From ef699c096bd27bd809b19977041b08cd3cbb2ff9 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Sun, 29 Sep 2019 12:35:40 +0200 Subject: [PATCH] Support parameters in IAsyncEnumerator methods --- .../IL/ControlFlow/AsyncAwaitDecompiler.cs | 79 ++++++++++++++----- .../IL/ControlFlow/YieldReturnDecompiler.cs | 5 ++ 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs index 3a38f028e..b82525591 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs @@ -373,20 +373,50 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow // HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state initialState = -1; try { - AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out stateField); + AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out builderType, out stateField); } catch (SymbolicAnalysisFailedException) { return false; } - builderType = builderField.Type.GetDefinition(); - if (builderType == null) - return false; return true; } else { return false; } + } else { + + // stloc v(newobj d__0..ctor(ldc.i4 - 2)) + // stfld <>4__this(ldloc v, ldloc this) + // stfld <>3__otherParam(ldloc v, ldloc otherParam) + // leave IL_0000(ldloc v) + int pos = 0; + if (!body.Instructions[pos].MatchStLoc(out var v, out var newObj)) + return false; + if (!MatchEnumeratorCreationNewObj(newObj, context, out initialState, out stateMachineType)) + return false; + pos++; + + while (MatchStFld(body.Instructions[pos], v, out var field, out var value)) { + if (!value.MatchLdLoc(out var p)) + return false; + fieldToParameterMap[field] = p; + pos++; + } + if (!body.Instructions[pos].MatchReturn(out var returnValue)) + return false; + if (!returnValue.MatchLdLoc(v)) + return false; + + // HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state + initialState = -1; + try { + AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out builderType, out stateField); + if (methodType == AsyncMethodType.AsyncEnumerable) { + ResolveIEnumerableIEnumeratorFieldMapping(); + } + } catch (SymbolicAnalysisFailedException) { + return false; + } + return true; } - // TODO: enumerator creation with parameters - return false; } static bool MatchEnumeratorCreationNewObj(ILInstruction inst, ILTransformContext context, @@ -424,7 +454,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow return false; } - static void AnalyzeEnumeratorCtor(IMethod ctor, ILTransformContext context, out IField builderField, out IField stateField) + static void AnalyzeEnumeratorCtor(IMethod ctor, ILTransformContext context, out IField builderField, out ITypeDefinition builderType, out IField stateField) { builderField = null; stateField = null; @@ -439,21 +469,33 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow // stfld <>t__builder(ldloc this, call Create()) // leave IL_0000 (nop) // } - if (body.Instructions.ElementAtOrDefault(1).MatchStFld(out var target, out var field, out var value) + foreach (var inst in body.Instructions) { + if (inst.MatchStFld(out var target, out var field, out var value) && target.MatchLdThis() && value.MatchLdLoc(out var arg) && arg.Kind == VariableKind.Parameter && arg.Index == 0) { - stateField = (IField)field.MemberDefinition; - } else { - throw new SymbolicAnalysisFailedException("Could not find stateField"); - } - if (body.Instructions.ElementAtOrDefault(2).MatchStFld(out target, out field, out value) - && target.MatchLdThis() - && value is Call call && call.Method.Name == "Create") { - builderField = (IField)field.MemberDefinition; - } else { - throw new SymbolicAnalysisFailedException("Could not find builderField"); + stateField = (IField)field.MemberDefinition; + } + if (inst.MatchStFld(out target, out field, out value) + && target.MatchLdThis() + && value is Call call && call.Method.Name == "Create") { + builderField = (IField)field.MemberDefinition; + } } + if (stateField == null || builderField == null) + throw new SymbolicAnalysisFailedException(); + + builderType = builderField.Type.GetDefinition(); + if (builderType == null) + throw new SymbolicAnalysisFailedException(); + } + + private void ResolveIEnumerableIEnumeratorFieldMapping() + { + var getAsyncEnumerator = stateMachineType.GetMethods(m => m.Name.EndsWith(".GetAsyncEnumerator", StringComparison.Ordinal)).FirstOrDefault(); + if (getAsyncEnumerator == null) + throw new SymbolicAnalysisFailedException(); + YieldReturnDecompiler.ResolveIEnumerableIEnumeratorFieldMapping((MethodDefinitionHandle)getAsyncEnumerator.MetadataToken, context, fieldToParameterMap); } #endregion @@ -689,6 +731,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow function.AsyncReturnType = underlyingReturnType; function.MoveNextMethod = moveNextFunction.Method; function.CodeSize = moveNextFunction.CodeSize; + function.IsIterator = IsAsyncEnumerator; moveNextFunction.Variables.Clear(); moveNextFunction.ReleaseRef(); foreach (var branch in function.Descendants.OfType()) { diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs b/ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs index a4e90674b..dae5445a7 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs @@ -453,6 +453,11 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow MethodDefinitionHandle getEnumeratorMethod = metadata.GetTypeDefinition(enumeratorType).GetMethods().FirstOrDefault( m => metadata.GetString(metadata.GetMethodDefinition(m).Name).StartsWith("System.Collections.Generic.IEnumerable", StringComparison.Ordinal) && metadata.GetString(metadata.GetMethodDefinition(m).Name).EndsWith(".GetEnumerator", StringComparison.Ordinal)); + } + + internal static void ResolveIEnumerableIEnumeratorFieldMapping(MethodDefinitionHandle getEnumeratorMethod, ILTransformContext context, + Dictionary fieldToParameterMap) + { if (getEnumeratorMethod.IsNil) return; // no mappings (maybe it's just an IEnumerator implementation?) var function = CreateILAst(getEnumeratorMethod, context);