Browse Source

Refactor LocalFunctionDecompiler to allow mutually recursive local functions to be decompiled correctly.

pull/1586/head
Siegfried Pammer 6 years ago
parent
commit
a109b77858
  1. 28
      ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs
  2. 5
      ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs
  3. 2
      ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs
  4. 220
      ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs
  5. 8
      ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs

28
ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs

@ -217,6 +217,34 @@ namespace LocalFunctions @@ -217,6 +217,34 @@ namespace LocalFunctions
return FibHelper(n - 1) + FibHelper(n - 2);
}
}
public int MutuallyRecursiveLocalFunctions()
{
return B(4) + C(3);
int A(int i)
{
if (i > 0) {
return A(i - 1) + 2 * B(i - 1) + 3 * C(i - 1);
}
return 1;
}
int B(int i)
{
if (i > 0) {
return 3 * A(i - 1) + B(i - 1);
}
return 1;
}
int C(int i)
{
if (i > 0) {
return 2 * A(i - 1) + C(i - 1);
}
return 1;
}
}
public static int NestedLocalFunctions(int i)
{

5
ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs

@ -207,11 +207,6 @@ namespace ICSharpCode.Decompiler.CSharp @@ -207,11 +207,6 @@ namespace ICSharpCode.Decompiler.CSharp
}
}
internal bool IsLocalFunction(IMethod method)
{
return settings.LocalFunctions && IL.Transforms.LocalFunctionDecompiler.IsLocalFunctionMethod(method);
}
internal ILFunction ResolveLocalFunction(IMethod method)
{
Debug.Assert(method.IsLocalFunction);

2
ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs

@ -126,7 +126,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -126,7 +126,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
return null;
if (targetMethod.MetadataToken.IsNil)
return null;
if (LocalFunctionDecompiler.IsLocalFunctionMethod(targetMethod))
if (LocalFunctionDecompiler.IsLocalFunctionMethod(targetMethod, context))
return null;
target = value.Arguments[0];
var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken);

220
ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs

@ -40,6 +40,12 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -40,6 +40,12 @@ namespace ICSharpCode.Decompiler.IL.Transforms
{
ILTransformContext context;
struct LocalFunctionInfo
{
public List<CallInstruction> UseSites;
public ILFunction Definition;
}
/// <summary>
/// The transform works like this:
///
@ -51,67 +57,103 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -51,67 +57,103 @@ namespace ICSharpCode.Decompiler.IL.Transforms
/// After all use-sites are collected we construct the ILAst of the local function and add it to the parent function.
/// Then all use-sites of the local-function are transformed to a call to the <c>LocalFunctionMethod</c> or a ldftn of the <c>LocalFunctionMethod</c>.
/// In a next step we handle all nested local functions.
/// After all local functions are transformed, we move all local functions that do not capture any variables to the top-level function.
/// After all local functions are transformed, we move all local functions that capture any variables to their respective declaration scope.
/// </summary>
public void Run(ILFunction function, ILTransformContext context)
{
if (!context.Settings.LocalFunctions)
return;
this.context = context;
var localFunctions = new Dictionary<IMethod, List<CallInstruction>>();
var localFunctions = new Dictionary<MethodDefinitionHandle, LocalFunctionInfo>();
var cancellationToken = context.CancellationToken;
// Find use-sites
foreach (var inst in function.Descendants) {
// Find all local functions declared inside this method, including nested local functions or local functions declared in lambdas.
FindUseSites(function, context, localFunctions);
foreach (var (method, info) in localFunctions) {
cancellationToken.ThrowIfCancellationRequested();
if (inst is CallInstruction call && IsLocalFunctionMethod(call.Method) && !call.Method.IsLocalFunction) {
if (!localFunctions.TryGetValue(call.Method, out var info)) {
info = new List<CallInstruction>() { call };
localFunctions.Add(call.Method, info);
} else {
info.Add(call);
var firstUseSite = info.UseSites[0];
context.StepStartGroup($"Transform " + info.Definition.Name, info.Definition);
try {
var localFunction = info.Definition;
if (!localFunction.Method.IsStatic) {
var target = firstUseSite.Arguments[0];
context.Step($"Replace 'this' with {target}", localFunction);
var thisVar = localFunction.Variables.SingleOrDefault(VariableKindExtensions.IsThis);
localFunction.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar));
}
} else if (inst is LdFtn ldftn && !ldftn.Method.IsLocalFunction && ldftn.Parent is NewObj newObj && IsLocalFunctionMethod(ldftn.Method) && DelegateConstruction.IsDelegateConstruction(newObj)) {
context.StepStartGroup($"LocalFunctionDecompiler {ldftn.StartILOffset}", ldftn);
if (!localFunctions.TryGetValue(ldftn.Method, out var info)) {
info = new List<CallInstruction>() { newObj };
localFunctions.Add(ldftn.Method, info);
foreach (var useSite in info.UseSites) {
context.Step("Transform use site at " + useSite.StartILOffset, useSite);
if (useSite.OpCode == OpCode.NewObj) {
TransformToLocalFunctionReference(localFunction, useSite);
} else {
info.Add(newObj);
}
context.StepEndGroup();
DetermineCaptureAndDeclarationScope(localFunction, useSite);
TransformToLocalFunctionInvocation(localFunction.ReducedMethod, useSite);
}
}
foreach (var (method, useSites) in localFunctions) {
context.StepStartGroup($"LocalFunctionDecompiler {useSites[0].StartILOffset}", useSites[0]);
try {
TransformLocalFunction(function, method, useSites);
if (localFunction.DeclarationScope == null) {
localFunction.DeclarationScope = (BlockContainer)function.Body;
} else if (localFunction.DeclarationScope != function.Body && localFunction.DeclarationScope.Parent is ILFunction declaringFunction) {
function.LocalFunctions.Remove(localFunction);
declaringFunction.LocalFunctions.Add(localFunction);
}
} finally {
context.StepEndGroup();
}
}
}
foreach (var f in function.LocalFunctions) {
// handle nested functions
var nestedContext = new ILTransformContext(context, f);
nestedContext.StepStartGroup("LocalFunctionDecompiler (nested functions)", f);
new LocalFunctionDecompiler().Run(f, nestedContext);
nestedContext.StepEndGroup();
void FindUseSites(ILFunction function, ILTransformContext context, Dictionary<MethodDefinitionHandle, LocalFunctionInfo> localFunctions)
{
foreach (var inst in function.Body.Descendants) {
context.CancellationToken.ThrowIfCancellationRequested();
if (inst is CallInstruction call && !call.Method.IsLocalFunction && IsLocalFunctionMethod(call.Method, context)) {
HandleUseSite(call.Method, call);
} else if (inst is LdFtn ldftn && !ldftn.Method.IsLocalFunction && ldftn.Parent is NewObj newObj && IsLocalFunctionMethod(ldftn.Method, context) && DelegateConstruction.IsDelegateConstruction(newObj)) {
HandleUseSite(ldftn.Method, newObj);
}
}
if (function.Kind == ILFunctionKind.TopLevelFunction) {
var movableFunctions = TreeTraversal.PostOrder(function, f => f.LocalFunctions)
.Where(f => f.Kind == ILFunctionKind.LocalFunction && f.DeclarationScope == null)
.ToArray();
foreach (var f in movableFunctions) {
var parent = (ILFunction)f.Parent;
f.DeclarationScope = (BlockContainer)function.Body;
parent.LocalFunctions.Remove(f);
function.LocalFunctions.Add(f);
void HandleUseSite(IMethod targetMethod, CallInstruction inst)
{
if (!localFunctions.TryGetValue((MethodDefinitionHandle)targetMethod.MetadataToken, out var info)) {
context.StepStartGroup($"Read local function '{targetMethod.Name}'", inst);
info = new LocalFunctionInfo() {
UseSites = new List<CallInstruction>() { inst },
Definition = ReadLocalFunctionDefinition(context.Function, targetMethod)
};
localFunctions.Add((MethodDefinitionHandle)targetMethod.MetadataToken, info);
FindUseSites(info.Definition, context, localFunctions);
context.StepEndGroup();
} else {
info.UseSites.Add(inst);
}
}
}
ILFunction ReadLocalFunctionDefinition(ILFunction rootFunction, IMethod targetMethod)
{
var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken);
if (!methodDefinition.HasBody())
return null;
var genericContext = DelegateConstruction.GenericContextFromTypeArguments(targetMethod.Substitution);
if (genericContext == null)
return null;
var ilReader = context.CreateILReader();
var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress);
var function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.Value, ILFunctionKind.LocalFunction, context.CancellationToken);
// Embed the local function into the parent function's ILAst, so that "Show steps" can show
// how the local function body is being transformed.
rootFunction.LocalFunctions.Add(function);
function.DeclarationScope = (BlockContainer)rootFunction.Body;
function.CheckInvariant(ILPhase.Normal);
var nestedContext = new ILTransformContext(context, function);
function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext);
function.DeclarationScope = null;
function.ReducedMethod = ReduceToLocalFunction(targetMethod);
return function;
}
static T FindCommonAncestorInstruction<T>(ILInstruction a, ILInstruction b)
where T : ILInstruction
{
@ -137,71 +179,13 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -137,71 +179,13 @@ namespace ICSharpCode.Decompiler.IL.Transforms
internal static ILInstruction GetStatement(ILInstruction inst)
{
while (inst.Parent != null) {
if (inst.Parent is Block)
if (inst.Parent is Block b && b.Kind == BlockKind.ControlFlow)
return inst;
inst = inst.Parent;
}
return inst;
}
private ILFunction TransformLocalFunction(ILFunction parentFunction, IMethod targetMethod, List<CallInstruction> useSites)
{
var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken);
if (!methodDefinition.HasBody())
return null;
var genericContext = DelegateConstruction.GenericContextFromTypeArguments(targetMethod.Substitution);
if (genericContext == null)
return null;
var function = parentFunction.Ancestors.OfType<ILFunction>().SelectMany(f => f.LocalFunctions).FirstOrDefault(f => f.Method == targetMethod);
if (function == null) {
var ilReader = context.CreateILReader();
var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress);
function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.Value, ILFunctionKind.LocalFunction, context.CancellationToken);
// Embed the local function into the parent function's ILAst, so that "Show steps" can show
// how the local function body is being transformed.
parentFunction.LocalFunctions.Add(function);
function.DeclarationScope = (BlockContainer)parentFunction.Body;
function.CheckInvariant(ILPhase.Normal);
var nestedContext = new ILTransformContext(context, function);
function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext);
if (IsNonLocalTarget(targetMethod, useSites, out var target)) {
Debug.Assert(target != null);
nestedContext.Step("LocalFunctionDecompiler (ReplaceDelegateTargetVisitor)", function);
var thisVar = function.Variables.SingleOrDefault(v => v.Index == -1 && v.Kind == VariableKind.Parameter);
function.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar));
}
function.DeclarationScope = null;
function.ReducedMethod = ReduceToLocalFunction(targetMethod);
foreach (var innerUseSite in function.Descendants.OfType<CallInstruction>()) {
if (innerUseSite.Method != function.Method)
continue;
if (innerUseSite.OpCode == OpCode.NewObj) {
TransformToLocalFunctionReference(function, innerUseSite);
} else {
TransformToLocalFunctionInvocation(function.ReducedMethod, innerUseSite);
}
}
}
foreach (var useSite in useSites) {
if (useSite.OpCode == OpCode.NewObj) {
TransformToLocalFunctionReference(function, useSite);
} else {
DetermineCaptureAndDeclarationScope(function, useSite);
TransformToLocalFunctionInvocation(function.ReducedMethod, useSite);
}
}
if (function.DeclarationScope != null
&& parentFunction.LocalFunctions.Contains(function)
&& function.DeclarationScope.Parent is ILFunction betterParentFunction) {
parentFunction.LocalFunctions.Remove(function);
betterParentFunction.LocalFunctions.Add(function);
}
return function;
}
LocalFunctionMethod ReduceToLocalFunction(IMethod method)
{
int parametersToRemove = 0;
@ -272,47 +256,28 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -272,47 +256,28 @@ namespace ICSharpCode.Decompiler.IL.Transforms
}
}
bool IsNonLocalTarget(IMethod targetMethod, List<CallInstruction> useSites, out ILInstruction target)
{
target = null;
if (targetMethod.IsStatic)
return false;
ValidateUseSites(useSites);
target = useSites.Select(call => call.Arguments.First()).First();
return !target.MatchLdThis();
}
[Conditional("DEBUG")]
static void ValidateUseSites(List<CallInstruction> useSites)
{
ILInstruction targetInstruction = null;
foreach (var site in useSites) {
if (targetInstruction == null)
targetInstruction = site.Arguments.First();
else
Debug.Assert(targetInstruction.Match(site.Arguments[0]).Success);
}
}
internal static bool IsLocalFunctionReference(NewObj inst)
internal static bool IsLocalFunctionReference(NewObj inst, ILTransformContext context)
{
if (inst == null || inst.Arguments.Count != 2 || inst.Method.DeclaringType.Kind != TypeKind.Delegate)
return false;
var opCode = inst.Arguments[1].OpCode;
return (opCode == OpCode.LdFtn || opCode == OpCode.LdVirtFtn)
&& IsLocalFunctionMethod(((IInstructionWithMethodOperand)inst.Arguments[1]).Method);
&& IsLocalFunctionMethod(((IInstructionWithMethodOperand)inst.Arguments[1]).Method, context);
}
public static bool IsLocalFunctionMethod(IMethod method)
public static bool IsLocalFunctionMethod(IMethod method, ILTransformContext context)
{
if (method.MetadataToken.IsNil)
return false;
return IsLocalFunctionMethod(method.ParentModule.PEFile, (MethodDefinitionHandle)method.MetadataToken);
return IsLocalFunctionMethod(method.ParentModule.PEFile, (MethodDefinitionHandle)method.MetadataToken, context);
}
public static bool IsLocalFunctionMethod(PEFile module, MethodDefinitionHandle methodHandle)
public static bool IsLocalFunctionMethod(PEFile module, MethodDefinitionHandle methodHandle, ILTransformContext context = null)
{
if (context != null && context.PEFile != module)
return false;
var metadata = module.Metadata;
var method = metadata.GetMethodDefinition(methodHandle);
var declaringType = method.GetDeclaringType();
@ -326,12 +291,15 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -326,12 +291,15 @@ namespace ICSharpCode.Decompiler.IL.Transforms
return true;
}
public static bool IsLocalFunctionDisplayClass(PEFile module, TypeDefinitionHandle typeHandle)
public static bool IsLocalFunctionDisplayClass(PEFile module, TypeDefinitionHandle typeHandle, ILTransformContext context = null)
{
if (context != null && context.PEFile != module)
return false;
var metadata = module.Metadata;
var type = metadata.GetTypeDefinition(typeHandle);
if ((type.Attributes & TypeAttributes.NestedPrivate) == 0)
if ((type.Attributes & TypeAttributes.VisibilityMask) != TypeAttributes.NestedPrivate)
return false;
if (!type.HasGeneratedName(metadata))
return false;
@ -340,7 +308,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -340,7 +308,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
var declaringType = metadata.GetTypeDefinition(declaringTypeHandle);
foreach (var method in declaringType.GetMethods()) {
if (!IsLocalFunctionMethod(module, method))
if (!IsLocalFunctionMethod(module, method, context))
continue;
var md = metadata.GetMethodDefinition(method);
if (md.DecodeSignature(new FindTypeDecoder(typeHandle), default).ParameterTypes.Any())

8
ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs

@ -39,6 +39,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -39,6 +39,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
public ITypeDefinition Definition;
public Dictionary<IField, DisplayClassVariable> Variables;
public BlockContainer CaptureScope;
public ILFunction DeclaringFunction;
}
struct DisplayClassVariable
@ -105,7 +106,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -105,7 +106,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms
Variable = v,
Definition = closureType,
Variables = new Dictionary<IField, DisplayClassVariable>(),
CaptureScope = (isMono && IsMonoNestedCaptureScope(closureType)) || localFunctionClosureParameter ? null : v.CaptureScope
CaptureScope = (isMono && IsMonoNestedCaptureScope(closureType)) || localFunctionClosureParameter ? null : v.CaptureScope,
DeclaringFunction = localFunctionClosureParameter ? f.DeclarationScope.Ancestors.OfType<ILFunction>().First() : f
});
} else {
if (displayClass.CaptureScope == null && !localFunctionClosureParameter)
@ -296,10 +298,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -296,10 +298,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms
} else {
Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition);
// Introduce a fresh variable for the display class field.
v = currentFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
if (displayClass.IsMono && displayClass.CaptureScope == null && !IsOuterClosureReference(field)) {
displayClass.CaptureScope = BlockContainer.FindClosestContainer(inst);
}
v = displayClass.DeclaringFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
v.CaptureScope = displayClass.CaptureScope;
inst.ReplaceWith(new StLoc(v, inst.Value).WithILRange(inst));
value = new LdLoc(v);
@ -346,7 +348,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -346,7 +348,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
if (!displayClass.Variables.TryGetValue(field, out DisplayClassVariable info)) {
// Introduce a fresh variable for the display class field.
Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition);
var v = currentFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
var v = displayClass.DeclaringFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
v.CaptureScope = displayClass.CaptureScope;
inst.ReplaceWith(new LdLoca(v).WithILRange(inst));
displayClass.Variables.Add(field, new DisplayClassVariable { Value = new LdLoc(v), Variable = v });

Loading…
Cancel
Save