Browse Source

Fix DecimalConstantTransform.

pull/170/merge
Daniel Grunwald 15 years ago
parent
commit
30fe30c236
  1. 137
      ICSharpCode.Decompiler/Ast/Transforms/DecimalConstantTransform.cs
  2. 2
      ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs

137
ICSharpCode.Decompiler/Ast/Transforms/DecimalConstantTransform.cs

@ -17,145 +17,42 @@ @@ -17,145 +17,42 @@
// DEALINGS IN THE SOFTWARE.
using System;
using System.Collections.Generic;
using System.Linq;
using ICSharpCode.NRefactory.CSharp;
using ICSharpCode.NRefactory.PatternMatching;
using Mono.Cecil;
namespace ICSharpCode.Decompiler.Ast.Transforms
{
/// <summary>
/// Description of DecimalConstantTransform.
/// Transforms decimal constant fields.
/// </summary>
public class DecimalConstantTransform : DepthFirstAstVisitor<object, object>, IAstTransform
{
static readonly ICSharpCode.NRefactory.CSharp.Attribute decimalConstantAttribute = new ICSharpCode.NRefactory.CSharp.Attribute {
Type = new SimpleType("DecimalConstant"),
Arguments = { new Repeat(new AnyNode()) }
};
static readonly FieldDeclaration decimalConstantPattern = new FieldDeclaration {
Attributes = {
new AttributeSection {
Attributes = {
decimalConstantAttribute
}
}
},
Modifiers = Modifiers.Any,
Variables = { new Repeat(new AnyNode()) },
ReturnType = new PrimitiveType("decimal")
};
Expression ConstructExpression(string typeName, string member)
{
return new AssignmentExpression(
new MemberReferenceExpression(
new TypeReferenceExpression(new SimpleType(typeName)),
member
),
AssignmentOperatorType.Assign,
new AnyNode()
);
}
class ClassInfo
{
public ClassInfo()
{
Fields = new List<string>();
}
public List<string> Fields { get; private set; }
public TypeDeclaration Declaration { get; set; }
}
Stack<ClassInfo> replaceableFields;
public DecimalConstantTransform()
{
}
void IAstTransform.Run(AstNode compilationUnit)
{
this.replaceableFields = new Stack<ClassInfo>();
compilationUnit.AcceptVisitor(this, null);
}
static readonly PrimitiveType decimalType = new PrimitiveType("decimal");
public override object VisitFieldDeclaration(FieldDeclaration fieldDeclaration, object data)
{
base.VisitFieldDeclaration(fieldDeclaration, data);
var match = decimalConstantPattern.Match(fieldDeclaration);
if (match.Success) {
Modifiers pattern = Modifiers.Static | Modifiers.Readonly;
if ((fieldDeclaration.Modifiers & pattern) == pattern) {
replaceableFields.Peek().Fields.AddRange(fieldDeclaration.Variables.Select(v => v.Name));
fieldDeclaration.ReplaceWith(ReplaceFieldWithConstant);
}
}
return null;
}
AstNode ReplaceFieldWithConstant(AstNode node)
{
var old = node as FieldDeclaration;
var fd = new FieldDeclaration {
Modifiers = old.Modifiers & ~(Modifiers.Readonly | Modifiers.Static) | Modifiers.Const,
ReturnType = new PrimitiveType("decimal")
};
var foundAttr = old.Attributes.SelectMany(section => section.Attributes)
.First(a => decimalConstantAttribute.IsMatch(a));
foundAttr.Remove();
foreach (var attr in old.Attributes.Where(section => section.Attributes.Count == 0))
attr.Remove();
old.Attributes.MoveTo(fd.Attributes);
old.Variables.MoveTo(fd.Variables);
fd.Variables.Single().Initializer = new PrimitiveExpression(CreateDecimalValue(foundAttr));
return fd;
}
object CreateDecimalValue(ICSharpCode.NRefactory.CSharp.Attribute foundAttr)
{
byte scale = (byte)((PrimitiveExpression)foundAttr.Arguments.ElementAt(0)).Value;
byte sign = (byte)((PrimitiveExpression)foundAttr.Arguments.ElementAt(1)).Value;
int high = (int)(uint)((PrimitiveExpression)foundAttr.Arguments.ElementAt(2)).Value;
int mid = (int)(uint)((PrimitiveExpression)foundAttr.Arguments.ElementAt(3)).Value;
int low = (int)(uint)((PrimitiveExpression)foundAttr.Arguments.ElementAt(4)).Value;
return new Decimal(low, mid, high, sign == 1, scale);
}
public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data)
{
if ((constructorDeclaration.Modifiers & Modifiers.Static) == Modifiers.Static && replaceableFields.Count > 0) {
var current = replaceableFields.Peek();
foreach (var fieldName in current.Fields) {
var pattern = ConstructExpression(current.Declaration.Name, fieldName);
foreach (var expr in constructorDeclaration.Body
.OfType<ExpressionStatement>()) {
if (pattern.IsMatch(expr.Expression))
expr.Remove();
const Modifiers staticReadOnly = Modifiers.Static | Modifiers.Readonly;
if ((fieldDeclaration.Modifiers & staticReadOnly) == staticReadOnly && decimalType.IsMatch(fieldDeclaration.ReturnType)) {
foreach (var attributeSection in fieldDeclaration.Attributes) {
foreach (var attribute in attributeSection.Attributes) {
TypeReference tr = attribute.Type.Annotation<TypeReference>();
if (tr != null && tr.Name == "DecimalConstantAttribute" && tr.Namespace == "System.Runtime.CompilerServices") {
attribute.Remove();
if (attributeSection.Attributes.Count == 0)
attributeSection.Remove();
fieldDeclaration.Modifiers = (fieldDeclaration.Modifiers & ~staticReadOnly) | Modifiers.Const;
return null;
}
}
}
}
return null;
}
public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data)
public void Run(AstNode compilationUnit)
{
replaceableFields.Push(new ClassInfo() { Declaration = typeDeclaration });
base.VisitTypeDeclaration(typeDeclaration, data);
replaceableFields.Pop();
return null;
compilationUnit.AcceptVisitor(this, null);
}
}
}

2
ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs

@ -32,7 +32,6 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -32,7 +32,6 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
public static IAstTransform[] CreatePipeline(DecompilerContext context)
{
return new IAstTransform[] {
new DecimalConstantTransform(),
new PushNegation(),
new DelegateConstruction(context),
new PatternStatementTransform(context),
@ -41,6 +40,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -41,6 +40,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
new AddCheckedBlocks(),
new DeclareVariables(context), // should run after most transforms that modify statements
new ConvertConstructorCallIntoInitializer(), // must run after DeclareVariables
new DecimalConstantTransform(),
new IntroduceUsingDeclarations(context),
new IntroduceExtensionMethods(context), // must run after IntroduceUsingDeclarations
new IntroduceQueryExpressions(context), // must run after IntroduceExtensionMethods

Loading…
Cancel
Save