From 9fb7d244ed8b5d06c13850d4f7e041bfe6a9b497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Srbeck=C3=BD?= Date: Mon, 14 Feb 2011 02:07:05 +0000 Subject: [PATCH] Find conditions --- .../Ast/AstMetodBodyBuilder.cs | 64 ++++---- .../ILAst/ILAstOptimizer.cs | 140 ++++++++++++------ ICSharpCode.Decompiler/ILAst/ILAstTypes.cs | 14 +- 3 files changed, 125 insertions(+), 93 deletions(-) diff --git a/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs b/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs index eff57689b..61d448e5b 100644 --- a/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstMetodBodyBuilder.cs @@ -150,15 +150,12 @@ namespace Decompiler */ } else if (node is ILCondition) { ILCondition conditionalNode = (ILCondition)node; - yield return TransformBlock(conditionalNode.ConditionBlock); - - Ast.IfElseStatement ifElseStmt = new Ast.IfElseStatement { - Condition = new PrimitiveExpression(true), - TrueStatement = TransformBlock(conditionalNode.Block1), - FalseStatement = TransformBlock(conditionalNode.Block2) + // Swap bodies + yield return new Ast.IfElseStatement { + Condition = new UnaryOperatorExpression(UnaryOperatorType.Not, MakeBranchCondition(conditionalNode.Condition)), + TrueStatement = TransformBlock(conditionalNode.FalseBlock), + FalseStatement = TransformBlock(conditionalNode.TrueBlock) }; - - yield return ifElseStmt; } else if (node is ILTryCatchBlock) { ILTryCatchBlock tryCachNode = ((ILTryCatchBlock)node); List catchClauses = new List(); @@ -197,35 +194,27 @@ namespace Decompiler return TransformByteCode(methodDef, expr, args); } - /* - - Ast.Expression MakeBranchCondition(Branch branch) - { - return MakeBranchCondition_Internal(branch); - } - - Ast.Expression MakeBranchCondition_Internal(Branch branch) + Ast.Expression MakeBranchCondition(ILExpression expr) { - if (branch is SimpleBranch) { - List args = TransformExpressionArguments((ILExpression)((SimpleBranch)branch).BasicBlock.Body[0]); - Ast.Expression arg1 
= args.Count >= 1 ? args[0] : null; - Ast.Expression arg2 = args.Count >= 2 ? args[1] : null; - switch(((ILExpression)((SimpleBranch)branch).BasicBlock.Body[0]).OpCode.Code) { - case Code.Brfalse: return new Ast.UnaryOperatorExpression(UnaryOperatorType.Not, arg1); - case Code.Brtrue: return arg1; - case Code.Beq: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, arg2); - case Code.Bge: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThanOrEqual, arg2); - case Code.Bge_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThanOrEqual, arg2); - case Code.Bgt: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); - case Code.Bgt_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); - case Code.Ble: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThanOrEqual, arg2); - case Code.Ble_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThanOrEqual, arg2); - case Code.Blt: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThan, arg2); - case Code.Blt_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThan, arg2); - case Code.Bne_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.InEquality, arg2); - case Code.Leave: return new Ast.PrimitiveExpression(true); - default: throw new Exception("Bad opcode"); - } + List args = TransformExpressionArguments(expr); + Ast.Expression arg1 = args.Count >= 1 ? args[0] : null; + Ast.Expression arg2 = args.Count >= 2 ? 
args[1] : null; + switch(expr.OpCode.Code) { + case Code.Brfalse: return new Ast.UnaryOperatorExpression(UnaryOperatorType.Not, arg1); + case Code.Brtrue: return arg1; + case Code.Beq: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Equality, arg2); + case Code.Bge: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThanOrEqual, arg2); + case Code.Bge_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThanOrEqual, arg2); + case Code.Bgt: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); + case Code.Bgt_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.GreaterThan, arg2); + case Code.Ble: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThanOrEqual, arg2); + case Code.Ble_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThanOrEqual, arg2); + case Code.Blt: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThan, arg2); + case Code.Blt_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.LessThan, arg2); + case Code.Bne_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.InEquality, arg2); + default: throw new Exception("Bad opcode"); + } + /* } else if (branch is ShortCircuitBranch) { ShortCircuitBranch scBranch = (ShortCircuitBranch)branch; switch(scBranch.Operator) { @@ -259,10 +248,9 @@ namespace Decompiler } else { throw new Exception("Bad type"); } + */ } - */ - static object TransformByteCode(MethodDefinition methodDef, ILExpression byteCode, List args) { try { diff --git a/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs b/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs index 485b6b528..0ad42e54b 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs @@ -10,6 +10,8 @@ namespace Decompiler.ControlFlow { public class ILAstOptimizer { + Dictionary labelToCfNode = new Dictionary(); + public void 
Optimize(ref List ast) { OptimizeRecursive(ref ast); @@ -17,6 +19,7 @@ namespace Decompiler.ControlFlow // Provide a container for the algorithms below ILBlock astBlock = new ILBlock(ast); + OrderNodes(astBlock); FlattenNestedMovableBlocks(astBlock); SimpleGotoRemoval(astBlock); RemoveDeadLabels(astBlock); @@ -52,14 +55,11 @@ namespace Decompiler.ControlFlow Optimize(ref tryCatchBlock.FinallyBlock.Body); } - // Sort the nodes in the original order - ast = ast.OrderBy(n => n.GetSelfAndChildrenRecursive().First().OriginalOrder).ToList(); - ast.Insert(0, new ILExpression(OpCodes.Br, entryLabel)); } - class ILMoveAbleBlock: ILBlock + class ILMoveableBlock: ILBlock { public int OriginalOrder; } @@ -75,7 +75,7 @@ namespace Decompiler.ControlFlow { List blocks = new List(); - ILMoveAbleBlock block = new ILMoveAbleBlock() { OriginalOrder = (nextBlockIndex++) }; + ILMoveableBlock block = new ILMoveableBlock() { OriginalOrder = (nextBlockIndex++) }; blocks.Add(block); entryLabel = new ILLabel() { Name = "Block_" + block.OriginalOrder }; block.Body.Add(entryLabel); @@ -96,7 +96,7 @@ namespace Decompiler.ControlFlow (currNode is ILExpression) && ((ILExpression)currNode).OpCode.IsBranch()) { ILBlock lastBlock = block; - block = new ILMoveAbleBlock() { OriginalOrder = (nextBlockIndex++) }; + block = new ILMoveableBlock() { OriginalOrder = (nextBlockIndex++) }; blocks.Add(block); // Explicit branch from one block to other @@ -126,7 +126,7 @@ namespace Decompiler.ControlFlow cfNodes.Add(exceptionalExit); // Create graph nodes - Dictionary labelToCfNode = new Dictionary(); + labelToCfNode = new Dictionary(); Dictionary astNodeToCfNode = new Dictionary(); foreach(ILNode node in nodes) { ControlFlowNode cfNode = new ControlFlowNode(index++, -1, ControlFlowNodeType.Normal); @@ -214,28 +214,7 @@ namespace Decompiler.ControlFlow } } - static HashSet FindDominatedNodes(HashSet nodes, ControlFlowNode head) - { - var exitNodes = head.DominanceFrontier.SelectMany(n => n.Predecessors); - 
HashSet agenda = new HashSet(exitNodes); - HashSet result = new HashSet(); - - while(agenda.Count > 0) { - ControlFlowNode addNode = agenda.First(); - agenda.Remove(addNode); - - if (nodes.Contains(addNode) && head.Dominates(addNode) && result.Add(addNode)) { - foreach (var predecessor in addNode.Predecessors) { - agenda.Add(predecessor); - } - } - } - result.Add(head); - - return result; - } - - static List FindConditions(HashSet nodes, ControlFlowNode entryNode) + List FindConditions(HashSet nodes, ControlFlowNode entryNode) { List result = new List(); @@ -244,25 +223,58 @@ namespace Decompiler.ControlFlow while(agenda.Count > 0) { ControlFlowNode node = agenda.Dequeue(); - if (nodes.Contains(node) && node.Outgoing.Count == 2) { - ILCondition condition = new ILCondition() { - ConditionBlock = new ILBlock((ILNode)node.UserData) - }; - HashSet frontiers = new HashSet(); - frontiers.UnionWith(node.Outgoing[0].Target.DominanceFrontier); - frontiers.UnionWith(node.Outgoing[1].Target.DominanceFrontier); - if (!frontiers.Contains(node.Outgoing[0].Target)) { - HashSet content1 = FindDominatedNodes(nodes, node.Outgoing[0].Target); - nodes.ExceptWith(content1); - condition.Block1 = new ILBlock(FindConditions(content1, node.Outgoing[0].Target)); - } - if (!frontiers.Contains(node.Outgoing[1].Target)) { - HashSet content2 = FindDominatedNodes(nodes, node.Outgoing[1].Target); - nodes.ExceptWith(content2); - condition.Block2 = new ILBlock(FindConditions(content2, node.Outgoing[1].Target)); + ILMoveableBlock block = node.UserData as ILMoveableBlock; + + // Find a block that represents a simple condition + if (nodes.Contains(node) && block != null && block.Body.Count == 3) { + + ILLabel label = block.Body[0] as ILLabel; + ILExpression condBranch = block.Body[1] as ILExpression; + ILExpression statBranch = block.Body[2] as ILExpression; + + if (label != null && + condBranch != null && condBranch.Operand is ILLabel && condBranch.Arguments.Count > 0 && + statBranch != null && 
statBranch.Operand is ILLabel && statBranch.Arguments.Count == 0) + { + ControlFlowNode condTarget; + ControlFlowNode statTarget; + if (labelToCfNode.TryGetValue((ILLabel)condBranch.Operand, out condTarget) && + labelToCfNode.TryGetValue((ILLabel)statBranch.Operand, out statTarget)) + { + ILCondition condition = new ILCondition() { + Condition = condBranch, + TrueTarget = (ILLabel)condBranch.Operand, + FalseTarget = (ILLabel)statBranch.Operand + }; + + // TODO: Use the labels to ensure correctness + // TODO: Ensure that the labels are considered live in dead label removal + + // Replace the two branches with a conditional structure + block.Body.Remove(condBranch); + block.Body.Remove(statBranch); + block.Body.Add(condition); + result.Add(block); + + // Pull in the conditional code + HashSet frontiers = new HashSet(); + frontiers.UnionWith(condTarget.DominanceFrontier); + frontiers.UnionWith(statTarget.DominanceFrontier); + + if (!frontiers.Contains(condTarget)) { + HashSet content = FindDominatedNodes(nodes, condTarget); + nodes.ExceptWith(content); + condition.TrueBlock = new ILBlock(FindConditions(content, condTarget)); + } + if (!frontiers.Contains(statTarget)) { + HashSet content = FindDominatedNodes(nodes, statTarget); + nodes.ExceptWith(content); + condition.FalseBlock = new ILBlock(FindConditions(content, statTarget)); + } + + nodes.Remove(node); + } } - nodes.Remove(node); - result.Add(condition); } // Using the dominator tree should ensure we find the the widest loop first @@ -279,6 +291,27 @@ namespace Decompiler.ControlFlow return result; } + static HashSet FindDominatedNodes(HashSet nodes, ControlFlowNode head) + { + var exitNodes = head.DominanceFrontier.SelectMany(n => n.Predecessors); + HashSet agenda = new HashSet(exitNodes); + HashSet result = new HashSet(); + + while(agenda.Count > 0) { + ControlFlowNode addNode = agenda.First(); + agenda.Remove(addNode); + + if (nodes.Contains(addNode) && head.Dominates(addNode) && result.Add(addNode)) { + foreach 
(var predecessor in addNode.Predecessors) { + agenda.Add(predecessor); + } + } + } + result.Add(head); + + return result; + } + /* public enum ShortCircuitOperator @@ -345,6 +378,15 @@ namespace Decompiler.ControlFlow */ + void OrderNodes(ILBlock ast) + { + var blocks = ast.GetSelfAndChildrenRecursive().ToList(); + ILMoveableBlock first = new ILMoveableBlock() { OriginalOrder = -1 }; + foreach(ILBlock block in blocks) { + block.Body = block.Body.OrderBy(n => (n.GetSelfAndChildrenRecursive().FirstOrDefault() ?? first).OriginalOrder).ToList(); + } + } + /// /// Flattens all nested movable blocks, except the the top level 'node' argument /// @@ -355,8 +397,8 @@ namespace Decompiler.ControlFlow List flatBody = new List(); foreach (ILNode child in block.Body) { FlattenNestedMovableBlocks(child); - if (child is ILMoveAbleBlock) { - flatBody.AddRange(((ILMoveAbleBlock)child).Body); + if (child is ILMoveableBlock) { + flatBody.AddRange(((ILMoveableBlock)child).Body); } else { flatBody.Add(child); } diff --git a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs index db96898e5..584390299 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs @@ -174,15 +174,17 @@ namespace Decompiler public class ILCondition: ILNode { - public ILBlock ConditionBlock; - public ILBlock Block1; - public ILBlock Block2; + public ILExpression Condition; + public ILBlock TrueBlock; // Branch was taken + public ILLabel TrueTarget; // Entry label + public ILBlock FalseBlock; // Fall-through + public ILLabel FalseTarget; // Entry label public override IEnumerable GetChildren() { - yield return ConditionBlock; - yield return Block1; - yield return Block2; + yield return Condition; + yield return TrueBlock; + yield return FalseBlock; } } }