Browse Source

Refactor LocalFunctionDecompiler to allow mutually recursive local functions to be decompiled correctly.

pull/1586/head
Siegfried Pammer 6 years ago
parent
commit
a109b77858
  1. 28
      ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs
  2. 5
      ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs
  3. 2
      ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs
  4. 220
      ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs
  5. 8
      ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs

28
ICSharpCode.Decompiler.Tests/TestCases/Pretty/LocalFunctions.cs

@ -217,6 +217,34 @@ namespace LocalFunctions
return FibHelper(n - 1) + FibHelper(n - 2); return FibHelper(n - 1) + FibHelper(n - 2);
} }
} }
public int MutuallyRecursiveLocalFunctions()
{
return B(4) + C(3);
int A(int i)
{
if (i > 0) {
return A(i - 1) + 2 * B(i - 1) + 3 * C(i - 1);
}
return 1;
}
int B(int i)
{
if (i > 0) {
return 3 * A(i - 1) + B(i - 1);
}
return 1;
}
int C(int i)
{
if (i > 0) {
return 2 * A(i - 1) + C(i - 1);
}
return 1;
}
}
public static int NestedLocalFunctions(int i) public static int NestedLocalFunctions(int i)
{ {

5
ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs

@ -207,11 +207,6 @@ namespace ICSharpCode.Decompiler.CSharp
} }
} }
internal bool IsLocalFunction(IMethod method)
{
return settings.LocalFunctions && IL.Transforms.LocalFunctionDecompiler.IsLocalFunctionMethod(method);
}
internal ILFunction ResolveLocalFunction(IMethod method) internal ILFunction ResolveLocalFunction(IMethod method)
{ {
Debug.Assert(method.IsLocalFunction); Debug.Assert(method.IsLocalFunction);

2
ICSharpCode.Decompiler/IL/Transforms/DelegateConstruction.cs

@ -126,7 +126,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
return null; return null;
if (targetMethod.MetadataToken.IsNil) if (targetMethod.MetadataToken.IsNil)
return null; return null;
if (LocalFunctionDecompiler.IsLocalFunctionMethod(targetMethod)) if (LocalFunctionDecompiler.IsLocalFunctionMethod(targetMethod, context))
return null; return null;
target = value.Arguments[0]; target = value.Arguments[0];
var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken); var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken);

220
ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs

@ -40,6 +40,12 @@ namespace ICSharpCode.Decompiler.IL.Transforms
{ {
ILTransformContext context; ILTransformContext context;
struct LocalFunctionInfo
{
public List<CallInstruction> UseSites;
public ILFunction Definition;
}
/// <summary> /// <summary>
/// The transform works like this: /// The transform works like this:
/// ///
@ -51,67 +57,103 @@ namespace ICSharpCode.Decompiler.IL.Transforms
/// After all use-sites are collected we construct the ILAst of the local function and add it to the parent function. /// After all use-sites are collected we construct the ILAst of the local function and add it to the parent function.
/// Then all use-sites of the local-function are transformed to a call to the <c>LocalFunctionMethod</c> or a ldftn of the <c>LocalFunctionMethod</c>. /// Then all use-sites of the local-function are transformed to a call to the <c>LocalFunctionMethod</c> or a ldftn of the <c>LocalFunctionMethod</c>.
/// In a next step we handle all nested local functions. /// In a next step we handle all nested local functions.
/// After all local functions are transformed, we move all local functions that do not capture any variables to the top-level function. /// After all local functions are transformed, we move all local functions that capture any variables to their respective declaration scope.
/// </summary> /// </summary>
public void Run(ILFunction function, ILTransformContext context) public void Run(ILFunction function, ILTransformContext context)
{ {
if (!context.Settings.LocalFunctions) if (!context.Settings.LocalFunctions)
return; return;
this.context = context; this.context = context;
var localFunctions = new Dictionary<IMethod, List<CallInstruction>>(); var localFunctions = new Dictionary<MethodDefinitionHandle, LocalFunctionInfo>();
var cancellationToken = context.CancellationToken; var cancellationToken = context.CancellationToken;
// Find use-sites // Find all local functions declared inside this method, including nested local functions or local functions declared in lambdas.
foreach (var inst in function.Descendants) { FindUseSites(function, context, localFunctions);
foreach (var (method, info) in localFunctions) {
cancellationToken.ThrowIfCancellationRequested(); cancellationToken.ThrowIfCancellationRequested();
if (inst is CallInstruction call && IsLocalFunctionMethod(call.Method) && !call.Method.IsLocalFunction) { var firstUseSite = info.UseSites[0];
if (!localFunctions.TryGetValue(call.Method, out var info)) { context.StepStartGroup($"Transform " + info.Definition.Name, info.Definition);
info = new List<CallInstruction>() { call }; try {
localFunctions.Add(call.Method, info); var localFunction = info.Definition;
} else { if (!localFunction.Method.IsStatic) {
info.Add(call); var target = firstUseSite.Arguments[0];
context.Step($"Replace 'this' with {target}", localFunction);
var thisVar = localFunction.Variables.SingleOrDefault(VariableKindExtensions.IsThis);
localFunction.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar));
} }
} else if (inst is LdFtn ldftn && !ldftn.Method.IsLocalFunction && ldftn.Parent is NewObj newObj && IsLocalFunctionMethod(ldftn.Method) && DelegateConstruction.IsDelegateConstruction(newObj)) {
context.StepStartGroup($"LocalFunctionDecompiler {ldftn.StartILOffset}", ldftn); foreach (var useSite in info.UseSites) {
if (!localFunctions.TryGetValue(ldftn.Method, out var info)) { context.Step("Transform use site at " + useSite.StartILOffset, useSite);
info = new List<CallInstruction>() { newObj }; if (useSite.OpCode == OpCode.NewObj) {
localFunctions.Add(ldftn.Method, info); TransformToLocalFunctionReference(localFunction, useSite);
} else { } else {
info.Add(newObj); DetermineCaptureAndDeclarationScope(localFunction, useSite);
} TransformToLocalFunctionInvocation(localFunction.ReducedMethod, useSite);
context.StepEndGroup();
} }
} }
foreach (var (method, useSites) in localFunctions) { if (localFunction.DeclarationScope == null) {
context.StepStartGroup($"LocalFunctionDecompiler {useSites[0].StartILOffset}", useSites[0]); localFunction.DeclarationScope = (BlockContainer)function.Body;
try { } else if (localFunction.DeclarationScope != function.Body && localFunction.DeclarationScope.Parent is ILFunction declaringFunction) {
TransformLocalFunction(function, method, useSites); function.LocalFunctions.Remove(localFunction);
declaringFunction.LocalFunctions.Add(localFunction);
}
} finally { } finally {
context.StepEndGroup(); context.StepEndGroup();
} }
} }
}
foreach (var f in function.LocalFunctions) { void FindUseSites(ILFunction function, ILTransformContext context, Dictionary<MethodDefinitionHandle, LocalFunctionInfo> localFunctions)
// handle nested functions {
var nestedContext = new ILTransformContext(context, f); foreach (var inst in function.Body.Descendants) {
nestedContext.StepStartGroup("LocalFunctionDecompiler (nested functions)", f); context.CancellationToken.ThrowIfCancellationRequested();
new LocalFunctionDecompiler().Run(f, nestedContext); if (inst is CallInstruction call && !call.Method.IsLocalFunction && IsLocalFunctionMethod(call.Method, context)) {
nestedContext.StepEndGroup(); HandleUseSite(call.Method, call);
} else if (inst is LdFtn ldftn && !ldftn.Method.IsLocalFunction && ldftn.Parent is NewObj newObj && IsLocalFunctionMethod(ldftn.Method, context) && DelegateConstruction.IsDelegateConstruction(newObj)) {
HandleUseSite(ldftn.Method, newObj);
}
} }
if (function.Kind == ILFunctionKind.TopLevelFunction) { void HandleUseSite(IMethod targetMethod, CallInstruction inst)
var movableFunctions = TreeTraversal.PostOrder(function, f => f.LocalFunctions) {
.Where(f => f.Kind == ILFunctionKind.LocalFunction && f.DeclarationScope == null) if (!localFunctions.TryGetValue((MethodDefinitionHandle)targetMethod.MetadataToken, out var info)) {
.ToArray(); context.StepStartGroup($"Read local function '{targetMethod.Name}'", inst);
foreach (var f in movableFunctions) { info = new LocalFunctionInfo() {
var parent = (ILFunction)f.Parent; UseSites = new List<CallInstruction>() { inst },
f.DeclarationScope = (BlockContainer)function.Body; Definition = ReadLocalFunctionDefinition(context.Function, targetMethod)
parent.LocalFunctions.Remove(f); };
function.LocalFunctions.Add(f); localFunctions.Add((MethodDefinitionHandle)targetMethod.MetadataToken, info);
FindUseSites(info.Definition, context, localFunctions);
context.StepEndGroup();
} else {
info.UseSites.Add(inst);
} }
} }
} }
ILFunction ReadLocalFunctionDefinition(ILFunction rootFunction, IMethod targetMethod)
{
var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken);
if (!methodDefinition.HasBody())
return null;
var genericContext = DelegateConstruction.GenericContextFromTypeArguments(targetMethod.Substitution);
if (genericContext == null)
return null;
var ilReader = context.CreateILReader();
var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress);
var function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.Value, ILFunctionKind.LocalFunction, context.CancellationToken);
// Embed the local function into the parent function's ILAst, so that "Show steps" can show
// how the local function body is being transformed.
rootFunction.LocalFunctions.Add(function);
function.DeclarationScope = (BlockContainer)rootFunction.Body;
function.CheckInvariant(ILPhase.Normal);
var nestedContext = new ILTransformContext(context, function);
function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext);
function.DeclarationScope = null;
function.ReducedMethod = ReduceToLocalFunction(targetMethod);
return function;
}
static T FindCommonAncestorInstruction<T>(ILInstruction a, ILInstruction b) static T FindCommonAncestorInstruction<T>(ILInstruction a, ILInstruction b)
where T : ILInstruction where T : ILInstruction
{ {
@ -137,71 +179,13 @@ namespace ICSharpCode.Decompiler.IL.Transforms
internal static ILInstruction GetStatement(ILInstruction inst) internal static ILInstruction GetStatement(ILInstruction inst)
{ {
while (inst.Parent != null) { while (inst.Parent != null) {
if (inst.Parent is Block) if (inst.Parent is Block b && b.Kind == BlockKind.ControlFlow)
return inst; return inst;
inst = inst.Parent; inst = inst.Parent;
} }
return inst; return inst;
} }
private ILFunction TransformLocalFunction(ILFunction parentFunction, IMethod targetMethod, List<CallInstruction> useSites)
{
var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken);
if (!methodDefinition.HasBody())
return null;
var genericContext = DelegateConstruction.GenericContextFromTypeArguments(targetMethod.Substitution);
if (genericContext == null)
return null;
var function = parentFunction.Ancestors.OfType<ILFunction>().SelectMany(f => f.LocalFunctions).FirstOrDefault(f => f.Method == targetMethod);
if (function == null) {
var ilReader = context.CreateILReader();
var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress);
function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.Value, ILFunctionKind.LocalFunction, context.CancellationToken);
// Embed the local function into the parent function's ILAst, so that "Show steps" can show
// how the local function body is being transformed.
parentFunction.LocalFunctions.Add(function);
function.DeclarationScope = (BlockContainer)parentFunction.Body;
function.CheckInvariant(ILPhase.Normal);
var nestedContext = new ILTransformContext(context, function);
function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext);
if (IsNonLocalTarget(targetMethod, useSites, out var target)) {
Debug.Assert(target != null);
nestedContext.Step("LocalFunctionDecompiler (ReplaceDelegateTargetVisitor)", function);
var thisVar = function.Variables.SingleOrDefault(v => v.Index == -1 && v.Kind == VariableKind.Parameter);
function.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar));
}
function.DeclarationScope = null;
function.ReducedMethod = ReduceToLocalFunction(targetMethod);
foreach (var innerUseSite in function.Descendants.OfType<CallInstruction>()) {
if (innerUseSite.Method != function.Method)
continue;
if (innerUseSite.OpCode == OpCode.NewObj) {
TransformToLocalFunctionReference(function, innerUseSite);
} else {
TransformToLocalFunctionInvocation(function.ReducedMethod, innerUseSite);
}
}
}
foreach (var useSite in useSites) {
if (useSite.OpCode == OpCode.NewObj) {
TransformToLocalFunctionReference(function, useSite);
} else {
DetermineCaptureAndDeclarationScope(function, useSite);
TransformToLocalFunctionInvocation(function.ReducedMethod, useSite);
}
}
if (function.DeclarationScope != null
&& parentFunction.LocalFunctions.Contains(function)
&& function.DeclarationScope.Parent is ILFunction betterParentFunction) {
parentFunction.LocalFunctions.Remove(function);
betterParentFunction.LocalFunctions.Add(function);
}
return function;
}
LocalFunctionMethod ReduceToLocalFunction(IMethod method) LocalFunctionMethod ReduceToLocalFunction(IMethod method)
{ {
int parametersToRemove = 0; int parametersToRemove = 0;
@ -272,47 +256,28 @@ namespace ICSharpCode.Decompiler.IL.Transforms
} }
} }
bool IsNonLocalTarget(IMethod targetMethod, List<CallInstruction> useSites, out ILInstruction target) internal static bool IsLocalFunctionReference(NewObj inst, ILTransformContext context)
{
target = null;
if (targetMethod.IsStatic)
return false;
ValidateUseSites(useSites);
target = useSites.Select(call => call.Arguments.First()).First();
return !target.MatchLdThis();
}
[Conditional("DEBUG")]
static void ValidateUseSites(List<CallInstruction> useSites)
{
ILInstruction targetInstruction = null;
foreach (var site in useSites) {
if (targetInstruction == null)
targetInstruction = site.Arguments.First();
else
Debug.Assert(targetInstruction.Match(site.Arguments[0]).Success);
}
}
internal static bool IsLocalFunctionReference(NewObj inst)
{ {
if (inst == null || inst.Arguments.Count != 2 || inst.Method.DeclaringType.Kind != TypeKind.Delegate) if (inst == null || inst.Arguments.Count != 2 || inst.Method.DeclaringType.Kind != TypeKind.Delegate)
return false; return false;
var opCode = inst.Arguments[1].OpCode; var opCode = inst.Arguments[1].OpCode;
return (opCode == OpCode.LdFtn || opCode == OpCode.LdVirtFtn) return (opCode == OpCode.LdFtn || opCode == OpCode.LdVirtFtn)
&& IsLocalFunctionMethod(((IInstructionWithMethodOperand)inst.Arguments[1]).Method); && IsLocalFunctionMethod(((IInstructionWithMethodOperand)inst.Arguments[1]).Method, context);
} }
public static bool IsLocalFunctionMethod(IMethod method) public static bool IsLocalFunctionMethod(IMethod method, ILTransformContext context)
{ {
if (method.MetadataToken.IsNil) if (method.MetadataToken.IsNil)
return false; return false;
return IsLocalFunctionMethod(method.ParentModule.PEFile, (MethodDefinitionHandle)method.MetadataToken); return IsLocalFunctionMethod(method.ParentModule.PEFile, (MethodDefinitionHandle)method.MetadataToken, context);
} }
public static bool IsLocalFunctionMethod(PEFile module, MethodDefinitionHandle methodHandle) public static bool IsLocalFunctionMethod(PEFile module, MethodDefinitionHandle methodHandle, ILTransformContext context = null)
{ {
if (context != null && context.PEFile != module)
return false;
var metadata = module.Metadata; var metadata = module.Metadata;
var method = metadata.GetMethodDefinition(methodHandle); var method = metadata.GetMethodDefinition(methodHandle);
var declaringType = method.GetDeclaringType(); var declaringType = method.GetDeclaringType();
@ -326,12 +291,15 @@ namespace ICSharpCode.Decompiler.IL.Transforms
return true; return true;
} }
public static bool IsLocalFunctionDisplayClass(PEFile module, TypeDefinitionHandle typeHandle) public static bool IsLocalFunctionDisplayClass(PEFile module, TypeDefinitionHandle typeHandle, ILTransformContext context = null)
{ {
if (context != null && context.PEFile != module)
return false;
var metadata = module.Metadata; var metadata = module.Metadata;
var type = metadata.GetTypeDefinition(typeHandle); var type = metadata.GetTypeDefinition(typeHandle);
if ((type.Attributes & TypeAttributes.NestedPrivate) == 0) if ((type.Attributes & TypeAttributes.VisibilityMask) != TypeAttributes.NestedPrivate)
return false; return false;
if (!type.HasGeneratedName(metadata)) if (!type.HasGeneratedName(metadata))
return false; return false;
@ -340,7 +308,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
var declaringType = metadata.GetTypeDefinition(declaringTypeHandle); var declaringType = metadata.GetTypeDefinition(declaringTypeHandle);
foreach (var method in declaringType.GetMethods()) { foreach (var method in declaringType.GetMethods()) {
if (!IsLocalFunctionMethod(module, method)) if (!IsLocalFunctionMethod(module, method, context))
continue; continue;
var md = metadata.GetMethodDefinition(method); var md = metadata.GetMethodDefinition(method);
if (md.DecodeSignature(new FindTypeDecoder(typeHandle), default).ParameterTypes.Any()) if (md.DecodeSignature(new FindTypeDecoder(typeHandle), default).ParameterTypes.Any())

8
ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs

@ -39,6 +39,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
public ITypeDefinition Definition; public ITypeDefinition Definition;
public Dictionary<IField, DisplayClassVariable> Variables; public Dictionary<IField, DisplayClassVariable> Variables;
public BlockContainer CaptureScope; public BlockContainer CaptureScope;
public ILFunction DeclaringFunction;
} }
struct DisplayClassVariable struct DisplayClassVariable
@ -105,7 +106,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms
Variable = v, Variable = v,
Definition = closureType, Definition = closureType,
Variables = new Dictionary<IField, DisplayClassVariable>(), Variables = new Dictionary<IField, DisplayClassVariable>(),
CaptureScope = (isMono && IsMonoNestedCaptureScope(closureType)) || localFunctionClosureParameter ? null : v.CaptureScope CaptureScope = (isMono && IsMonoNestedCaptureScope(closureType)) || localFunctionClosureParameter ? null : v.CaptureScope,
DeclaringFunction = localFunctionClosureParameter ? f.DeclarationScope.Ancestors.OfType<ILFunction>().First() : f
}); });
} else { } else {
if (displayClass.CaptureScope == null && !localFunctionClosureParameter) if (displayClass.CaptureScope == null && !localFunctionClosureParameter)
@ -296,10 +298,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms
} else { } else {
Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition); Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition);
// Introduce a fresh variable for the display class field. // Introduce a fresh variable for the display class field.
v = currentFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
if (displayClass.IsMono && displayClass.CaptureScope == null && !IsOuterClosureReference(field)) { if (displayClass.IsMono && displayClass.CaptureScope == null && !IsOuterClosureReference(field)) {
displayClass.CaptureScope = BlockContainer.FindClosestContainer(inst); displayClass.CaptureScope = BlockContainer.FindClosestContainer(inst);
} }
v = displayClass.DeclaringFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
v.CaptureScope = displayClass.CaptureScope; v.CaptureScope = displayClass.CaptureScope;
inst.ReplaceWith(new StLoc(v, inst.Value).WithILRange(inst)); inst.ReplaceWith(new StLoc(v, inst.Value).WithILRange(inst));
value = new LdLoc(v); value = new LdLoc(v);
@ -346,7 +348,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms
if (!displayClass.Variables.TryGetValue(field, out DisplayClassVariable info)) { if (!displayClass.Variables.TryGetValue(field, out DisplayClassVariable info)) {
// Introduce a fresh variable for the display class field. // Introduce a fresh variable for the display class field.
Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition); Debug.Assert(displayClass.Definition == field.DeclaringTypeDefinition);
var v = currentFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name); var v = displayClass.DeclaringFunction.RegisterVariable(VariableKind.Local, field.Type, field.Name);
v.CaptureScope = displayClass.CaptureScope; v.CaptureScope = displayClass.CaptureScope;
inst.ReplaceWith(new LdLoca(v).WithILRange(inst)); inst.ReplaceWith(new LdLoca(v).WithILRange(inst));
displayClass.Variables.Add(field, new DisplayClassVariable { Value = new LdLoc(v), Variable = v }); displayClass.Variables.Add(field, new DisplayClassVariable { Value = new LdLoc(v), Variable = v });

Loading…
Cancel
Save