Browse Source

Optimize PatternStatementTransform.

pull/100/head
Daniel Grunwald 15 years ago
parent
commit
039483ddbd
  1. 32
      ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs
  2. 2
      ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs
  3. 243
      ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs

32
ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs

@ -11,7 +11,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -11,7 +11,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
/// <summary>
/// Base class for AST visitors that need the current type/method context info.
/// </summary>
public abstract class ContextTrackingVisitor : DepthFirstAstVisitor<object, object>, IAstTransform
public abstract class ContextTrackingVisitor<TResult> : DepthFirstAstVisitor<object, TResult>, IAstTransform
{
protected readonly DecompilerContext context;
@ -22,7 +22,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -22,7 +22,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
this.context = context;
}
public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data)
public override TResult VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data)
{
TypeDefinition oldType = context.CurrentType;
try {
@ -33,7 +33,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -33,7 +33,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}
public override object VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data)
public override TResult VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data)
{
Debug.Assert(context.CurrentMethod == null);
try {
@ -44,7 +44,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -44,7 +44,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}
public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data)
public override TResult VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data)
{
Debug.Assert(context.CurrentMethod == null);
try {
@ -55,7 +55,29 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -55,7 +55,29 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}
public override object VisitAccessor(Accessor accessor, object data)
public override TResult VisitDestructorDeclaration(DestructorDeclaration destructorDeclaration, object data)
{
Debug.Assert(context.CurrentMethod == null);
try {
context.CurrentMethod = destructorDeclaration.Annotation<MethodDefinition>();
return base.VisitDestructorDeclaration(destructorDeclaration, data);
} finally {
context.CurrentMethod = null;
}
}
public override TResult VisitOperatorDeclaration(OperatorDeclaration operatorDeclaration, object data)
{
Debug.Assert(context.CurrentMethod == null);
try {
context.CurrentMethod = operatorDeclaration.Annotation<MethodDefinition>();
return base.VisitOperatorDeclaration(operatorDeclaration, data);
} finally {
context.CurrentMethod = null;
}
}
public override TResult VisitAccessor(Accessor accessor, object data)
{
Debug.Assert(context.CurrentMethod == null);
try {

2
ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs

@ -18,7 +18,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -18,7 +18,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
/// For anonymous methods, creates an AnonymousMethodExpression.
/// Also gets rid of any "Display Classes" left over after inlining an anonymous method.
/// </summary>
public class DelegateConstruction : ContextTrackingVisitor
public class DelegateConstruction : ContextTrackingVisitor<object>
{
internal sealed class Annotation
{

243
ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs

@ -3,6 +3,7 @@ @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using ICSharpCode.NRefactory.CSharp;
using ICSharpCode.NRefactory.CSharp.PatternMatching;
@ -13,37 +14,105 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -13,37 +14,105 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
/// <summary>
/// Finds the expanded form of using statements using pattern matching and replaces it with a UsingStatement.
/// </summary>
public class PatternStatementTransform : IAstTransform
public sealed class PatternStatementTransform : ContextTrackingVisitor<AstNode>, IAstTransform
{
DecompilerContext context;
public PatternStatementTransform(DecompilerContext context) : base(context)
{
}
#region Visitor Overrides
protected override AstNode VisitChildren(AstNode node, object data)
{
// Go through the children, and keep visiting a node as long as it changes.
// Because some transforms delete/replace nodes before and after the node being transformed, we rely
// on the transform's return value to know where we need to keep iterating.
for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) {
AstNode oldChild;
do {
oldChild = child;
child = child.AcceptVisitor(this, data);
Debug.Assert(child != null && child.Parent == node);
} while (child != oldChild);
}
return node;
}
public override AstNode VisitVariableDeclarationStatement(VariableDeclarationStatement variableDeclarationStatement, object data)
{
AstNode result;
if (context.Settings.UsingStatement) {
result = TransformUsings(variableDeclarationStatement);
if (result != null)
return result;
}
result = TransformFor(variableDeclarationStatement);
if (result != null)
return result;
if (context.Settings.LockStatement) {
result = TransformLock(variableDeclarationStatement);
if (result != null)
return result;
}
return base.VisitVariableDeclarationStatement(variableDeclarationStatement, data);
}
public override AstNode VisitUsingStatement(UsingStatement usingStatement, object data)
{
if (context.Settings.ForEachStatement) {
AstNode result = TransformForeach(usingStatement);
if (result != null)
return result;
}
return base.VisitUsingStatement(usingStatement, data);
}
public override AstNode VisitWhileStatement(WhileStatement whileStatement, object data)
{
return TransformDoWhile(whileStatement) ?? base.VisitWhileStatement(whileStatement, data);
}
public override AstNode VisitIfElseStatement(IfElseStatement ifElseStatement, object data)
{
if (context.Settings.SwitchStatementOnString) {
AstNode result = TransformSwitchOnString(ifElseStatement);
if (result != null)
return result;
}
return base.VisitIfElseStatement(ifElseStatement, data);
}
public override AstNode VisitPropertyDeclaration(PropertyDeclaration propertyDeclaration, object data)
{
if (context.Settings.AutomaticProperties) {
AstNode result = TransformAutomaticProperties(propertyDeclaration);
if (result != null)
return result;
}
return base.VisitPropertyDeclaration(propertyDeclaration, data);
}
public override AstNode VisitCustomEventDeclaration(CustomEventDeclaration eventDeclaration, object data)
{
// first apply transforms to the accessor bodies
base.VisitCustomEventDeclaration(eventDeclaration, data);
if (context.Settings.AutomaticEvents) {
AstNode result = TransformAutomaticEvents(eventDeclaration);
if (result != null)
return result;
}
return eventDeclaration;
}
public PatternStatementTransform(DecompilerContext context)
public override AstNode VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data)
{
if (context == null)
throw new ArgumentNullException("context");
this.context = context;
return TransformDestructor(methodDeclaration) ?? base.VisitMethodDeclaration(methodDeclaration, data);
}
public void Run(AstNode compilationUnit)
public override AstNode VisitTryCatchStatement(TryCatchStatement tryCatchStatement, object data)
{
if (context.Settings.UsingStatement)
TransformUsings(compilationUnit);
if (context.Settings.ForEachStatement)
TransformForeach(compilationUnit);
TransformFor(compilationUnit);
TransformDoWhile(compilationUnit);
if (context.Settings.LockStatement)
TransformLock(compilationUnit);
if (context.Settings.SwitchStatementOnString)
TransformSwitchOnString(compilationUnit);
if (context.Settings.AutomaticProperties)
TransformAutomaticProperties(compilationUnit);
if (context.Settings.AutomaticEvents)
TransformAutomaticEvents(compilationUnit);
TransformDestructor(compilationUnit);
TransformTryCatchFinally(compilationUnit);
return TransformTryCatchFinally(tryCatchStatement) ?? base.VisitTryCatchStatement(tryCatchStatement, data);
}
#endregion
/// <summary>
/// $type $variable = $initializer;
@ -94,31 +163,30 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -94,31 +163,30 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
};
public void TransformUsings(AstNode compilationUnit)
public UsingStatement TransformUsings(VariableDeclarationStatement node)
{
foreach (AstNode node in compilationUnit.Descendants.OfType<VariableDeclarationStatement>().ToArray()) {
Match m1 = variableDeclPattern.Match(node);
if (m1 == null) continue;
if (m1 == null) return null;
AstNode tryCatch = node.NextSibling;
while (simpleVariableDefinition.Match(tryCatch) != null)
tryCatch = tryCatch.NextSibling;
Match m2 = usingTryCatchPattern.Match(tryCatch);
if (m2 == null) continue;
if (m2 == null) return null;
if (m1.Get<VariableInitializer>("variable").Single().Name == m2.Get<IdentifierExpression>("ident").Single().Identifier) {
if (m2.Has("valueType")) {
// if there's no if(x!=null), then it must be a value type
TypeReference tr = m1.Get<AstType>("type").Single().Annotation<TypeReference>();
if (tr == null || !tr.IsValueType)
continue;
return null;
}
BlockStatement body = m2.Get<BlockStatement>("body").Single();
tryCatch.ReplaceWith(
new UsingStatement {
ResourceAcquisition = node.Detach(),
EmbeddedStatement = body.Detach()
});
}
UsingStatement usingStatement = new UsingStatement();
usingStatement.ResourceAcquisition = node.Detach();
usingStatement.EmbeddedStatement = body.Detach();
tryCatch.ReplaceWith(usingStatement);
return usingStatement;
}
return null;
}
#endregion
@ -185,29 +253,28 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -185,29 +253,28 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}.ToStatement()
};
public void TransformForeach(AstNode compilationUnit)
public ForeachStatement TransformForeach(UsingStatement node)
{
foreach (AstNode node in compilationUnit.Descendants.OfType<UsingStatement>().ToArray()) {
Match m = foreachPattern.Match(node);
if (m == null)
continue;
return null;
VariableInitializer enumeratorVar = m.Get<VariableInitializer>("enumeratorVariable").Single();
VariableInitializer itemVar = m.Get<VariableInitializer>("itemVariable").Single();
if (m.Has("itemVariableInsideLoop") && itemVar.Annotation<DelegateConstruction.CapturedVariableAnnotation>() != null) {
// cannot move captured variables out of loops
continue;
return null;
}
BlockStatement newBody = new BlockStatement();
foreach (Statement stmt in m.Get<Statement>("statement"))
newBody.Add(stmt.Detach());
node.ReplaceWith(
new ForeachStatement {
ForeachStatement foreachStatement = new ForeachStatement {
VariableType = m.Get<AstType>("itemType").Single().Detach(),
VariableName = itemVar.Name,
InExpression = m.Get<Expression>("collection").Single().Detach(),
EmbeddedStatement = newBody
});
}
};
node.ReplaceWith(foreachStatement);
return foreachStatement;
}
#endregion
@ -232,32 +299,30 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -232,32 +299,30 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}};
public void TransformFor(AstNode compilationUnit)
public ForStatement TransformFor(VariableDeclarationStatement node)
{
foreach (AstNode node in compilationUnit.Descendants.OfType<VariableDeclarationStatement>().ToArray()) {
Match m1 = variableDeclPattern.Match(node);
if (m1 == null) continue;
if (m1 == null) return null;
AstNode next = node.NextSibling;
while (simpleVariableDefinition.Match(next) != null)
next = next.NextSibling;
Match m2 = forPattern.Match(next);
if (m2 == null) continue;
if (m2 == null) return null;
// ensure the variable in the for pattern is the same as in the declaration
if (m1.Get<VariableInitializer>("variable").Single().Name != m2.Get<IdentifierExpression>("ident").Single().Identifier)
continue;
return null;
WhileStatement loop = (WhileStatement)next;
node.Remove();
BlockStatement newBody = new BlockStatement();
foreach (Statement stmt in m2.Get<Statement>("statement"))
newBody.Add(stmt.Detach());
loop.ReplaceWith(
new ForStatement {
Initializers = { (VariableDeclarationStatement)node },
Condition = loop.Condition.Detach(),
Iterators = { m2.Get<Statement>("increment").Single().Detach() },
EmbeddedStatement = newBody
});
}
ForStatement forStatement = new ForStatement();
forStatement.Initializers.Add(node);
forStatement.Condition = loop.Condition.Detach();
forStatement.Iterators.Add(m2.Get<Statement>("increment").Single().Detach());
forStatement.EmbeddedStatement = newBody;
loop.ReplaceWith(forStatement);
return forStatement;
}
#endregion
@ -274,9 +339,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -274,9 +339,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}};
public void TransformDoWhile(AstNode compilationUnit)
public DoWhileStatement TransformDoWhile(WhileStatement whileLoop)
{
foreach (WhileStatement whileLoop in compilationUnit.Descendants.OfType<WhileStatement>().ToArray()) {
Match m = doWhilePattern.Match(whileLoop);
if (m != null) {
DoWhileStatement doLoop = new DoWhileStatement();
@ -303,8 +367,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -303,8 +367,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
doLoop.Parent.InsertChildBefore(doLoop, varDecl, BlockStatement.StatementRole);
}
}
return doLoop;
}
}
return null;
}
#endregion
@ -339,16 +404,15 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -339,16 +404,15 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}};
public void TransformLock(AstNode compilationUnit)
public LockStatement TransformLock(VariableDeclarationStatement node)
{
foreach (AstNode node in compilationUnit.Descendants.OfType<VariableDeclarationStatement>().ToArray()) {
Match m1 = lockFlagInitPattern.Match(node);
if (m1 == null) continue;
if (m1 == null) return null;
AstNode tryCatch = node.NextSibling;
while (simpleVariableDefinition.Match(tryCatch) != null)
tryCatch = tryCatch.NextSibling;
Match m2 = lockTryCatchPattern.Match(tryCatch);
if (m2 == null) continue;
if (m2 == null) return null;
if (m1.Get<VariableInitializer>("variable").Single().Name == m2.Get<IdentifierExpression>("flag").Single().Identifier) {
Expression enter = m2.Get<Expression>("enter").Single();
IdentifierExpression exit = m2.Get<IdentifierExpression>("exit").Single();
@ -356,9 +420,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -356,9 +420,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
// If exit and enter are not the same, then enter must be "exit = ..."
AssignmentExpression assign = enter as AssignmentExpression;
if (assign == null)
continue;
return null;
if (exit.Match(assign.Left) == null)
continue;
return null;
enter = assign.Right;
// Remove 'exit' variable:
bool ok = false;
@ -371,7 +435,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -371,7 +435,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}
if (!ok)
continue;
return null;
}
// transform the code into a lock statement:
LockStatement l = new LockStatement();
@ -380,8 +444,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -380,8 +444,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
((BlockStatement)l.EmbeddedStatement).Statements.First().Remove(); // Remove 'Enter()' call
tryCatch.ReplaceWith(l);
node.Remove(); // remove flag variable
return l;
}
}
return null;
}
#endregion
@ -428,23 +493,22 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -428,23 +493,22 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
FalseStatement = new OptionalNode("nullStmt", new BlockStatement { Statements = { new Repeat(new AnyNode()) } })
};
public void TransformSwitchOnString(AstNode compilationUnit)
public SwitchStatement TransformSwitchOnString(IfElseStatement node)
{
foreach (AstNode node in compilationUnit.Descendants.OfType<IfElseStatement>().ToArray()) {
Match m = switchOnStringPattern.Match(node);
if (m == null)
continue;
return null;
if (m.Has("nonNullDefaultStmt") && !m.Has("nullStmt"))
continue;
return null;
// switchVar must be the same as switchExpr; or switchExpr must be an assignment and switchVar the left side of that assignment
if (m.Get("switchVar").Single().Match(m.Get("switchExpr").Single()) == null) {
AssignmentExpression assign = m.Get("switchExpr").Single() as AssignmentExpression;
if (m.Get("switchVar").Single().Match(assign.Left) == null)
continue;
return null;
}
FieldReference cachedDictField = m.Get("cachedDict").Single().Annotation<FieldReference>();
if (cachedDictField == null || !cachedDictField.DeclaringType.Name.StartsWith("<PrivateImplementationDetails>", StringComparison.Ordinal))
continue;
return null;
List<Statement> dictCreation = m.Get<BlockStatement>("dictCreation").Single().Statements.ToList();
List<KeyValuePair<string, int>> dict = BuildDictionary(dictCreation);
SwitchStatement sw = m.Get<SwitchStatement>("switch").Single();
@ -481,7 +545,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -481,7 +545,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}
node.ReplaceWith(sw);
}
return sw;
}
List<KeyValuePair<string, int>> BuildDictionary(List<Statement> dictCreation)
@ -527,14 +591,13 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -527,14 +591,13 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
}}};
void TransformAutomaticProperties(AstNode compilationUnit)
PropertyDeclaration TransformAutomaticProperties(PropertyDeclaration property)
{
foreach (var property in compilationUnit.Descendants.OfType<PropertyDeclaration>()) {
PropertyDefinition cecilProperty = property.Annotation<PropertyDefinition>();
if (cecilProperty == null || cecilProperty.GetMethod == null || cecilProperty.SetMethod == null)
continue;
return null;
if (!(cecilProperty.GetMethod.IsCompilerGenerated() && cecilProperty.SetMethod.IsCompilerGenerated()))
continue;
return null;
Match m = automaticPropertyPattern.Match(property);
if (m != null) {
FieldDefinition field = m.Get("fieldReference").Single().Annotation<FieldReference>().ResolveWithinSameModule();
@ -545,7 +608,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -545,7 +608,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
property.Setter.Body = null;
}
}
}
// Since the event instance is not changed, we can continue in the visitor as usual, so return null
return null;
}
void RemoveCompilerGeneratedAttribute(AstNodeCollection<AttributeSection> attributeSections)
@ -625,15 +689,14 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -625,15 +689,14 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
return combineMethod.DeclaringType.FullName == "System.Delegate";
}
void TransformAutomaticEvents(AstNode compilationUnit)
EventDeclaration TransformAutomaticEvents(CustomEventDeclaration ev)
{
foreach (var ev in compilationUnit.Descendants.OfType<CustomEventDeclaration>().ToArray()) {
Match m1 = automaticEventPatternV4.Match(ev.AddAccessor);
if (!CheckAutomaticEventV4Match(m1, ev, true))
continue;
return null;
Match m2 = automaticEventPatternV4.Match(ev.RemoveAccessor);
if (!CheckAutomaticEventV4Match(m2, ev, false))
continue;
return null;
EventDeclaration ed = new EventDeclaration();
ev.Attributes.MoveTo(ed.Attributes);
ed.ReturnType = ev.ReturnType.Detach();
@ -651,7 +714,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -651,7 +714,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
ev.ReplaceWith(ed);
}
return ed;
}
#endregion
@ -671,9 +734,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -671,9 +734,8 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
};
void TransformDestructor(AstNode compilationUnit)
DestructorDeclaration TransformDestructor(MethodDeclaration methodDef)
{
foreach (var methodDef in compilationUnit.Descendants.OfType<MethodDeclaration>()) {
Match m = destructorPattern.Match(methodDef);
if (m != null) {
DestructorDeclaration dd = new DestructorDeclaration();
@ -682,8 +744,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -682,8 +744,9 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
dd.Body = m.Get<BlockStatement>("body").Single().Detach();
dd.Name = AstBuilder.CleanName(context.CurrentType.Name);
methodDef.ReplaceWith(dd);
return dd;
}
}
return null;
}
#endregion
@ -702,15 +765,15 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -702,15 +765,15 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
/// Simplify nested 'try { try {} catch {} } finally {}'.
/// This transformation must run after the using/lock tranformations.
/// </summary>
void TransformTryCatchFinally(AstNode compilationUnit)
TryCatchStatement TransformTryCatchFinally(TryCatchStatement tryFinally)
{
foreach (var tryFinally in compilationUnit.Descendants.OfType<TryCatchStatement>()) {
if (tryCatchFinallyPattern.Match(tryFinally) != null) {
TryCatchStatement tryCatch = (TryCatchStatement)tryFinally.TryBlock.Statements.Single();
tryFinally.TryBlock = tryCatch.TryBlock.Detach();
tryCatch.CatchClauses.MoveTo(tryFinally.CatchClauses);
}
}
// Since the tryFinally instance is not changed, we can continue in the visitor as usual, so return null
return null;
}
#endregion

Loading…
Cancel
Save