From 40212685b68a04e3e2feee727d1c1b9ef5247a0b Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Sun, 3 Sep 2017 14:17:49 +0200 Subject: [PATCH] [async]: control flow reconstruction --- .../IL/ControlFlow/AsyncAwaitDecompiler.cs | 130 +++++++++++++++++- .../IL/ControlFlow/StateRangeAnalysis.cs | 10 ++ ICSharpCode.Decompiler/IL/Instructions.cs | 117 ++++++++++++++++ ICSharpCode.Decompiler/IL/Instructions.tt | 5 +- .../IL/Instructions/InstructionCollection.cs | 4 + 5 files changed, 262 insertions(+), 4 deletions(-) diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs index 462018dc0..ff98e0eea 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/AsyncAwaitDecompiler.cs @@ -68,6 +68,8 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow AnalyzeMoveNext(); ValidateCatchBlock(); InlineBodyOfMoveNext(function); + AnalyzeStateMachine(function); + FinalizeInlineMoveNext(function); } catch (SymbolicAnalysisFailedException) { return; } @@ -382,17 +384,139 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow branch.ReplaceWith(new Leave((BlockContainer)function.Body) { ILRange = branch.ILRange }); } } + function.Variables.AddRange(function.Descendants.OfType().Select(inst => inst.Variable).Distinct()); + function.Variables.RemoveDead(); + } + + void FinalizeInlineMoveNext(ILFunction function) + { + context.Step("FinalizeInlineMoveNext()", function); foreach (var leave in function.Descendants.OfType()) { if (leave.TargetContainer == moveNextFunction.Body) { leave.ReplaceWith(new InvalidBranch { - Message = " leave MoveNext - await not detected correctly ", + Message = "leave MoveNext - await not detected correctly", ILRange = leave.ILRange }); } } - function.Variables.AddRange(function.Descendants.OfType().Select(inst => inst.Variable).Distinct()); - function.Variables.RemoveDead(); } #endregion + + /// + /// Analyze the the state machine; and replace 'leave IL_0000' with await+jump to block that gets + /// entered on the next MoveNext() call. + /// + void AnalyzeStateMachine(ILFunction function) + { + context.Step("AnalyzeStateMachine()", function); + foreach (var container in function.Descendants.OfType()) { + // Use a separate state range analysis per container. + var sra = new StateRangeAnalysis(StateRangeAnalysisMode.AsyncMoveNext, stateField, cachedStateVar); + sra.CancellationToken = context.CancellationToken; + sra.AssignStateRanges(container, LongSet.Universe); + + foreach (var block in container.Blocks) { + if (block.Instructions.Last().MatchLeave((BlockContainer)moveNextFunction.Body)) { + // This is likely an 'await' block + if (AnalyzeAwaitBlock(block, out var awaiter, out var awaiterField, out var state)) { + block.Instructions.Add(new Await(new LdLoca(awaiter))); + Block targetBlock = sra.FindBlock(container, state); + if (targetBlock != null) { + block.Instructions.Add(new Branch(targetBlock)); + } else { + block.Instructions.Add(new InvalidBranch("Could not find block for state " + state)); + } + } + } + } + var entryPoint = sra.FindBlock(container, initialState); + if (entryPoint != null) { + container.Blocks.Insert(0, new Block { + Instructions = { + new Branch(entryPoint) + } + }); + } + container.SortBlocks(deleteUnreachableBlocks: true); + } + } + + bool AnalyzeAwaitBlock(Block block, out ILVariable awaiter, out IField awaiterField, out int state) + { + awaiter = null; + awaiterField = null; + state = 0; + context.CancellationToken.ThrowIfCancellationRequested(); + int pos = block.Instructions.Count - 2; + // call AwaitUnsafeOnCompleted(ldflda <>t__builder(ldloc this), ldloca awaiter, ldloc this) + if (!MatchCall(block.Instructions[pos], "AwaitUnsafeOnCompleted", out var callArgs)) + return false; + if (callArgs.Count != 3) + return false; + if (!IsBuilderFieldOnThis(callArgs[0])) + return false; + if (!callArgs[1].MatchLdLoca(out awaiter)) + return false; + if (callArgs[2].MatchLdThis()) { + // OK (if state machine is a struct) + pos--; + } else if (callArgs[2].MatchLdLoca(out var tempVar)) { + // Roslyn, non-optimized uses a class for the state machine. + // stloc tempVar(ldloc this) + // call AwaitUnsafeOnCompleted(ldflda <>t__builder](ldloc this), ldloca awaiter, ldloca tempVar) + if (!(pos > 0 && block.Instructions[pos - 1].MatchStLoc(tempVar, out var tempVal))) + return false; + if (!tempVal.MatchLdThis()) + return false; + pos -= 2; + } else { + return false; + } + // stfld StateMachine.<>awaiter(ldloc this, ldloc awaiter) + if (!block.Instructions[pos].MatchStFld(out var target, out awaiterField, out var value)) + return false; + if (!target.MatchLdThis()) + return false; + if (!value.MatchLdLoc(awaiter)) + return false; + pos--; + // stloc S_10(ldloc this) + // stloc S_11(ldc.i4 0) + // stloc cachedStateVar(ldloc S_11) + // stfld <>1__state(ldloc S_10, ldloc S_11) + if (!block.Instructions[pos].MatchStFld(out target, out var field, out value)) + return false; + if (!StackSlotValue(target).MatchLdThis()) + return false; + if (field.MemberDefinition != stateField) + return false; + if (!StackSlotValue(value).MatchLdcI4(out state)) + return false; + if (pos > 0 && block.Instructions[pos - 1] is StLoc stloc + && stloc.Variable.Kind == VariableKind.Local && stloc.Variable.Index == cachedStateVar.Index + && StackSlotValue(stloc.Value).MatchLdcI4(state)) { + // also delete the assignment to cachedStateVar + pos--; + } + block.Instructions.RemoveRange(pos, block.Instructions.Count - pos); + // delete preceding dead stores: + while (pos > 0 && block.Instructions[pos - 1] is StLoc stloc2 + && stloc2.Variable.IsSingleDefinition && stloc2.Variable.LoadCount == 0 + && stloc2.Variable.Kind == VariableKind.StackSlot) { + pos--; + } + block.Instructions.RemoveRange(pos, block.Instructions.Count - pos); + return true; + } + + static ILInstruction StackSlotValue(ILInstruction inst) + { + if (inst.MatchLdLoc(out var v) && v.Kind == VariableKind.StackSlot && v.IsSingleDefinition) { + if (v.StoreInstructions[0] is StLoc stloc) { + return stloc.Value; + } + } + return inst; + } } } diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/StateRangeAnalysis.cs b/ICSharpCode.Decompiler/IL/ControlFlow/StateRangeAnalysis.cs index 8a316473f..a5d65c845 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/StateRangeAnalysis.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/StateRangeAnalysis.cs @@ -187,5 +187,15 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow yield return (block, stateSet); } } + + public Block FindBlock(BlockContainer container, int newState) + { + Block targetBlock = null; + foreach (var (block, stateSet) in GetBlockStateSetMapping(container)) { + if (stateSet.Contains(newState)) + targetBlock = block; + } + return targetBlock; + } } } diff --git a/ICSharpCode.Decompiler/IL/Instructions.cs b/ICSharpCode.Decompiler/IL/Instructions.cs index 6a426461b..59e5d6dea 100644 --- a/ICSharpCode.Decompiler/IL/Instructions.cs +++ b/ICSharpCode.Decompiler/IL/Instructions.cs @@ -162,6 +162,8 @@ namespace ICSharpCode.Decompiler.IL RefAnyValue, /// Yield an element from an iterator. YieldReturn, + /// C# await operator. + Await, /// Matches any node AnyNode, } @@ -3995,6 +3997,98 @@ namespace ICSharpCode.Decompiler.IL } } } +namespace ICSharpCode.Decompiler.IL +{ + /// C# await operator. + public sealed partial class Await : ILInstruction + { + public Await(ILInstruction value) : base(OpCode.Await) + { + this.Value = value; + } + public static readonly SlotInfo ValueSlot = new SlotInfo("Value", canInlineInto: true); + ILInstruction value; + public ILInstruction Value { + get { return this.value; } + set { + ValidateChild(value); + SetChildInstruction(ref this.value, value, 0); + } + } + protected sealed override int GetChildCount() + { + return 1; + } + protected sealed override ILInstruction GetChild(int index) + { + switch (index) { + case 0: + return this.value; + default: + throw new IndexOutOfRangeException(); + } + } + protected sealed override void SetChild(int index, ILInstruction value) + { + switch (index) { + case 0: + this.Value = value; + break; + default: + throw new IndexOutOfRangeException(); + } + } + protected sealed override SlotInfo GetChildSlot(int index) + { + switch (index) { + case 0: + return ValueSlot; + default: + throw new IndexOutOfRangeException(); + } + } + public sealed override ILInstruction Clone() + { + var clone = (Await)ShallowClone(); + clone.Value = this.value.Clone(); + return clone; + } + public override StackType ResultType { get { return StackType.Void; } } + protected override InstructionFlags ComputeFlags() + { + return InstructionFlags.SideEffect | value.Flags; + } + public override InstructionFlags DirectFlags { + get { + return InstructionFlags.SideEffect; + } + } + public override void WriteTo(ITextOutput output) + { + output.Write(OpCode); + output.Write('('); + this.value.WriteTo(output); + output.Write(')'); + } + public override void AcceptVisitor(ILVisitor visitor) + { + visitor.VisitAwait(this); + } + public override T AcceptVisitor(ILVisitor visitor) + { + return visitor.VisitAwait(this); + } + public override T AcceptVisitor(ILVisitor visitor, C context) + { + return visitor.VisitAwait(this, context); + } + protected internal override bool PerformMatch(ILInstruction other, ref Patterns.Match match) + { + var o = other as Await; + return o != null && this.value.PerformMatch(o.value, ref match); + } + } +} namespace ICSharpCode.Decompiler.IL.Patterns { /// Matches any node @@ -4313,6 +4407,10 @@ namespace ICSharpCode.Decompiler.IL { Default(inst); } + protected internal virtual void VisitAwait(Await inst) + { + Default(inst); + } } /// @@ -4587,6 +4685,10 @@ namespace ICSharpCode.Decompiler.IL { return Default(inst); } + protected internal virtual T VisitAwait(Await inst) + { + return Default(inst); + } } /// @@ -4861,6 +4963,10 @@ namespace ICSharpCode.Decompiler.IL { return Default(inst, context); } + protected internal virtual T VisitAwait(Await inst, C context) + { + return Default(inst, context); + } } partial class InstructionOutputExtensions @@ -4932,6 +5038,7 @@ namespace ICSharpCode.Decompiler.IL "refanytype", "refanyval", "yield.return", + "await", "AnyNode", }; } @@ -5420,6 +5527,16 @@ namespace ICSharpCode.Decompiler.IL value = default(ILInstruction); return false; } + public bool MatchAwait(out ILInstruction value) + { + var inst = this as Await; + if (inst != null) { + value = inst.Value; + return true; + } + value = default(ILInstruction); + return false; + } } } diff --git a/ICSharpCode.Decompiler/IL/Instructions.tt b/ICSharpCode.Decompiler/IL/Instructions.tt index be5cbf167..8fab367ca 100644 --- a/ICSharpCode.Decompiler/IL/Instructions.tt +++ b/ICSharpCode.Decompiler/IL/Instructions.tt @@ -219,10 +219,13 @@ CustomClassName("RefAnyValue"), Unary, HasTypeOperand, MayThrow, ResultType("Ref")), new OpCode("yield.return", "Yield an element from an iterator.", - MayBranch, // yield return may end up returning if the consumer disposes the iterator without + MayBranch, // yield return may end up returning if the consumer disposes the iterator SideEffect, // consumer can have arbitrary side effects while we're yielding CustomArguments("value"), VoidResult), // note: "yield break" is always represented using a "leave" instruction + new OpCode("await", "C# await operator.", + SideEffect, // other code can run with arbitrary side effects while we're waiting + CustomArguments("value"), ResultType("Void")), // patterns new OpCode("AnyNode", "Matches any node", Pattern, CustomArguments(), CustomConstructor), diff --git a/ICSharpCode.Decompiler/IL/Instructions/InstructionCollection.cs b/ICSharpCode.Decompiler/IL/Instructions/InstructionCollection.cs index 747541260..d0030a9bb 100644 --- a/ICSharpCode.Decompiler/IL/Instructions/InstructionCollection.cs +++ b/ICSharpCode.Decompiler/IL/Instructions/InstructionCollection.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; namespace ICSharpCode.Decompiler.IL { @@ -86,15 +87,18 @@ namespace ICSharpCode.Decompiler.IL #endif } + [DebuggerStepThrough] public bool MoveNext() { return ++pos < list.Count; } public T Current { + [DebuggerStepThrough] get { return list[pos]; } } + [DebuggerStepThrough] public void Dispose() { #if DEBUG