diff --git a/ICSharpCode.Decompiler/Ast/AstBuilder.cs b/ICSharpCode.Decompiler/Ast/AstBuilder.cs index 3b2b71a7c..8310340f6 100644 --- a/ICSharpCode.Decompiler/Ast/AstBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstBuilder.cs @@ -20,7 +20,7 @@ namespace Decompiler for (int i = 0; i < 4; i++) { if (Options.ReduceAstJumps) { //astCompileUnit.AcceptVisitor(new Transforms.Ast.RemoveGotos(), null); - astCompileUnit.AcceptVisitor(new Transforms.Ast.RemoveDeadLabels(), null); + //astCompileUnit.AcceptVisitor(new Transforms.Ast.RemoveDeadLabels(), null); } if (Options.ReduceAstLoops) { //astCompileUnit.AcceptVisitor(new Transforms.Ast.RestoreLoop(), null); diff --git a/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs b/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs index 1f26b9aa1..eff57689b 100644 --- a/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using Ast = ICSharpCode.NRefactory.CSharp; using ICSharpCode.NRefactory.CSharp; using Cecil = Mono.Cecil; @@ -48,7 +49,7 @@ namespace Decompiler List body = new ILAstBuilder().Build(methodDef, true); ILAstOptimizer bodyGraph = new ILAstOptimizer(); - bodyGraph.Optimize(body); + bodyGraph.Optimize(ref body); List intNames = new List(new string[] {"i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t"}); Dictionary typeNames = new Dictionary(); @@ -385,7 +386,7 @@ namespace Decompiler throw new NotImplementedException(); #endregion #region Branching - case Code.Br: return branchCommand; + case Code.Br: return new Ast.GotoStatement(((ILLabel)byteCode.Operand).Name); case Code.Brfalse: return new Ast.IfElseStatement(new Ast.UnaryOperatorExpression(UnaryOperatorType.Not, arg1), branchCommand); case Code.Brtrue: return new Ast.IfElseStatement(arg1, branchCommand); case Code.Beq: return new Ast.IfElseStatement(new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, arg2), branchCommand); diff --git a/ICSharpCode.Decompiler/Ast/Transforms/RemoveDeadLabels.cs b/ICSharpCode.Decompiler/Ast/Transforms/RemoveDeadLabels.cs deleted file mode 100644 index b8f65771a..000000000 --- a/ICSharpCode.Decompiler/Ast/Transforms/RemoveDeadLabels.cs +++ /dev/null @@ -1,56 +0,0 @@ -using System; -using System.Collections.Generic; -using ICSharpCode.NRefactory.CSharp; - -namespace Decompiler.Transforms.Ast -{ - public class RemoveDeadLabels : DepthFirstAstVisitor - { - List usedLabels = new List(); - bool collectingUsedLabels; - - public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data) - { - collectingUsedLabels = true; - base.VisitConstructorDeclaration(constructorDeclaration, data); - collectingUsedLabels = false; - base.VisitConstructorDeclaration(constructorDeclaration, data); - return null; - } - - public override object VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) - { - collectingUsedLabels = true; - base.VisitMethodDeclaration(methodDeclaration, data); - collectingUsedLabels = false; - base.VisitMethodDeclaration(methodDeclaration, data); - return null; - } - - public override object VisitAccessor(Accessor accessor, object data) - { - collectingUsedLabels = true; - base.VisitAccessor(accessor, data); - collectingUsedLabels = false; - return base.VisitAccessor(accessor, data); - } - - public override object VisitGotoStatement(GotoStatement gotoStatement, object data) - { - if (collectingUsedLabels) { - usedLabels.Add(gotoStatement.Label); - } - return null; - } - - public override object VisitLabelStatement(LabelStatement labelStatement, object data) - { - if (!collectingUsedLabels) { - if (!usedLabels.Contains(labelStatement.Label)) { - labelStatement.Remove(); - } - } - return null; - } - } -} diff --git a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj index 2c5ead47b..eaa5373ac 100644 --- a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj +++ b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj @@ -56,7 +56,6 @@ - diff --git a/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs b/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs index dcc8a0a37..485b6b528 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs @@ -10,7 +10,21 @@ namespace Decompiler.ControlFlow { public class ILAstOptimizer { - public void Optimize(List ast) + public void Optimize(ref List ast) + { + OptimizeRecursive(ref ast); + + // Provide a container for the algorithms below + ILBlock astBlock = new ILBlock(ast); + + FlattenNestedMovableBlocks(astBlock); + SimpleGotoRemoval(astBlock); + RemoveDeadLabels(astBlock); + + ast = astBlock.Body; + } + + void OptimizeRecursive(ref List ast) { ILLabel entryLabel; List tryCatchBlocks = ast.OfType().ToList(); @@ -31,15 +45,26 @@ namespace Decompiler.ControlFlow // Recursively optimze try-cath blocks foreach(ILTryCatchBlock tryCatchBlock in tryCatchBlocks) { - Optimize(tryCatchBlock.TryBlock.Body); + Optimize(ref tryCatchBlock.TryBlock.Body); foreach(ILTryCatchBlock.CatchBlock catchBlock in tryCatchBlock.CatchBlocks) { - Optimize(catchBlock.Body); + Optimize(ref catchBlock.Body); } - Optimize(tryCatchBlock.FinallyBlock.Body); + Optimize(ref tryCatchBlock.FinallyBlock.Body); } + + // Sort the nodes in the original order + ast = ast.OrderBy(n => n.GetSelfAndChildrenRecursive().First().OriginalOrder).ToList(); + + ast.Insert(0, new ILExpression(OpCodes.Br, entryLabel)); + } + + + class ILMoveAbleBlock: ILBlock + { + public int OriginalOrder; } - int nextLabelIndex = 0; + int nextBlockIndex = 0; /// /// Group input into a set of blocks that can be later arbitraliby schufled. @@ -50,9 +75,9 @@ namespace Decompiler.ControlFlow { List blocks = new List(); - ILBlock block = new ILBlock(); + ILMoveAbleBlock block = new ILMoveAbleBlock() { OriginalOrder = (nextBlockIndex++) }; blocks.Add(block); - entryLabel = new ILLabel() { Name = "Block_" + (nextLabelIndex++) }; + entryLabel = new ILLabel() { Name = "Block_" + block.OriginalOrder }; block.Body.Add(entryLabel); if (ast.Count == 0) @@ -71,14 +96,14 @@ namespace Decompiler.ControlFlow (currNode is ILExpression) && ((ILExpression)currNode).OpCode.IsBranch()) { ILBlock lastBlock = block; - block = new ILBlock(); + block = new ILMoveAbleBlock() { OriginalOrder = (nextBlockIndex++) }; blocks.Add(block); // Explicit branch from one block to other // (unless the last expression was unconditional branch) if (!(lastNode is ILExpression) || ((ILExpression)lastNode).OpCode.CanFallThough()) { - ILLabel blockLabel = new ILLabel() { Name = "Block_" + (nextLabelIndex++) }; - lastBlock.Body.Add(new ILExpression(OpCodes.Br_S, blockLabel)); + ILLabel blockLabel = new ILLabel() { Name = "Block_" + block.OriginalOrder }; + lastBlock.Body.Add(new ILExpression(OpCodes.Br, blockLabel)); block.Body.Add(blockLabel); } } @@ -110,7 +135,7 @@ namespace Decompiler.ControlFlow cfNode.UserData = node; // Find all contained labels - foreach(ILLabel label in node.GetChildrenRecursive()) { + foreach(ILLabel label in node.GetSelfAndChildrenRecursive()) { labelToCfNode[label] = cfNode; } } @@ -126,7 +151,7 @@ namespace Decompiler.ControlFlow ControlFlowNode source = astNodeToCfNode[node]; // Find all branches - foreach(ILExpression child in node.GetChildrenRecursive()) { + foreach(ILExpression child in node.GetSelfAndChildrenRecursive()) { IEnumerable targets = child.GetBranchTargets(); if (targets != null) { foreach(ILLabel target in targets) { @@ -145,7 +170,7 @@ namespace Decompiler.ControlFlow return new ControlFlowGraph(cfNodes.ToArray()); } - static List FindLoops(HashSet body, ControlFlowNode entryPoint) + static List FindLoops(HashSet nodes, ControlFlowNode entryPoint) { List result = new List(); @@ -154,15 +179,15 @@ namespace Decompiler.ControlFlow while(agenda.Count > 0) { ControlFlowNode node = agenda.Dequeue(); - if (body.Contains(node) + if (nodes.Contains(node) && node.DominanceFrontier.Contains(node) && node != entryPoint) { HashSet loopContents = new HashSet(); - FindLoopContents(body, loopContents, node, node); + FindLoopContents(nodes, loopContents, node, node); // Move the content into loop block - body.ExceptWith(loopContents); + nodes.ExceptWith(loopContents); result.Add(new ILLoop() { ContentBlock = new ILBlock(FindLoops(loopContents, node)) }); } @@ -173,23 +198,23 @@ namespace Decompiler.ControlFlow } // Add whatever is left - foreach(var node in body) { + foreach(var node in nodes) { result.Add((ILNode)node.UserData); } return result; } - static void FindLoopContents(HashSet body, HashSet loopContents, ControlFlowNode loopHead, ControlFlowNode addNode) + static void FindLoopContents(HashSet nodes, HashSet loopContents, ControlFlowNode loopHead, ControlFlowNode addNode) { - if (body.Contains(addNode) && loopHead.Dominates(addNode) && loopContents.Add(addNode)) { + if (nodes.Contains(addNode) && loopHead.Dominates(addNode) && loopContents.Add(addNode)) { foreach (var edge in addNode.Incoming) { - FindLoopContents(body, loopContents, loopHead, edge.Source); + FindLoopContents(nodes, loopContents, loopHead, edge.Source); } } } - static HashSet FindDominatedNodes(HashSet body, ControlFlowNode head) + static HashSet FindDominatedNodes(HashSet nodes, ControlFlowNode head) { var exitNodes = head.DominanceFrontier.SelectMany(n => n.Predecessors); HashSet agenda = new HashSet(exitNodes); @@ -199,7 +224,7 @@ namespace Decompiler.ControlFlow ControlFlowNode addNode = agenda.First(); agenda.Remove(addNode); - if (body.Contains(addNode) && head.Dominates(addNode) && result.Add(addNode)) { + if (nodes.Contains(addNode) && head.Dominates(addNode) && result.Add(addNode)) { foreach (var predecessor in addNode.Predecessors) { agenda.Add(predecessor); } @@ -210,7 +235,7 @@ namespace Decompiler.ControlFlow return result; } - static List FindConditions(HashSet body, ControlFlowNode entryNode) + static List FindConditions(HashSet nodes, ControlFlowNode entryNode) { List result = new List(); @@ -219,7 +244,7 @@ namespace Decompiler.ControlFlow while(agenda.Count > 0) { ControlFlowNode node = agenda.Dequeue(); - if (body.Contains(node) && node.Outgoing.Count == 2) { + if (nodes.Contains(node) && node.Outgoing.Count == 2) { ILCondition condition = new ILCondition() { ConditionBlock = new ILBlock((ILNode)node.UserData) }; @@ -227,15 +252,16 @@ namespace Decompiler.ControlFlow frontiers.UnionWith(node.Outgoing[0].Target.DominanceFrontier); frontiers.UnionWith(node.Outgoing[1].Target.DominanceFrontier); if (!frontiers.Contains(node.Outgoing[0].Target)) { - HashSet content1 = FindDominatedNodes(body, node.Outgoing[0].Target); - body.ExceptWith(content1); + HashSet content1 = FindDominatedNodes(nodes, node.Outgoing[0].Target); + nodes.ExceptWith(content1); condition.Block1 = new ILBlock(FindConditions(content1, node.Outgoing[0].Target)); } if (!frontiers.Contains(node.Outgoing[1].Target)) { - HashSet content2 = FindDominatedNodes(body, node.Outgoing[1].Target); - body.ExceptWith(content2); + HashSet content2 = FindDominatedNodes(nodes, node.Outgoing[1].Target); + nodes.ExceptWith(content2); condition.Block2 = new ILBlock(FindConditions(content2, node.Outgoing[1].Target)); } + nodes.Remove(node); result.Add(condition); } @@ -246,7 +272,7 @@ namespace Decompiler.ControlFlow } // Add whatever is left - foreach(var node in body) { + foreach(var node in nodes) { result.Add((ILNode)node.UserData); } @@ -318,5 +344,71 @@ namespace Decompiler.ControlFlow } */ + + /// + /// Flattens all nested movable blocks, except the the top level 'node' argument + /// + void FlattenNestedMovableBlocks(ILNode node) + { + ILBlock block = node as ILBlock; + if (block != null) { + List flatBody = new List(); + foreach (ILNode child in block.Body) { + FlattenNestedMovableBlocks(child); + if (child is ILMoveAbleBlock) { + flatBody.AddRange(((ILMoveAbleBlock)child).Body); + } else { + flatBody.Add(child); + } + } + block.Body = flatBody; + } else if (node is ILExpression) { + // Optimization - no need to check expressions + } else if (node != null) { + // Recursively find all ILBlocks + foreach(ILNode child in node.GetChildren()) { + FlattenNestedMovableBlocks(child); + } + } + } + + void SimpleGotoRemoval(ILBlock ast) + { + var blocks = ast.GetSelfAndChildrenRecursive().ToList(); + foreach(ILBlock block in blocks) { + for (int i = 0; i < block.Body.Count; i++) { + ILExpression expr = block.Body[i] as ILExpression; + // Uncoditional branch + if (expr != null && (expr.OpCode == OpCodes.Br || expr.OpCode == OpCodes.Br_S)) { + // Check that branch is followed by its label (allow multiple labels) + for (int j = i + 1; j < block.Body.Count; j++) { + ILLabel label = block.Body[j] as ILLabel; + if (label == null) + break; // Can not optimize + if (expr.Operand == label) { + block.Body.RemoveAt(i); + break; // Branch removed + } + } + } + } + } + } + + void RemoveDeadLabels(ILBlock ast) + { + HashSet liveLabels = new HashSet(ast.GetSelfAndChildrenRecursive().SelectMany(e => e.GetBranchTargets())); + var blocks = ast.GetSelfAndChildrenRecursive().ToList(); + foreach(ILBlock block in blocks) { + for (int i = 0; i < block.Body.Count;) { + ILLabel label = block.Body[i] as ILLabel; + if (label != null && !liveLabels.Contains(label)) { + block.Body.RemoveAt(i); + } else { + i++; + } + } + } + } } } diff --git a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs index af8be061f..db96898e5 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs @@ -11,19 +11,24 @@ namespace Decompiler { public abstract class ILNode { - public IEnumerable GetChildrenRecursive() where T: ILNode + public IEnumerable GetSelfAndChildrenRecursive() where T: ILNode { + if (this is T) + yield return (T)this; + Stack> stack = new Stack>(); try { stack.Push(GetChildren().GetEnumerator()); while (stack.Count > 0) { while (stack.Peek().MoveNext()) { ILNode element = stack.Peek().Current; - if (element is T) - yield return (T)element; - IEnumerable children = element.GetChildren(); - if (children != null) { - stack.Push(children.GetEnumerator()); + if (element != null) { + if (element is T) + yield return (T)element; + IEnumerable children = element.GetChildren(); + if (children != null) { + stack.Push(children.GetEnumerator()); + } } } stack.Pop().Dispose(); @@ -37,17 +42,7 @@ namespace Decompiler public virtual IEnumerable GetChildren() { - return null; - } - } - - public class ILLabel: ILNode - { - public string Name; - - public override string ToString() - { - return Name + ":"; + yield break; } } @@ -71,6 +66,16 @@ namespace Decompiler } } + public class ILLabel: ILNode + { + public string Name; + + public override string ToString() + { + return Name + ":"; + } + } + public class ILTryCatchBlock: ILNode { public class CatchBlock: ILBlock @@ -128,7 +133,7 @@ namespace Decompiler } else if (this.Operand is ILLabel[]) { return (ILLabel[])this.Operand; } else { - return null; + return new ILLabel[] { }; } }