From 66e21340727bcbc1b5b48e6ed79d932f394b1e49 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Mon, 21 Feb 2011 23:36:20 +0100 Subject: [PATCH] TypeAnalysis: Make type analysis store both the inferred type and the expected type in each ILExpression. AstMethodBodyBuilder: Use that type information to insert conversions between int/bool/enum where required (previously this was working only for branch conditions/comparison instructions). --- .../Ast/AstMethodBodyBuilder.cs | 116 +++++++++--------- ICSharpCode.Decompiler/ILAst/ILAstTypes.cs | 10 ++ ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs | 48 +++++++- 3 files changed, 110 insertions(+), 64 deletions(-) diff --git a/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs b/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs index eeb8bfb36..e15d85e7d 100644 --- a/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs @@ -176,12 +176,6 @@ namespace Decompiler return args; } - AstNode TransformExpression(ILExpression expr) - { - List args = TransformExpressionArguments(expr); - return TransformByteCode(expr, args); - } - Ast.Expression MakeBranchCondition(ILExpression expr) { switch(expr.Code) { @@ -204,22 +198,11 @@ namespace Decompiler List args = TransformExpressionArguments(expr); Ast.Expression arg1 = args.Count >= 1 ? args[0] : null; Ast.Expression arg2 = args.Count >= 2 ? args[1] : null; - TypeReference arg1Type = args.Count >= 1 ? expr.Arguments[0].InferredType : null; switch((Code)expr.Code) { case Code.Brfalse: - if (arg1Type == typeSystem.Boolean) - return new Ast.UnaryOperatorExpression(UnaryOperatorType.Not, arg1); - else if (TypeAnalysis.IsIntegerOrEnum(typeSystem, arg1Type)) - return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, new PrimitiveExpression(0)); - else - return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, new NullReferenceExpression()); + return new Ast.UnaryOperatorExpression(UnaryOperatorType.Not, arg1); case Code.Brtrue: - if (arg1Type == typeSystem.Boolean) - return arg1; - else if (TypeAnalysis.IsIntegerOrEnum(typeSystem, arg1Type)) - return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.InEquality, new PrimitiveExpression(0)); - else - return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.InEquality, new NullReferenceExpression()); + return arg1; case Code.Beq: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, arg2); case Code.Bge: @@ -245,21 +228,6 @@ namespace Decompiler } } - AstNode TransformByteCode(ILExpression byteCode, List args) - { - try { - AstNode ret = TransformByteCode_Internal(byteCode, args); - // ret.UserData["Type"] = byteCode.Type; - return ret; - } catch (NotImplementedException) { - // Output the operand of the unknown IL code as well - if (byteCode.Operand != null) { - args.Insert(0, new IdentifierExpression(FormatByteCodeOperand(byteCode.Operand))); - } - return new IdentifierExpression(byteCode.Code.GetName()).Invoke(args); - } - } - static string FormatByteCodeOperand(object operand) { if (operand == null) { @@ -285,7 +253,18 @@ namespace Decompiler } } - AstNode TransformByteCode_Internal(ILExpression byteCode, List args) + AstNode TransformExpression(ILExpression expr) + { + List args = TransformExpressionArguments(expr); + AstNode node = TransformByteCode(expr, args); + Expression astExpr = node as Expression; + if (astExpr != null) + return Convert(astExpr, expr.InferredType, expr.ExpectedType); + else + return node; + } + + AstNode TransformByteCode(ILExpression byteCode, List args) { ILCode opCode = byteCode.Code; object operand = byteCode.Operand; @@ -381,9 +360,11 @@ namespace Decompiler case Code.Bne_Un: return new Ast.IfElseStatement(new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.InEquality, arg2), branchCommand); #endregion #region Comparison - case Code.Ceq: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, ConvertIntToBool(arg2)); + case Code.Ceq: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, arg2); case Code.Cgt: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); - case Code.Cgt_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); + case Code.Cgt_Un: + // TODO: can also mean Inequality, when used with object references + return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); case Code.Clt: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThan, arg2); case Code.Clt_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThan, arg2); #endregion @@ -417,8 +398,7 @@ namespace Decompiler case Code.Conv_Ovf_I1_Un: return arg1.CastTo(typeof(SByte)); case Code.Conv_Ovf_I2_Un: return arg1.CastTo(typeof(Int16)); case Code.Conv_Ovf_I4_Un: return arg1.CastTo(typeof(Int32)); - case Code.Conv_Ovf_I8_Un: return arg1.CastTo(typeof(Int64)); - case Code.Conv_Ovf_U_Un: return arg1.CastTo(typeof(uint)); + case Code.Conv_Ovf_I8_Un: return arg1.CastTo(typeof(Int64)); case Code.Conv_Ovf_U_Un: return arg1.CastTo(typeof(uint)); case Code.Conv_Ovf_U1_Un: return arg1.CastTo(typeof(Byte)); case Code.Conv_Ovf_U2_Un: return arg1.CastTo(typeof(UInt16)); case Code.Conv_Ovf_U4_Un: return arg1.CastTo(typeof(UInt32)); @@ -504,20 +484,7 @@ namespace Decompiler return MakeRef(new Ast.IdentifierExpression(((ParameterDefinition)operand).Name).WithAnnotation(operand)); } case Code.Ldc_I4: - if (byteCode.InferredType == typeSystem.Boolean && (int)operand == 0) - return new Ast.PrimitiveExpression(false); - else if (byteCode.InferredType == typeSystem.Boolean && (int)operand == 1) - return new Ast.PrimitiveExpression(true); - if (byteCode.InferredType != null) { // cannot rely on IsValueType, it's not set for typerefs (but is set for typespecs) - TypeDefinition enumDefinition = byteCode.InferredType.Resolve(); - if (enumDefinition != null && enumDefinition.IsEnum) { - foreach (FieldDefinition field in enumDefinition.Fields) { - if (field.IsStatic && object.Equals(CSharpPrimitiveCast.Cast(TypeCode.Int32, field.Constant, false), operand)) - return AstBuilder.ConvertType(enumDefinition).Member(field.Name).WithAnnotation(field); - } - } - } - return new Ast.PrimitiveExpression(operand); + return PrimitiveExpression((int)operand, byteCode.InferredType); case Code.Ldc_I8: case Code.Ldc_R4: case Code.Ldc_R8: @@ -580,7 +547,6 @@ namespace Decompiler case Code.Refanyval: return InlineAssembly(byteCode, args); case Code.Ret: { if (methodDef.ReturnType.FullName != "System.Void") { - arg1 = Convert(arg1, methodDef.ReturnType); return new Ast.ReturnStatement { Expression = arg1 }; } else { return new Ast.ReturnStatement(); @@ -719,19 +685,51 @@ namespace Decompiler return new DirectionExpression { Expression = expr, FieldDirection = FieldDirection.Ref }; } - static Ast.Expression Convert(Ast.Expression expr, Cecil.TypeReference reqType) + Ast.Expression Convert(Ast.Expression expr, Cecil.TypeReference actualType, Cecil.TypeReference reqType) { - if (reqType == null) { + if (reqType == null || actualType == reqType) { return expr; } else { + if (reqType == typeSystem.Boolean) { + if (TypeAnalysis.IsIntegerOrEnum(typeSystem, actualType)) { + return new BinaryOperatorExpression(expr, BinaryOperatorType.InEquality, PrimitiveExpression(0, actualType)); + } else { + return new BinaryOperatorExpression(expr, BinaryOperatorType.InEquality, new NullReferenceExpression()); + } + } + if (actualType == typeSystem.Boolean && TypeAnalysis.IsIntegerOrEnum(typeSystem, reqType)) { + return new ConditionalExpression { + Condition = expr, + TrueExpression = PrimitiveExpression(1, reqType), + FalseExpression = PrimitiveExpression(0, reqType) + }; + } return expr; } } - static Ast.Expression ConvertIntToBool(Ast.Expression astInt) + Expression PrimitiveExpression(long val, TypeReference type) { - return astInt; - // return new Ast.ParenthesizedExpression(new Ast.BinaryOperatorExpression(astInt, BinaryOperatorType.InEquality, new Ast.PrimitiveExpression(0, "0"))); + if (type == typeSystem.Boolean && val == 0) + return new Ast.PrimitiveExpression(false); + else if (type == typeSystem.Boolean && val == 1) + return new Ast.PrimitiveExpression(true); + if (type != null) { // cannot rely on type.IsValueType, it's not set for typerefs (but is set for typespecs) + TypeDefinition enumDefinition = type.Resolve(); + if (enumDefinition != null && enumDefinition.IsEnum) { + foreach (FieldDefinition field in enumDefinition.Fields) { + if (field.IsStatic && object.Equals(CSharpPrimitiveCast.Cast(TypeCode.Int64, field.Constant, false), val)) + return AstBuilder.ConvertType(enumDefinition).Member(field.Name).WithAnnotation(field); + else if (!field.IsStatic && field.IsRuntimeSpecialName) + type = field.FieldType; // use primitive type of the enum + } + } + } + TypeCode code = TypeAnalysis.GetTypeCode(typeSystem, type); + if (code == TypeCode.Object) + return new Ast.PrimitiveExpression((int)val); + else + return new Ast.PrimitiveExpression(CSharpPrimitiveCast.Cast(code, val, false)); } } } diff --git a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs index 667084ba8..82c05d985 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs @@ -203,6 +203,7 @@ namespace Decompiler // Mapping to the original instructions (useful for debugging) public List ILRanges { get; set; } + public TypeReference ExpectedType { get; set; } public TypeReference InferredType { get; set; } public ILExpression(ILCode code, object operand, params ILExpression[] args) @@ -296,6 +297,15 @@ namespace Decompiler if (this.InferredType != null) { output.Write(':'); this.InferredType.WriteTo(output, true, true); + if (this.ExpectedType != null && this.ExpectedType.FullName != this.InferredType.FullName) { + output.Write("[exp:"); + this.ExpectedType.WriteTo(output, true, true); + output.Write(']'); + } + } else if (this.ExpectedType != null) { + output.Write("[exp:"); + this.ExpectedType.WriteTo(output, true, true); + output.Write(']'); } output.Write('('); bool first = true; diff --git a/ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs b/ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs index 202ea5a1a..d1ccc1249 100644 --- a/ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs +++ b/ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs @@ -55,7 +55,7 @@ namespace Decompiler } bool anyArgumentIsMissingType = expr.Arguments.Any(a => a.InferredType == null); if (expr.InferredType == null || anyArgumentIsMissingType) - expr.InferredType = InferTypeForExpression(expr, null, forceInferChildren: anyArgumentIsMissingType); + expr.InferredType = InferTypeForExpression(expr, expr.ExpectedType, forceInferChildren: anyArgumentIsMissingType); } foreach (ILNode child in node.GetChildren()) { InferTypes(child); @@ -97,6 +97,7 @@ namespace Decompiler /// The inferred type TypeReference InferTypeForExpression(ILExpression expr, TypeReference expectedType, bool forceInferChildren = false) { + expr.ExpectedType = expectedType; if (forceInferChildren || expr.InferredType == null) expr.InferredType = DoInferTypeForExpression(expr, expectedType, forceInferChildren); return expr.InferredType; @@ -552,13 +553,16 @@ namespace Decompiler TypeReference leftPreferred = DoInferTypeForExpression(left, null); TypeReference rightPreferred = DoInferTypeForExpression(right, null); if (leftPreferred == rightPreferred) { - return left.InferredType = right.InferredType = leftPreferred; + return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = leftPreferred; } else if (rightPreferred == DoInferTypeForExpression(left, rightPreferred)) { - return left.InferredType = right.InferredType = rightPreferred; + return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = rightPreferred; } else if (leftPreferred == DoInferTypeForExpression(right, leftPreferred)) { - return left.InferredType = right.InferredType = leftPreferred; + return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = leftPreferred; } else { - return left.InferredType = right.InferredType = TypeWithMoreInformation(leftPreferred, rightPreferred); + left.ExpectedType = right.ExpectedType = TypeWithMoreInformation(leftPreferred, rightPreferred); + left.InferredType = DoInferTypeForExpression(left, left.ExpectedType); + right.InferredType = DoInferTypeForExpression(left, right.ExpectedType); + return left.ExpectedType; } } @@ -637,5 +641,39 @@ namespace Decompiler return true; return null; } + + public static TypeCode GetTypeCode(TypeSystem typeSystem, TypeReference type) + { + if (type == typeSystem.Boolean) + return TypeCode.Boolean; + else if (type == typeSystem.Byte) + return TypeCode.Byte; + else if (type == typeSystem.Char) + return TypeCode.Char; + else if (type == typeSystem.Double) + return TypeCode.Double; + else if (type == typeSystem.Int16) + return TypeCode.Int16; + else if (type == typeSystem.Int32) + return TypeCode.Int32; + else if (type == typeSystem.Int64) + return TypeCode.Int64; + else if (type == typeSystem.Single) + return TypeCode.Single; + else if (type == typeSystem.Double) + return TypeCode.Double; + else if (type == typeSystem.SByte) + return TypeCode.SByte; + else if (type == typeSystem.UInt16) + return TypeCode.UInt16; + else if (type == typeSystem.UInt32) + return TypeCode.UInt32; + else if (type == typeSystem.UInt64) + return TypeCode.UInt64; + else if (type == typeSystem.String) + return TypeCode.String; + else + return TypeCode.Object; + } } }