From b6ba673a2d798283de459c532cd52bd5533d9962 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Sat, 8 Oct 2011 22:40:12 +0200 Subject: [PATCH] Implemented Expression Tree decompilation. Closes #175. --- .../Ast/AstMethodBodyBuilder.cs | 8 +- .../Ast/Transforms/ExpressionTreeConverter.cs | 234 ++++++++++++++++-- 2 files changed, 214 insertions(+), 28 deletions(-) diff --git a/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs b/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs index 7f4740b59..52d9cc579 100644 --- a/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs @@ -658,18 +658,24 @@ namespace ICSharpCode.Decompiler.Ast if (operand is Cecil.TypeReference) { return AstBuilder.CreateTypeOfExpression((TypeReference)operand).Member("TypeHandle"); } else { - var referencedEntity = new IdentifierExpression(FormatByteCodeOperand(byteCode.Operand)).WithAnnotation(byteCode.Operand); + Expression referencedEntity; string loadName; string handleName; if (operand is Cecil.FieldReference) { loadName = "fieldof"; handleName = "FieldHandle"; + FieldReference fr = (FieldReference)operand; + referencedEntity = AstBuilder.ConvertType(fr.DeclaringType).Member(fr.Name).WithAnnotation(fr); } else if (operand is Cecil.MethodReference) { loadName = "methodof"; handleName = "MethodHandle"; + MethodReference mr = (MethodReference)operand; + var methodParameters = mr.Parameters.Select(p => new TypeReferenceExpression(AstBuilder.ConvertType(p.ParameterType))); + referencedEntity = AstBuilder.ConvertType(mr.DeclaringType).Invoke(mr.Name, methodParameters).WithAnnotation(mr); } else { loadName = "ldtoken"; handleName = "Handle"; + referencedEntity = new IdentifierExpression(FormatByteCodeOperand(byteCode.Operand)); } return new IdentifierExpression(loadName).Invoke(referencedEntity).WithAnnotation(new LdTokenAnnotation()).Member(handleName); } diff --git a/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs b/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs index e294c8193..3bc059637 100644 --- a/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs +++ b/ICSharpCode.Decompiler/Ast/Transforms/ExpressionTreeConverter.cs @@ -81,7 +81,6 @@ namespace ICSharpCode.Decompiler.Ast.Transforms case "AndAssign": return ConvertAssignmentOperator(invocation, AssignmentOperatorType.BitwiseAnd); case "ArrayAccess": - return NotImplemented(invocation); case "ArrayIndex": return ConvertArrayIndex(invocation); case "ArrayLength": @@ -93,7 +92,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms case "Coalesce": return ConvertBinaryOperator(invocation, BinaryOperatorType.NullCoalescing); case "Condition": - return NotImplemented(invocation); + return ConvertCondition(invocation); case "Constant": if (invocation.Arguments.Count >= 1) return invocation.Arguments.First().Clone(); @@ -103,8 +102,6 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return ConvertCast(invocation, false); case "ConvertChecked": return ConvertCast(invocation, true); - case "Default": - return NotImplemented(invocation); case "Divide": return ConvertBinaryOperator(invocation, BinaryOperatorType.Divide); case "DivideAssign": @@ -133,6 +130,10 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return ConvertBinaryOperator(invocation, BinaryOperatorType.LessThan); case "LessThanOrEqual": return ConvertBinaryOperator(invocation, BinaryOperatorType.LessThanOrEqual); + case "ListInit": + return ConvertListInit(invocation); + case "MemberInit": + return ConvertMemberInit(invocation); case "Modulo": return ConvertBinaryOperator(invocation, BinaryOperatorType.Modulus); case "ModuloAssign": @@ -151,6 +152,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return ConvertUnaryOperator(invocation, UnaryOperatorType.Minus, true); case "New": return ConvertNewObject(invocation); + case "NewArrayBounds": + return ConvertNewArrayBounds(invocation); case "NewArrayInit": return ConvertNewArrayInit(invocation); case "Not": @@ -168,7 +171,10 @@ namespace ICSharpCode.Decompiler.Ast.Transforms case "Property": return ConvertProperty(invocation); case "Quote": - return NotImplemented(invocation); + if (invocation.Arguments.Count == 1) + return Convert(invocation.Arguments.Single()); + else + return NotSupported(invocation); case "RightShift": return ConvertBinaryOperator(invocation, BinaryOperatorType.ShiftRight); case "RightShiftAssign": @@ -208,11 +214,6 @@ namespace ICSharpCode.Decompiler.Ast.Transforms Debug.WriteLine("Expression Tree Conversion Failed: '" + expr + "' is not supported"); return null; } - - Expression NotImplemented(Expression expr) - { - return new IdentifierExpression("NotImplemented").Invoke(expr.Clone()); - } #endregion #region Convert Lambda @@ -323,10 +324,15 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return null; } - string name = mr.Name; + return convertedTarget.Member(GetPropertyName(mr)).WithAnnotation(mr); + } + + string GetPropertyName(MethodReference accessor) + { + string name = accessor.Name; if (name.StartsWith("get_", StringComparison.Ordinal) || name.StartsWith("set_", StringComparison.Ordinal)) name = name.Substring(4); - return convertedTarget.Member(name).WithAnnotation(mr); + return name; } #endregion @@ -460,7 +466,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms #region Convert Assignment Operator Expression ConvertAssignmentOperator(InvocationExpression invocation, AssignmentOperatorType op, bool? isChecked = null) { - return NotImplemented(invocation); + return NotSupported(invocation); } #endregion @@ -493,6 +499,22 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } #endregion + #region Convert Condition Operator + Expression ConvertCondition(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 3) + return NotSupported(invocation); + + Expression condition = Convert(invocation.Arguments.ElementAt(0)); + Expression trueExpr = Convert(invocation.Arguments.ElementAt(1)); + Expression falseExpr = Convert(invocation.Arguments.ElementAt(2)); + if (condition != null && trueExpr != null && falseExpr != null) + return new ConditionalExpression(condition, trueExpr, falseExpr); + else + return null; + } + #endregion + #region Convert New Object static readonly Expression newObjectCtorPattern = new TypePattern(typeof(MethodBase)).ToType().Invoke ( @@ -557,36 +579,162 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } #endregion + #region Convert ListInit + static readonly Pattern elementInitArrayPattern = ArrayInitializationPattern( + typeof(System.Linq.Expressions.ElementInit), + new TypePattern(typeof(System.Linq.Expressions.Expression)).ToType().Invoke("ElementInit", new AnyNode("methodInfos"), new AnyNode("addArgumentsArrays")) + ); + + Expression ConvertListInit(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + ObjectCreateExpression oce = Convert(invocation.Arguments.ElementAt(0)) as ObjectCreateExpression; + if (oce == null) + return null; + Expression elementsArray = invocation.Arguments.ElementAt(1); + ArrayInitializerExpression initializer = ConvertElementInit(elementsArray); + if (initializer != null) { + oce.Initializer = initializer; + return oce; + } else { + return null; + } + } + + ArrayInitializerExpression ConvertElementInit(Expression elementsArray) + { + IList elements = ConvertExpressionsArray(elementsArray); + if (elements != null) { + return new ArrayInitializerExpression(elements); + } + Match m = elementInitArrayPattern.Match(elementsArray); + if (!m.Success) + return null; + ArrayInitializerExpression result = new ArrayInitializerExpression(); + foreach (var elementInit in m.Get("addArgumentsArrays")) { + IList arguments = ConvertExpressionsArray(elementInit); + if (arguments == null) + return null; + result.Elements.Add(new ArrayInitializerExpression(arguments)); + } + return result; + } + #endregion + + #region Convert MemberInit + Expression ConvertMemberInit(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + ObjectCreateExpression oce = Convert(invocation.Arguments.ElementAt(0)) as ObjectCreateExpression; + if (oce == null) + return null; + Expression elementsArray = invocation.Arguments.ElementAt(1); + ArrayInitializerExpression bindings = ConvertMemberBindings(elementsArray); + if (bindings == null) + return null; + oce.Initializer = bindings; + return oce; + } + + static readonly Pattern memberBindingArrayPattern = ArrayInitializationPattern(typeof(System.Linq.Expressions.MemberBinding), new AnyNode("binding")); + static readonly INode expressionTypeReference = new TypeReferenceExpression(new TypePattern(typeof(System.Linq.Expressions.Expression))); + + ArrayInitializerExpression ConvertMemberBindings(Expression elementsArray) + { + Match m = memberBindingArrayPattern.Match(elementsArray); + if (!m.Success) + return null; + ArrayInitializerExpression result = new ArrayInitializerExpression(); + foreach (var binding in m.Get("binding")) { + InvocationExpression bindingInvocation = binding as InvocationExpression; + if (bindingInvocation == null || bindingInvocation.Arguments.Count != 2) + return null; + MemberReferenceExpression bindingMRE = bindingInvocation.Target as MemberReferenceExpression; + if (bindingMRE == null || !expressionTypeReference.IsMatch(bindingMRE.Target)) + return null; + + Expression bindingTarget = bindingInvocation.Arguments.ElementAt(0); + Expression bindingValue = bindingInvocation.Arguments.ElementAt(1); + + string memberName; + Match m2 = getMethodFromHandlePattern.Match(bindingTarget); + if (m2.Success) { + MethodReference setter = m2.Get("method").Single().Annotation(); + if (setter == null) + return null; + memberName = GetPropertyName(setter); + } else { + return null; + } + + Expression convertedValue; + switch (bindingMRE.MemberName) { + case "Bind": + convertedValue = Convert(bindingValue); + break; + case "MemberBind": + convertedValue = ConvertMemberBindings(bindingValue); + break; + case "ListBind": + convertedValue = ConvertElementInit(bindingValue); + break; + default: + return null; + } + if (convertedValue == null) + return null; + result.Elements.Add(new NamedExpression(memberName, convertedValue)); + } + return result; + } + #endregion + #region Convert Cast Expression ConvertCast(InvocationExpression invocation, bool isChecked) { - if (invocation.Arguments.Count != 2) + if (invocation.Arguments.Count < 2) return null; Expression converted = Convert(invocation.Arguments.ElementAt(0)); AstType type = ConvertTypeReference(invocation.Arguments.ElementAt(1)); if (converted != null && type != null) { CastExpression cast = converted.CastTo(type); cast.AddAnnotation(isChecked ? AddCheckedBlocks.CheckedAnnotation : AddCheckedBlocks.UncheckedAnnotation); - return cast; + switch (invocation.Arguments.Count) { + case 2: + return cast; + case 3: + Match m = getMethodFromHandlePattern.Match(invocation.Arguments.ElementAt(2)); + if (m.Success) + return cast.WithAnnotation(m.Get("method").Single().Annotation()); + else + return null; + } } return null; } #endregion #region ConvertExpressionsArray - static readonly Pattern expressionArrayPattern = new Choice { - new ArrayCreateExpression { - Type = new TypePattern(typeof(System.Linq.Expressions.Expression)), - Arguments = { new PrimitiveExpression(0) } - }, - new ArrayCreateExpression { - Type = new TypePattern(typeof(System.Linq.Expressions.Expression)), - AdditionalArraySpecifiers = { new ArraySpecifier() }, - Initializer = new ArrayInitializerExpression { - Elements = { new Repeat(new AnyNode("elements")) } + static Pattern ArrayInitializationPattern(Type arrayElementType, INode elementPattern) + { + return new Choice { + new ArrayCreateExpression { + Type = new TypePattern(arrayElementType), + Arguments = { new PrimitiveExpression(0) } + }, + new ArrayCreateExpression { + Type = new TypePattern(arrayElementType), + AdditionalArraySpecifiers = { new ArraySpecifier() }, + Initializer = new ArrayInitializerExpression { + Elements = { new Repeat(elementPattern) } + } } - } - }; + }; + } + + static readonly Pattern expressionArrayPattern = ArrayInitializationPattern(typeof(System.Linq.Expressions.Expression), new AnyNode("elements")); IList ConvertExpressionsArray(Expression arrayExpression) { @@ -682,6 +830,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms AstType elementType = ConvertTypeReference(invocation.Arguments.ElementAt(0)); IList elements = ConvertExpressionsArray(invocation.Arguments.ElementAt(1)); if (elementType != null && elements != null) { + if (ContainsAnonymousType(elementType)) { + elementType = null; + } return new ArrayCreateExpression { Type = elementType, AdditionalArraySpecifiers = { new ArraySpecifier() }, @@ -690,6 +841,35 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } return null; } + + Expression ConvertNewArrayBounds(InvocationExpression invocation) + { + if (invocation.Arguments.Count != 2) + return NotSupported(invocation); + + AstType elementType = ConvertTypeReference(invocation.Arguments.ElementAt(0)); + IList arguments = ConvertExpressionsArray(invocation.Arguments.ElementAt(1)); + if (elementType != null && arguments != null) { + if (ContainsAnonymousType(elementType)) { + elementType = null; + } + ArrayCreateExpression ace = new ArrayCreateExpression(); + ace.Type = elementType; + ace.Arguments.AddRange(arguments); + return ace; + } + return null; + } + + bool ContainsAnonymousType(AstType type) + { + foreach (AstType t in type.DescendantsAndSelf.OfType()) { + TypeReference tr = t.Annotation(); + if (tr != null && tr.IsAnonymousType()) + return true; + } + return false; + } #endregion } }