Browse Source

Support parameters in IAsyncEnumerator methods

pull/1730/head
Daniel Grunwald 6 years ago
parent
commit
ef699c096b
  1. 79
      ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs
  2. 5
      ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs

79
ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs

@ -373,20 +373,50 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
// HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state // HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state
initialState = -1; initialState = -1;
try { try {
AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out stateField); AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out builderType, out stateField);
} catch (SymbolicAnalysisFailedException) { } catch (SymbolicAnalysisFailedException) {
return false; return false;
} }
builderType = builderField.Type.GetDefinition();
if (builderType == null)
return false;
return true; return true;
} else { } else {
return false; return false;
} }
} else {
// stloc v(newobj<CountUpSlowly> d__0..ctor(ldc.i4 - 2))
// stfld <>4__this(ldloc v, ldloc this)
// stfld <>3__otherParam(ldloc v, ldloc otherParam)
// leave IL_0000(ldloc v)
int pos = 0;
if (!body.Instructions[pos].MatchStLoc(out var v, out var newObj))
return false;
if (!MatchEnumeratorCreationNewObj(newObj, context, out initialState, out stateMachineType))
return false;
pos++;
while (MatchStFld(body.Instructions[pos], v, out var field, out var value)) {
if (!value.MatchLdLoc(out var p))
return false;
fieldToParameterMap[field] = p;
pos++;
}
if (!body.Instructions[pos].MatchReturn(out var returnValue))
return false;
if (!returnValue.MatchLdLoc(v))
return false;
// HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state
initialState = -1;
try {
AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out builderType, out stateField);
if (methodType == AsyncMethodType.AsyncEnumerable) {
ResolveIEnumerableIEnumeratorFieldMapping();
}
} catch (SymbolicAnalysisFailedException) {
return false;
}
return true;
} }
// TODO: enumerator creation with parameters
return false;
} }
static bool MatchEnumeratorCreationNewObj(ILInstruction inst, ILTransformContext context, static bool MatchEnumeratorCreationNewObj(ILInstruction inst, ILTransformContext context,
@ -424,7 +454,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
return false; return false;
} }
static void AnalyzeEnumeratorCtor(IMethod ctor, ILTransformContext context, out IField builderField, out IField stateField) static void AnalyzeEnumeratorCtor(IMethod ctor, ILTransformContext context, out IField builderField, out ITypeDefinition builderType, out IField stateField)
{ {
builderField = null; builderField = null;
stateField = null; stateField = null;
@ -439,21 +469,33 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
// stfld <>t__builder(ldloc this, call Create()) // stfld <>t__builder(ldloc this, call Create())
// leave IL_0000 (nop) // leave IL_0000 (nop)
// } // }
if (body.Instructions.ElementAtOrDefault(1).MatchStFld(out var target, out var field, out var value) foreach (var inst in body.Instructions) {
if (inst.MatchStFld(out var target, out var field, out var value)
&& target.MatchLdThis() && target.MatchLdThis()
&& value.MatchLdLoc(out var arg) && value.MatchLdLoc(out var arg)
&& arg.Kind == VariableKind.Parameter && arg.Index == 0) { && arg.Kind == VariableKind.Parameter && arg.Index == 0) {
stateField = (IField)field.MemberDefinition; stateField = (IField)field.MemberDefinition;
} else { }
throw new SymbolicAnalysisFailedException("Could not find stateField"); if (inst.MatchStFld(out target, out field, out value)
} && target.MatchLdThis()
if (body.Instructions.ElementAtOrDefault(2).MatchStFld(out target, out field, out value) && value is Call call && call.Method.Name == "Create") {
&& target.MatchLdThis() builderField = (IField)field.MemberDefinition;
&& value is Call call && call.Method.Name == "Create") { }
builderField = (IField)field.MemberDefinition;
} else {
throw new SymbolicAnalysisFailedException("Could not find builderField");
} }
if (stateField == null || builderField == null)
throw new SymbolicAnalysisFailedException();
builderType = builderField.Type.GetDefinition();
if (builderType == null)
throw new SymbolicAnalysisFailedException();
}
private void ResolveIEnumerableIEnumeratorFieldMapping()
{
var getAsyncEnumerator = stateMachineType.GetMethods(m => m.Name.EndsWith(".GetAsyncEnumerator", StringComparison.Ordinal)).FirstOrDefault();
if (getAsyncEnumerator == null)
throw new SymbolicAnalysisFailedException();
YieldReturnDecompiler.ResolveIEnumerableIEnumeratorFieldMapping((MethodDefinitionHandle)getAsyncEnumerator.MetadataToken, context, fieldToParameterMap);
} }
#endregion #endregion
@ -689,6 +731,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
function.AsyncReturnType = underlyingReturnType; function.AsyncReturnType = underlyingReturnType;
function.MoveNextMethod = moveNextFunction.Method; function.MoveNextMethod = moveNextFunction.Method;
function.CodeSize = moveNextFunction.CodeSize; function.CodeSize = moveNextFunction.CodeSize;
function.IsIterator = IsAsyncEnumerator;
moveNextFunction.Variables.Clear(); moveNextFunction.Variables.Clear();
moveNextFunction.ReleaseRef(); moveNextFunction.ReleaseRef();
foreach (var branch in function.Descendants.OfType<Branch>()) { foreach (var branch in function.Descendants.OfType<Branch>()) {

5
ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs

@ -453,6 +453,11 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
MethodDefinitionHandle getEnumeratorMethod = metadata.GetTypeDefinition(enumeratorType).GetMethods().FirstOrDefault( MethodDefinitionHandle getEnumeratorMethod = metadata.GetTypeDefinition(enumeratorType).GetMethods().FirstOrDefault(
m => metadata.GetString(metadata.GetMethodDefinition(m).Name).StartsWith("System.Collections.Generic.IEnumerable", StringComparison.Ordinal) m => metadata.GetString(metadata.GetMethodDefinition(m).Name).StartsWith("System.Collections.Generic.IEnumerable", StringComparison.Ordinal)
&& metadata.GetString(metadata.GetMethodDefinition(m).Name).EndsWith(".GetEnumerator", StringComparison.Ordinal)); && metadata.GetString(metadata.GetMethodDefinition(m).Name).EndsWith(".GetEnumerator", StringComparison.Ordinal));
}
internal static void ResolveIEnumerableIEnumeratorFieldMapping(MethodDefinitionHandle getEnumeratorMethod, ILTransformContext context,
Dictionary<IField, ILVariable> fieldToParameterMap)
{
if (getEnumeratorMethod.IsNil) if (getEnumeratorMethod.IsNil)
return; // no mappings (maybe it's just an IEnumerator implementation?) return; // no mappings (maybe it's just an IEnumerator implementation?)
var function = CreateILAst(getEnumeratorMethod, context); var function = CreateILAst(getEnumeratorMethod, context);

Loading…
Cancel
Save