From adda97cf8477dcac42e809a8768f09715db60c1c Mon Sep 17 00:00:00 2001 From: Siegfried Pammer Date: Sat, 3 Mar 2018 16:50:10 +0100 Subject: [PATCH] Add AwaitInFinallyTransform --- .gitmodules | 2 +- .../TestCases/Correctness/Async.cs | 27 ++- .../IL/ControlFlow/AsyncAwaitDecompiler.cs | 1 + .../IL/ControlFlow/AwaitInCatchTransform.cs | 174 +++++++++++++++++- .../IL/Instructions/PatternMatching.cs | 21 +++ 5 files changed, 220 insertions(+), 5 deletions(-) diff --git a/.gitmodules b/.gitmodules index bedfc09b3..a9e291c02 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,4 +4,4 @@ [submodule "ILSpy-tests"] path = ILSpy-tests url = https://github.com/icsharpcode/ILSpy-tests - ignore = all + \ No newline at end of file diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Correctness/Async.cs b/ICSharpCode.Decompiler.Tests/TestCases/Correctness/Async.cs index d94009270..320ef94bf 100644 --- a/ICSharpCode.Decompiler.Tests/TestCases/Correctness/Async.cs +++ b/ICSharpCode.Decompiler.Tests/TestCases/Correctness/Async.cs @@ -41,7 +41,13 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness await TaskMethodWithoutAwaitButWithExceptionHandling(); #if CS60 await AwaitCatch(Task.FromResult(1)); - await AwaitFinally(Task.FromResult(2)); + await AwaitMultipleCatchBlocks(Task.FromResult(1)); + await AwaitMultipleCatchBlocks2(Task.FromResult(1)); + try { + await AwaitFinally(Task.FromResult(2)); + } catch (Exception ex) { + Console.WriteLine(ex + " caught!"); + } #endif await NestedAwait(Task.FromResult(Task.FromResult(5))); await AwaitWithStack(Task.FromResult(3)); @@ -49,6 +55,7 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness #if CS60 await AwaitInCatch(Task.FromResult(1), Task.FromResult(2)); await AwaitInFinally(Task.FromResult(2), Task.FromResult(4)); + await AwaitInCatchAndFinally(Task.FromResult(3), Task.FromResult(6), Task.FromResult(9)); #endif } @@ -211,6 +218,24 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness } Console.WriteLine("End Method"); } + + public async Task AwaitInCatchAndFinally(Task task1, Task task2, Task task3) + { + try { + Console.WriteLine("Start try"); + await task1; + Console.WriteLine("End try"); + } catch (Exception ex) { + Console.WriteLine("Start catch"); + await task2; + Console.WriteLine("End catch"); + } finally { + Console.WriteLine("Start finally"); + await task3; + Console.WriteLine("End finally"); + } + Console.WriteLine("End Method"); + } #endif } } \ No newline at end of file diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs index 97ca6be17..a1728a2b7 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs @@ -128,6 +128,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow function.RunTransforms(CSharpDecompiler.EarlyILTransforms(), context); AwaitInCatchTransform.Run(function, context); + AwaitInFinallyTransform.Run(function, context); } private void CleanUpBodyOfMoveNext(ILFunction function) diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/AwaitInCatchTransform.cs b/ICSharpCode.Decompiler/IL/ControlFlow/AwaitInCatchTransform.cs index 9d14a6549..1ca08ba93 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/AwaitInCatchTransform.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/AwaitInCatchTransform.cs @@ -18,9 +18,12 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; +using ICSharpCode.Decompiler.FlowAnalysis; using ICSharpCode.Decompiler.IL.Transforms; +using ICSharpCode.Decompiler.TypeSystem; namespace ICSharpCode.Decompiler.IL.ControlFlow { @@ -35,7 +38,7 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow if (!(tryCatch.Parent?.Parent is BlockContainer container)) continue; // Detect all handlers that contain an await expression - AnalyzeHandlers(tryCatch.Handlers, out var transformableCatchBlocks); + AnalyzeHandlers(tryCatch.Handlers, out var catchHandlerIdentifier, out var transformableCatchBlocks); var cfg = new ControlFlowGraph(container, context.CancellationToken); if (transformableCatchBlocks.Count > 0) changedContainers.Add(container); @@ -99,10 +102,10 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow /// /// Analyzes all catch handlers and returns every handler that follows the await catch handler pattern. /// - static bool AnalyzeHandlers(InstructionCollection handlers, out List<(int Id, TryCatchHandler Handler, Block RealCatchBlockEntryPoint, Block NextBlock, IfInstruction JumpTableEntry, StLoc ObjectVariableStore)> transformableCatchBlocks) + static bool AnalyzeHandlers(InstructionCollection handlers, out ILVariable catchHandlerIdentifier, out List<(int Id, TryCatchHandler Handler, Block RealCatchBlockEntryPoint, Block NextBlock, IfInstruction JumpTableEntry, StLoc ObjectVariableStore)> transformableCatchBlocks) { transformableCatchBlocks = new List<(int Id, TryCatchHandler Handler, Block RealCatchBlockEntryPoint, Block NextBlock, IfInstruction JumpTableEntry, StLoc ObjectVariableStore)>(); - ILVariable catchHandlerIdentifier = null; + catchHandlerIdentifier = null; foreach (var handler in handlers) { if (!MatchAwaitCatchHandler((BlockContainer)handler.Body, out int id, out var identifierVariable, out var realEntryPoint, out var nextBlock, out var jumpTableEntry, out var objectVariableStore)) continue; @@ -184,4 +187,169 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow return false; } } + + class AwaitInFinallyTransform + { + public static void Run(ILFunction function, ILTransformContext context) + { + HashSet changedContainers = new HashSet(); + + // analyze all try-catch statements in the function + foreach (var tryCatch in function.Descendants.OfType().ToArray()) { + if (!(tryCatch.Parent?.Parent is BlockContainer container)) + continue; + // await in finally uses a single catch block with catch-type object + if (tryCatch.Handlers.Count != 1 || !(tryCatch.Handlers[0].Body is BlockContainer catchBlockContainer) || !tryCatch.Handlers[0].Variable.Type.IsKnownType(KnownTypeCode.Object)) + continue; + // and consists of an assignment to a temporary that is used outside the catch block + // and a jump to the finally block + var block = catchBlockContainer.EntryPoint; + if (block.Instructions.Count < 2 || !block.Instructions[0].MatchStLoc(out var globalCopyVar, out var value) || !value.MatchLdLoc(tryCatch.Handlers[0].Variable)) + continue; + if (block.Instructions.Count == 3) { + if (!block.Instructions[1].MatchStLoc(out var globalCopyVarTemp, out value) || !value.MatchLdLoc(globalCopyVar)) + continue; + globalCopyVar = globalCopyVarTemp; + } + if (!block.Instructions[block.Instructions.Count - 1].MatchBranch(out var entryPointOfFinally)) + continue; + // globalCopyVar should only be used once, at the end of the finally-block + if (globalCopyVar.LoadCount != 1 || globalCopyVar.StoreCount > 2) + continue; + var tempStore = globalCopyVar.LoadInstructions[0].Parent as StLoc; + if (tempStore == null || !MatchExceptionCaptureBlock(tempStore, out var exitOfFinally, out var afterFinally, out var blocksToRemove)) + continue; + if (afterFinally.Instructions.Count < 2) + continue; + int offset = 0; + if (afterFinally.Instructions[0].MatchLdLoc(out var identifierVariable)) { + if (identifierVariable.LoadCount != 1 || identifierVariable.StoreCount != 1) + continue; + offset = 1; + } + if (!afterFinally.Instructions[offset].MatchStLoc(out var globalCopyVarSplitted, out var ldnull) || !ldnull.MatchLdNull()) + continue; + if (globalCopyVarSplitted.StoreCount != 1 || globalCopyVarSplitted.LoadCount != 0) + continue; + context.Step("Inline finally block with await", tryCatch.Handlers[0]); + var cfg = new ControlFlowGraph(container, context.CancellationToken); + var exitOfFinallyNode = cfg.GetNode(exitOfFinally); + var entryPointOfFinallyNode = cfg.GetNode(entryPointOfFinally); + var nodes = new Stack(new[] { entryPointOfFinallyNode }); + var blocksInFinally = new HashSet(); + var invalidExits = new List(); + while (nodes.Count > 0) { + var currentNode = nodes.Pop(); + if (currentNode != exitOfFinallyNode) { + foreach (var successor in currentNode.Successors) + nodes.Push(successor); + if (entryPointOfFinallyNode.Dominates(currentNode)) + blocksInFinally.Add((Block)currentNode.UserData); + else + invalidExits.Add(currentNode); + } + } + foreach (var blockToRemove in blocksToRemove) { + blockToRemove.Remove(); + } + var finallyContainer = new BlockContainer(); + entryPointOfFinally.Remove(); + if (offset == 1) + afterFinally.Instructions.RemoveAt(0); + changedContainers.Add(container); + finallyContainer.Blocks.Add(entryPointOfFinally); + exitOfFinally.Instructions.RemoveRange(tempStore.ChildIndex, 3); + exitOfFinally.Instructions.Add(new Leave(finallyContainer)); + foreach (var branchToFinally in container.Descendants.OfType()) { + if (branchToFinally.TargetBlock == entryPointOfFinally) + branchToFinally.ReplaceWith(new Branch(afterFinally)); + } + foreach (var newBlock in blocksInFinally) { + newBlock.Remove(); + finallyContainer.Blocks.Add(newBlock); + } + tryCatch.ReplaceWith(new TryFinally(tryCatch.TryBlock, finallyContainer)); + } + + // clean up all modified containers + foreach (var container in changedContainers) + container.SortBlocks(deleteUnreachableBlocks: true); + } + + /// + /// Block finallyHead (incoming: 2) { + /// [body of finally] + /// stloc V_4(ldloc V_1) + /// if (comp(ldloc V_4 == ldnull)) br afterFinally + /// br typeCheckBlock + /// } + /// + /// Block typeCheckBlock (incoming: 1) { + /// stloc S_110(isinst System.Exception(ldloc V_4)) + /// if (comp(ldloc S_110 != ldnull)) br captureBlock + /// br throwBlock + /// } + /// + /// Block throwBlock (incoming: 1) { + /// throw(ldloc V_4) + /// } + /// + /// Block captureBlock (incoming: 1) { + /// callvirt Throw(call Capture(ldloc S_110)) + /// br afterFinally + /// } + /// + /// Block afterFinally (incoming: 2) { + /// stloc V_1(ldnull) + /// [after finally] + /// } + /// + static bool MatchExceptionCaptureBlock(StLoc tempStore, out Block endOfFinally, out Block afterFinally, out List blocksToRemove) + { + afterFinally = null; + endOfFinally = (Block)tempStore.Parent; + blocksToRemove = new List(); + int count = endOfFinally.Instructions.Count; + if (tempStore.ChildIndex != count - 3) + return false; + if (!(endOfFinally.Instructions[count - 2] is IfInstruction ifInst)) + return false; + if (!endOfFinally.Instructions.Last().MatchBranch(out var typeCheckBlock)) + return false; + if (!ifInst.TrueInst.MatchBranch(out afterFinally)) + return false; + // match typeCheckBlock + if (typeCheckBlock.Instructions.Count != 3) + return false; + if (!typeCheckBlock.Instructions[0].MatchStLoc(out var castStore, out var cast) + || !cast.MatchIsInst(out var arg, out var type) || !type.IsKnownType(KnownTypeCode.Exception) || !arg.MatchLdLoc(tempStore.Variable)) + return false; + if (!typeCheckBlock.Instructions[1].MatchIfInstruction(out var cond, out var jumpToCaptureBlock)) + return false; + if (!cond.MatchCompNotEqualsNull(out arg) || !arg.MatchLdLoc(castStore)) + return false; + if (!typeCheckBlock.Instructions[2].MatchBranch(out var throwBlock)) + return false; + if (!jumpToCaptureBlock.MatchBranch(out var captureBlock)) + return false; + // match throwBlock + if (throwBlock.Instructions.Count != 1 || !throwBlock.Instructions[0].MatchThrow(out arg) || !arg.MatchLdLoc(tempStore.Variable)) + return false; + // match captureBlock + if (captureBlock.Instructions.Count != 2) + return false; + if (!captureBlock.Instructions[1].MatchBranch(afterFinally)) + return false; + if (!(captureBlock.Instructions[0] is CallVirt callVirt) || callVirt.Method.FullName != "System.Runtime.ExceptionServices.ExceptionDispatchInfo.Throw" || callVirt.Arguments.Count != 1) + return false; + if (!(callVirt.Arguments[0] is Call call) || call.Method.FullName != "System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture" || call.Arguments.Count != 1) + return false; + if (!call.Arguments[0].MatchLdLoc(castStore)) + return false; + blocksToRemove.Add(typeCheckBlock); + blocksToRemove.Add(throwBlock); + blocksToRemove.Add(captureBlock); + return true; + } + } } diff --git a/ICSharpCode.Decompiler/IL/Instructions/PatternMatching.cs b/ICSharpCode.Decompiler/IL/Instructions/PatternMatching.cs index 9403210be..b4bd321eb 100644 --- a/ICSharpCode.Decompiler/IL/Instructions/PatternMatching.cs +++ b/ICSharpCode.Decompiler/IL/Instructions/PatternMatching.cs @@ -363,6 +363,27 @@ namespace ICSharpCode.Decompiler.IL } } + /// + /// Matches 'comp(arg != ldnull)' + /// + public bool MatchCompNotEqualsNull(out ILInstruction arg) + { + if (!MatchCompNotEquals(out var left, out var right)) { + arg = null; + return false; + } + if (right.MatchLdNull()) { + arg = left; + return true; + } else if (left.MatchLdNull()) { + arg = right; + return true; + } else { + arg = null; + return false; + } + } + /// /// Matches comp(left != right) or logic.not(comp(left == right)). ///