Browse Source

Support parameters in IAsyncEnumerator methods

pull/1730/head
Daniel Grunwald 6 years ago
parent
commit
ef699c096b
  1. 67
      ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs
  2. 5
      ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs

67
ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs

@ -373,21 +373,51 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow @@ -373,21 +373,51 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
// HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state
initialState = -1;
try {
AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out stateField);
AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out builderType, out stateField);
} catch (SymbolicAnalysisFailedException) {
return false;
}
builderType = builderField.Type.GetDefinition();
if (builderType == null)
return false;
return true;
} else {
return false;
}
} else {
// stloc v(newobj<CountUpSlowly> d__0..ctor(ldc.i4 - 2))
// stfld <>4__this(ldloc v, ldloc this)
// stfld <>3__otherParam(ldloc v, ldloc otherParam)
// leave IL_0000(ldloc v)
int pos = 0;
if (!body.Instructions[pos].MatchStLoc(out var v, out var newObj))
return false;
if (!MatchEnumeratorCreationNewObj(newObj, context, out initialState, out stateMachineType))
return false;
pos++;
while (MatchStFld(body.Instructions[pos], v, out var field, out var value)) {
if (!value.MatchLdLoc(out var p))
return false;
fieldToParameterMap[field] = p;
pos++;
}
if (!body.Instructions[pos].MatchReturn(out var returnValue))
return false;
if (!returnValue.MatchLdLoc(v))
return false;
// HACK: the normal async/await logic expects 'initialState' to be the 'in progress' state
initialState = -1;
try {
AnalyzeEnumeratorCtor(((NewObj)newObj).Method, context, out builderField, out builderType, out stateField);
if (methodType == AsyncMethodType.AsyncEnumerable) {
ResolveIEnumerableIEnumeratorFieldMapping();
}
// TODO: enumerator creation with parameters
} catch (SymbolicAnalysisFailedException) {
return false;
}
return true;
}
}
static bool MatchEnumeratorCreationNewObj(ILInstruction inst, ILTransformContext context,
out int initialState, out ITypeDefinition stateMachineType)
@ -424,7 +454,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow @@ -424,7 +454,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
return false;
}
static void AnalyzeEnumeratorCtor(IMethod ctor, ILTransformContext context, out IField builderField, out IField stateField)
static void AnalyzeEnumeratorCtor(IMethod ctor, ILTransformContext context, out IField builderField, out ITypeDefinition builderType, out IField stateField)
{
builderField = null;
stateField = null;
@ -439,22 +469,34 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow @@ -439,22 +469,34 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
// stfld <>t__builder(ldloc this, call Create())
// leave IL_0000 (nop)
// }
if (body.Instructions.ElementAtOrDefault(1).MatchStFld(out var target, out var field, out var value)
foreach (var inst in body.Instructions) {
if (inst.MatchStFld(out var target, out var field, out var value)
&& target.MatchLdThis()
&& value.MatchLdLoc(out var arg)
&& arg.Kind == VariableKind.Parameter && arg.Index == 0) {
stateField = (IField)field.MemberDefinition;
} else {
throw new SymbolicAnalysisFailedException("Could not find stateField");
}
if (body.Instructions.ElementAtOrDefault(2).MatchStFld(out target, out field, out value)
if (inst.MatchStFld(out target, out field, out value)
&& target.MatchLdThis()
&& value is Call call && call.Method.Name == "Create") {
builderField = (IField)field.MemberDefinition;
} else {
throw new SymbolicAnalysisFailedException("Could not find builderField");
}
}
if (stateField == null || builderField == null)
throw new SymbolicAnalysisFailedException();
builderType = builderField.Type.GetDefinition();
if (builderType == null)
throw new SymbolicAnalysisFailedException();
}
private void ResolveIEnumerableIEnumeratorFieldMapping()
{
var getAsyncEnumerator = stateMachineType.GetMethods(m => m.Name.EndsWith(".GetAsyncEnumerator", StringComparison.Ordinal)).FirstOrDefault();
if (getAsyncEnumerator == null)
throw new SymbolicAnalysisFailedException();
YieldReturnDecompiler.ResolveIEnumerableIEnumeratorFieldMapping((MethodDefinitionHandle)getAsyncEnumerator.MetadataToken, context, fieldToParameterMap);
}
#endregion
#region AnalyzeMoveNext
@ -689,6 +731,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow @@ -689,6 +731,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
function.AsyncReturnType = underlyingReturnType;
function.MoveNextMethod = moveNextFunction.Method;
function.CodeSize = moveNextFunction.CodeSize;
function.IsIterator = IsAsyncEnumerator;
moveNextFunction.Variables.Clear();
moveNextFunction.ReleaseRef();
foreach (var branch in function.Descendants.OfType<Branch>()) {

5
ICSharpCode.Decompiler/IL/ControlFlow/YieldReturnDecompiler.cs

@ -453,6 +453,11 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow @@ -453,6 +453,11 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow
MethodDefinitionHandle getEnumeratorMethod = metadata.GetTypeDefinition(enumeratorType).GetMethods().FirstOrDefault(
m => metadata.GetString(metadata.GetMethodDefinition(m).Name).StartsWith("System.Collections.Generic.IEnumerable", StringComparison.Ordinal)
&& metadata.GetString(metadata.GetMethodDefinition(m).Name).EndsWith(".GetEnumerator", StringComparison.Ordinal));
}
internal static void ResolveIEnumerableIEnumeratorFieldMapping(MethodDefinitionHandle getEnumeratorMethod, ILTransformContext context,
Dictionary<IField, ILVariable> fieldToParameterMap)
{
if (getEnumeratorMethod.IsNil)
return; // no mappings (maybe it's just an IEnumerator implementation?)
var function = CreateILAst(getEnumeratorMethod, context);

Loading…
Cancel
Save