diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs index 03c37273a..f66883c7b 100644 --- a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs +++ b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs @@ -217,6 +217,34 @@ namespace LocalFunctions return FibHelper(n - 1) + FibHelper(n - 2); } } + public int MutuallyRecursiveLocalFunctions() + { + return B(4) + C(3); + + int A(int i) + { + if (i > 0) { + return A(i - 1) + 2 * B(i - 1) + 3 * C(i - 1); + } + return 1; + } + + int B(int i) + { + if (i > 0) { + return 3 * A(i - 1) + B(i - 1); + } + return 1; + } + + int C(int i) + { + if (i > 0) { + return 2 * A(i - 1) + C(i - 1); + } + return 1; + } + } public static int NestedLocalFunctions(int i) { diff --git a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs index 5fb95a319..87602db3f 100644 --- a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs @@ -207,11 +207,6 @@ namespace ICSharpCode.Decompiler.CSharp } } - internal bool IsLocalFunction(IMethod method) - { - return settings.LocalFunctions && IL.Transforms.LocalFunctionDecompiler.IsLocalFunctionMethod(method); - } - internal ILFunction ResolveLocalFunction(IMethod method) { Debug.Assert(method.IsLocalFunction); diff --git a/ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs b/ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs index b3f617029..7edfeb1ef 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs @@ -126,7 +126,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms return null; if (targetMethod.MetadataToken.IsNil) return null; - if (LocalFunctionDecompiler.IsLocalFunctionMethod(targetMethod)) + if (LocalFunctionDecompiler.IsLocalFunctionMethod(targetMethod, context)) return null; target = value.Arguments[0]; var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken); diff --git a/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs b/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs index 7937cf2a7..dafda4357 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs @@ -40,6 +40,12 @@ namespace ICSharpCode.Decompiler.IL.Transforms { ILTransformContext context; + struct LocalFunctionInfo + { + public List UseSites; + public ILFunction Definition; + } + /// /// The transform works like this: /// @@ -51,67 +57,103 @@ namespace ICSharpCode.Decompiler.IL.Transforms /// After all use-sites are collected we construct the ILAst of the local function and add it to the parent function. /// Then all use-sites of the local-function are transformed to a call to the LocalFunctionMethod or a ldftn of the LocalFunctionMethod. /// In a next step we handle all nested local functions. - /// After all local functions are transformed, we move all local functions that do not capture any variables to the top-level function. + /// After all local functions are transformed, we move all local functions that capture any variables to their respective declaration scope. /// public void Run(ILFunction function, ILTransformContext context) { if (!context.Settings.LocalFunctions) return; this.context = context; - var localFunctions = new Dictionary>(); + var localFunctions = new Dictionary(); var cancellationToken = context.CancellationToken; - // Find use-sites - foreach (var inst in function.Descendants) { + // Find all local functions declared inside this method, including nested local functions or local functions declared in lambdas. + FindUseSites(function, context, localFunctions); + foreach (var (method, info) in localFunctions) { cancellationToken.ThrowIfCancellationRequested(); - if (inst is CallInstruction call && IsLocalFunctionMethod(call.Method) && !call.Method.IsLocalFunction) { - if (!localFunctions.TryGetValue(call.Method, out var info)) { - info = new List() { call }; - localFunctions.Add(call.Method, info); - } else { - info.Add(call); + var firstUseSite = info.UseSites[0]; + context.StepStartGroup($"Transform " + info.Definition.Name, info.Definition); + try { + var localFunction = info.Definition; + if (!localFunction.Method.IsStatic) { + var target = firstUseSite.Arguments[0]; + context.Step($"Replace 'this' with {target}", localFunction); + var thisVar = localFunction.Variables.SingleOrDefault(VariableKindExtensions.IsThis); + localFunction.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar)); } - } else if (inst is LdFtn ldftn && !ldftn.Method.IsLocalFunction && ldftn.Parent is NewObj newObj && IsLocalFunctionMethod(ldftn.Method) && DelegateConstruction.IsDelegateConstruction(newObj)) { - context.StepStartGroup($"LocalFunctionDecompiler {ldftn.StartILOffset}", ldftn); - if (!localFunctions.TryGetValue(ldftn.Method, out var info)) { - info = new List() { newObj }; - localFunctions.Add(ldftn.Method, info); - } else { - info.Add(newObj); + + foreach (var useSite in info.UseSites) { + context.Step("Transform use site at " + useSite.StartILOffset, useSite); + if (useSite.OpCode == OpCode.NewObj) { + TransformToLocalFunctionReference(localFunction, useSite); + } else { + DetermineCaptureAndDeclarationScope(localFunction, useSite); + TransformToLocalFunctionInvocation(localFunction.ReducedMethod, useSite); + } } - context.StepEndGroup(); - } - } - foreach (var (method, useSites) in localFunctions) { - context.StepStartGroup($"LocalFunctionDecompiler {useSites[0].StartILOffset}", useSites[0]); - try { - TransformLocalFunction(function, method, useSites); + if (localFunction.DeclarationScope == null) { + localFunction.DeclarationScope = (BlockContainer)function.Body; + } else if (localFunction.DeclarationScope != function.Body && localFunction.DeclarationScope.Parent is ILFunction declaringFunction) { + function.LocalFunctions.Remove(localFunction); + declaringFunction.LocalFunctions.Add(localFunction); + } } finally { context.StepEndGroup(); } } + } - foreach (var f in function.LocalFunctions) { - // handle nested functions - var nestedContext = new ILTransformContext(context, f); - nestedContext.StepStartGroup("LocalFunctionDecompiler (nested functions)", f); - new LocalFunctionDecompiler().Run(f, nestedContext); - nestedContext.StepEndGroup(); + void FindUseSites(ILFunction function, ILTransformContext context, Dictionary localFunctions) + { + foreach (var inst in function.Body.Descendants) { + context.CancellationToken.ThrowIfCancellationRequested(); + if (inst is CallInstruction call && !call.Method.IsLocalFunction && IsLocalFunctionMethod(call.Method, context)) { + HandleUseSite(call.Method, call); + } else if (inst is LdFtn ldftn && !ldftn.Method.IsLocalFunction && ldftn.Parent is NewObj newObj && IsLocalFunctionMethod(ldftn.Method, context) && DelegateConstruction.IsDelegateConstruction(newObj)) { + HandleUseSite(ldftn.Method, newObj); + } } - if (function.Kind == ILFunctionKind.TopLevelFunction) { - var movableFunctions = TreeTraversal.PostOrder(function, f => f.LocalFunctions) - .Where(f => f.Kind == ILFunctionKind.LocalFunction && f.DeclarationScope == null) - .ToArray(); - foreach (var f in movableFunctions) { - var parent = (ILFunction)f.Parent; - f.DeclarationScope = (BlockContainer)function.Body; - parent.LocalFunctions.Remove(f); - function.LocalFunctions.Add(f); + void HandleUseSite(IMethod targetMethod, CallInstruction inst) + { + if (!localFunctions.TryGetValue((MethodDefinitionHandle)targetMethod.MetadataToken, out var info)) { + context.StepStartGroup($"Read local function '{targetMethod.Name}'", inst); + info = new LocalFunctionInfo() { + UseSites = new List() { inst }, + Definition = ReadLocalFunctionDefinition(context.Function, targetMethod) + }; + localFunctions.Add((MethodDefinitionHandle)targetMethod.MetadataToken, info); + FindUseSites(info.Definition, context, localFunctions); + context.StepEndGroup(); + } else { + info.UseSites.Add(inst); } } } + ILFunction ReadLocalFunctionDefinition(ILFunction rootFunction, IMethod targetMethod) + { + var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken); + if (!methodDefinition.HasBody()) + return null; + var genericContext = DelegateConstruction.GenericContextFromTypeArguments(targetMethod.Substitution); + if (genericContext == null) + return null; + var ilReader = context.CreateILReader(); + var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress); + var function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.Value, ILFunctionKind.LocalFunction, context.CancellationToken); + // Embed the local function into the parent function's ILAst, so that "Show steps" can show + // how the local function body is being transformed. + rootFunction.LocalFunctions.Add(function); + function.DeclarationScope = (BlockContainer)rootFunction.Body; + function.CheckInvariant(ILPhase.Normal); + var nestedContext = new ILTransformContext(context, function); + function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext); + function.DeclarationScope = null; + function.ReducedMethod = ReduceToLocalFunction(targetMethod); + return function; + } + static T FindCommonAncestorInstruction(ILInstruction a, ILInstruction b) where T : ILInstruction { @@ -137,71 +179,13 @@ namespace ICSharpCode.Decompiler.IL.Transforms internal static ILInstruction GetStatement(ILInstruction inst) { while (inst.Parent != null) { - if (inst.Parent is Block) + if (inst.Parent is Block b && b.Kind == BlockKind.ControlFlow) return inst; inst = inst.Parent; } return inst; } - private ILFunction TransformLocalFunction(ILFunction parentFunction, IMethod targetMethod, List useSites) - { - var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken); - if (!methodDefinition.HasBody()) - return null; - var genericContext = DelegateConstruction.GenericContextFromTypeArguments(targetMethod.Substitution); - if (genericContext == null) - return null; - var function = parentFunction.Ancestors.OfType().SelectMany(f => f.LocalFunctions).FirstOrDefault(f => f.Method == targetMethod); - if (function == null) { - var ilReader = context.CreateILReader(); - var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress); - function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.Value, ILFunctionKind.LocalFunction, context.CancellationToken); - // Embed the local function into the parent function's ILAst, so that "Show steps" can show - // how the local function body is being transformed. - parentFunction.LocalFunctions.Add(function); - function.DeclarationScope = (BlockContainer)parentFunction.Body; - function.CheckInvariant(ILPhase.Normal); - var nestedContext = new ILTransformContext(context, function); - function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext); - if (IsNonLocalTarget(targetMethod, useSites, out var target)) { - Debug.Assert(target != null); - nestedContext.Step("LocalFunctionDecompiler (ReplaceDelegateTargetVisitor)", function); - var thisVar = function.Variables.SingleOrDefault(v => v.Index == -1 && v.Kind == VariableKind.Parameter); - function.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar)); - } - function.DeclarationScope = null; - function.ReducedMethod = ReduceToLocalFunction(targetMethod); - - foreach (var innerUseSite in function.Descendants.OfType()) { - if (innerUseSite.Method != function.Method) - continue; - if (innerUseSite.OpCode == OpCode.NewObj) { - TransformToLocalFunctionReference(function, innerUseSite); - } else { - TransformToLocalFunctionInvocation(function.ReducedMethod, innerUseSite); - } - } - } - foreach (var useSite in useSites) { - if (useSite.OpCode == OpCode.NewObj) { - TransformToLocalFunctionReference(function, useSite); - } else { - DetermineCaptureAndDeclarationScope(function, useSite); - TransformToLocalFunctionInvocation(function.ReducedMethod, useSite); - } - } - - if (function.DeclarationScope != null - && parentFunction.LocalFunctions.Contains(function) - && function.DeclarationScope.Parent is ILFunction betterParentFunction) { - parentFunction.LocalFunctions.Remove(function); - betterParentFunction.LocalFunctions.Add(function); - } - - return function; - } - LocalFunctionMethod ReduceToLocalFunction(IMethod method) { int parametersToRemove = 0; @@ -272,47 +256,28 @@ namespace ICSharpCode.Decompiler.IL.Transforms } } - bool IsNonLocalTarget(IMethod targetMethod, List useSites, out ILInstruction target) - { - target = null; - if (targetMethod.IsStatic) - return false; - ValidateUseSites(useSites); - target = useSites.Select(call => call.Arguments.First()).First(); - return !target.MatchLdThis(); - } - - [Conditional("DEBUG")] - static void ValidateUseSites(List useSites) - { - ILInstruction targetInstruction = null; - foreach (var site in useSites) { - if (targetInstruction == null) - targetInstruction = site.Arguments.First(); - else - Debug.Assert(targetInstruction.Match(site.Arguments[0]).Success); - } - } - - internal static bool IsLocalFunctionReference(NewObj inst) + internal static bool IsLocalFunctionReference(NewObj inst, ILTransformContext context) { if (inst == null || inst.Arguments.Count != 2 || inst.Method.DeclaringType.Kind != TypeKind.Delegate) return false; var opCode = inst.Arguments[1].OpCode; return (opCode == OpCode.LdFtn || opCode == OpCode.LdVirtFtn) - && IsLocalFunctionMethod(((IInstructionWithMethodOperand)inst.Arguments[1]).Method); + && IsLocalFunctionMethod(((IInstructionWithMethodOperand)inst.Arguments[1]).Method, context); } - public static bool IsLocalFunctionMethod(IMethod method) + public static bool IsLocalFunctionMethod(IMethod method, ILTransformContext context) { if (method.MetadataToken.IsNil) return false; - return IsLocalFunctionMethod(method.ParentModule.PEFile, (MethodDefinitionHandle)method.MetadataToken); + return IsLocalFunctionMethod(method.ParentModule.PEFile, (MethodDefinitionHandle)method.MetadataToken, context); } - public static bool IsLocalFunctionMethod(PEFile module, MethodDefinitionHandle methodHandle) + public static bool IsLocalFunctionMethod(PEFile module, MethodDefinitionHandle methodHandle, ILTransformContext context = null) { + if (context != null && context.PEFile != module) + return false; + var metadata = module.Metadata; var method = metadata.GetMethodDefinition(methodHandle); var declaringType = method.GetDeclaringType(); @@ -326,12 +291,15 @@ namespace ICSharpCode.Decompiler.IL.Transforms return true; } - public static bool IsLocalFunctionDisplayClass(PEFile module, TypeDefinitionHandle typeHandle) + public static bool IsLocalFunctionDisplayClass(PEFile module, TypeDefinitionHandle typeHandle, ILTransformContext context = null) { + if (context != null && context.PEFile != module) + return false; + var metadata = module.Metadata; var type = metadata.GetTypeDefinition(typeHandle); - if ((type.Attributes & TypeAttributes.NestedPrivate) == 0) + if ((type.Attributes & TypeAttributes.VisibilityMask) != TypeAttributes.NestedPrivate) return false; if (!type.HasGeneratedName(metadata)) return false; @@ -340,7 +308,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms var declaringType = metadata.GetTypeDefinition(declaringTypeHandle); foreach (var method in declaringType.GetMethods()) { - if (!IsLocalFunctionMethod(module, method)) + if (!IsLocalFunctionMethod(module, method, context)) continue; var md = metadata.GetMethodDefinition(method); if (md.DecodeSignature(new FindTypeDecoder(typeHandle), default).ParameterTypes.Any()) diff --git a/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs b/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs index 4fb8ddb9e..382abe96a 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs @@ -39,6 +39,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms public ITypeDefinition Definition; public Dictionary Variables; public BlockContainer CaptureScope; + public ILFunction DeclaringFunction; } struct DisplayClassVariable @@ -105,7 +106,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms Variable = v, Definition = closureType, Variables = new Dictionary(), - CaptureScope = (isMono && IsMonoNestedCaptureScope(closureType)) || localFunctionClosureParameter ? null : v.CaptureScope + CaptureScope = (isMono && IsMonoNestedCaptureScope(closureType)) || localFunctionClosureParameter ? null : v.CaptureScope, + DeclaringFunction = localFunctionClosureParameter ? f.DeclarationScope.Ancestors.OfType().First() : f }); } else { if (displayClass.CaptureScope == null && !localFunctionClosureParameter) @@ -296,10 +298,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms } else { Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition); // Introduce a fresh variable for the display class field. - v = currentFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name); if (displayClass.IsMono && displayClass.CaptureScope == null && !IsOuterClosureReference(field)) { displayClass.CaptureScope = BlockContainer.FindClosestContainer(inst); } + v = displayClass.DeclaringFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name); v.CaptureScope = displayClass.CaptureScope; inst.ReplaceWith(new StLoc(v, inst.Value).WithILRange(inst)); value = new LdLoc(v); @@ -346,7 +348,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms if (!displayClass.Variables.TryGetValue(field, out DisplayClassVariable info)) { // Introduce a fresh variable for the display class field. Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition); - var v = currentFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name); + var v = displayClass.DeclaringFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name); v.CaptureScope = displayClass.CaptureScope; inst.ReplaceWith(new LdLoca(v).WithILRange(inst)); displayClass.Variables.Add(field, new DisplayClassVariable { Value = new LdLoc(v), Variable = v });