diff --git a/ICSharpCode.Decompiler/IL/Transforms/TransformExpressionTrees.cs b/ICSharpCode.Decompiler/IL/Transforms/TransformExpressionTrees.cs index 7335d006e..8cfc236c5 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/TransformExpressionTrees.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/TransformExpressionTrees.cs @@ -31,16 +31,30 @@ namespace ICSharpCode.Decompiler.IL.Transforms { static bool MightBeExpressionTree(ILInstruction inst, ILInstruction stmt) { - if (!(inst is CallInstruction call && call.Method.FullName == "System.Linq.Expressions.Expression.Lambda")) + if (!(inst is CallInstruction call + && call.Method.FullName == "System.Linq.Expressions.Expression.Lambda" + && call.Arguments.Count == 2)) return false; - if (!ILInlining.CanUninline(call, stmt) || call.Arguments.Count != 2) + if (call.Parent is CallInstruction parentCall && parentCall.Method.FullName == "System.Linq.Expressions.Expression.Quote") return false; - if (!((call.Arguments[1] is CallInstruction emptyCall && emptyCall.Method.FullName == "System.Array.Empty" && emptyCall.Arguments.Count == 0) - || (call.Arguments[1] is Block block && block.Kind == BlockKind.ArrayInitializer))) + if (!(IsEmptyParameterList(call.Arguments[1]) || (call.Arguments[1] is Block block && block.Kind == BlockKind.ArrayInitializer))) return false; + //if (!ILInlining.CanUninline(call, stmt)) + // return false; return true; } + static bool IsEmptyParameterList(ILInstruction inst) + { + if (inst is CallInstruction emptyCall && emptyCall.Method.FullName == "System.Array.Empty" && emptyCall.Arguments.Count == 0) + return true; + if (inst.MatchNewArr(out var type) && type.FullName == "System.Linq.Expressions.ParameterExpression") + return true; + if (inst.MatchNewArr(out type) && type.FullName == "System.Linq.Expressions.Expression") + return true; + return false; + } + bool MatchParameterVariableAssignment(ILInstruction expr, out ILVariable parameterReferenceVar, out IType type, out string name) { // stloc(v, call(Expression::Parameter, call(Type::GetTypeFromHandle, ldtoken(...)), ldstr(...))) @@ -97,6 +111,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms if (MightBeExpressionTree(instruction, statement)) { var lambda = ConvertLambda((CallInstruction)instruction); if (lambda != null) { + context.Step("Convert Expression Tree", instruction); instruction.ReplaceWith(lambda); return true; } @@ -159,10 +174,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms i++; } return true; - case CallInstruction emptyCall: - return emptyCall.Method.FullName == "System.Array.Empty" && emptyCall.Arguments.Count == 0; default: - return false; + return IsEmptyParameterList(initializer); } } @@ -211,8 +224,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms return ConvertComparison(invocation, ComparisonKind.GreaterThan); case "GreaterThanOrEqual": return ConvertComparison(invocation, ComparisonKind.GreaterThanOrEqual); - /*case "Invoke": - return ConvertInvoke(invocation);*/ + case "Invoke": + return ConvertInvoke(invocation); case "Lambda": return ConvertLambda(invocation); case "LeftShift": @@ -221,10 +234,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms return ConvertComparison(invocation, ComparisonKind.LessThan); case "LessThanOrEqual": return ConvertComparison(invocation, ComparisonKind.LessThanOrEqual); - /*case "ListInit": + case "ListInit": return ConvertListInit(invocation); case "MemberInit": - return ConvertMemberInit(invocation);*/ + return ConvertMemberInit(invocation); case "Modulo": return ConvertBinaryNumericOperator(invocation, BinaryNumericOperator.Rem); case "Multiply": @@ -237,10 +250,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms return ConvertUnaryNumericOperator(invocation, BinaryNumericOperator.Sub, true); case "New": return ConvertNewObject(invocation); - /*case "NewArrayBounds": + case "NewArrayBounds": return ConvertNewArrayBounds(invocation); case "NewArrayInit": - return ConvertNewArrayInit(invocation);*/ + return ConvertNewArrayInit(invocation); case "Not": return ConvertNotOperator(invocation); case "NotEqual": @@ -253,11 +266,11 @@ namespace ICSharpCode.Decompiler.IL.Transforms return ConvertLogicOperator(invocation, false); case "Property": return ConvertProperty(invocation); - /*case "Quote": + case "Quote": if (invocation.Arguments.Count == 1) - return Convert(invocation.Arguments.Single()); + return ConvertInstruction(invocation.Arguments.Single()); else - return null;*/ + return null; case "RightShift": return ConvertBinaryNumericOperator(invocation, BinaryNumericOperator.ShiftRight); case "Subtract": @@ -271,7 +284,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms } return null; default: - return ConvertValue(instruction); + return ConvertValue(instruction, instruction.Parent); } } @@ -282,13 +295,18 @@ namespace ICSharpCode.Decompiler.IL.Transforms var array = ConvertInstruction(invocation.Arguments[0]); if (array == null) return null; - var type = ((ArrayType)array.InferType()).ElementType; + var arrayType = InferType(array); + if (!(arrayType is ArrayType type)) + return null; if (!MatchArgumentList(invocation.Arguments[1], out var arguments)) arguments = new[] { invocation.Arguments[1] }; - arguments = arguments.Select(ConvertInstruction).ToArray(); - if (arguments.Any(p => p == null)) - return null; - return new LdObj(new LdElema(type, array, arguments.ToArray()), type); + for (int i = 0; i < arguments.Count; i++) { + var converted = ConvertInstruction(arguments[i]); + if (converted == null) + return null; + arguments[i] = converted; + } + return new LdObj(new LdElema(type.ElementType, array, arguments.ToArray()), type.ElementType); } ILInstruction ConvertArrayLength(CallInstruction invocation) @@ -305,10 +323,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms { if (invocation.Arguments.Count < 2) return null; - var left = ConvertInstruction(invocation.Arguments.ElementAt(0)); + var left = ConvertInstruction(invocation.Arguments[0]); if (left == null) return null; - var right = ConvertInstruction(invocation.Arguments.ElementAt(1)); + var right = ConvertInstruction(invocation.Arguments[1]); if (right == null) return null; IMember method; @@ -322,7 +340,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms Arguments = { left, right } }; case 4: - //if (!trueOrFalse.IsMatch(invocation.Arguments.ElementAt(2))) + //if (!trueOrFalse.IsMatch(invocation.Arguments[2])) // return null; if (!MatchGetMethodFromHandle(invocation.Arguments[3], out method)) return null; @@ -334,27 +352,6 @@ namespace ICSharpCode.Decompiler.IL.Transforms } } - ILInstruction ConvertNotOperator(CallInstruction invocation) - { - if (invocation.Arguments.Count < 1) - return null; - var argument = ConvertInstruction(invocation.Arguments.ElementAt(0)); - if (argument == null) - return null; - switch (invocation.Arguments.Count) { - case 1: - return argument.InferType().IsKnownType(KnownTypeCode.Boolean) ? Comp.LogicNot(argument) : (ILInstruction)new BitNot(argument); - case 2: - if (!MatchGetMethodFromHandle(invocation.Arguments[1], out var method)) - return null; - return new Call((IMethod)method) { - Arguments = { argument } - }; - default: - return null; - } - } - ILInstruction ConvertCall(CallInstruction invocation) { if (invocation.Arguments.Count < 2) @@ -377,11 +374,21 @@ namespace ICSharpCode.Decompiler.IL.Transforms arguments = arguments.Select(ConvertInstruction).ToArray(); if (arguments.Any(p => p == null)) return null; + IMethod method = (IMethod)member; + if (method.FullName == "System.Reflection.MethodInfo.CreateDelegate" && method.Parameters.Count == 2) { + if (!MatchGetMethodFromHandle(arguments[0], out var targetMethod)) + return null; + if (!MatchGetTypeFromHandle(arguments[1], out var delegateType)) + return null; + return new NewObj(delegateType.GetConstructors().Single()) { + Arguments = { arguments[2], new LdFtn((IMethod)targetMethod) } + }; + } CallInstruction call; - if (member.IsAbstract || member.IsVirtual || member.IsOverridable) { - call = new CallVirt((IMethod)member); + if (method.IsAbstract || method.IsVirtual || method.IsOverridable) { + call = new CallVirt(method); } else { - call = new Call((IMethod)member); + call = new Call(method); } call.Arguments.AddRange(arguments); return call; @@ -391,19 +398,12 @@ namespace ICSharpCode.Decompiler.IL.Transforms { if (invocation.Arguments.Count < 2) return null; - if (!MatchTypeOfCall(invocation.Arguments[1], out var targetType)) + if (!MatchGetTypeFromHandle(invocation.Arguments[1], out var targetType)) return null; var expr = ConvertInstruction(invocation.Arguments[0]); if (expr == null) return null; - var sourceType = expr.InferType(); - if (sourceType.Equals(SpecialType.UnknownType)) - return null; - if (sourceType.IsReferenceType != true) { - if (targetType.IsKnownType(KnownTypeCode.Object)) - return new Box(expr, sourceType); - } - return null; + return new ExpressionTreeCast(targetType, expr, isChecked); } ILInstruction ConvertCoalesce(CallInstruction invocation) @@ -457,10 +457,31 @@ namespace ICSharpCode.Decompiler.IL.Transforms { if (!MatchConstantCall(invocation, out var value, out var type)) return null; - if (value.MatchBox(out var arg, out _)) { + if (value.MatchBox(out var arg, out var boxType)) { + if (boxType.Kind == TypeKind.Enum) + return new ExpressionTreeCast(boxType, ConvertValue(arg, invocation), false); value = arg; } - return ConvertValue(value); + return ConvertValue(value, invocation); + } + + ILInstruction ConvertElementInit(CallInstruction invocation) + { + if (invocation.Arguments.Count != 2) + return null; + if (!MatchGetMethodFromHandle(invocation.Arguments[0], out var member)) + return null; + if (!MatchArgumentList(invocation.Arguments[1], out var arguments)) + return null; + CallInstruction call = new Call((IMethod)member); + for (int i = 0; i < arguments.Count; i++) { + ILInstruction arg = ConvertInstruction(arguments[i]); + if (arg == null) + return null; + arguments[i] = arg; + } + call.Arguments.AddRange(arguments); + return call; } ILInstruction ConvertField(CallInstruction invocation) @@ -482,14 +503,80 @@ namespace ICSharpCode.Decompiler.IL.Transforms } } + ILInstruction ConvertInvoke(CallInstruction invocation) + { + if (invocation.Arguments.Count != 2) + return null; + var target = ConvertInstruction(invocation.Arguments[0]); + if (target == null) + return null; + var invokeMethod = InferType(target).GetDelegateInvokeMethod(); + if (invokeMethod == null) + return null; + if (!MatchArgumentList(invocation.Arguments[1], out var arguments)) + return null; + for (int i = 0; i < arguments.Count; i++) { + var arg = ConvertInstruction(arguments[i]); + if (arg == null) + return null; + arguments[i] = arg; + } + var call = new Call(invokeMethod); + call.Arguments.Add(target); + call.Arguments.AddRange(arguments); + return call; + } + + ILInstruction ConvertListInit(CallInstruction invocation) + { + if (invocation.Arguments.Count < 2) + return null; + var newObj = ConvertInstruction(invocation.Arguments[0]) as NewObj; + if (newObj == null) + return null; + IList arguments = null; + ILFunction function; + if (!MatchGetMethodFromHandle(invocation.Arguments[1], out var member)) { + if (!MatchArgumentList(invocation.Arguments[1], out arguments)) + return null; + function = ((LdLoc)((Block)invocation.Arguments[1]).FinalInstruction).Variable.Function; + } else { + if (invocation.Arguments.Count != 3 || !MatchArgumentList(invocation.Arguments[2], out arguments)) + return null; + function = ((LdLoc)((Block)invocation.Arguments[2]).FinalInstruction).Variable.Function; + } + if (arguments == null || arguments.Count == 0) + return null; + var initializer = function.RegisterVariable(VariableKind.InitializerTarget, newObj.Method.DeclaringType); + for (int i = 0; i < arguments.Count; i++) { + ILInstruction arg; + if (arguments[i] is CallInstruction elementInit && elementInit.Method.FullName == "System.Linq.Expressions.Expression.ElementInit") { + arg = ConvertElementInit(elementInit); + if (arg == null) + return null; + ((CallInstruction)arg).Arguments.Insert(0, new LdLoc(initializer)); + } else { + arg = ConvertInstruction(arguments[i]); + if (arg == null) + return null; + } + arguments[i] = arg; + } + var initializerBlock = new Block(BlockKind.CollectionInitializer); + initializerBlock.FinalInstruction = new LdLoc(initializer); + initializerBlock.Instructions.Add(new StLoc(initializer, newObj)); + initializerBlock.Instructions.AddRange(arguments); + return initializerBlock; + } + ILInstruction ConvertLogicOperator(CallInstruction invocation, bool and) { if (invocation.Arguments.Count < 2) return null; - var left = ConvertInstruction(invocation.Arguments.ElementAt(0)); + var left = ConvertInstruction(invocation.Arguments[0]); if (left == null) return null; - var right = ConvertInstruction(invocation.Arguments.ElementAt(1)); + var right = ConvertInstruction(invocation.Arguments[1]); if (right == null) return null; IMember method; @@ -503,7 +590,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms Arguments = { left, right } }; case 4: - //if (!trueOrFalse.IsMatch(invocation.Arguments.ElementAt(2))) + //if (!trueOrFalse.IsMatch(invocation.Arguments[2])) // return null; if (!MatchGetMethodFromHandle(invocation.Arguments[3], out method)) return null; @@ -515,12 +602,65 @@ namespace ICSharpCode.Decompiler.IL.Transforms } } + ILInstruction ConvertMemberInit(CallInstruction invocation) + { + return null; + } + + ILInstruction ConvertNewArrayBounds(CallInstruction invocation) + { + if (invocation.Arguments.Count != 2) + return null; + if (!MatchGetTypeFromHandle(invocation.Arguments[0], out var type)) + return null; + if (!MatchArgumentList(invocation.Arguments[1], out var arguments)) + return null; + if (arguments.Count == 0) + return null; + var indices = new ILInstruction[arguments.Count]; + for (int i = 0; i < arguments.Count; i++) { + var index = ConvertInstruction(arguments[i]); + if (index == null) + return null; + indices[i] = index; + } + return new NewArr(type, indices); + } + + ILInstruction ConvertNewArrayInit(CallInstruction invocation) + { + if (invocation.Arguments.Count != 2) + return null; + if (!MatchGetTypeFromHandle(invocation.Arguments[0], out var type)) + return null; + if (!MatchArgumentList(invocation.Arguments[1], out var arguments)) + return null; + if (arguments.Count == 0) + return null; + var block = (Block)invocation.Arguments[1]; + var function = ((LdLoc)block.FinalInstruction).Variable.Function; + var variable = function.RegisterVariable(VariableKind.InitializerTarget, new ArrayType(context.BlockContext.TypeSystem.Compilation, type)); + Block initializer = new Block(BlockKind.ArrayInitializer); + int i = 0; + initializer.Instructions.Add(new StLoc(variable, new NewArr(type, new LdcI4(arguments.Count)))); + foreach (var item in arguments) { + var value = ConvertInstruction(item); + if (value == null) + return null; + initializer.Instructions.Add(new StObj(new LdElema(type, new LdLoc(variable), new LdcI4(i)), value, type)); + } + initializer.FinalInstruction = new LdLoc(variable); + return initializer; + } + ILInstruction ConvertNewObject(CallInstruction invocation) { IMember member; + IList arguments; + NewObj newObj; switch (invocation.Arguments.Count) { case 1: - if (MatchTypeOfCall(invocation.Arguments[0], out var type)) { + if (MatchGetTypeFromHandle(invocation.Arguments[0], out var type)) { var ctors = type.GetConstructors().ToArray(); if (ctors.Length != 1 || ctors[0].Parameters.Count > 0) return null; @@ -533,18 +673,50 @@ namespace ICSharpCode.Decompiler.IL.Transforms case 2: if (!MatchGetConstructorFromHandle(invocation.Arguments[0], out member)) return null; - if (!MatchArgumentList(invocation.Arguments[1], out var arguments)) + if (!MatchArgumentList(invocation.Arguments[1], out arguments)) return null; var args = arguments.SelectArray(ConvertInstruction); if (args.Any(a => a == null)) return null; - var newObj = new NewObj((IMethod)member); + newObj = new NewObj((IMethod)member); newObj.Arguments.AddRange(args); return newObj; + case 3: + if (!MatchGetConstructorFromHandle(invocation.Arguments[0], out member)) + return null; + if (!MatchArgumentList(invocation.Arguments[1], out arguments)) + return null; + var args2 = arguments.SelectArray(ConvertInstruction); + if (args2.Any(a => a == null)) + return null; + newObj = new NewObj((IMethod)member); + newObj.Arguments.AddRange(args2); + return newObj; } return null; } + ILInstruction ConvertNotOperator(CallInstruction invocation) + { + if (invocation.Arguments.Count < 1) + return null; + var argument = ConvertInstruction(invocation.Arguments[0]); + if (argument == null) + return null; + switch (invocation.Arguments.Count) { + case 1: + return InferType(argument).IsKnownType(KnownTypeCode.Boolean) ? Comp.LogicNot(argument) : (ILInstruction)new BitNot(argument); + case 2: + if (!MatchGetMethodFromHandle(invocation.Arguments[1], out var method)) + return null; + return new Call((IMethod)method) { + Arguments = { argument } + }; + default: + return null; + } + } + ILInstruction ConvertProperty(CallInstruction invocation) { if (invocation.Arguments.Count < 2) @@ -585,7 +757,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms if (invocation.Arguments.Count != 2) return null; var converted = ConvertInstruction(invocation.Arguments[0]); - if (!MatchTypeOfCall(invocation.Arguments[1], out var type)) + if (!MatchGetTypeFromHandle(invocation.Arguments[1], out var type)) return null; if (converted != null) return new IsInst(converted, type); @@ -597,7 +769,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms if (invocation.Arguments.Count != 2) return null; var converted = ConvertInstruction(invocation.Arguments[0]); - if (!MatchTypeOfCall(invocation.Arguments[1], out var type)) + if (!MatchGetTypeFromHandle(invocation.Arguments[1], out var type)) return null; if (converted != null) return new Comp(ComparisonKind.Inequality, Sign.None, new IsInst(converted, type), new LdNull()); @@ -608,7 +780,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms { if (invocation.Arguments.Count < 1) return null; - var argument = ConvertInstruction(invocation.Arguments.ElementAt(0)); + var argument = ConvertInstruction(invocation.Arguments[0]); if (argument == null) return null; switch (invocation.Arguments.Count) { @@ -624,18 +796,26 @@ namespace ICSharpCode.Decompiler.IL.Transforms return null; } - ILInstruction ConvertValue(ILInstruction value) + ILInstruction ConvertValue(ILInstruction value, ILInstruction context) { switch (value) { case LdLoc ldloc: if (IsExpressionTreeParameter(ldloc.Variable)) { if (!parameterMapping.TryGetValue(ldloc.Variable, out var v)) return null; + if (context is CallInstruction parentCall + && parentCall.Method.FullName == "System.Linq.Expressions.Expression.Call" + && v.StackType.IsIntegerType()) + return new LdLoca(v); return new LdLoc(v); } else { - return ldloc; + if (ldloc.Variable.Kind != VariableKind.StackSlot) + return new LdLoc(ldloc.Variable); + return null; } default: + if (SemanticHelper.IsPure(value.Flags)) + return value.Clone(); return value; } } @@ -649,14 +829,17 @@ namespace ICSharpCode.Decompiler.IL.Transforms { value = null; type = null; - if (inst is CallInstruction call && call.Arguments.Count == 2 && call.Method.FullName == "System.Linq.Expressions.Expression.Constant") { + if (inst is CallInstruction call && call.Method.FullName == "System.Linq.Expressions.Expression.Constant") { value = call.Arguments[0]; - return MatchTypeOfCall(call.Arguments[1], out type); + if (call.Arguments.Count == 2) + return MatchGetTypeFromHandle(call.Arguments[1], out type); + type = InferType(value); + return true; } return false; } - bool MatchTypeOfCall(ILInstruction inst, out IType type) + bool MatchGetTypeFromHandle(ILInstruction inst, out IType type) { type = null; return inst is CallInstruction getTypeCall @@ -739,7 +922,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms { arguments = null; if (!(inst is Block block && block.Kind == BlockKind.ArrayInitializer)) { - if (inst is CallInstruction emptyCall && emptyCall.Method.FullName == "System.Array.Empty" && emptyCall.Arguments.Count == 0) { + if (IsEmptyParameterList(inst)) { arguments = new List(); return true; } @@ -755,5 +938,14 @@ namespace ICSharpCode.Decompiler.IL.Transforms } return true; } + + IType InferType(ILInstruction inst) + { + if (inst is Block b && b.Kind == BlockKind.ArrayInitializer) + return b.FinalInstruction.InferType(); + if (inst is ExpressionTreeCast cast) + return cast.Type; + return inst.InferType(); + } } }