diff --git a/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs b/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs index 39c7f6783..c0fcabab3 100644 --- a/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs +++ b/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs @@ -94,13 +94,20 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return base.VisitObjectCreateExpression(objectCreateExpression, data); } + internal static bool IsAnonymousMethod(DecompilerContext context, MethodDefinition method) + { + if (method == null || !method.Name.StartsWith("<", StringComparison.Ordinal)) + return false; + if (!(method.IsCompilerGenerated() || IsPotentialClosure(context, method.DeclaringType))) + return false; + return true; + } + bool HandleAnonymousMethod(ObjectCreateExpression objectCreateExpression, Expression target, MethodReference methodRef) { // Anonymous methods are defined in the same assembly, so there's no need to Resolve(). MethodDefinition method = methodRef as MethodDefinition; - if (method == null || !method.Name.StartsWith("<", StringComparison.Ordinal)) - return false; - if (!(method.IsCompilerGenerated() || IsPotentialClosure(method.DeclaringType))) + if (!IsAnonymousMethod(context, method)) return false; // Decompile the anonymous method: @@ -143,7 +150,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return true; } - bool IsPotentialClosure(TypeDefinition potentialDisplayClass) + static bool IsPotentialClosure(DecompilerContext context, TypeDefinition potentialDisplayClass) { if (potentialDisplayClass == null || !potentialDisplayClass.IsCompilerGenerated()) return false; @@ -164,7 +171,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms continue; var variable = stmt.Variables.Single(); TypeDefinition type = stmt.Type.Annotation(); - if (!IsPotentialClosure(type)) + if (!IsPotentialClosure(context, type)) continue; ObjectCreateExpression oce = variable.Initializer as ObjectCreateExpression; if (oce == null || oce.Type.Annotation() != type || oce.Arguments.Any() || !oce.Initializer.IsNull) diff --git a/ICSharpCode.Decompiler/CecilExtensions.cs b/ICSharpCode.Decompiler/CecilExtensions.cs index a97ed2427..1ef980d1f 100644 --- a/ICSharpCode.Decompiler/CecilExtensions.cs +++ b/ICSharpCode.Decompiler/CecilExtensions.cs @@ -169,5 +169,14 @@ namespace ICSharpCode.Decompiler } return false; } + + public static bool IsCompilerGeneratedOrIsInCompilerGeneratedClass(this IMemberDefinition member) + { + if (member == null) + return false; + if (member.IsCompilerGenerated()) + return true; + return IsCompilerGeneratedOrIsInCompilerGeneratedClass(member.DeclaringType); + } } } diff --git a/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs b/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs index c2ee9a99e..daf0ecc48 100644 --- a/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs +++ b/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs @@ -13,12 +13,20 @@ namespace ICSharpCode.Decompiler.ILAst /// /// Handles peephole transformations on the ILAst. /// - public static class PeepholeTransforms + public class PeepholeTransforms { + DecompilerContext context; + ILBlock method; + public static void Run(DecompilerContext context, ILBlock method) { + PeepholeTransforms transforms = new PeepholeTransforms(); + transforms.context = context; + transforms.method = method; + PeepholeTransform[] blockTransforms = { - ArrayInitializers.Transform(method) + ArrayInitializers.Transform(method), + transforms.CachedDelegateInitialization }; Func[] exprTransforms = { EliminateDups, @@ -67,6 +75,7 @@ namespace ICSharpCode.Decompiler.ILAst return expr; } + #region HandleDecimalConstants static ILExpression HandleDecimalConstants(ILExpression expr) { if (expr.Code == ILCode.Newobj) { @@ -105,5 +114,69 @@ namespace ICSharpCode.Decompiler.ILAst else return null; } + #endregion + + #region CachedDelegateInitialization + void CachedDelegateInitialization(ILBlock block, ref int i) + { + // if (logicnot(brtrue(ldsfld(field)))) { + // stsfld(field, newobj(Action::.ctor, ldnull(), ldftn(method))) + // } else { + // } + // ...(..., ldsfld(field), ...) + + ILCondition c = block.Body[i] as ILCondition; + if (c == null || c.Condition == null && c.TrueBlock == null || c.FalseBlock == null) + return; + if (!(c.TrueBlock.Body.Count == 1 && c.FalseBlock.Body.Count == 0)) + return; + ILExpression condition = UnpackBrFalse(c.Condition); + if (condition == null || condition.Code != ILCode.Ldsfld) + return; + FieldDefinition field = condition.Operand as FieldDefinition; // field is defined in current assembly + if (field == null || !field.IsCompilerGeneratedOrIsInCompilerGeneratedClass()) + return; + ILExpression stsfld = c.TrueBlock.Body[0] as ILExpression; + if (!(stsfld != null && stsfld.Code == ILCode.Stsfld && stsfld.Operand == field)) + return; + ILExpression newObj = stsfld.Arguments[0]; + if (!(newObj.Code == ILCode.Newobj && newObj.Arguments.Count == 2)) + return; + if (newObj.Arguments[0].Code != ILCode.Ldnull) + return; + if (newObj.Arguments[1].Code != ILCode.Ldftn) + return; + MethodDefinition anonymousMethod = newObj.Arguments[1].Operand as MethodDefinition; // method is defined in current assembly + if (!Ast.Transforms.DelegateConstruction.IsAnonymousMethod(context, anonymousMethod)) + return; + + ILExpression expr = block.Body.ElementAtOrDefault(i + 1) as ILExpression; + if (expr != null && expr.GetSelfAndChildrenRecursive().Count(e => e.Code == ILCode.Ldsfld && e.Operand == field) == 1) { + foreach (ILExpression parent in expr.GetSelfAndChildrenRecursive()) { + for (int j = 0; j < parent.Arguments.Count; j++) { + if (parent.Arguments[j].Code == ILCode.Ldsfld && parent.Arguments[j].Operand == field) { + parent.Arguments[j] = newObj; + block.Body.RemoveAt(i); + i -= ILInlining.InlineInto(block, i, method); + return; + } + } + } + } + } + + /// + /// Returns 'result' in brfalse(result) or logicnot(brtrue(result)). + /// + static ILExpression UnpackBrFalse(ILExpression condition) + { + if (condition.Code == ILCode.Brfalse) { + return condition.Arguments.Single(); + } else if (condition.Code == ILCode.LogicNot && condition.Arguments.Single().Code == ILCode.Brtrue) { + return condition.Arguments.Single().Arguments.Single(); + } + return null; + } + #endregion } }