From 5e6fecebf5046a01b62824886768d6357c7a6b5b Mon Sep 17 00:00:00 2001 From: SilverFox Date: Thu, 14 Nov 2019 23:28:24 +0800 Subject: [PATCH] Rework support for generic local function, and fix tests `LocalFunctions.Generic.Test_CaptureT` and `LocalFunctions.Generic.TestGenericArgs` --- ICSharpCode.Decompiler/CSharp/CallBuilder.cs | 35 ++--- .../CSharp/ExpressionBuilder.cs | 4 +- .../CSharp/StatementBuilder.cs | 5 +- .../IL/Transforms/LocalFunctionDecompiler.cs | 147 +++++++++++++++--- .../Implementation/LocalFunctionMethod.cs | 4 + 5 files changed, 143 insertions(+), 52 deletions(-) diff --git a/ICSharpCode.Decompiler/CSharp/CallBuilder.cs b/ICSharpCode.Decompiler/CSharp/CallBuilder.cs index ea43abd63..0d6fde5da 100644 --- a/ICSharpCode.Decompiler/CSharp/CallBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/CallBuilder.cs @@ -194,19 +194,10 @@ namespace ICSharpCode.Decompiler.CSharp TranslatedExpression target; if (callOpCode == OpCode.NewObj) { target = default(TranslatedExpression); // no target - } else if (method.IsLocalFunction && localFunction != null) { + } else if (localFunction != null) { var ide = new IdentifierExpression(localFunction.Name); if (method.TypeArguments.Count > 0) { - var parentMethod = ((ILFunction)localFunction.Parent).Method; - int skipCount = parentMethod.DeclaringType.TypeParameterCount + parentMethod.TypeParameters.Count - localFunction.Method.DeclaringType.TypeParameterCount; -#if DEBUG - Debug.Assert(skipCount >= 0); - if (skipCount > 0) { - var currentMethod = expressionBuilder.currentFunction.Method; - currentMethod = currentMethod.ReducedFrom ?? currentMethod; - Debug.Assert(currentMethod.DeclaringType.TypeParameters.Concat(currentMethod.TypeParameters).Take(skipCount).SequenceEqual(method.DeclaringType.TypeArguments.Concat(method.TypeArguments).Take(skipCount))); - } -#endif + int skipCount = localFunction.ReducedMethod.NumberOfCompilerGeneratedGenerics; ide.TypeArguments.AddRange(method.TypeArguments.Skip(skipCount).Select(expressionBuilder.ConvertType)); } target = ide.WithoutILInstruction() @@ -238,7 +229,7 @@ namespace ICSharpCode.Decompiler.CSharp var argumentList = BuildArgumentList(expectedTargetDetails, target.ResolveResult, method, firstParamIndex, callArguments, argumentToParameterMap); - if (method.IsLocalFunction) { + if (localFunction != null) { return new InvocationExpression(target, argumentList.GetArgumentExpressions()) .WithRR(new CSharpInvocationResolveResult(target.ResolveResult, method, argumentList.GetArgumentResolveResults().ToList(), isExpandedForm: argumentList.IsExpandedForm)); @@ -1368,11 +1359,12 @@ namespace ICSharpCode.Decompiler.CSharp memberDeclaringType: method.DeclaringType); requireTarget = expressionBuilder.HidesVariableWithName(method.Name) || (method.IsStatic ? !expressionBuilder.IsCurrentOrContainingType(method.DeclaringTypeDefinition) : !(target.Expression is ThisReferenceExpression)); - + step = requireTarget ? 1 : 0; var savedTarget = target; - for (step = requireTarget ? 1 : 0; step < 7; step++) { + for (; step < 7; step++) { ResolveResult targetResolveResult; - if (!method.IsLocalFunction && (step & 1) != 0) { + //TODO: why there is an check for IsLocalFunction here, it should be unreachable in old code + if (localFunction == null && (step & 1) != 0) { targetResolveResult = savedTarget.ResolveResult; target = savedTarget; } else { @@ -1395,7 +1387,7 @@ namespace ICSharpCode.Decompiler.CSharp break; } } - requireTarget = !method.IsLocalFunction && (step & 1) != 0; + requireTarget = localFunction == null && (step & 1) != 0; ExpressionWithResolveResult targetExpression; Debug.Assert(result != null); if (requireTarget) { @@ -1409,16 +1401,7 @@ namespace ICSharpCode.Decompiler.CSharp if ((step & 2) != 0) { int skipCount = 0; if (localFunction != null && method.TypeArguments.Count > 0) { - var parentMethod = ((ILFunction)localFunction.Parent).Method; - skipCount = parentMethod.DeclaringType.TypeParameterCount + parentMethod.TypeParameters.Count - localFunction.Method.DeclaringType.TypeParameterCount; -#if DEBUG - Debug.Assert(skipCount >= 0); - if (skipCount > 0) { - var currentMethod = expressionBuilder.currentFunction.Method; - currentMethod = currentMethod.ReducedFrom ?? currentMethod; - Debug.Assert(currentMethod.DeclaringType.TypeParameters.Concat(currentMethod.TypeParameters).Take(skipCount).SequenceEqual(method.DeclaringType.TypeArguments.Concat(method.TypeArguments).Take(skipCount))); - } -#endif + skipCount = localFunction.ReducedMethod.NumberOfCompilerGeneratedGenerics; } ide.TypeArguments.AddRange(method.TypeArguments.Skip(skipCount).Select(expressionBuilder.ConvertType)); } diff --git a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs index d6f012783..cee0b3531 100644 --- a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs @@ -217,9 +217,9 @@ namespace ICSharpCode.Decompiler.CSharp internal ILFunction ResolveLocalFunction(IMethod method) { Debug.Assert(method.IsLocalFunction); - method = (IMethod)method.ReducedFrom.MemberDefinition; + method = (IMethod)((IMethod)method.MemberDefinition).ReducedFrom.MemberDefinition; foreach (var parent in currentFunction.Ancestors.OfType()) { - var definition = parent.LocalFunctions.FirstOrDefault(f => f.Method.Equals(method)); + var definition = parent.LocalFunctions.FirstOrDefault(f => f.Method.MemberDefinition.Equals(method)); if (definition != null) { return definition; } diff --git a/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs b/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs index ea0076215..1c58f6e92 100644 --- a/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs @@ -988,10 +988,9 @@ namespace ICSharpCode.Decompiler.CSharp stmt.ReturnType = exprBuilder.ConvertType(function.Method.ReturnType); stmt.Body = nestedBuilder.ConvertAsBlock(function.Body); if (function.Method.TypeParameters.Count > 0) { - var parentMethod = ((ILFunction)function.Parent).Method; - int skipCount = parentMethod.DeclaringType.TypeParameterCount + parentMethod.TypeParameters.Count - function.Method.DeclaringType.TypeParameterCount; var astBuilder = exprBuilder.astBuilder; if (astBuilder.ShowTypeParameters) { + int skipCount = function.ReducedMethod.NumberOfCompilerGeneratedGenerics; stmt.TypeParameters.AddRange(function.Method.TypeParameters.Skip(skipCount).Select(t => astBuilder.ConvertTypeParameter(t))); if (astBuilder.ShowTypeParameterConstraints) { stmt.Constraints.AddRange(function.Method.TypeParameters.Skip(skipCount).Select(t => astBuilder.ConvertTypeParameterConstraint(t)).Where(c => c != null)); @@ -1001,7 +1000,7 @@ namespace ICSharpCode.Decompiler.CSharp if (function.IsAsync) { stmt.Modifiers |= Modifiers.Async; } - if (settings.StaticLocalFunctions && function.Method.IsStatic && function.ReducedMethod.NumberOfCompilerGeneratedParameters == 0) { + if (settings.StaticLocalFunctions && function.ReducedMethod.IsStaticLocalFunction) { stmt.Modifiers |= Modifiers.Static; } stmt.AddAnnotation(new MemberResolveResult(null, function.ReducedMethod)); diff --git a/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs b/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs index ee62f568b..f71ccf4ce 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/LocalFunctionDecompiler.cs @@ -80,25 +80,18 @@ namespace ICSharpCode.Decompiler.IL.Transforms continue; } - var firstUseSite = info.UseSites[0]; context.StepStartGroup($"Transform " + info.Definition.Name, info.Definition); try { var localFunction = info.Definition; if (!localFunction.Method.IsStatic) { - var target = firstUseSite.Arguments[0]; + var target = info.UseSites[0].Arguments[0]; context.Step($"Replace 'this' with {target}", localFunction); var thisVar = localFunction.Variables.SingleOrDefault(VariableKindExtensions.IsThis); localFunction.AcceptVisitor(new DelegateConstruction.ReplaceDelegateTargetVisitor(target, thisVar)); } foreach (var useSite in info.UseSites) { - context.Step($"Transform use site at IL_{useSite.StartILOffset:x4}", useSite); DetermineCaptureAndDeclarationScope(localFunction, useSite); - if (useSite.OpCode == OpCode.NewObj) { - TransformToLocalFunctionReference(localFunction, useSite); - } else { - TransformToLocalFunctionInvocation(localFunction.ReducedMethod, useSite); - } if (function.Method.IsConstructor && localFunction.DeclarationScope == null) { localFunction.DeclarationScope = BlockContainer.FindClosestContainer(useSite); @@ -111,12 +104,69 @@ namespace ICSharpCode.Decompiler.IL.Transforms function.LocalFunctions.Remove(localFunction); declaringFunction.LocalFunctions.Add(localFunction); } + + if (TryValidateSkipCount(info, out int skipCount) && skipCount != localFunction.ReducedMethod.NumberOfCompilerGeneratedGenerics) { + Debug.Assert(false); + function.Warnings.Add($"Could not decode local function '{info.Method}'"); + if (localFunction.DeclarationScope != function.Body && localFunction.DeclarationScope.Parent is ILFunction declaringFunction) { + declaringFunction.LocalFunctions.Remove(localFunction); + } + continue; + } + + foreach (var useSite in info.UseSites) { + context.Step($"Transform use site at IL_{useSite.StartILOffset:x4}", useSite); + if (useSite.OpCode == OpCode.NewObj) { + TransformToLocalFunctionReference(localFunction, useSite); + } else { + TransformToLocalFunctionInvocation(localFunction.ReducedMethod, useSite); + } + } } finally { context.StepEndGroup(); } } } + bool TryValidateSkipCount(LocalFunctionInfo info, out int skipCount) + { + skipCount = 0; + var localFunction = info.Definition; + if (localFunction.Method.TypeParameters.Count == 0) + return true; + var parentMethod = ((ILFunction)localFunction.Parent).Method; + var method = localFunction.Method; + skipCount = parentMethod.DeclaringType.TypeParameterCount - method.DeclaringType.TypeParameterCount; + + if (skipCount > 0) + return false; + skipCount += parentMethod.TypeParameters.Count; + Debug.Assert(skipCount >= 0 && skipCount <= method.TypeArguments.Count); + if (skipCount < 0 || skipCount > method.TypeArguments.Count) + return false; + + if (skipCount > 0) { +#if DEBUG + foreach (var useSite in info.UseSites) { + var callerMethod = useSite.Ancestors.OfType().First().Method; + callerMethod = callerMethod.ReducedFrom ?? callerMethod; + IMethod method0; + if (useSite.OpCode == OpCode.NewObj) { + method0 = ((LdFtn)useSite.Arguments[1]).Method; + } else { + method0 = useSite.Method; + } + var totalSkipCount = skipCount + method0.DeclaringType.TypeParameterCount; + var methodSkippedArgs = method0.DeclaringType.TypeArguments.Concat(method0.TypeArguments).Take(totalSkipCount); + Debug.Assert(methodSkippedArgs.SequenceEqual(callerMethod.DeclaringType.TypeArguments.Concat(callerMethod.TypeArguments).Take(totalSkipCount))); + Debug.Assert(methodSkippedArgs.All(p => p.Kind == TypeKind.TypeParameter)); + Debug.Assert(methodSkippedArgs.Select(p => p.Name).SequenceEqual(method0.DeclaringType.TypeParameters.Concat(method0.TypeParameters).Take(totalSkipCount).Select(p => p.Name))); + } +#endif + } + return true; + } + void FindUseSites(ILFunction function, ILTransformContext context, Dictionary localFunctions) { foreach (var inst in function.Body.Descendants) { @@ -131,13 +181,14 @@ namespace ICSharpCode.Decompiler.IL.Transforms void HandleUseSite(IMethod targetMethod, CallInstruction inst) { if (!localFunctions.TryGetValue((MethodDefinitionHandle)targetMethod.MetadataToken, out var info)) { - targetMethod = (IMethod)targetMethod.MemberDefinition; context.StepStartGroup($"Read local function '{targetMethod.Name}'", inst); info = new LocalFunctionInfo() { UseSites = new List() { inst }, - Method = targetMethod, - Definition = ReadLocalFunctionDefinition(context.Function, targetMethod) + Method = (IMethod)targetMethod.MemberDefinition, }; + var rootFunction = context.Function; + int skipCount = GetSkipCount(rootFunction, targetMethod); + info.Definition = ReadLocalFunctionDefinition(rootFunction, targetMethod, skipCount); localFunctions.Add((MethodDefinitionHandle)targetMethod.MetadataToken, info); if (info.Definition != null) { FindUseSites(info.Definition, context, localFunctions); @@ -149,14 +200,17 @@ namespace ICSharpCode.Decompiler.IL.Transforms } } - ILFunction ReadLocalFunctionDefinition(ILFunction rootFunction, IMethod targetMethod) + ILFunction ReadLocalFunctionDefinition(ILFunction rootFunction, IMethod targetMethod, int skipCount) { var methodDefinition = context.PEFile.Metadata.GetMethodDefinition((MethodDefinitionHandle)targetMethod.MetadataToken); if (!methodDefinition.HasBody()) return null; var ilReader = context.CreateILReader(); var body = context.PEFile.Reader.GetMethodBody(methodDefinition.RelativeVirtualAddress); - var function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, default, ILFunctionKind.LocalFunction, context.CancellationToken); + var genericContext = GenericContextFromTypeArguments(targetMethod, skipCount); + if (genericContext == null) + return null; + var function = ilReader.ReadIL((MethodDefinitionHandle)targetMethod.MetadataToken, body, genericContext.GetValueOrDefault(), ILFunctionKind.LocalFunction, context.CancellationToken); // Embed the local function into the parent function's ILAst, so that "Show steps" can show // how the local function body is being transformed. rootFunction.LocalFunctions.Add(function); @@ -165,10 +219,64 @@ namespace ICSharpCode.Decompiler.IL.Transforms var nestedContext = new ILTransformContext(context, function); function.RunTransforms(CSharpDecompiler.GetILTransforms().TakeWhile(t => !(t is LocalFunctionDecompiler)), nestedContext); function.DeclarationScope = null; - function.ReducedMethod = ReduceToLocalFunction(targetMethod); + function.ReducedMethod = ReduceToLocalFunction(function.Method); + function.ReducedMethod.NumberOfCompilerGeneratedGenerics = skipCount; return function; } + int GetSkipCount(ILFunction rootFunction, IMethod targetMethod) + { + targetMethod = (IMethod)targetMethod.MemberDefinition; + var skipCount = rootFunction.Method.DeclaringType.TypeParameters.Count + rootFunction.Method.TypeParameters.Count - targetMethod.DeclaringType.TypeParameters.Count; + if (skipCount < 0) { + skipCount = 0; + } + if (targetMethod.TypeParameters.Count > 0) { + var lastParams = targetMethod.Parameters.Where(p => IsClosureParameter(p, this.resolveContext)).SelectMany(p => UnwrapByRef(p.Type).TypeArguments) + .Select(pt => (int?)targetMethod.TypeArguments.IndexOf(pt)).DefaultIfEmpty().Max(); + if (lastParams != null && lastParams.GetValueOrDefault() + 1 > skipCount) + skipCount = lastParams.GetValueOrDefault() + 1; + } + return skipCount; + } + + static TypeSystem.GenericContext? GenericContextFromTypeArguments(IMethod targetMethod, int skipCount) + { + if (skipCount < 0 || skipCount > targetMethod.TypeParameters.Count) { + Debug.Assert(false); + return null; + } + int total = targetMethod.DeclaringType.TypeParameters.Count + skipCount; + if (total == 0) + return default(TypeSystem.GenericContext); + + var classTypeParameters = new List(targetMethod.DeclaringType.TypeParameters); + var methodTypeParameters = new List(targetMethod.TypeParameters); + var a = targetMethod.DeclaringType.TypeArguments.Concat(targetMethod.TypeArguments).Take(total); + int idx = 0; + foreach (var curA in a) { + int curIdx; + List curParameters; + IReadOnlyList curArgs; + if (idx < classTypeParameters.Count) { + curIdx = idx; + curParameters = classTypeParameters; + curArgs = targetMethod.DeclaringType.TypeArguments; + } else { + curIdx = idx - classTypeParameters.Count; + curParameters = methodTypeParameters; + curArgs = targetMethod.TypeArguments; + } + if (curArgs[curIdx].Kind != TypeKind.TypeParameter) + break; + curParameters[curIdx] = (ITypeParameter)curA; + idx++; + } + Debug.Assert(idx == total); + + return new TypeSystem.GenericContext(classTypeParameters, methodTypeParameters); + } + static T FindCommonAncestorInstruction(ILInstruction a, ILInstruction b) where T : ILInstruction { @@ -219,19 +327,16 @@ namespace ICSharpCode.Decompiler.IL.Transforms { useSite.Arguments[0].ReplaceWith(new LdNull().WithILRange(useSite.Arguments[0])); var fnptr = (IInstructionWithMethodOperand)useSite.Arguments[1]; - LocalFunctionMethod reducedMethod = function.ReducedMethod; - if (reducedMethod.TypeParameters.Count > 0) - reducedMethod = new LocalFunctionMethod(fnptr.Method, reducedMethod.NumberOfCompilerGeneratedParameters); - var replacement = new LdFtn(reducedMethod).WithILRange((ILInstruction)fnptr); + var specializeMethod = function.ReducedMethod.Specialize(fnptr.Method.Substitution); + var replacement = new LdFtn(specializeMethod).WithILRange((ILInstruction)fnptr); useSite.Arguments[1].ReplaceWith(replacement); } void TransformToLocalFunctionInvocation(LocalFunctionMethod reducedMethod, CallInstruction useSite) { - if (reducedMethod.TypeParameters.Count > 0) - reducedMethod = new LocalFunctionMethod(useSite.Method, reducedMethod.NumberOfCompilerGeneratedParameters); + var specializeMethod = reducedMethod.Specialize(useSite.Method.Substitution); bool wasInstanceCall = !useSite.Method.IsStatic; - var replacement = new Call(reducedMethod); + var replacement = new Call(specializeMethod); int firstArgumentIndex = wasInstanceCall ? 1 : 0; int argumentCount = useSite.Arguments.Count; int reducedArgumentCount = argumentCount - (reducedMethod.NumberOfCompilerGeneratedParameters + firstArgumentIndex); diff --git a/ICSharpCode.Decompiler/TypeSystem/Implementation/LocalFunctionMethod.cs b/ICSharpCode.Decompiler/TypeSystem/Implementation/LocalFunctionMethod.cs index cb53d51a6..6475bae1a 100644 --- a/ICSharpCode.Decompiler/TypeSystem/Implementation/LocalFunctionMethod.cs +++ b/ICSharpCode.Decompiler/TypeSystem/Implementation/LocalFunctionMethod.cs @@ -67,6 +67,10 @@ namespace ICSharpCode.Decompiler.TypeSystem.Implementation internal int NumberOfCompilerGeneratedParameters { get; } + internal int NumberOfCompilerGeneratedGenerics { get; set; } + + internal bool IsStaticLocalFunction => NumberOfCompilerGeneratedParameters == 0 && baseMethod.IsStatic; + public IMember MemberDefinition => this; public IType ReturnType => baseMethod.ReturnType;