From 2730133f5f5e155c87cc7e70a40debb3f2d21d25 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Sat, 25 Jun 2016 16:00:24 +0200 Subject: [PATCH] Combine switch sections that branch to same label. --- .../ControlFlow/ControlFlowSimplification.cs | 86 +++++++++++++++---- ICSharpCode.Decompiler/IL/ILReader.cs | 2 +- .../IL/Instructions/SwitchInstruction.cs | 12 +++ .../Util/CollectionExtensions.cs | 28 ++++++ ICSharpCode.Decompiler/Util/Interval.cs | 82 +++++++++++++++++- 5 files changed, 189 insertions(+), 21 deletions(-) diff --git a/ICSharpCode.Decompiler/IL/ControlFlow/ControlFlowSimplification.cs b/ICSharpCode.Decompiler/IL/ControlFlow/ControlFlowSimplification.cs index d17151b01..04824258f 100644 --- a/ICSharpCode.Decompiler/IL/ControlFlow/ControlFlowSimplification.cs +++ b/ICSharpCode.Decompiler/IL/ControlFlow/ControlFlowSimplification.cs @@ -41,29 +41,76 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow foreach (var block in function.Descendants.OfType()) { // Remove 'nop' instructions block.Instructions.RemoveAll(inst => inst.OpCode == OpCode.Nop); - // Ensure return blocks are inlined: - if (block.Instructions.Count == 2 && block.Instructions[1].OpCode == OpCode.Return) { - Return ret = (Return)block.Instructions[1]; - ILVariable v; - ILInstruction inst; - if (ret.ReturnValue != null && ret.ReturnValue.MatchLdLoc(out v) - && v.IsSingleDefinition && v.LoadCount == 1 && block.Instructions[0].MatchStLoc(v, out inst)) - { - inst.AddILRange(ret.ReturnValue.ILRange); - inst.AddILRange(block.Instructions[0].ILRange); - ret.ReturnValue = inst; - block.Instructions.RemoveAt(0); - } + + InlineReturnBlock(block); + // due to our use of basic blocks at this point, + // switch instructions can only appear as second-to-last instruction + SimplifySwitchInstruction(block.Instructions.ElementAtOrDefault(block.Instructions.Count - 2) as SwitchInstruction); + } + SimplifyBranchChains(function); + CleanUpEmptyBlocks(function); + } + + void InlineReturnBlock(Block block) + { 
+ // In debug mode, the C#-compiler generates 'return blocks' that + // unnecessarily store the return value to a local and then load it again: + // v = <inst> + // ret(v) + // (where 'v' has no other uses) + // Simplify these to a simple `ret(<inst>)` so that they match the release build version. + // + if (block.Instructions.Count == 2 && block.Instructions[1].OpCode == OpCode.Return) { + Return ret = (Return)block.Instructions[1]; + ILVariable v; + ILInstruction inst; + if (ret.ReturnValue != null && ret.ReturnValue.MatchLdLoc(out v) + && v.IsSingleDefinition && v.LoadCount == 1 && block.Instructions[0].MatchStLoc(v, out inst)) + { + inst.AddILRange(ret.ReturnValue.ILRange); + inst.AddILRange(block.Instructions[0].ILRange); + ret.ReturnValue = inst; + block.Instructions.RemoveAt(0); } } + } + + void SimplifySwitchInstruction(SwitchInstruction sw) + { + if (sw == null) + return; + + // ControlFlowSimplification runs early (before any other control flow transforms). + // Any switch instructions will only have branch instructions in the sections. 
+ + // dict from branch target to switch section + var dict = new Dictionary(); + sw.Sections.RemoveAll( + section => { + Block target; + if (section.Body.MatchBranch(out target)) { + SwitchSection primarySection; + if (dict.TryGetValue(target, out primarySection)) { + primarySection.Labels = primarySection.Labels.UnionWith(section.Labels); + return true; // remove this section + } else { + dict.Add(target, section); + } + } + return false; + }); + } + + void SimplifyBranchChains(ILFunction function) + { HashSet visitedBlocks = new HashSet(); foreach (var branch in function.Descendants.OfType()) { - // Resolve indirect branches + // Resolve chained branches to the final target: var targetBlock = branch.TargetBlock; visitedBlocks.Clear(); while (targetBlock.Instructions.Count == 1 && targetBlock.Instructions[0].OpCode == OpCode.Branch) { if (!visitedBlocks.Add(targetBlock)) { - // prevent infinite loop when indirect branches point in infinite loop + // prevent infinite loop when branch chain is cyclic break; } var nextBranch = (Branch)targetBlock.Instructions[0]; @@ -75,14 +122,18 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow if (IsReturnBlock(targetBlock)) { // Replace branches to 'return blocks' with the return instruction branch.ReplaceWith(targetBlock.Instructions[0].Clone()); - } else if (branch.TargetBlock.Instructions.Count == 1 && branch.TargetBlock.Instructions[0].OpCode == OpCode.Leave) { + } else if (targetBlock.Instructions.Count == 1 && targetBlock.Instructions[0].OpCode == OpCode.Leave) { // Replace branches to 'leave' instruction with the leave instruction - Leave leave = (Leave)branch.TargetBlock.Instructions[0]; + Leave leave = (Leave)targetBlock.Instructions[0]; branch.ReplaceWith(new Leave(leave.TargetContainer) { ILRange = branch.ILRange }); } if (targetBlock.IncomingEdgeCount == 0) targetBlock.Instructions.Clear(); // mark the block for deletion } + } + + void CleanUpEmptyBlocks(ILFunction function) + { foreach (var container in 
function.Descendants.OfType()) { foreach (var block in container.Blocks) { if (block.Instructions.Count == 0) @@ -123,5 +174,6 @@ namespace ICSharpCode.Decompiler.IL.ControlFlow targetBlock.Instructions.Clear(); // mark targetBlock for deletion return true; } + } } diff --git a/ICSharpCode.Decompiler/IL/ILReader.cs b/ICSharpCode.Decompiler/IL/ILReader.cs index a6981f90b..ab1a8592d 100644 --- a/ICSharpCode.Decompiler/IL/ILReader.cs +++ b/ICSharpCode.Decompiler/IL/ILReader.cs @@ -1042,7 +1042,7 @@ namespace ICSharpCode.Decompiler.IL Debug.Assert(right.ResultType.IsIntegerType()); return new Comp(kind, un ? Sign.Unsigned : Sign.Signed, left, right); } else { - // object reference or managed reference comparison + // integer equality, object reference or managed reference comparison return new Comp(kind, Sign.None, left, right); } } diff --git a/ICSharpCode.Decompiler/IL/Instructions/SwitchInstruction.cs b/ICSharpCode.Decompiler/IL/Instructions/SwitchInstruction.cs index d3d604961..5335c2c25 100644 --- a/ICSharpCode.Decompiler/IL/Instructions/SwitchInstruction.cs +++ b/ICSharpCode.Decompiler/IL/Instructions/SwitchInstruction.cs @@ -17,6 +17,8 @@ // DEALINGS IN THE SOFTWARE. 
using System; +using System.Collections.Immutable; +using System.Diagnostics; using System.Linq; namespace ICSharpCode.Decompiler.IL @@ -111,6 +113,16 @@ namespace ICSharpCode.Decompiler.IL clone.Sections.AddRange(this.Sections.Select(h => (SwitchSection)h.Clone())); return clone; } + + internal override void CheckInvariant(ILPhase phase) + { + base.CheckInvariant(phase); + LongSet sets = new LongSet(ImmutableArray.Empty); + foreach (var section in Sections) { + Debug.Assert(!section.Labels.Intersects(sets)); + sets = sets.UnionWith(section.Labels); + } + } } partial class SwitchSection diff --git a/ICSharpCode.Decompiler/Util/CollectionExtensions.cs b/ICSharpCode.Decompiler/Util/CollectionExtensions.cs index 3a25f965f..322b0b053 100644 --- a/ICSharpCode.Decompiler/Util/CollectionExtensions.cs +++ b/ICSharpCode.Decompiler/Util/CollectionExtensions.cs @@ -46,5 +46,33 @@ namespace ICSharpCode.Decompiler } return result; } + + /// <summary> + /// The merge step of merge sort. + /// </summary> + public static IEnumerable Merge(this IEnumerable input1, IEnumerable input2, Comparison comparison) + { + var enumA = input1.GetEnumerator(); + var enumB = input2.GetEnumerator(); + bool moreA = enumA.MoveNext(); + bool moreB = enumB.MoveNext(); + while (moreA && moreB) { + if (comparison(enumA.Current, enumB.Current) <= 0) { + yield return enumA.Current; + moreA = enumA.MoveNext(); + } else { + yield return enumB.Current; + moreB = enumB.MoveNext(); + } + } + while (moreA) { + yield return enumA.Current; + moreA = enumA.MoveNext(); + } + while (moreB) { + yield return enumB.Current; + moreB = enumB.MoveNext(); + } + } } } diff --git a/ICSharpCode.Decompiler/Util/Interval.cs b/ICSharpCode.Decompiler/Util/Interval.cs index 142b4f42d..0d65f0bab 100644 --- a/ICSharpCode.Decompiler/Util/Interval.cs +++ b/ICSharpCode.Decompiler/Util/Interval.cs @@ -1,4 +1,5 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; @@ -94,7 
+95,7 @@ namespace ICSharpCode.Decompiler if (End == int.MinValue) return string.Format("[{0}..int.MaxValue]", Start); else - return string.Format("[{0}..{1})", Start, End); + return string.Format("[{0}..{1})", Start, End); } #region Equals and GetHashCode implementation @@ -111,7 +112,7 @@ namespace ICSharpCode.Decompiler public override int GetHashCode() { return Start ^ End ^ (End << 7); - } + } public static bool operator ==(Interval lhs, Interval rhs) { @@ -230,7 +231,7 @@ namespace ICSharpCode.Decompiler if (End == long.MinValue) return string.Format("[{0}..long.MaxValue]", Start); else - return string.Format("[{0}..{1})", Start, End); + return string.Format("[{0}..{1})", Start, End); } #region Equals and GetHashCode implementation @@ -283,6 +284,81 @@ namespace ICSharpCode.Decompiler get { return Intervals.IsDefaultOrEmpty; } } + IEnumerable DoIntersectWith(LongSet other) + { + var enumA = this.Intervals.GetEnumerator(); + var enumB = other.Intervals.GetEnumerator(); + bool moreA = enumA.MoveNext(); + bool moreB = enumB.MoveNext(); + while (moreA && moreB) { + LongInterval a = enumA.Current; + LongInterval b = enumB.Current; + LongInterval intersection = a.Intersect(b); + if (!intersection.IsEmpty) { + yield return intersection; + } + if (a.InclusiveEnd < b.InclusiveEnd) { + moreA = enumA.MoveNext(); + } else { + moreB = enumB.MoveNext(); + } + } + } + + public bool Intersects(LongSet other) + { + return DoIntersectWith(other).Any(); + } + + public LongSet IntersectWith(LongSet other) + { + return new LongSet(DoIntersectWith(other).ToImmutableArray()); + } + + IEnumerable DoUnionWith(LongSet other) + { + long start = long.MinValue; + long end = long.MinValue; + bool empty = true; + foreach (var element in this.Intervals.Merge(other.Intervals, (a, b) => a.Start.CompareTo(b.Start))) { + Debug.Assert(start <= element.Start); + + if (!empty && element.Start < end - 1) { + // element overlaps or touches [start, end), so combine the intervals: + if (element.End == 
long.MinValue) { + // special case: element goes all the way up to long.MaxValue inclusive + yield return new LongInterval(start, element.End); + break; + } else { + end = Math.Max(end, element.End); + } + } else { + // flush existing interval: + if (!empty) { + yield return new LongInterval(start, end); + } else { + empty = false; + } + start = element.Start; + end = element.End; + } + if (end == long.MinValue) { + // special case: element goes all the way up to long.MaxValue inclusive + // all further intervals in the input must be contained in [start, end), + // so ignore them (and avoid trouble due to the overflow in `end`). + break; + } + } + if (!empty) { + yield return new LongInterval(start, end); + } + } + + public LongSet UnionWith(LongSet other) + { + return new LongSet(DoUnionWith(other).ToImmutableArray()); + } + public bool Contains(long val) { int index = upper_bound(val);