diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs index b82525591..7dfc1048a 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs @@ -80,6 +80,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow int initialState; Dictionary fieldToParameterMap = new Dictionary(); Dictionary cachedFieldToParameterMap = new Dictionary(); + IField disposeModeField; // 'disposeMode' field (IAsyncEnumerable/IAsyncEnumerator only) // These fields are set by AnalyzeMoveNext(): ILFunction moveNextFunction; @@ -119,6 +120,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow try { AnalyzeMoveNext(); ValidateCatchBlock(); + AnalyzeDisposeAsync(); } catch (SymbolicAnalysisFailedException) { return; } @@ -492,7 +494,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow private void ResolveIEnumerableIEnumeratorFieldMapping() { - var getAsyncEnumerator = stateMachineType.GetMethods(m => m.Name.EndsWith(".GetAsyncEnumerator", StringComparison.Ordinal)).FirstOrDefault(); + var getAsyncEnumerator = stateMachineType.Methods.FirstOrDefault(m => m.Name.EndsWith(".GetAsyncEnumerator", StringComparison.Ordinal)); if (getAsyncEnumerator == null) throw new SymbolicAnalysisFailedException(); YieldReturnDecompiler.ResolveIEnumerableIEnumeratorFieldMapping((MethodDefinitionHandle)getAsyncEnumerator.MetadataToken, context, fieldToParameterMap); @@ -721,6 +723,33 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow newState = 0; return false; } + + /// + /// Analyse the DisposeAsync() method in order to find the disposeModeField. + /// + private void AnalyzeDisposeAsync() + { + disposeModeField = null; + if (!IsAsyncEnumerator) { + return; + } + var disposeAsync = stateMachineType.Methods.FirstOrDefault(m => m.Name.EndsWith(".DisposeAsync", StringComparison.Ordinal)); + if (disposeAsync == null) + throw new SymbolicAnalysisFailedException("Could not find DisposeAsync()"); + var disposeAsyncHandle = (MethodDefinitionHandle)disposeAsync.MetadataToken; + var function = YieldReturnDecompiler.CreateILAst(disposeAsyncHandle, context); + foreach (var store in function.Descendants) { + if (!store.MatchStFld(out var target, out var field, out var value)) + continue; + if (!target.MatchLdThis()) + continue; + if (!value.MatchLdcI4(1)) + throw new SymbolicAnalysisFailedException(); + if (disposeModeField != null) + throw new SymbolicAnalysisFailedException("Multiple stores to disposeMode in DisposeAsync()"); + disposeModeField = (IField)field.MemberDefinition; + } + } #endregion #region InlineBodyOfMoveNext @@ -828,6 +857,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow block.Instructions.Add(new InvalidBranch("Could not detect 'yield return'")); } } + SimplifyIfDisposeMode(block); } // Skip the state dispatcher and directly jump to the initial state var entryPoint = stateToBlockMap.GetOrDefault(initialState); @@ -842,6 +872,20 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow } } + private bool SimplifyIfDisposeMode(Block block) + { + // if (logic.not(ldfld disposeMode(ldloc this))) br falseInst + // br trueInst + if (!block.MatchIfAtEndOfBlock(out var condition, out _, out var falseInst)) + return false; + if (!condition.MatchLdFld(out var target, out var field)) + return false; + if (!(target.MatchLdThis() && field.MemberDefinition == disposeModeField)) + return false; + block.Instructions[block.Instructions.Count - 2] = falseInst; + block.Instructions.RemoveAt(block.Instructions.Count - 1); + return true; + } bool AnalyzeAwaitBlock(Block block, out ILVariable awaiter, out IField awaiterField, out int state, out int yieldOffset) { diff --git a/ICSharpCode.Decompiler/IL/Instructions/Block.cs b/ICSharpCode.Decompiler/IL/Instructions/Block.cs index 20b9da814..5bb1a2dae 100644 --- a/ICSharpCode.Decompiler/IL/Instructions/Block.cs +++ b/ICSharpCode.Decompiler/IL/Instructions/Block.cs @@ -311,6 +311,27 @@ namespace ICSharpCode.Decompiler.IL return false; return this.FinalInstruction.MatchLdLoc(tmp); } + + public bool MatchIfAtEndOfBlock(out ILInstruction condition, out ILInstruction trueInst, out ILInstruction falseInst) + { + condition = null; + trueInst = null; + falseInst = null; + if (Instructions.Count < 2) + return false; + if (Instructions[Instructions.Count - 2].MatchIfInstruction(out condition, out trueInst)) { + // Swap trueInst<>falseInst for every logic.not in the condition. + falseInst = Instructions.Last(); + while (condition.MatchLogicNot(out var arg)) { + condition = arg; + ILInstruction tmp = trueInst; + trueInst = falseInst; + falseInst = tmp; + } + return true; + } + return false; + } } public enum BlockKind