diff --git a/ICSharpCode.Decompiler/Ast/AstBuilder.cs b/ICSharpCode.Decompiler/Ast/AstBuilder.cs index 896b79d05..0f684e80b 100644 --- a/ICSharpCode.Decompiler/Ast/AstBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstBuilder.cs @@ -19,7 +19,15 @@ namespace ICSharpCode.Decompiler.Ast using Ast = ICSharpCode.NRefactory.CSharp; using ClassType = ICSharpCode.NRefactory.TypeSystem.ClassType; using VarianceModifier = ICSharpCode.NRefactory.TypeSystem.VarianceModifier; - + + [Flags] + public enum ConvertTypeOptions + { + None = 0, + IncludeNamespace = 1, + IncludeTypeParameterDefinitions = 2 + } + public class AstBuilder { DecompilerContext context = new DecompilerContext(); @@ -89,21 +97,16 @@ namespace ICSharpCode.Decompiler.Ast public void AddAssembly(AssemblyDefinition assemblyDefinition, bool onlyAssemblyLevel = false) { - astCompileUnit.AddChild( - new UsingDeclaration { - Import = new SimpleType("System") - }, CompilationUnit.MemberRole); - ConvertCustomAttributes(astCompileUnit, assemblyDefinition, AttributeTarget.Assembly); ConvertCustomAttributes(astCompileUnit, assemblyDefinition.MainModule, AttributeTarget.Module); - + if (!onlyAssemblyLevel) { - foreach (TypeDefinition typeDef in assemblyDefinition.MainModule.Types) - { - // Skip nested types - they will be added by the parent type - if (typeDef.DeclaringType != null) continue; + foreach (TypeDefinition typeDef in assemblyDefinition.MainModule.Types) { // Skip the class if (typeDef.Name == "") continue; + // Skip any hidden types + if (AstBuilder.MemberIsHidden(typeDef, context.Settings)) + continue; AddType(typeDef); } @@ -264,7 +267,7 @@ namespace ICSharpCode.Decompiler.Ast transform.Run(astCompileUnit); } - string CleanName(string name) + internal static string CleanName(string name) { int pos = name.LastIndexOf('`'); if (pos >= 0) @@ -283,13 +286,13 @@ namespace ICSharpCode.Decompiler.Ast /// a type system type reference. /// Attributes associated with the Cecil type reference. /// This is used to support the 'dynamic' type. - public static AstType ConvertType(TypeReference type, ICustomAttributeProvider typeAttributes = null) + public static AstType ConvertType(TypeReference type, ICustomAttributeProvider typeAttributes = null, ConvertTypeOptions options = ConvertTypeOptions.None) { int typeIndex = 0; - return ConvertType(type, typeAttributes, ref typeIndex); + return ConvertType(type, typeAttributes, ref typeIndex, options); } - static AstType ConvertType(TypeReference type, ICustomAttributeProvider typeAttributes, ref int typeIndex) + static AstType ConvertType(TypeReference type, ICustomAttributeProvider typeAttributes, ref int typeIndex, ConvertTypeOptions options) { while (type is OptionalModifierType || type is RequiredModifierType) { type = ((TypeSpecification)type).ElementType; @@ -301,39 +304,44 @@ namespace ICSharpCode.Decompiler.Ast if (type is Mono.Cecil.ByReferenceType) { typeIndex++; // by reference type cannot be represented in C#; so we'll represent it as a pointer instead - return ConvertType((type as Mono.Cecil.ByReferenceType).ElementType, typeAttributes, ref typeIndex) + return ConvertType((type as Mono.Cecil.ByReferenceType).ElementType, typeAttributes, ref typeIndex, options) .MakePointerType(); } else if (type is Mono.Cecil.PointerType) { typeIndex++; - return ConvertType((type as Mono.Cecil.PointerType).ElementType, typeAttributes, ref typeIndex) + return ConvertType((type as Mono.Cecil.PointerType).ElementType, typeAttributes, ref typeIndex, options) .MakePointerType(); } else if (type is Mono.Cecil.ArrayType) { typeIndex++; - return ConvertType((type as Mono.Cecil.ArrayType).ElementType, typeAttributes, ref typeIndex) + return ConvertType((type as Mono.Cecil.ArrayType).ElementType, typeAttributes, ref typeIndex, options) .MakeArrayType((type as Mono.Cecil.ArrayType).Rank); } else if (type is GenericInstanceType) { GenericInstanceType gType = (GenericInstanceType)type; if (gType.ElementType.Namespace == "System" && gType.ElementType.Name == "Nullable`1" && gType.GenericArguments.Count == 1) { typeIndex++; return new ComposedType { - BaseType = ConvertType(gType.GenericArguments[0], typeAttributes, ref typeIndex), + BaseType = ConvertType(gType.GenericArguments[0], typeAttributes, ref typeIndex, options), HasNullableSpecifier = true }; } - AstType baseType = ConvertType(gType.ElementType, typeAttributes, ref typeIndex); + AstType baseType = ConvertType(gType.ElementType, typeAttributes, ref typeIndex, options & ~ConvertTypeOptions.IncludeTypeParameterDefinitions); List typeArguments = new List(); foreach (var typeArgument in gType.GenericArguments) { typeIndex++; - typeArguments.Add(ConvertType(typeArgument, typeAttributes, ref typeIndex)); + typeArguments.Add(ConvertType(typeArgument, typeAttributes, ref typeIndex, options)); } ApplyTypeArgumentsTo(baseType, typeArguments); return baseType; } else if (type is GenericParameter) { return new SimpleType(type.Name); } else if (type.IsNested) { - AstType typeRef = ConvertType(type.DeclaringType, typeAttributes, ref typeIndex); + AstType typeRef = ConvertType(type.DeclaringType, typeAttributes, ref typeIndex, options & ~ConvertTypeOptions.IncludeTypeParameterDefinitions); string namepart = ICSharpCode.NRefactory.TypeSystem.ReflectionHelper.SplitTypeParameterCountFromReflectionName(type.Name); - return new MemberType { Target = typeRef, MemberName = namepart }.WithAnnotation(type); + MemberType memberType = new MemberType { Target = typeRef, MemberName = namepart }; + memberType.AddAnnotation(type); + if ((options & ConvertTypeOptions.IncludeTypeParameterDefinitions) == ConvertTypeOptions.IncludeTypeParameterDefinitions) { + AddTypeParameterDefininitionsTo(type, memberType); + } + return memberType; } else { string ns = type.Namespace ?? string.Empty; string name = type.Name; @@ -382,18 +390,35 @@ namespace ICSharpCode.Decompiler.Ast name = ICSharpCode.NRefactory.TypeSystem.ReflectionHelper.SplitTypeParameterCountFromReflectionName(name); - // TODO: Until we can simplify type with 'using', use just the name without namesapce - return new SimpleType(name).WithAnnotation(type); + AstType astType; + if ((options & ConvertTypeOptions.IncludeNamespace) == ConvertTypeOptions.IncludeNamespace && ns.Length > 0) { + string[] parts = ns.Split('.'); + AstType nsType = new SimpleType(parts[0]); + for (int i = 1; i < parts.Length; i++) { + nsType = new MemberType { Target = nsType, MemberName = parts[i] }; + } + astType = new MemberType { Target = nsType, MemberName = name }; + } else { + astType = new SimpleType(name); + } + astType.AddAnnotation(type); -// if (ns.Length == 0) -// return new SimpleType(name).WithAnnotation(type); -// string[] parts = ns.Split('.'); -// AstType nsType = new SimpleType(parts[0]); -// for (int i = 1; i < parts.Length; i++) { -// nsType = new MemberType { Target = nsType, MemberName = parts[i] }; -// } -// return new MemberType { Target = nsType, MemberName = name }.WithAnnotation(type); + if ((options & ConvertTypeOptions.IncludeTypeParameterDefinitions) == ConvertTypeOptions.IncludeTypeParameterDefinitions) { + AddTypeParameterDefininitionsTo(type, astType); + } + return astType; + } + } + } + + static void AddTypeParameterDefininitionsTo(TypeReference type, AstType astType) + { + if (type.HasGenericParameters) { + List typeArguments = new List(); + foreach (GenericParameter gp in type.GenericParameters) { + typeArguments.Add(new SimpleType(gp.Name)); } + ApplyTypeArgumentsTo(astType, typeArguments); } } diff --git a/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs b/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs index b07ccb74c..7328779c2 100644 --- a/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs +++ b/ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs @@ -198,8 +198,6 @@ namespace ICSharpCode.Decompiler.Ast yield return fixedStatement; } else if (node is ILBlock) { yield return TransformBlock((ILBlock)node); - } else if (node is ILComment) { - yield return new CommentStatement(((ILComment)node).Text).WithAnnotation(((ILComment)node).ILRanges); } else { throw new Exception("Unknown node type"); } diff --git a/ICSharpCode.Decompiler/Ast/CecilTypeResolveContext.cs b/ICSharpCode.Decompiler/Ast/CecilTypeResolveContext.cs new file mode 100644 index 000000000..0a5348bf4 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/CecilTypeResolveContext.cs @@ -0,0 +1,148 @@ +// Copyright (c) AlphaSierraPapa for the SharpDevelop Team (for details please see \doc\copyright.txt) +// This code is distributed under MIT X11 license (for details please see \doc\license.txt) + +using System; +using System.Collections.Generic; +using System.Linq; +using ICSharpCode.NRefactory.TypeSystem; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast +{ + /// + /// ITypeResolveContext implementation that lazily loads types from Cecil. + /// + public class CecilTypeResolveContext : ISynchronizedTypeResolveContext, IProjectContent + { + readonly ModuleDefinition module; + readonly string[] namespaces; + readonly CecilLoader loader; + Dictionary dict = new Dictionary(); + int countUntilNextCleanup = 4; + + public CecilTypeResolveContext(ModuleDefinition module) + { + this.loader = new CecilLoader(); + this.loader.IncludeInternalMembers = true; + this.module = module; + this.namespaces = module.Types.Select(t => t.Namespace).Distinct().ToArray(); + + List assemblyAttributes = new List(); + foreach (var attr in module.Assembly.CustomAttributes) { + assemblyAttributes.Add(loader.ReadAttribute(attr)); + } + this.AssemblyAttributes = assemblyAttributes.AsReadOnly(); + } + + ITypeDefinition GetClass(TypeDefinition cecilType) + { + lock (dict) { + WeakReference wr; + ITypeDefinition type; + if (dict.TryGetValue(cecilType, out wr)) { + type = (ITypeDefinition)wr.Target; + } else { + wr = null; + type = null; + } + if (type == null) { + type = loader.LoadType(cecilType, this); + } + if (wr == null) { + if (--countUntilNextCleanup <= 0) + CleanupDict(); + wr = new WeakReference(type); + dict.Add(cecilType, wr); + } else { + wr.Target = type; + } + return type; + } + } + + void CleanupDict() + { + List deletedKeys = new List(); + foreach (var pair in dict) { + if (!pair.Value.IsAlive) { + deletedKeys.Add(pair.Key); + } + } + foreach (var key in deletedKeys) { + dict.Remove(key); + } + countUntilNextCleanup = dict.Count + 4; + } + + public IList AssemblyAttributes { get; private set; } + + public ITypeDefinition GetClass(string nameSpace, string name, int typeParameterCount, StringComparer nameComparer) + { + if (typeParameterCount > 0) + name = name + "`" + typeParameterCount.ToString(); + if (nameComparer == StringComparer.Ordinal) { + TypeDefinition cecilType = module.GetType(nameSpace, name); + if (cecilType != null) + return GetClass(cecilType); + else + return null; + } + foreach (TypeDefinition cecilType in module.Types) { + if (nameComparer.Equals(name, cecilType.Name) + && nameComparer.Equals(nameSpace, cecilType.Namespace) + && cecilType.GenericParameters.Count == typeParameterCount) + { + return GetClass(cecilType); + } + } + return null; + } + + public IEnumerable GetClasses() + { + foreach (TypeDefinition cecilType in module.Types) { + yield return GetClass(cecilType); + } + } + + public IEnumerable GetClasses(string nameSpace, StringComparer nameComparer) + { + foreach (TypeDefinition cecilType in module.Types) { + if (nameComparer.Equals(nameSpace, cecilType.Namespace)) + yield return GetClass(cecilType); + } + } + + public IEnumerable GetNamespaces() + { + return namespaces; + } + + public string GetNamespace(string nameSpace, StringComparer nameComparer) + { + foreach (string ns in namespaces) { + if (nameComparer.Equals(ns, nameSpace)) + return ns; + } + return null; + } + + ICSharpCode.NRefactory.Utils.CacheManager ITypeResolveContext.CacheManager { + get { + // We don't support caching + return null; + } + } + + ISynchronizedTypeResolveContext ITypeResolveContext.Synchronize() + { + // This class is logically immutable + return this; + } + + void IDisposable.Dispose() + { + // exit from Synchronize() block + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs b/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs index a9b3f3fc6..236b95711 100644 --- a/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs +++ b/ICSharpCode.Decompiler/Ast/Transforms/ContextTrackingVisitor.cs @@ -11,7 +11,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms /// /// Base class for AST visitors that need the current type/method context info. /// - public abstract class ContextTrackingVisitor : DepthFirstAstVisitor, IAstTransform + public abstract class ContextTrackingVisitor : DepthFirstAstVisitor, IAstTransform { protected readonly DecompilerContext context; @@ -22,7 +22,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms this.context = context; } - public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data) + public override TResult VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data) { TypeDefinition oldType = context.CurrentType; try { @@ -33,7 +33,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } } - public override object VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) + public override TResult VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) { Debug.Assert(context.CurrentMethod == null); try { @@ -44,7 +44,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } } - public override object VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data) + public override TResult VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration, object data) { Debug.Assert(context.CurrentMethod == null); try { @@ -55,7 +55,29 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } } - public override object VisitAccessor(Accessor accessor, object data) + public override TResult VisitDestructorDeclaration(DestructorDeclaration destructorDeclaration, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = destructorDeclaration.Annotation(); + return base.VisitDestructorDeclaration(destructorDeclaration, data); + } finally { + context.CurrentMethod = null; + } + } + + public override TResult VisitOperatorDeclaration(OperatorDeclaration operatorDeclaration, object data) + { + Debug.Assert(context.CurrentMethod == null); + try { + context.CurrentMethod = operatorDeclaration.Annotation(); + return base.VisitOperatorDeclaration(operatorDeclaration, data); + } finally { + context.CurrentMethod = null; + } + } + + public override TResult VisitAccessor(Accessor accessor, object data) { Debug.Assert(context.CurrentMethod == null); try { diff --git a/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs b/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs index c44200bf3..48a2ddf6c 100644 --- a/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs +++ b/ICSharpCode.Decompiler/Ast/Transforms/DelegateConstruction.cs @@ -18,7 +18,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms /// For anonymous methods, creates an AnonymousMethodExpression. /// Also gets rid of any "Display Classes" left over after inlining an anonymous method. /// - public class DelegateConstruction : ContextTrackingVisitor + public class DelegateConstruction : ContextTrackingVisitor { internal sealed class Annotation { @@ -182,7 +182,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms public override object VisitBlockStatement(BlockStatement blockStatement, object data) { base.VisitBlockStatement(blockStatement, data); - foreach (VariableDeclarationStatement stmt in blockStatement.Statements.OfType()) { + foreach (VariableDeclarationStatement stmt in blockStatement.Statements.OfType().ToArray()) { if (stmt.Variables.Count() != 1) continue; var variable = stmt.Variables.Single(); diff --git a/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUsingDeclarations.cs b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUsingDeclarations.cs new file mode 100644 index 000000000..0b0a40564 --- /dev/null +++ b/ICSharpCode.Decompiler/Ast/Transforms/IntroduceUsingDeclarations.cs @@ -0,0 +1,87 @@ +// Copyright (c) AlphaSierraPapa for the SharpDevelop Team (for details please see \doc\copyright.txt) +// This code is distributed under MIT X11 license (for details please see \doc\license.txt) + +using System; +using System.Collections.Generic; +using System.Linq; +using ICSharpCode.NRefactory.CSharp; +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.Ast.Transforms +{ + /// + /// Introduces using declarations. + /// + public class IntroduceUsingDeclarations : DepthFirstAstVisitor, IAstTransform + { + DecompilerContext context; + + public IntroduceUsingDeclarations(DecompilerContext context) + { + this.context = context; + currentNamespace = context.CurrentType != null ? context.CurrentType.Namespace : string.Empty; + } + + public void Run(AstNode compilationUnit) + { + // Don't show using when decompiling a single method or nested types: + if (context.CurrentMethod != null || (context.CurrentType != null && context.CurrentType.IsNested)) + return; + + // First determine all the namespaces that need to be imported: + compilationUnit.AcceptVisitor(this, null); + + importedNamespaces.Add("System"); // always import System, even when not necessary + + // Now add using declarations for those namespaces: + foreach (string ns in importedNamespaces.OrderByDescending(n => n)) { + // we go backwards (OrderByDescending) through the list of namespaces because we insert them backwards + // (always inserting at the start of the list) + string[] parts = ns.Split('.'); + AstType nsType = new SimpleType(parts[0]); + for (int i = 1; i < parts.Length; i++) { + nsType = new MemberType { Target = nsType, MemberName = parts[i] }; + } + compilationUnit.InsertChildAfter(null, new UsingDeclaration { Import = nsType }, CompilationUnit.MemberRole); + } + + // TODO: verify that the SimpleTypes refer to the correct type (no ambiguities) + } + + readonly HashSet importedNamespaces = new HashSet(); + string currentNamespace; + + bool IsParentOfCurrentNamespace(string ns) + { + if (ns.Length == 0) + return true; + if (currentNamespace.StartsWith(ns, StringComparison.Ordinal)) { + if (currentNamespace.Length == ns.Length) + return true; + if (currentNamespace[ns.Length] == '.') + return true; + } + return false; + } + + public override object VisitSimpleType(SimpleType simpleType, object data) + { + TypeReference tr = simpleType.Annotation(); + if (tr != null && !IsParentOfCurrentNamespace(tr.Namespace)) { + importedNamespaces.Add(tr.Namespace); + } + return base.VisitSimpleType(simpleType, data); // also visit type arguments + } + + public override object VisitNamespaceDeclaration(NamespaceDeclaration namespaceDeclaration, object data) + { + string oldNamespace = currentNamespace; + foreach (Identifier ident in namespaceDeclaration.Identifiers) { + currentNamespace = NamespaceDeclaration.BuildQualifiedName(currentNamespace, ident.Name); + } + base.VisitNamespaceDeclaration(namespaceDeclaration, data); + currentNamespace = oldNamespace; + return null; + } + } +} diff --git a/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs b/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs index 862856915..e2d6536d3 100644 --- a/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs +++ b/ICSharpCode.Decompiler/Ast/Transforms/PatternStatementTransform.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using ICSharpCode.NRefactory.CSharp; using ICSharpCode.NRefactory.CSharp.PatternMatching; @@ -13,37 +14,106 @@ namespace ICSharpCode.Decompiler.Ast.Transforms /// /// Finds the expanded form of using statements using pattern matching and replaces it with a UsingStatement. /// - public class PatternStatementTransform : IAstTransform + public sealed class PatternStatementTransform : ContextTrackingVisitor, IAstTransform { - DecompilerContext context; + public PatternStatementTransform(DecompilerContext context) : base(context) + { + } - public PatternStatementTransform(DecompilerContext context) + #region Visitor Overrides + protected override AstNode VisitChildren(AstNode node, object data) { - if (context == null) - throw new ArgumentNullException("context"); - this.context = context; + // Go through the children, and keep visiting a node as long as it changes. + // Because some transforms delete/replace nodes before and after the node being transformed, we rely + // on the transform's return value to know where we need to keep iterating. + for (AstNode child = node.FirstChild; child != null; child = child.NextSibling) { + AstNode oldChild; + do { + oldChild = child; + child = child.AcceptVisitor(this, data); + Debug.Assert(child != null && child.Parent == node); + } while (child != oldChild); + } + return node; } - public void Run(AstNode compilationUnit) + public override AstNode VisitVariableDeclarationStatement(VariableDeclarationStatement variableDeclarationStatement, object data) { - if (context.Settings.UsingStatement) - TransformUsings(compilationUnit); - if (context.Settings.ForEachStatement) - TransformForeach(compilationUnit); - TransformFor(compilationUnit); - TransformDoWhile(compilationUnit); - if (context.Settings.LockStatement) - TransformLock(compilationUnit); - if (context.Settings.SwitchStatementOnString) - TransformSwitchOnString(compilationUnit); - if (context.Settings.AutomaticProperties) - TransformAutomaticProperties(compilationUnit); - if (context.Settings.AutomaticEvents) - TransformAutomaticEvents(compilationUnit); - - TransformTryCatchFinally(compilationUnit); + AstNode result; + if (context.Settings.UsingStatement) { + result = TransformUsings(variableDeclarationStatement); + if (result != null) + return result; + } + result = TransformFor(variableDeclarationStatement); + if (result != null) + return result; + if (context.Settings.LockStatement) { + result = TransformLock(variableDeclarationStatement); + if (result != null) + return result; + } + return base.VisitVariableDeclarationStatement(variableDeclarationStatement, data); + } + + public override AstNode VisitUsingStatement(UsingStatement usingStatement, object data) + { + if (context.Settings.ForEachStatement) { + AstNode result = TransformForeach(usingStatement); + if (result != null) + return result; + } + return base.VisitUsingStatement(usingStatement, data); } + public override AstNode VisitWhileStatement(WhileStatement whileStatement, object data) + { + return TransformDoWhile(whileStatement) ?? base.VisitWhileStatement(whileStatement, data); + } + + public override AstNode VisitIfElseStatement(IfElseStatement ifElseStatement, object data) + { + if (context.Settings.SwitchStatementOnString) { + AstNode result = TransformSwitchOnString(ifElseStatement); + if (result != null) + return result; + } + return base.VisitIfElseStatement(ifElseStatement, data); + } + + public override AstNode VisitPropertyDeclaration(PropertyDeclaration propertyDeclaration, object data) + { + if (context.Settings.AutomaticProperties) { + AstNode result = TransformAutomaticProperties(propertyDeclaration); + if (result != null) + return result; + } + return base.VisitPropertyDeclaration(propertyDeclaration, data); + } + + public override AstNode VisitCustomEventDeclaration(CustomEventDeclaration eventDeclaration, object data) + { + // first apply transforms to the accessor bodies + base.VisitCustomEventDeclaration(eventDeclaration, data); + if (context.Settings.AutomaticEvents) { + AstNode result = TransformAutomaticEvents(eventDeclaration); + if (result != null) + return result; + } + return eventDeclaration; + } + + public override AstNode VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data) + { + return TransformDestructor(methodDeclaration) ?? base.VisitMethodDeclaration(methodDeclaration, data); + } + + public override AstNode VisitTryCatchStatement(TryCatchStatement tryCatchStatement, object data) + { + return TransformTryCatchFinally(tryCatchStatement) ?? base.VisitTryCatchStatement(tryCatchStatement, data); + } + #endregion + /// /// $type $variable = $initializer; /// @@ -93,31 +163,30 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } }; - public void TransformUsings(AstNode compilationUnit) + public UsingStatement TransformUsings(VariableDeclarationStatement node) { - foreach (AstNode node in compilationUnit.Descendants.OfType().ToArray()) { - Match m1 = variableDeclPattern.Match(node); - if (m1 == null) continue; - AstNode tryCatch = node.NextSibling; - while (simpleVariableDefinition.Match(tryCatch) != null) - tryCatch = tryCatch.NextSibling; - Match m2 = usingTryCatchPattern.Match(tryCatch); - if (m2 == null) continue; - if (m1.Get("variable").Single().Name == m2.Get("ident").Single().Identifier) { - if (m2.Has("valueType")) { - // if there's no if(x!=null), then it must be a value type - TypeReference tr = m1.Get("type").Single().Annotation(); - if (tr == null || !tr.IsValueType) - continue; - } - BlockStatement body = m2.Get("body").Single(); - tryCatch.ReplaceWith( - new UsingStatement { - ResourceAcquisition = node.Detach(), - EmbeddedStatement = body.Detach() - }); + Match m1 = variableDeclPattern.Match(node); + if (m1 == null) return null; + AstNode tryCatch = node.NextSibling; + while (simpleVariableDefinition.Match(tryCatch) != null) + tryCatch = tryCatch.NextSibling; + Match m2 = usingTryCatchPattern.Match(tryCatch); + if (m2 == null) return null; + if (m1.Get("variable").Single().Name == m2.Get("ident").Single().Identifier) { + if (m2.Has("valueType")) { + // if there's no if(x!=null), then it must be a value type + TypeReference tr = m1.Get("type").Single().Annotation(); + if (tr == null || !tr.IsValueType) + return null; } + BlockStatement body = m2.Get("body").Single(); + UsingStatement usingStatement = new UsingStatement(); + usingStatement.ResourceAcquisition = node.Detach(); + usingStatement.EmbeddedStatement = body.Detach(); + tryCatch.ReplaceWith(usingStatement); + return usingStatement; } + return null; } #endregion @@ -184,29 +253,28 @@ namespace ICSharpCode.Decompiler.Ast.Transforms }.ToStatement() }; - public void TransformForeach(AstNode compilationUnit) + public ForeachStatement TransformForeach(UsingStatement node) { - foreach (AstNode node in compilationUnit.Descendants.OfType().ToArray()) { - Match m = foreachPattern.Match(node); - if (m == null) - continue; - VariableInitializer enumeratorVar = m.Get("enumeratorVariable").Single(); - VariableInitializer itemVar = m.Get("itemVariable").Single(); - if (m.Has("itemVariableInsideLoop") && itemVar.Annotation() != null) { - // cannot move captured variables out of loops - continue; - } - BlockStatement newBody = new BlockStatement(); - foreach (Statement stmt in m.Get("statement")) - newBody.Add(stmt.Detach()); - node.ReplaceWith( - new ForeachStatement { - VariableType = m.Get("itemType").Single().Detach(), - VariableName = itemVar.Name, - InExpression = m.Get("collection").Single().Detach(), - EmbeddedStatement = newBody - }); + Match m = foreachPattern.Match(node); + if (m == null) + return null; + VariableInitializer enumeratorVar = m.Get("enumeratorVariable").Single(); + VariableInitializer itemVar = m.Get("itemVariable").Single(); + if (m.Has("itemVariableInsideLoop") && itemVar.Annotation() != null) { + // cannot move captured variables out of loops + return null; } + BlockStatement newBody = new BlockStatement(); + foreach (Statement stmt in m.Get("statement")) + newBody.Add(stmt.Detach()); + ForeachStatement foreachStatement = new ForeachStatement { + VariableType = m.Get("itemType").Single().Detach(), + VariableName = itemVar.Name, + InExpression = m.Get("collection").Single().Detach(), + EmbeddedStatement = newBody + }; + node.ReplaceWith(foreachStatement); + return foreachStatement; } #endregion @@ -231,32 +299,30 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } }}; - public void TransformFor(AstNode compilationUnit) + public ForStatement TransformFor(VariableDeclarationStatement node) { - foreach (AstNode node in compilationUnit.Descendants.OfType().ToArray()) { - Match m1 = variableDeclPattern.Match(node); - if (m1 == null) continue; - AstNode next = node.NextSibling; - while (simpleVariableDefinition.Match(next) != null) - next = next.NextSibling; - Match m2 = forPattern.Match(next); - if (m2 == null) continue; - // ensure the variable in the for pattern is the same as in the declaration - if (m1.Get("variable").Single().Name != m2.Get("ident").Single().Identifier) - continue; - WhileStatement loop = (WhileStatement)next; - node.Remove(); - BlockStatement newBody = new BlockStatement(); - foreach (Statement stmt in m2.Get("statement")) - newBody.Add(stmt.Detach()); - loop.ReplaceWith( - new ForStatement { - Initializers = { (VariableDeclarationStatement)node }, - Condition = loop.Condition.Detach(), - Iterators = { m2.Get("increment").Single().Detach() }, - EmbeddedStatement = newBody - }); - } + Match m1 = variableDeclPattern.Match(node); + if (m1 == null) return null; + AstNode next = node.NextSibling; + while (simpleVariableDefinition.Match(next) != null) + next = next.NextSibling; + Match m2 = forPattern.Match(next); + if (m2 == null) return null; + // ensure the variable in the for pattern is the same as in the declaration + if (m1.Get("variable").Single().Name != m2.Get("ident").Single().Identifier) + return null; + WhileStatement loop = (WhileStatement)next; + node.Remove(); + BlockStatement newBody = new BlockStatement(); + foreach (Statement stmt in m2.Get("statement")) + newBody.Add(stmt.Detach()); + ForStatement forStatement = new ForStatement(); + forStatement.Initializers.Add(node); + forStatement.Condition = loop.Condition.Detach(); + forStatement.Iterators.Add(m2.Get("increment").Single().Detach()); + forStatement.EmbeddedStatement = newBody; + loop.ReplaceWith(forStatement); + return forStatement; } #endregion @@ -273,37 +339,37 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } }}; - public void TransformDoWhile(AstNode compilationUnit) + public DoWhileStatement TransformDoWhile(WhileStatement whileLoop) { - foreach (WhileStatement whileLoop in compilationUnit.Descendants.OfType().ToArray()) { - Match m = doWhilePattern.Match(whileLoop); - if (m != null) { - DoWhileStatement doLoop = new DoWhileStatement(); - doLoop.Condition = new UnaryOperatorExpression(UnaryOperatorType.Not, m.Get("condition").Single().Detach()); - doLoop.Condition.AcceptVisitor(new PushNegation(), null); - BlockStatement block = (BlockStatement)whileLoop.EmbeddedStatement; - block.Statements.Last().Remove(); // remove if statement - doLoop.EmbeddedStatement = block.Detach(); - whileLoop.ReplaceWith(doLoop); - - // we may have to extract variable definitions out of the loop if they were used in the condition: - foreach (var varDecl in block.Statements.OfType()) { - VariableInitializer v = varDecl.Variables.Single(); - if (doLoop.Condition.DescendantsAndSelf.OfType().Any(i => i.Identifier == v.Name)) { - AssignmentExpression assign = new AssignmentExpression(new IdentifierExpression(v.Name), v.Initializer.Detach()); - // move annotations from v to assign: - assign.CopyAnnotationsFrom(v); - v.RemoveAnnotations(); - // remove varDecl with assignment; and move annotations from varDecl to the ExpressionStatement: - varDecl.ReplaceWith(new ExpressionStatement(assign).CopyAnnotationsFrom(varDecl)); - varDecl.RemoveAnnotations(); - - // insert the varDecl above the do-while loop: - doLoop.Parent.InsertChildBefore(doLoop, varDecl, BlockStatement.StatementRole); - } + Match m = doWhilePattern.Match(whileLoop); + if (m != null) { + DoWhileStatement doLoop = new DoWhileStatement(); + doLoop.Condition = new UnaryOperatorExpression(UnaryOperatorType.Not, m.Get("condition").Single().Detach()); + doLoop.Condition.AcceptVisitor(new PushNegation(), null); + BlockStatement block = (BlockStatement)whileLoop.EmbeddedStatement; + block.Statements.Last().Remove(); // remove if statement + doLoop.EmbeddedStatement = block.Detach(); + whileLoop.ReplaceWith(doLoop); + + // we may have to extract variable definitions out of the loop if they were used in the condition: + foreach (var varDecl in block.Statements.OfType()) { + VariableInitializer v = varDecl.Variables.Single(); + if (doLoop.Condition.DescendantsAndSelf.OfType().Any(i => i.Identifier == v.Name)) { + AssignmentExpression assign = new AssignmentExpression(new IdentifierExpression(v.Name), v.Initializer.Detach()); + // move annotations from v to assign: + assign.CopyAnnotationsFrom(v); + v.RemoveAnnotations(); + // remove varDecl with assignment; and move annotations from varDecl to the ExpressionStatement: + varDecl.ReplaceWith(new ExpressionStatement(assign).CopyAnnotationsFrom(varDecl)); + varDecl.RemoveAnnotations(); + + // insert the varDecl above the do-while loop: + doLoop.Parent.InsertChildBefore(doLoop, varDecl, BlockStatement.StatementRole); } } + return doLoop; } + return null; } #endregion @@ -338,49 +404,49 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } }}; - public void TransformLock(AstNode compilationUnit) + public LockStatement TransformLock(VariableDeclarationStatement node) { - foreach (AstNode node in compilationUnit.Descendants.OfType().ToArray()) { - Match m1 = lockFlagInitPattern.Match(node); - if (m1 == null) continue; - AstNode tryCatch = node.NextSibling; - while (simpleVariableDefinition.Match(tryCatch) != null) - tryCatch = tryCatch.NextSibling; - Match m2 = lockTryCatchPattern.Match(tryCatch); - if (m2 == null) continue; - if (m1.Get("variable").Single().Name == m2.Get("flag").Single().Identifier) { - Expression enter = m2.Get("enter").Single(); - IdentifierExpression exit = m2.Get("exit").Single(); - if (exit.Match(enter) == null) { - // If exit and enter are not the same, then enter must be "exit = ..." - AssignmentExpression assign = enter as AssignmentExpression; - if (assign == null) - continue; - if (exit.Match(assign.Left) == null) - continue; - enter = assign.Right; - // Remove 'exit' variable: - bool ok = false; - for (AstNode tmp = node.NextSibling; tmp != tryCatch; tmp = tmp.NextSibling) { - VariableDeclarationStatement v = (VariableDeclarationStatement)tmp; - if (v.Variables.Single().Name == exit.Identifier) { - ok = true; - v.Remove(); - break; - } + Match m1 = lockFlagInitPattern.Match(node); + if (m1 == null) return null; + AstNode tryCatch = node.NextSibling; + while (simpleVariableDefinition.Match(tryCatch) != null) + tryCatch = tryCatch.NextSibling; + Match m2 = lockTryCatchPattern.Match(tryCatch); + if (m2 == null) return null; + if (m1.Get("variable").Single().Name == m2.Get("flag").Single().Identifier) { + Expression enter = m2.Get("enter").Single(); + IdentifierExpression exit = m2.Get("exit").Single(); + if (exit.Match(enter) == null) { + // If exit and enter are not the same, then enter must be "exit = ..." + AssignmentExpression assign = enter as AssignmentExpression; + if (assign == null) + return null; + if (exit.Match(assign.Left) == null) + return null; + enter = assign.Right; + // Remove 'exit' variable: + bool ok = false; + for (AstNode tmp = node.NextSibling; tmp != tryCatch; tmp = tmp.NextSibling) { + VariableDeclarationStatement v = (VariableDeclarationStatement)tmp; + if (v.Variables.Single().Name == exit.Identifier) { + ok = true; + v.Remove(); + break; } - if (!ok) - continue; } - // transform the code into a lock statement: - LockStatement l = new LockStatement(); - l.Expression = enter.Detach(); - l.EmbeddedStatement = ((TryCatchStatement)tryCatch).TryBlock.Detach(); - ((BlockStatement)l.EmbeddedStatement).Statements.First().Remove(); // Remove 'Enter()' call - tryCatch.ReplaceWith(l); - node.Remove(); // remove flag variable + if (!ok) + return null; } + // transform the code into a lock statement: + LockStatement l = new LockStatement(); + l.Expression = enter.Detach(); + l.EmbeddedStatement = ((TryCatchStatement)tryCatch).TryBlock.Detach(); + ((BlockStatement)l.EmbeddedStatement).Statements.First().Remove(); // Remove 'Enter()' call + tryCatch.ReplaceWith(l); + node.Remove(); // remove flag variable + return l; } + return null; } #endregion @@ -427,60 +493,59 @@ namespace ICSharpCode.Decompiler.Ast.Transforms FalseStatement = new OptionalNode("nullStmt", new BlockStatement { Statements = { new Repeat(new AnyNode()) } }) }; - public void TransformSwitchOnString(AstNode compilationUnit) + public SwitchStatement TransformSwitchOnString(IfElseStatement node) { - foreach (AstNode node in compilationUnit.Descendants.OfType().ToArray()) { - Match m = switchOnStringPattern.Match(node); - if (m == null) - continue; - if (m.Has("nonNullDefaultStmt") && !m.Has("nullStmt")) - continue; - // switchVar must be the same as switchExpr; or switchExpr must be an assignment and switchVar the left side of that assignment - if (m.Get("switchVar").Single().Match(m.Get("switchExpr").Single()) == null) { - AssignmentExpression assign = m.Get("switchExpr").Single() as AssignmentExpression; - if (m.Get("switchVar").Single().Match(assign.Left) == null) + Match m = switchOnStringPattern.Match(node); + if (m == null) + return null; + if (m.Has("nonNullDefaultStmt") && !m.Has("nullStmt")) + return null; + // switchVar must be the same as switchExpr; or switchExpr must be an assignment and switchVar the left side of that assignment + if (m.Get("switchVar").Single().Match(m.Get("switchExpr").Single()) == null) { + AssignmentExpression assign = m.Get("switchExpr").Single() as AssignmentExpression; + if (m.Get("switchVar").Single().Match(assign.Left) == null) + return null; + } + FieldReference cachedDictField = m.Get("cachedDict").Single().Annotation(); + if (cachedDictField == null || !cachedDictField.DeclaringType.Name.StartsWith("", StringComparison.Ordinal)) + return null; + List dictCreation = m.Get("dictCreation").Single().Statements.ToList(); + List> dict = BuildDictionary(dictCreation); + SwitchStatement sw = m.Get("switch").Single(); + sw.Expression = m.Get("switchExpr").Single().Detach(); + foreach (SwitchSection section in sw.SwitchSections) { + List labels = section.CaseLabels.ToList(); + section.CaseLabels.Clear(); + foreach (CaseLabel label in labels) { + PrimitiveExpression expr = label.Expression as PrimitiveExpression; + if (expr == null || !(expr.Value is int)) continue; - } - FieldReference cachedDictField = m.Get("cachedDict").Single().Annotation(); - if (cachedDictField == null || !cachedDictField.DeclaringType.Name.StartsWith("", StringComparison.Ordinal)) - continue; - List dictCreation = m.Get("dictCreation").Single().Statements.ToList(); - List> dict = BuildDictionary(dictCreation); - SwitchStatement sw = m.Get("switch").Single(); - sw.Expression = m.Get("switchExpr").Single().Detach(); - foreach (SwitchSection section in sw.SwitchSections) { - List labels = section.CaseLabels.ToList(); - section.CaseLabels.Clear(); - foreach (CaseLabel label in labels) { - PrimitiveExpression expr = label.Expression as PrimitiveExpression; - if (expr == null || !(expr.Value is int)) - continue; - int val = (int)expr.Value; - foreach (var pair in dict) { - if (pair.Value == val) - section.CaseLabels.Add(new CaseLabel { Expression = new PrimitiveExpression(pair.Key) }); - } + int val = (int)expr.Value; + foreach (var pair in dict) { + if (pair.Value == val) + section.CaseLabels.Add(new CaseLabel { Expression = new PrimitiveExpression(pair.Key) }); } } - if (m.Has("nullStmt")) { - SwitchSection section = new SwitchSection(); - section.CaseLabels.Add(new CaseLabel { Expression = new NullReferenceExpression() }); - BlockStatement block = m.Get("nullStmt").Single(); - block.Statements.Add(new BreakStatement()); - section.Statements.Add(block.Detach()); + } + if (m.Has("nullStmt")) { + SwitchSection section = new SwitchSection(); + section.CaseLabels.Add(new CaseLabel { Expression = new NullReferenceExpression() }); + BlockStatement block = m.Get("nullStmt").Single(); + block.Statements.Add(new BreakStatement()); + section.Statements.Add(block.Detach()); + sw.SwitchSections.Add(section); + if (m.Has("nonNullDefaultStmt")) { + section = new SwitchSection(); + section.CaseLabels.Add(new CaseLabel()); + block = new BlockStatement(); + block.Statements.AddRange(m.Get("nonNullDefaultStmt").Select(s => s.Detach())); + block.Add(new BreakStatement()); + section.Statements.Add(block); sw.SwitchSections.Add(section); - if (m.Has("nonNullDefaultStmt")) { - section = new SwitchSection(); - section.CaseLabels.Add(new CaseLabel()); - block = new BlockStatement(); - block.Statements.AddRange(m.Get("nonNullDefaultStmt").Select(s => s.Detach())); - block.Add(new BreakStatement()); - section.Statements.Add(block); - sw.SwitchSections.Add(section); - } } - node.ReplaceWith(sw); } + node.ReplaceWith(sw); + return sw; } List> BuildDictionary(List dictCreation) @@ -526,25 +591,25 @@ namespace ICSharpCode.Decompiler.Ast.Transforms } }}}; - void TransformAutomaticProperties(AstNode compilationUnit) + PropertyDeclaration TransformAutomaticProperties(PropertyDeclaration property) { - foreach (var property in compilationUnit.Descendants.OfType()) { - PropertyDefinition cecilProperty = property.Annotation(); - if (cecilProperty == null || cecilProperty.GetMethod == null || cecilProperty.SetMethod == null) - continue; - if (!(cecilProperty.GetMethod.IsCompilerGenerated() && cecilProperty.SetMethod.IsCompilerGenerated())) - continue; - Match m = automaticPropertyPattern.Match(property); - if (m != null) { - FieldDefinition field = m.Get("fieldReference").Single().Annotation().ResolveWithinSameModule(); - if (field.IsCompilerGenerated()) { - RemoveCompilerGeneratedAttribute(property.Getter.Attributes); - RemoveCompilerGeneratedAttribute(property.Setter.Attributes); - property.Getter.Body = null; - property.Setter.Body = null; - } + PropertyDefinition cecilProperty = property.Annotation(); + if (cecilProperty == null || cecilProperty.GetMethod == null || cecilProperty.SetMethod == null) + return null; + if (!(cecilProperty.GetMethod.IsCompilerGenerated() && cecilProperty.SetMethod.IsCompilerGenerated())) + return null; + Match m = automaticPropertyPattern.Match(property); + if (m != null) { + FieldDefinition field = m.Get("fieldReference").Single().Annotation().ResolveWithinSameModule(); + if (field.IsCompilerGenerated()) { + RemoveCompilerGeneratedAttribute(property.Getter.Attributes); + RemoveCompilerGeneratedAttribute(property.Setter.Attributes); + property.Getter.Body = null; + property.Setter.Body = null; } } + // Since the event instance is not changed, we can continue in the visitor as usual, so return null + return null; } void RemoveCompilerGeneratedAttribute(AstNodeCollection attributeSections) @@ -624,33 +689,64 @@ namespace ICSharpCode.Decompiler.Ast.Transforms return combineMethod.DeclaringType.FullName == "System.Delegate"; } - void TransformAutomaticEvents(AstNode compilationUnit) + EventDeclaration TransformAutomaticEvents(CustomEventDeclaration ev) { - foreach (var ev in compilationUnit.Descendants.OfType().ToArray()) { - Match m1 = automaticEventPatternV4.Match(ev.AddAccessor); - if (!CheckAutomaticEventV4Match(m1, ev, true)) - continue; - Match m2 = automaticEventPatternV4.Match(ev.RemoveAccessor); - if (!CheckAutomaticEventV4Match(m2, ev, false)) - continue; - EventDeclaration ed = new EventDeclaration(); - ev.Attributes.MoveTo(ed.Attributes); - ed.ReturnType = ev.ReturnType.Detach(); - ed.Modifiers = ev.Modifiers; - ed.Variables.Add(new VariableInitializer(ev.Name)); - ed.CopyAnnotationsFrom(ev); - - EventDefinition eventDef = ev.Annotation(); - if (eventDef != null) { - FieldDefinition field = eventDef.DeclaringType.Fields.FirstOrDefault(f => f.Name == ev.Name); - if (field != null) { - ed.AddAnnotation(field); - AstBuilder.ConvertAttributes(ed, field, AttributeTarget.Field); + Match m1 = automaticEventPatternV4.Match(ev.AddAccessor); + if (!CheckAutomaticEventV4Match(m1, ev, true)) + return null; + Match m2 = automaticEventPatternV4.Match(ev.RemoveAccessor); + if (!CheckAutomaticEventV4Match(m2, ev, false)) + return null; + EventDeclaration ed = new EventDeclaration(); + ev.Attributes.MoveTo(ed.Attributes); + ed.ReturnType = ev.ReturnType.Detach(); + ed.Modifiers = ev.Modifiers; + ed.Variables.Add(new VariableInitializer(ev.Name)); + ed.CopyAnnotationsFrom(ev); + + EventDefinition eventDef = ev.Annotation(); + if (eventDef != null) { + FieldDefinition field = eventDef.DeclaringType.Fields.FirstOrDefault(f => f.Name == ev.Name); + if (field != null) { + ed.AddAnnotation(field); + AstBuilder.ConvertAttributes(ed, field, AttributeTarget.Field); + } + } + + ev.ReplaceWith(ed); + return ed; + } + #endregion + + #region Destructor + static readonly MethodDeclaration destructorPattern = new MethodDeclaration { + Attributes = { new Repeat(new AnyNode()) }, + Modifiers = Modifiers.Any, + ReturnType = new PrimitiveType("void"), + Name = "Finalize", + Body = new BlockStatement { + new TryCatchStatement { + TryBlock = new AnyNode("body"), + FinallyBlock = new BlockStatement { + new BaseReferenceExpression().Invoke("Finalize") } } - - ev.ReplaceWith(ed); } + }; + + DestructorDeclaration TransformDestructor(MethodDeclaration methodDef) + { + Match m = destructorPattern.Match(methodDef); + if (m != null) { + DestructorDeclaration dd = new DestructorDeclaration(); + methodDef.Attributes.MoveTo(dd.Attributes); + dd.Modifiers = methodDef.Modifiers & ~(Modifiers.Protected | Modifiers.Override); + dd.Body = m.Get("body").Single().Detach(); + dd.Name = AstBuilder.CleanName(context.CurrentType.Name); + methodDef.ReplaceWith(dd); + return dd; + } + return null; } #endregion @@ -669,15 +765,15 @@ namespace ICSharpCode.Decompiler.Ast.Transforms /// Simplify nested 'try { try {} catch {} } finally {}'. /// This transformation must run after the using/lock tranformations. /// - void TransformTryCatchFinally(AstNode compilationUnit) + TryCatchStatement TransformTryCatchFinally(TryCatchStatement tryFinally) { - foreach (var tryFinally in compilationUnit.Descendants.OfType()) { - if (tryCatchFinallyPattern.Match(tryFinally) != null) { - TryCatchStatement tryCatch = (TryCatchStatement)tryFinally.TryBlock.Statements.Single(); - tryFinally.TryBlock = tryCatch.TryBlock.Detach(); - tryCatch.CatchClauses.MoveTo(tryFinally.CatchClauses); - } + if (tryCatchFinallyPattern.Match(tryFinally) != null) { + TryCatchStatement tryCatch = (TryCatchStatement)tryFinally.TryBlock.Statements.Single(); + tryFinally.TryBlock = tryCatch.TryBlock.Detach(); + tryCatch.CatchClauses.MoveTo(tryFinally.CatchClauses); } + // Since the tryFinally instance is not changed, we can continue in the visitor as usual, so return null + return null; } #endregion diff --git a/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs b/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs index 88a44523f..eff00c6ed 100644 --- a/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs +++ b/ICSharpCode.Decompiler/Ast/Transforms/TransformationPipeline.cs @@ -23,6 +23,7 @@ namespace ICSharpCode.Decompiler.Ast.Transforms new ConvertConstructorCallIntoInitializer(), new ReplaceMethodCallsWithOperators(), new IntroduceUnsafeModifier(), + new IntroduceUsingDeclarations(context) }; } diff --git a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj index aa4456f8a..a5bb22877 100644 --- a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj +++ b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj @@ -51,6 +51,7 @@ + @@ -62,6 +63,7 @@ + @@ -100,6 +102,7 @@ + diff --git a/ICSharpCode.Decompiler/ILAst/GotoRemoval.cs b/ICSharpCode.Decompiler/ILAst/GotoRemoval.cs index 6fb4acf8e..9a31e8595 100644 --- a/ICSharpCode.Decompiler/ILAst/GotoRemoval.cs +++ b/ICSharpCode.Decompiler/ILAst/GotoRemoval.cs @@ -65,8 +65,9 @@ namespace ICSharpCode.Decompiler.ILAst int count = ilCase.Body.Count; if (count >= 2) { - if (!ilCase.Body[count - 2].CanFallThough() && - ilCase.Body[count - 1].Match(ILCode.LoopOrSwitchBreak)) { + if (ilCase.Body[count - 2].IsUnconditionalControlFlow() && + ilCase.Body[count - 1].Match(ILCode.LoopOrSwitchBreak)) + { ilCase.Body.RemoveAt(count - 1); } } diff --git a/ICSharpCode.Decompiler/ILAst/ILAstBuilder.cs b/ICSharpCode.Decompiler/ILAst/ILAstBuilder.cs index 397faf050..b75f3b069 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstBuilder.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstBuilder.cs @@ -316,7 +316,7 @@ namespace ICSharpCode.Decompiler.ILAst // Find all successors List branchTargets = new List(); - if (byteCode.Code.CanFallThough()) { + if (!byteCode.Code.IsUnconditionalControlFlow()) { branchTargets.Add(byteCode.Next); } if (byteCode.Operand is Instruction[]) { @@ -695,10 +695,7 @@ namespace ICSharpCode.Decompiler.ILAst ILRange ilRange = new ILRange() { From = byteCode.Offset, To = byteCode.EndOffset }; if (byteCode.StackBefore == null) { - ast.Add(new ILComment() { - Text = "Unreachable code: " + byteCode.Code.GetName(), - ILRanges = new List(new[] { ilRange }) - }); + // Unreachable code continue; } diff --git a/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs b/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs index eb19e00a8..034d93722 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs @@ -21,6 +21,7 @@ namespace ICSharpCode.Decompiler.ILAst SimplifyShortCircuit, SimplifyTernaryOperator, SimplifyNullCoalescing, + JoinBasicBlocks, TransformDecimalCtorToConstant, SimplifyLdObjAndStObj, TransformArrayInitializers, @@ -86,55 +87,45 @@ namespace ICSharpCode.Decompiler.ILAst // Types are needed for the ternary operator optimization TypeAnalysis.Run(context, method); - AnalyseLabels(method); foreach(ILBlock block in method.GetSelfAndChildrenRecursive()) { bool modified; do { modified = false; if (abortBeforeStep == ILAstOptimizationStep.SimplifyShortCircuit) return; - modified |= block.RunOptimization(SimplifyShortCircuit); + modified |= block.RunOptimization(new SimpleControlFlow(context, method).SimplifyShortCircuit); if (abortBeforeStep == ILAstOptimizationStep.SimplifyTernaryOperator) return; - modified |= block.RunOptimization(SimplifyTernaryOperator); + modified |= block.RunOptimization(new SimpleControlFlow(context, method).SimplifyTernaryOperator); if (abortBeforeStep == ILAstOptimizationStep.SimplifyNullCoalescing) return; - modified |= block.RunOptimization(SimplifyNullCoalescing); + modified |= block.RunOptimization(new SimpleControlFlow(context, method).SimplifyNullCoalescing); + + if (abortBeforeStep == ILAstOptimizationStep.JoinBasicBlocks) return; + modified |= block.RunOptimization(new SimpleControlFlow(context, method).JoinBasicBlocks); + + if (abortBeforeStep == ILAstOptimizationStep.TransformDecimalCtorToConstant) return; + modified |= block.RunOptimization(TransformDecimalCtorToConstant); + + if (abortBeforeStep == ILAstOptimizationStep.SimplifyLdObjAndStObj) return; + modified |= block.RunOptimization(SimplifyLdObjAndStObj); + + if (abortBeforeStep == ILAstOptimizationStep.TransformArrayInitializers) return; + modified |= block.RunOptimization(Initializers.TransformArrayInitializers); + + if (abortBeforeStep == ILAstOptimizationStep.TransformCollectionInitializers) return; + modified |= block.RunOptimization(Initializers.TransformCollectionInitializers); + + if (abortBeforeStep == ILAstOptimizationStep.MakeAssignmentExpression) return; + modified |= block.RunOptimization(MakeAssignmentExpression); + + if (abortBeforeStep == ILAstOptimizationStep.InlineVariables2) return; + modified |= new ILInlining(method).InlineAllInBlock(block); + new ILInlining(method).CopyPropagation(); } while(modified); } - ILInlining inlining2 = new ILInlining(method); - inlining2.InlineAllVariables(); - inlining2.CopyPropagation(); - - foreach(ILBlock block in method.GetSelfAndChildrenRecursive()) { - - // Intentionaly outside the while(modifed) loop, - // I will put it there later after more testing - - bool modified = false; - - if (abortBeforeStep == ILAstOptimizationStep.TransformDecimalCtorToConstant) return; - modified |= block.RunOptimization(TransformDecimalCtorToConstant); - - if (abortBeforeStep == ILAstOptimizationStep.SimplifyLdObjAndStObj) return; - modified |= block.RunOptimization(SimplifyLdObjAndStObj); - - if (abortBeforeStep == ILAstOptimizationStep.TransformArrayInitializers) return; - modified |= block.RunOptimization(Initializers.TransformArrayInitializers); - modified |= block.RunOptimization(Initializers.TransformArrayInitializers); - - if (abortBeforeStep == ILAstOptimizationStep.TransformCollectionInitializers) return; - modified |= block.RunOptimization(Initializers.TransformCollectionInitializers); - - if (abortBeforeStep == ILAstOptimizationStep.MakeAssignmentExpression) return; - modified |= block.RunOptimization(MakeAssignmentExpression); - - if (abortBeforeStep == ILAstOptimizationStep.InlineVariables2) return; - modified |= new ILInlining(method).InlineAllInBlock(block); - } - if (abortBeforeStep == ILAstOptimizationStep.FindLoops) return; foreach(ILBlock block in method.GetSelfAndChildrenRecursive()) { new LoopsAndConditions(context).FindLoops(block); @@ -175,9 +166,17 @@ namespace ICSharpCode.Decompiler.ILAst if (abortBeforeStep == ILAstOptimizationStep.IntroduceFixedStatements) return; foreach(ILBlock block in method.GetSelfAndChildrenRecursive()) { - for (int i = 0; i < block.Body.Count; i++) { + for (int i = block.Body.Count - 1; i >= 0; i--) { + // TODO: Move before loops + if (i < block.Body.Count) + IntroduceFixedStatements(block.Body, i); + } + } + foreach(ILBlock block in method.GetSelfAndChildrenRecursive()) { + for (int i = block.Body.Count - 1; i >= 0; i--) { // TODO: Move before loops - IntroduceFixedStatements(block.Body, i); + if (i < block.Body.Count) + IntroduceFixedStatements(block.Body, i); } } @@ -286,14 +285,14 @@ namespace ICSharpCode.Decompiler.ILAst { List basicBlocks = new List(); - ILBasicBlock basicBlock = new ILBasicBlock() { - EntryLabel = block.Body.FirstOrDefault() as ILLabel ?? new ILLabel() { Name = "Block_" + (nextLabelIndex++) } - }; + ILLabel entryLabel = block.Body.FirstOrDefault() as ILLabel ?? new ILLabel() { Name = "Block_" + (nextLabelIndex++) }; + ILBasicBlock basicBlock = new ILBasicBlock(); basicBlocks.Add(basicBlock); - block.EntryGoto = new ILExpression(ILCode.Br, basicBlock.EntryLabel); + basicBlock.Body.Add(entryLabel); + block.EntryGoto = new ILExpression(ILCode.Br, entryLabel); if (block.Body.Count > 0) { - if (block.Body[0] != basicBlock.EntryLabel) + if (block.Body[0] != entryLabel) basicBlock.Body.Add(block.Body[0]); for (int i = 1; i < block.Body.Count; i++) { @@ -302,31 +301,28 @@ namespace ICSharpCode.Decompiler.ILAst // Start a new basic block if necessary if (currNode is ILLabel || - lastNode is ILTryCatchBlock || - currNode is ILTryCatchBlock || - (lastNode is ILExpression && ((ILExpression)lastNode).IsBranch())) + currNode is ILTryCatchBlock || // Counts as label + lastNode.IsConditionalControlFlow() || + lastNode.IsUnconditionalControlFlow()) { // Try to reuse the label ILLabel label = currNode is ILLabel ? ((ILLabel)currNode) : new ILLabel() { Name = "Block_" + (nextLabelIndex++) }; // Terminate the last block - if (lastNode.CanFallThough()) { + if (!lastNode.IsUnconditionalControlFlow()) { // Explicit branch from one block to other - basicBlock.FallthoughGoto = new ILExpression(ILCode.Br, label); - } else if (lastNode.Match(ILCode.Br)) { - // Reuse the existing goto as FallthoughGoto - basicBlock.FallthoughGoto = (ILExpression)lastNode; - basicBlock.Body.RemoveAt(basicBlock.Body.Count - 1); + basicBlock.Body.Add(new ILExpression(ILCode.Br, label)); } - // Start the new block + // Start the new block basicBlock = new ILBasicBlock(); basicBlocks.Add(basicBlock); - basicBlock.EntryLabel = label; - } - - // Add the node to the basic block - if (currNode != basicBlock.EntryLabel) { + basicBlock.Body.Add(label); + + // Add the node to the basic block + if (currNode != label) + basicBlock.Body.Add(currNode); + } else { basicBlock.Body.Add(currNode); } } @@ -336,208 +332,6 @@ namespace ICSharpCode.Decompiler.ILAst return; } - Dictionary labelGlobalRefCount; - Dictionary labelToBasicBlock; - - void AnalyseLabels(ILBlock method) - { - labelGlobalRefCount = new Dictionary(); - foreach(ILLabel target in method.GetSelfAndChildrenRecursive(e => e.IsBranch()).SelectMany(e => e.GetBranchTargets())) { - if (!labelGlobalRefCount.ContainsKey(target)) - labelGlobalRefCount[target] = 0; - labelGlobalRefCount[target]++; - } - - labelToBasicBlock = new Dictionary(); - foreach(ILBasicBlock bb in method.GetSelfAndChildrenRecursive()) { - foreach(ILLabel label in bb.GetChildren().OfType()) { - labelToBasicBlock[label] = bb; - } - } - } - - bool SimplifyTernaryOperator(List body, ILBasicBlock head, int pos) - { - Debug.Assert(body.Contains(head)); - - ILExpression condExpr; - ILLabel trueLabel; - ILLabel falseLabel; - ILVariable trueLocVar = null; - ILExpression trueExpr; - ILLabel trueFall; - ILVariable falseLocVar = null; - ILExpression falseExpr; - ILLabel falseFall; - object unused; - - if (head.MatchLast(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel) && - labelGlobalRefCount[trueLabel] == 1 && - labelGlobalRefCount[falseLabel] == 1 && - ((labelToBasicBlock[trueLabel].MatchSingle(ILCode.Stloc, out trueLocVar, out trueExpr, out trueFall) && - labelToBasicBlock[falseLabel].MatchSingle(ILCode.Stloc, out falseLocVar, out falseExpr, out falseFall)) || - (labelToBasicBlock[trueLabel].MatchSingle(ILCode.Ret, out unused, out trueExpr, out trueFall) && - labelToBasicBlock[falseLabel].MatchSingle(ILCode.Ret, out unused, out falseExpr, out falseFall))) && - trueLocVar == falseLocVar && - trueFall == falseFall && - body.Contains(labelToBasicBlock[trueLabel]) && - body.Contains(labelToBasicBlock[falseLabel]) - ) - { - ILCode opCode = trueLocVar != null ? ILCode.Stloc : ILCode.Ret; - TypeReference retType = trueLocVar != null ? trueLocVar.Type : this.context.CurrentMethod.ReturnType; - int leftBoolVal; - int rightBoolVal; - ILExpression newExpr; - // a ? true:false is equivalent to a - // a ? false:true is equivalent to !a - // a ? true : b is equivalent to a || b - // a ? b : true is equivalent to !a || b - // a ? b : false is equivalent to a && b - // a ? false : b is equivalent to !a && b - if (retType == typeSystem.Boolean && - trueExpr.Match(ILCode.Ldc_I4, out leftBoolVal) && - falseExpr.Match(ILCode.Ldc_I4, out rightBoolVal) && - ((leftBoolVal != 0 && rightBoolVal == 0) || (leftBoolVal == 0 && rightBoolVal != 0)) - ) - { - // It can be expressed as trivilal expression - if (leftBoolVal != 0) { - newExpr = condExpr; - } else { - newExpr = new ILExpression(ILCode.LogicNot, null, condExpr); - } - } else if (retType == typeSystem.Boolean && trueExpr.Match(ILCode.Ldc_I4, out leftBoolVal)) { - // It can be expressed as logical expression - if (leftBoolVal != 0) { - newExpr = new ILExpression(ILCode.LogicOr, null, condExpr, falseExpr); - } else { - newExpr = new ILExpression(ILCode.LogicAnd, null, new ILExpression(ILCode.LogicNot, null, condExpr), falseExpr); - } - } else if (retType == typeSystem.Boolean && falseExpr.Match(ILCode.Ldc_I4, out rightBoolVal)) { - // It can be expressed as logical expression - if (rightBoolVal != 0) { - newExpr = new ILExpression(ILCode.LogicOr, null, new ILExpression(ILCode.LogicNot, null, condExpr), trueExpr); - } else { - newExpr = new ILExpression(ILCode.LogicAnd, null, condExpr, trueExpr); - } - } else { - // Ternary operator tends to create long complicated return statements - if (opCode == ILCode.Ret) - return false; - - // Create ternary expression - newExpr = new ILExpression(ILCode.TernaryOp, null, condExpr, trueExpr, falseExpr); - } - head.Body[head.Body.Count - 1] = new ILExpression(opCode, trueLocVar, newExpr); - head.FallthoughGoto = trueFall != null ? new ILExpression(ILCode.Br, trueFall) : null; - - // Remove the old basic blocks - foreach(ILLabel deleteLabel in new [] { trueLabel, falseLabel }) { - body.RemoveOrThrow(labelToBasicBlock[deleteLabel]); - labelGlobalRefCount.RemoveOrThrow(deleteLabel); - labelToBasicBlock.RemoveOrThrow(deleteLabel); - } - - return true; - } - return false; - } - - bool SimplifyNullCoalescing(List body, ILBasicBlock head, int pos) - { - // ... - // v = ldloc(leftVar) - // brtrue(endBBLabel, ldloc(leftVar)) - // br(rightBBLabel) - // - // rightBBLabel: - // v = rightExpr - // br(endBBLabel) - // ... - // => - // ... - // v = NullCoalescing(ldloc(leftVar), rightExpr) - // br(endBBLabel) - - ILVariable v, v2; - ILExpression leftExpr, leftExpr2; - ILVariable leftVar; - ILLabel endBBLabel, endBBLabel2; - ILLabel rightBBLabel; - ILBasicBlock rightBB; - ILExpression rightExpr; - if (head.Body.Count >= 2 && - head.Body[head.Body.Count - 2].Match(ILCode.Stloc, out v, out leftExpr) && - leftExpr.Match(ILCode.Ldloc, out leftVar) && - head.MatchLast(ILCode.Brtrue, out endBBLabel, out leftExpr2, out rightBBLabel) && - leftExpr2.Match(ILCode.Ldloc, leftVar) && - labelToBasicBlock.TryGetValue(rightBBLabel, out rightBB) && - rightBB.MatchSingle(ILCode.Stloc, out v2, out rightExpr, out endBBLabel2) && - v == v2 && - endBBLabel == endBBLabel2 && - labelGlobalRefCount.GetOrDefault(rightBBLabel) == 1 && - body.Contains(rightBB) - ) - { - head.Body[head.Body.Count - 2] = new ILExpression(ILCode.Stloc, v, new ILExpression(ILCode.NullCoalescing, null, leftExpr, rightExpr)); - head.Body.RemoveAt(head.Body.Count - 1); - head.FallthoughGoto = new ILExpression(ILCode.Br, endBBLabel); - - body.RemoveOrThrow(labelToBasicBlock[rightBBLabel]); - labelGlobalRefCount.RemoveOrThrow(rightBBLabel); - labelToBasicBlock.RemoveOrThrow(rightBBLabel); - return true; - } - return false; - } - - bool SimplifyShortCircuit(List body, ILBasicBlock head, int pos) - { - Debug.Assert(body.Contains(head)); - - ILExpression condExpr; - ILLabel trueLabel; - ILLabel falseLabel; - if(head.MatchLast(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel)) { - for (int pass = 0; pass < 2; pass++) { - - // On the second pass, swap labels and negate expression of the first branch - // It is slightly ugly, but much better then copy-pasting this whole block - ILLabel nextLabel = (pass == 0) ? trueLabel : falseLabel; - ILLabel otherLablel = (pass == 0) ? falseLabel : trueLabel; - bool negate = (pass == 1); - - ILBasicBlock nextBasicBlock = labelToBasicBlock[nextLabel]; - ILExpression nextCondExpr; - ILLabel nextTrueLablel; - ILLabel nextFalseLabel; - if (body.Contains(nextBasicBlock) && - nextBasicBlock != head && - labelGlobalRefCount[nextBasicBlock.EntryLabel] == 1 && - nextBasicBlock.MatchSingle(ILCode.Brtrue, out nextTrueLablel, out nextCondExpr, out nextFalseLabel) && - (otherLablel == nextFalseLabel || otherLablel == nextTrueLablel)) - { - // Create short cicuit branch - if (otherLablel == nextFalseLabel) { - head.Body[head.Body.Count - 1] = new ILExpression(ILCode.Brtrue, nextTrueLablel, new ILExpression(ILCode.LogicAnd, null, negate ? new ILExpression(ILCode.LogicNot, null, condExpr) : condExpr, nextCondExpr)); - } else { - head.Body[head.Body.Count - 1] = new ILExpression(ILCode.Brtrue, nextTrueLablel, new ILExpression(ILCode.LogicOr, null, negate ? condExpr : new ILExpression(ILCode.LogicNot, null, condExpr), nextCondExpr)); - } - head.FallthoughGoto = new ILExpression(ILCode.Br, nextFalseLabel); - - // Remove the inlined branch from scope - labelGlobalRefCount.RemoveOrThrow(nextBasicBlock.EntryLabel); - labelToBasicBlock.RemoveOrThrow(nextBasicBlock.EntryLabel); - body.RemoveOrThrow(nextBasicBlock); - - return true; - } - } - } - return false; - } - void DuplicateReturnStatements(ILBlock method) { Dictionary nextSibling = new Dictionary(); @@ -598,8 +392,13 @@ namespace ICSharpCode.Decompiler.ILAst List flatBody = new List(); foreach (ILNode child in block.GetChildren()) { FlattenBasicBlocks(child); - if (child is ILBasicBlock) { - flatBody.AddRange(child.GetChildren()); + ILBasicBlock childAsBB = child as ILBasicBlock; + if (childAsBB != null) { + if (!(childAsBB.Body.FirstOrDefault() is ILLabel)) + throw new Exception("Basic block has to start with a label. \n" + childAsBB.ToString()); + if (childAsBB.Body.LastOrDefault() is ILExpression && !childAsBB.Body.LastOrDefault().IsUnconditionalControlFlow()) + throw new Exception("Basci block has to end with unconditional control flow. \n" + childAsBB.ToString()); + flatBody.AddRange(childAsBB.GetChildren()); } else { flatBody.Add(child); } @@ -627,8 +426,8 @@ namespace ICSharpCode.Decompiler.ILAst for (int i = 0; i < block.Body.Count; i++) { ILCondition cond = block.Body[i] as ILCondition; if (cond != null) { - bool trueExits = cond.TrueBlock.Body.Count > 0 && !cond.TrueBlock.Body.Last().CanFallThough(); - bool falseExits = cond.FalseBlock.Body.Count > 0 && !cond.FalseBlock.Body.Last().CanFallThough(); + bool trueExits = cond.TrueBlock.Body.LastOrDefault().IsUnconditionalControlFlow(); + bool falseExits = cond.FalseBlock.Body.LastOrDefault().IsUnconditionalControlFlow(); if (trueExits) { // Move the false block after the condition @@ -677,12 +476,9 @@ namespace ICSharpCode.Decompiler.ILAst { bool modified = false; List body = block.Body; - for (int i = 0; i < body.Count;) { - if (optimization(body, (ILBasicBlock)body[i], i)) { + for (int i = body.Count - 1; i >= 0; i--) { + if (i < body.Count && optimization(body, (ILBasicBlock)body[i], i)) { modified = true; - i = Math.Max(0, i - 1); // Go back one step - } else { - i++; } } return modified; @@ -692,26 +488,26 @@ namespace ICSharpCode.Decompiler.ILAst { bool modified = false; foreach (ILBasicBlock bb in block.Body) { - for (int j = 0; j < bb.Body.Count;) { - ILExpression expr = bb.Body[j] as ILExpression; - if (expr != null && optimization(bb.Body, expr, j)) { + for (int i = bb.Body.Count - 1; i >= 0; i--) { + ILExpression expr = bb.Body.ElementAtOrDefault(i) as ILExpression; + if (expr != null && optimization(bb.Body, expr, i)) { modified = true; - j = Math.Max(0, j - 1); // Go back one step - } else { - j++; } } } return modified; } - public static bool CanFallThough(this ILNode node) + public static bool IsConditionalControlFlow(this ILNode node) { ILExpression expr = node as ILExpression; - if (expr != null) { - return expr.Code.CanFallThough(); - } - return true; + return expr != null && expr.Code.IsConditionalControlFlow(); + } + + public static bool IsUnconditionalControlFlow(this ILNode node) + { + ILExpression expr = node as ILExpression; + return expr != null && expr.Code.IsUnconditionalControlFlow(); } /// @@ -768,12 +564,22 @@ namespace ICSharpCode.Decompiler.ILAst return !mr.Name.StartsWith("get_", StringComparison.Ordinal); case ILCode.Newobj: case ILCode.Newarr: + case ILCode.Stloc: return true; default: return false; } } + public static void RemoveTail(this List body, params ILCode[] codes) + { + for (int i = 0; i < codes.Length; i++) { + if (((ILExpression)body[body.Count - codes.Length + i]).Code != codes[i]) + throw new Exception("Tailing code does not match expected."); + } + body.RemoveRange(body.Count - codes.Length, codes.Length); + } + public static V GetOrDefault(this Dictionary dict, K key) { V ret; diff --git a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs index c28dbc8df..3a7ae376f 100644 --- a/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs +++ b/ICSharpCode.Decompiler/ILAst/ILAstTypes.cs @@ -85,19 +85,12 @@ namespace ICSharpCode.Decompiler.ILAst public class ILBasicBlock: ILNode { - public ILLabel EntryLabel; + /// Body has to start with a label and end with unconditional control flow public List Body = new List(); - public ILExpression FallthoughGoto; public override IEnumerable GetChildren() { - if (this.EntryLabel != null) - yield return this.EntryLabel; - foreach (ILNode child in this.Body) { - yield return child; - } - if (this.FallthoughGoto != null) - yield return this.FallthoughGoto; + return this.Body; } public override void WriteTo(ITextOutput output) @@ -119,17 +112,6 @@ namespace ICSharpCode.Decompiler.ILAst } } - public class ILComment: ILNode - { - public string Text; - public List ILRanges { get; set; } - - public override void WriteTo(ITextOutput output) - { - output.WriteLine("// " + this.Text); - } - } - public class ILTryCatchBlock: ILNode { public class CatchBlock: ILBlock @@ -290,6 +272,9 @@ namespace ICSharpCode.Decompiler.ILAst public ILExpression(ILCode code, object operand, List args) { + if (operand is ILExpression) + throw new ArgumentException("operand"); + this.Code = code; this.Operand = operand; this.Arguments = new List(args); @@ -298,6 +283,9 @@ namespace ICSharpCode.Decompiler.ILAst public ILExpression(ILCode code, object operand, params ILExpression[] args) { + if (operand is ILExpression) + throw new ArgumentException("operand"); + this.Code = code; this.Operand = operand; this.Arguments = new List(args); diff --git a/ICSharpCode.Decompiler/ILAst/ILCodes.cs b/ICSharpCode.Decompiler/ILAst/ILCodes.cs index ea0569582..7ab0ec9db 100644 --- a/ICSharpCode.Decompiler/ILAst/ILCodes.cs +++ b/ICSharpCode.Decompiler/ILAst/ILCodes.cs @@ -280,7 +280,41 @@ namespace ICSharpCode.Decompiler.ILAst return code.ToString().ToLowerInvariant().TrimStart('_').Replace('_','.'); } - public static bool CanFallThough(this ILCode code) + public static bool IsConditionalControlFlow(this ILCode code) + { + switch(code) { + case ILCode.__Brfalse_S: + case ILCode.__Brtrue_S: + case ILCode.__Beq_S: + case ILCode.__Bge_S: + case ILCode.__Bgt_S: + case ILCode.__Ble_S: + case ILCode.__Blt_S: + case ILCode.__Bne_Un_S: + case ILCode.__Bge_Un_S: + case ILCode.__Bgt_Un_S: + case ILCode.__Ble_Un_S: + case ILCode.__Blt_Un_S: + case ILCode.__Brfalse: + case ILCode.Brtrue: + case ILCode.__Beq: + case ILCode.__Bge: + case ILCode.__Bgt: + case ILCode.__Ble: + case ILCode.__Blt: + case ILCode.__Bne_Un: + case ILCode.__Bge_Un: + case ILCode.__Bgt_Un: + case ILCode.__Ble_Un: + case ILCode.__Blt_Un: + case ILCode.Switch: + return true; + default: + return false; + } + } + + public static bool IsUnconditionalControlFlow(this ILCode code) { switch(code) { case ILCode.Br: @@ -295,9 +329,9 @@ namespace ICSharpCode.Decompiler.ILAst case ILCode.LoopContinue: case ILCode.LoopOrSwitchBreak: case ILCode.YieldBreak: - return false; - default: return true; + default: + return false; } } diff --git a/ICSharpCode.Decompiler/ILAst/ILInlining.cs b/ICSharpCode.Decompiler/ILAst/ILInlining.cs index 79cf33365..647f035a7 100644 --- a/ICSharpCode.Decompiler/ILAst/ILInlining.cs +++ b/ICSharpCode.Decompiler/ILAst/ILInlining.cs @@ -81,7 +81,7 @@ namespace ICSharpCode.Decompiler.ILAst { bool modified = false; List body = bb.Body; - for(int i = 0; i < body.Count - 1;) { + for(int i = 0; i < body.Count;) { ILVariable locVar; ILExpression expr; if (body[i].Match(ILCode.Stloc, out locVar, out expr) && InlineOneIfPossible(bb.Body, i, aggressive: false)) { diff --git a/ICSharpCode.Decompiler/ILAst/InitializerPeepholeTransforms.cs b/ICSharpCode.Decompiler/ILAst/InitializerPeepholeTransforms.cs index e66cf577b..29d869526 100644 --- a/ICSharpCode.Decompiler/ILAst/InitializerPeepholeTransforms.cs +++ b/ICSharpCode.Decompiler/ILAst/InitializerPeepholeTransforms.cs @@ -162,7 +162,7 @@ namespace ICSharpCode.Decompiler.ILAst if (nextExpr.Match(ILCode.Callvirt, out addMethod, out args) && addMethod.Name == "Add" && addMethod.HasThis && - args.Count == 2 && + args.Count >= 2 && args[0].Match(ILCode.Ldloc, out v2) && v == v2) { diff --git a/ICSharpCode.Decompiler/ILAst/LoopsAndConditions.cs b/ICSharpCode.Decompiler/ILAst/LoopsAndConditions.cs index 79516ceac..964b1f84e 100644 --- a/ICSharpCode.Decompiler/ILAst/LoopsAndConditions.cs +++ b/ICSharpCode.Decompiler/ILAst/LoopsAndConditions.cs @@ -88,7 +88,7 @@ namespace ICSharpCode.Decompiler.ILAst ControlFlowNode destination; // Labels which are out of out scope will not be int the collection // Insert self edge only if we are sure we are a loop - if (labelToCfNode.TryGetValue(target, out destination) && (destination != source || target == node.EntryLabel)) { + if (labelToCfNode.TryGetValue(target, out destination) && (destination != source || target == node.Body.FirstOrDefault())) { ControlFlowEdge edge = new ControlFlowEdge(source, destination, JumpType.Normal); source.Outgoing.Add(edge); destination.Incoming.Add(edge); @@ -123,7 +123,8 @@ namespace ICSharpCode.Decompiler.ILAst ILExpression condExpr; ILLabel trueLabel; ILLabel falseLabel; - if(basicBlock.MatchSingle(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel)) + // It has to be just brtrue - any preceding code would introduce goto + if(basicBlock.MatchSingleAndBr(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel)) { ControlFlowNode trueTarget; labelToCfNode.TryGetValue(trueLabel, out trueTarget); @@ -157,7 +158,7 @@ namespace ICSharpCode.Decompiler.ILAst } // Use loop to implement the brtrue - basicBlock.Body.RemoveAt(basicBlock.Body.Count - 1); + basicBlock.Body.RemoveTail(ILCode.Brtrue, ILCode.Br); basicBlock.Body.Add(new ILWhileLoop() { Condition = condExpr, BodyBlock = new ILBlock() { @@ -165,7 +166,7 @@ namespace ICSharpCode.Decompiler.ILAst Body = FindLoops(loopContents, node, false) } }); - basicBlock.FallthoughGoto = new ILExpression(ILCode.Br, falseLabel); + basicBlock.Body.Add(new ILExpression(ILCode.Br, falseLabel)); result.Add(basicBlock); scope.ExceptWith(loopContents); @@ -175,16 +176,15 @@ namespace ICSharpCode.Decompiler.ILAst // Fallback method: while(true) if (scope.Contains(node)) { result.Add(new ILBasicBlock() { - EntryLabel = new ILLabel() { Name = "Loop_" + (nextLabelIndex++) }, Body = new List() { + new ILLabel() { Name = "Loop_" + (nextLabelIndex++) }, new ILWhileLoop() { BodyBlock = new ILBlock() { - EntryGoto = new ILExpression(ILCode.Br, basicBlock.EntryLabel), + EntryGoto = new ILExpression(ILCode.Br, (ILLabel)basicBlock.Body.First()), Body = FindLoops(loopContents, node, true) } }, }, - FallthoughGoto = null }); scope.ExceptWith(loopContents); @@ -233,12 +233,13 @@ namespace ICSharpCode.Decompiler.ILAst ILLabel[] caseLabels; ILExpression switchArg; ILLabel fallLabel; - if (block.MatchLast(ILCode.Switch, out caseLabels, out switchArg, out fallLabel)) { + if (block.MatchLastAndBr(ILCode.Switch, out caseLabels, out switchArg, out fallLabel)) { // Replace the switch code with ILSwitch ILSwitch ilSwitch = new ILSwitch() { Condition = switchArg }; - block.Body.RemoveAt(block.Body.Count - 1); + block.Body.RemoveTail(ILCode.Switch, ILCode.Br); block.Body.Add(ilSwitch); + block.Body.Add(new ILExpression(ILCode.Br, fallLabel)); result.Add(block); // Remove the item so that it is not picked up as content @@ -285,7 +286,12 @@ namespace ICSharpCode.Decompiler.ILAst scope.ExceptWith(content); caseBlock.Body.AddRange(FindConditions(content, condTarget)); // Add explicit break which should not be used by default, but the goto removal might decide to use it - caseBlock.Body.Add(new ILBasicBlock() { Body = { new ILExpression(ILCode.LoopOrSwitchBreak, null) } }); + caseBlock.Body.Add(new ILBasicBlock() { + Body = { + new ILLabel() { Name = "SwitchBreak_" + (nextLabelIndex++) }, + new ILExpression(ILCode.LoopOrSwitchBreak, null) + } + }); } } caseBlock.Values.Add(i + addValue); @@ -297,12 +303,17 @@ namespace ICSharpCode.Decompiler.ILAst if (content.Any()) { var caseBlock = new ILSwitch.CaseBlock() { EntryGoto = new ILExpression(ILCode.Br, fallLabel) }; ilSwitch.CaseBlocks.Add(caseBlock); - block.FallthoughGoto = null; + block.Body.RemoveTail(ILCode.Br); scope.ExceptWith(content); caseBlock.Body.AddRange(FindConditions(content, fallTarget)); // Add explicit break which should not be used by default, but the goto removal might decide to use it - caseBlock.Body.Add(new ILBasicBlock() { Body = { new ILExpression(ILCode.LoopOrSwitchBreak, null) } }); + caseBlock.Body.Add(new ILBasicBlock() { + Body = { + new ILLabel() { Name = "SwitchBreak_" + (nextLabelIndex++) }, + new ILExpression(ILCode.LoopOrSwitchBreak, null) + } + }); } } } @@ -311,7 +322,7 @@ namespace ICSharpCode.Decompiler.ILAst ILExpression condExpr; ILLabel trueLabel; ILLabel falseLabel; - if(block.MatchLast(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel)) { + if(block.MatchLastAndBr(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel)) { // Swap bodies since that seems to be the usual C# order ILLabel temp = trueLabel; @@ -325,9 +336,8 @@ namespace ICSharpCode.Decompiler.ILAst TrueBlock = new ILBlock() { EntryGoto = new ILExpression(ILCode.Br, trueLabel) }, FalseBlock = new ILBlock() { EntryGoto = new ILExpression(ILCode.Br, falseLabel) } }; - block.Body.RemoveAt(block.Body.Count - 1); + block.Body.RemoveTail(ILCode.Brtrue, ILCode.Br); block.Body.Add(ilCond); - block.FallthoughGoto = null; result.Add(block); // Remove the item immediately so that it is not picked up as content diff --git a/ICSharpCode.Decompiler/ILAst/PatternMatching.cs b/ICSharpCode.Decompiler/ILAst/PatternMatching.cs index b8d1f3651..3921121e3 100644 --- a/ICSharpCode.Decompiler/ILAst/PatternMatching.cs +++ b/ICSharpCode.Decompiler/ILAst/PatternMatching.cs @@ -98,29 +98,44 @@ namespace ICSharpCode.Decompiler.ILAst return false; } - public static bool MatchSingle(this ILBasicBlock bb, ILCode code, out T operand, out ILExpression arg, out ILLabel fallLabel) + public static bool MatchSingle(this ILBasicBlock bb, ILCode code, out T operand, out ILExpression arg) { - if (bb.Body.Count == 1) { - if (bb.Body[0].Match(code, out operand, out arg)) { - fallLabel = bb.FallthoughGoto != null ? (ILLabel)bb.FallthoughGoto.Operand : null; - return true; - } + if (bb.Body.Count == 2 && + bb.Body[0] is ILLabel && + bb.Body[1].Match(code, out operand, out arg)) + { + return true; + } + operand = default(T); + arg = null; + return false; + } + + public static bool MatchSingleAndBr(this ILBasicBlock bb, ILCode code, out T operand, out ILExpression arg, out ILLabel brLabel) + { + if (bb.Body.Count == 3 && + bb.Body[0] is ILLabel && + bb.Body[1].Match(code, out operand, out arg) && + bb.Body[2].Match(ILCode.Br, out brLabel)) + { + return true; } operand = default(T); arg = null; - fallLabel = null; + brLabel = null; return false; } - public static bool MatchLast(this ILBasicBlock bb, ILCode code, out T operand, out ILExpression arg, out ILLabel fallLabel) + public static bool MatchLastAndBr(this ILBasicBlock bb, ILCode code, out T operand, out ILExpression arg, out ILLabel brLabel) { - if (bb.Body.LastOrDefault().Match(code, out operand, out arg)) { - fallLabel = bb.FallthoughGoto != null ? (ILLabel)bb.FallthoughGoto.Operand : null; + if (bb.Body.ElementAtOrDefault(bb.Body.Count - 2).Match(code, out operand, out arg) && + bb.Body.LastOrDefault().Match(ILCode.Br, out brLabel)) + { return true; } operand = default(T); arg = null; - fallLabel = null; + brLabel = null; return false; } diff --git a/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs b/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs index 891357ca1..da9908e45 100644 --- a/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs +++ b/ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs @@ -55,10 +55,11 @@ namespace ICSharpCode.Decompiler.ILAst static bool SimplifyLdObjAndStObj(List body, ILExpression expr, int pos) { + bool modified = false; if (expr.Code == ILCode.Initobj) { expr.Code = ILCode.Stobj; expr.Arguments.Add(new ILExpression(ILCode.DefaultValue, expr.Operand)); - return true; + modified = true; } ILExpression arg, arg2; TypeReference type; @@ -84,9 +85,9 @@ namespace ICSharpCode.Decompiler.ILAst arg.Arguments.Add(arg2); arg.ILRanges.AddRange(expr.ILRanges); body[pos] = arg; - return true; + modified = true; } - return false; + return modified; } #region CachedDelegateInitialization diff --git a/ICSharpCode.Decompiler/ILAst/SimpleControlFlow.cs b/ICSharpCode.Decompiler/ILAst/SimpleControlFlow.cs new file mode 100644 index 000000000..d1cbe720c --- /dev/null +++ b/ICSharpCode.Decompiler/ILAst/SimpleControlFlow.cs @@ -0,0 +1,259 @@ +// Copyright (c) AlphaSierraPapa for the SharpDevelop Team (for details please see \doc\copyright.txt) +// This code is distributed under MIT X11 license (for details please see \doc\license.txt) + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +using Mono.Cecil; + +namespace ICSharpCode.Decompiler.ILAst +{ + public class SimpleControlFlow + { + Dictionary labelGlobalRefCount = new Dictionary(); + Dictionary labelToBasicBlock = new Dictionary(); + + DecompilerContext context; + TypeSystem typeSystem; + + public SimpleControlFlow(DecompilerContext context, ILBlock method) + { + this.context = context; + this.typeSystem = context.CurrentMethod.Module.TypeSystem; + + foreach(ILLabel target in method.GetSelfAndChildrenRecursive(e => e.IsBranch()).SelectMany(e => e.GetBranchTargets())) { + labelGlobalRefCount[target] = labelGlobalRefCount.GetOrDefault(target) + 1; + } + foreach(ILBasicBlock bb in method.GetSelfAndChildrenRecursive()) { + foreach(ILLabel label in bb.GetChildren().OfType()) { + labelToBasicBlock[label] = bb; + } + } + } + + public bool SimplifyTernaryOperator(List body, ILBasicBlock head, int pos) + { + Debug.Assert(body.Contains(head)); + + ILExpression condExpr; + ILLabel trueLabel; + ILLabel falseLabel; + ILVariable trueLocVar = null; + ILExpression trueExpr; + ILLabel trueFall; + ILVariable falseLocVar = null; + ILExpression falseExpr; + ILLabel falseFall; + object unused; + + if (head.MatchLastAndBr(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel) && + labelGlobalRefCount[trueLabel] == 1 && + labelGlobalRefCount[falseLabel] == 1 && + ((labelToBasicBlock[trueLabel].MatchSingleAndBr(ILCode.Stloc, out trueLocVar, out trueExpr, out trueFall) && + labelToBasicBlock[falseLabel].MatchSingleAndBr(ILCode.Stloc, out falseLocVar, out falseExpr, out falseFall) && + trueLocVar == falseLocVar && trueFall == falseFall) || + (labelToBasicBlock[trueLabel].MatchSingle(ILCode.Ret, out unused, out trueExpr) && + labelToBasicBlock[falseLabel].MatchSingle(ILCode.Ret, out unused, out falseExpr))) && + body.Contains(labelToBasicBlock[trueLabel]) && + body.Contains(labelToBasicBlock[falseLabel]) + ) + { + bool isStloc = trueLocVar != null; + ILCode opCode = isStloc ? ILCode.Stloc : ILCode.Ret; + TypeReference retType = isStloc ? trueLocVar.Type : this.context.CurrentMethod.ReturnType; + int leftBoolVal; + int rightBoolVal; + ILExpression newExpr; + // a ? true:false is equivalent to a + // a ? false:true is equivalent to !a + // a ? true : b is equivalent to a || b + // a ? b : true is equivalent to !a || b + // a ? b : false is equivalent to a && b + // a ? false : b is equivalent to !a && b + if (retType == typeSystem.Boolean && + trueExpr.Match(ILCode.Ldc_I4, out leftBoolVal) && + falseExpr.Match(ILCode.Ldc_I4, out rightBoolVal) && + ((leftBoolVal != 0 && rightBoolVal == 0) || (leftBoolVal == 0 && rightBoolVal != 0)) + ) + { + // It can be expressed as trivilal expression + if (leftBoolVal != 0) { + newExpr = condExpr; + } else { + newExpr = new ILExpression(ILCode.LogicNot, null, condExpr); + } + } else if (retType == typeSystem.Boolean && trueExpr.Match(ILCode.Ldc_I4, out leftBoolVal)) { + // It can be expressed as logical expression + if (leftBoolVal != 0) { + newExpr = MakeLeftAssociativeShortCircuit(ILCode.LogicOr, condExpr, falseExpr); + } else { + newExpr = MakeLeftAssociativeShortCircuit(ILCode.LogicAnd, new ILExpression(ILCode.LogicNot, null, condExpr), falseExpr); + } + } else if (retType == typeSystem.Boolean && falseExpr.Match(ILCode.Ldc_I4, out rightBoolVal)) { + // It can be expressed as logical expression + if (rightBoolVal != 0) { + newExpr = MakeLeftAssociativeShortCircuit(ILCode.LogicOr, new ILExpression(ILCode.LogicNot, null, condExpr), trueExpr); + } else { + newExpr = MakeLeftAssociativeShortCircuit(ILCode.LogicAnd, condExpr, trueExpr); + } + } else { + // Ternary operator tends to create long complicated return statements + if (opCode == ILCode.Ret) + return false; + + // Only simplify generated variables + if (opCode == ILCode.Stloc && !trueLocVar.IsGenerated) + return false; + + // Create ternary expression + newExpr = new ILExpression(ILCode.TernaryOp, null, condExpr, trueExpr, falseExpr); + } + + head.Body.RemoveTail(ILCode.Brtrue, ILCode.Br); + head.Body.Add(new ILExpression(opCode, trueLocVar, newExpr)); + if (isStloc) + head.Body.Add(new ILExpression(ILCode.Br, trueFall)); + + // Remove the old basic blocks + body.RemoveOrThrow(labelToBasicBlock[trueLabel]); + body.RemoveOrThrow(labelToBasicBlock[falseLabel]); + + return true; + } + return false; + } + + public bool SimplifyNullCoalescing(List body, ILBasicBlock head, int pos) + { + // ... + // v = ldloc(leftVar) + // brtrue(endBBLabel, ldloc(leftVar)) + // br(rightBBLabel) + // + // rightBBLabel: + // v = rightExpr + // br(endBBLabel) + // ... + // => + // ... + // v = NullCoalescing(ldloc(leftVar), rightExpr) + // br(endBBLabel) + + ILVariable v, v2; + ILExpression leftExpr, leftExpr2; + ILVariable leftVar; + ILLabel endBBLabel, endBBLabel2; + ILLabel rightBBLabel; + ILBasicBlock rightBB; + ILExpression rightExpr; + if (head.Body.Count >= 3 && + head.Body[head.Body.Count - 3].Match(ILCode.Stloc, out v, out leftExpr) && + leftExpr.Match(ILCode.Ldloc, out leftVar) && + head.MatchLastAndBr(ILCode.Brtrue, out endBBLabel, out leftExpr2, out rightBBLabel) && + leftExpr2.Match(ILCode.Ldloc, leftVar) && + labelToBasicBlock.TryGetValue(rightBBLabel, out rightBB) && + rightBB.MatchSingleAndBr(ILCode.Stloc, out v2, out rightExpr, out endBBLabel2) && + v == v2 && + endBBLabel == endBBLabel2 && + labelGlobalRefCount.GetOrDefault(rightBBLabel) == 1 && + body.Contains(rightBB) + ) + { + head.Body.RemoveTail(ILCode.Stloc, ILCode.Brtrue, ILCode.Br); + head.Body.Add(new ILExpression(ILCode.Stloc, v, new ILExpression(ILCode.NullCoalescing, null, leftExpr, rightExpr))); + head.Body.Add(new ILExpression(ILCode.Br, endBBLabel)); + + body.RemoveOrThrow(labelToBasicBlock[rightBBLabel]); + return true; + } + return false; + } + + public bool SimplifyShortCircuit(List body, ILBasicBlock head, int pos) + { + Debug.Assert(body.Contains(head)); + + ILExpression condExpr; + ILLabel trueLabel; + ILLabel falseLabel; + if(head.MatchLastAndBr(ILCode.Brtrue, out trueLabel, out condExpr, out falseLabel)) { + for (int pass = 0; pass < 2; pass++) { + + // On the second pass, swap labels and negate expression of the first branch + // It is slightly ugly, but much better then copy-pasting this whole block + ILLabel nextLabel = (pass == 0) ? trueLabel : falseLabel; + ILLabel otherLablel = (pass == 0) ? falseLabel : trueLabel; + bool negate = (pass == 1); + + ILBasicBlock nextBasicBlock = labelToBasicBlock[nextLabel]; + ILExpression nextCondExpr; + ILLabel nextTrueLablel; + ILLabel nextFalseLabel; + if (body.Contains(nextBasicBlock) && + nextBasicBlock != head && + labelGlobalRefCount[(ILLabel)nextBasicBlock.Body.First()] == 1 && + nextBasicBlock.MatchSingleAndBr(ILCode.Brtrue, out nextTrueLablel, out nextCondExpr, out nextFalseLabel) && + (otherLablel == nextFalseLabel || otherLablel == nextTrueLablel)) + { + // Create short cicuit branch + ILExpression logicExpr; + if (otherLablel == nextFalseLabel) { + logicExpr = MakeLeftAssociativeShortCircuit(ILCode.LogicAnd, negate ? new ILExpression(ILCode.LogicNot, null, condExpr) : condExpr, nextCondExpr); + } else { + logicExpr = MakeLeftAssociativeShortCircuit(ILCode.LogicOr, negate ? condExpr : new ILExpression(ILCode.LogicNot, null, condExpr), nextCondExpr); + } + head.Body.RemoveTail(ILCode.Brtrue, ILCode.Br); + head.Body.Add(new ILExpression(ILCode.Brtrue, nextTrueLablel, logicExpr)); + head.Body.Add(new ILExpression(ILCode.Br, nextFalseLabel)); + + // Remove the inlined branch from scope + body.RemoveOrThrow(nextBasicBlock); + + return true; + } + } + } + return false; + } + + ILExpression MakeLeftAssociativeShortCircuit(ILCode code, ILExpression left, ILExpression right) + { + // Assuming that the inputs are already left associative + if (right.Match(code)) { + // Find the leftmost logical expression + ILExpression current = right; + while(current.Arguments[0].Match(code)) + current = current.Arguments[0]; + current.Arguments[0] = new ILExpression(code, null, left, current.Arguments[0]); + return right; + } else { + return new ILExpression(code, null, left, right); + } + } + + public bool JoinBasicBlocks(List body, ILBasicBlock head, int pos) + { + ILLabel nextLabel; + ILBasicBlock nextBB; + if (!head.Body.ElementAtOrDefault(head.Body.Count - 2).IsConditionalControlFlow() && + head.Body.Last().Match(ILCode.Br, out nextLabel) && + labelGlobalRefCount[nextLabel] == 1 && + labelToBasicBlock.TryGetValue(nextLabel, out nextBB) && + body.Contains(nextBB) && + nextBB.Body.First() == nextLabel && + !nextBB.Body.OfType().Any() + ) + { + head.Body.RemoveTail(ILCode.Br); + nextBB.Body.RemoveAt(0); // Remove label + head.Body.AddRange(nextBB.Body); + + body.RemoveOrThrow(nextBB); + return true; + } + return false; + } + } +} diff --git a/ICSharpCode.Decompiler/Tests/UnsafeCode.cs b/ICSharpCode.Decompiler/Tests/UnsafeCode.cs index 03b8b6ab5..83c292400 100644 --- a/ICSharpCode.Decompiler/Tests/UnsafeCode.cs +++ b/ICSharpCode.Decompiler/Tests/UnsafeCode.cs @@ -74,4 +74,9 @@ public class UnsafeCode } return PointerReferenceExpression((double*)a); } + + unsafe ~UnsafeCode() + { + PassPointerAsRefParameter(NullPointer); + } } diff --git a/ILSpy/CSharpLanguage.cs b/ILSpy/CSharpLanguage.cs index 11bc1b3dc..2140f498e 100644 --- a/ILSpy/CSharpLanguage.cs +++ b/ILSpy/CSharpLanguage.cs @@ -82,6 +82,7 @@ namespace ICSharpCode.ILSpy public override void DecompileMethod(MethodDefinition method, ITextOutput output, DecompilationOptions options) { + WriteCommentLine(output, TypeToString(method.DeclaringType, includeNamespace: true)); AstBuilder codeDomBuilder = CreateAstBuilder(options, method.DeclaringType); codeDomBuilder.AddMethod(method); codeDomBuilder.GenerateCode(output, transformAbortCondition); @@ -89,6 +90,7 @@ namespace ICSharpCode.ILSpy public override void DecompileProperty(PropertyDefinition property, ITextOutput output, DecompilationOptions options) { + WriteCommentLine(output, TypeToString(property.DeclaringType, includeNamespace: true)); AstBuilder codeDomBuilder = CreateAstBuilder(options, property.DeclaringType); codeDomBuilder.AddProperty(property); codeDomBuilder.GenerateCode(output, transformAbortCondition); @@ -96,6 +98,7 @@ namespace ICSharpCode.ILSpy public override void DecompileField(FieldDefinition field, ITextOutput output, DecompilationOptions options) { + WriteCommentLine(output, TypeToString(field.DeclaringType, includeNamespace: true)); AstBuilder codeDomBuilder = CreateAstBuilder(options, field.DeclaringType); codeDomBuilder.AddField(field); codeDomBuilder.GenerateCode(output, transformAbortCondition); @@ -103,6 +106,7 @@ namespace ICSharpCode.ILSpy public override void DecompileEvent(EventDefinition ev, ITextOutput output, DecompilationOptions options) { + WriteCommentLine(output, TypeToString(ev.DeclaringType, includeNamespace: true)); AstBuilder codeDomBuilder = CreateAstBuilder(options, ev.DeclaringType); codeDomBuilder.AddEvent(ev); codeDomBuilder.GenerateCode(output, transformAbortCondition); @@ -117,26 +121,15 @@ namespace ICSharpCode.ILSpy public override void DecompileAssembly(AssemblyDefinition assembly, string fileName, ITextOutput output, DecompilationOptions options) { - if (options.FullDecompilation) { - if (options.SaveAsProjectDirectory != null) { - HashSet directories = new HashSet(StringComparer.OrdinalIgnoreCase); - var files = WriteCodeFilesInProject(assembly, options, directories).ToList(); - files.AddRange(WriteResourceFilesInProject(assembly, fileName, options, directories)); - WriteProjectFile(new TextOutputWriter(output), files, assembly.MainModule); - } else { - foreach (TypeDefinition type in assembly.MainModule.Types) { - if (AstBuilder.MemberIsHidden(type, options.DecompilerSettings)) - continue; - AstBuilder codeDomBuilder = CreateAstBuilder(options, type); - codeDomBuilder.AddType(type); - codeDomBuilder.GenerateCode(output, transformAbortCondition); - output.WriteLine(); - } - } + if (options.FullDecompilation && options.SaveAsProjectDirectory != null) { + HashSet directories = new HashSet(StringComparer.OrdinalIgnoreCase); + var files = WriteCodeFilesInProject(assembly, options, directories).ToList(); + files.AddRange(WriteResourceFilesInProject(assembly, fileName, options, directories)); + WriteProjectFile(new TextOutputWriter(output), files, assembly.MainModule); } else { base.DecompileAssembly(assembly, fileName, output, options); AstBuilder codeDomBuilder = CreateAstBuilder(options, currentType: null); - codeDomBuilder.AddAssembly(assembly, onlyAssemblyLevel: true); + codeDomBuilder.AddAssembly(assembly, onlyAssemblyLevel: !options.FullDecompilation); codeDomBuilder.GenerateCode(output, transformAbortCondition); } } @@ -392,14 +385,12 @@ namespace ICSharpCode.ILSpy }); } - public override string TypeToString(TypeReference type, bool includeNamespace, ICustomAttributeProvider typeAttributes) + public override string TypeToString(TypeReference type, bool includeNamespace, ICustomAttributeProvider typeAttributes = null) { - AstType astType = AstBuilder.ConvertType(type, typeAttributes); - if (!includeNamespace) { - var tre = new TypeReferenceExpression { Type = astType }; - tre.AcceptVisitor(new RemoveNamespaceFromType(), null); - astType = tre.Type; - } + ConvertTypeOptions options = ConvertTypeOptions.IncludeTypeParameterDefinitions; + if (includeNamespace) + options |= ConvertTypeOptions.IncludeNamespace; + AstType astType = AstBuilder.ConvertType(type, typeAttributes, options); StringWriter w = new StringWriter(); if (type.IsByReference) { @@ -417,21 +408,6 @@ namespace ICSharpCode.ILSpy return w.ToString(); } - sealed class RemoveNamespaceFromType : DepthFirstAstVisitor - { - public override object VisitMemberType(MemberType memberType, object data) - { - base.VisitMemberType(memberType, data); - SimpleType st = memberType.Target as SimpleType; - if (st != null && !st.TypeArguments.Any()) { - SimpleType newSt = new SimpleType(memberType.MemberName); - memberType.TypeArguments.MoveTo(newSt.TypeArguments); - memberType.ReplaceWith(newSt); - } - return null; - } - } - public override bool ShowMember(MemberReference member) { return showAllMembers || !AstBuilder.MemberIsHidden(member, new DecompilationOptions().DecompilerSettings); diff --git a/ILSpy/DecompilationOptions.cs b/ILSpy/DecompilationOptions.cs index 162fbe29c..bada56620 100644 --- a/ILSpy/DecompilationOptions.cs +++ b/ILSpy/DecompilationOptions.cs @@ -51,7 +51,15 @@ namespace ICSharpCode.ILSpy /// Gets the settings for the decompiler. /// public DecompilerSettings DecompilerSettings { get; set; } - + + /// + /// Gets/sets an optional state of a decompiler text view. + /// + /// + /// This state is used to restore test view's state when decompilation is started by Go Back/Forward action. + /// + public ICSharpCode.ILSpy.TextView.DecompilerTextViewState TextViewState { get; set; } + public DecompilationOptions() { this.DecompilerSettings = DecompilerSettingsPanel.CurrentDecompilerSettings; diff --git a/ILSpy/ILSpy.csproj b/ILSpy/ILSpy.csproj index 538dbdc44..c8dc9dc09 100644 --- a/ILSpy/ILSpy.csproj +++ b/ILSpy/ILSpy.csproj @@ -156,6 +156,7 @@ + @@ -259,11 +260,6 @@ ICSharpCode.TreeView - - - - - - + \ No newline at end of file diff --git a/ILSpy/MainWindow.xaml.cs b/ILSpy/MainWindow.xaml.cs index 97114d72e..b88809274 100644 --- a/ILSpy/MainWindow.xaml.cs +++ b/ILSpy/MainWindow.xaml.cs @@ -47,7 +47,8 @@ namespace ICSharpCode.ILSpy /// partial class MainWindow : Window { - NavigationHistory history = new NavigationHistory(); + NavigationHistory, DecompilerTextViewState>> history = + new NavigationHistory, DecompilerTextViewState>>(); ILSpySettings spySettings; internal SessionSettings sessionSettings; AssemblyListManager assemblyListManager; @@ -259,7 +260,7 @@ namespace ICSharpCode.ILSpy { if (e.OldItems != null) foreach (LoadedAssembly asm in e.OldItems) - history.RemoveAll(n => n.AncestorsAndSelf().OfType().Any(a => a.LoadedAssembly == asm)); + history.RemoveAll(n => n.Item1.Any(nd => nd.AncestorsAndSelf().OfType().Any(a => a.LoadedAssembly == asm))); } void LoadInitialAssemblies() @@ -289,7 +290,7 @@ namespace ICSharpCode.ILSpy { RefreshTreeViewFilter(); if (e.PropertyName == "Language") { - TreeView_SelectionChanged(null, null); + DecompileSelectedNodes(); } } @@ -316,7 +317,7 @@ namespace ICSharpCode.ILSpy if (obj != null) { SharpTreeNode oldNode = treeView.SelectedItem as SharpTreeNode; if (oldNode != null && recordNavigationInHistory) - history.Record(oldNode); + history.Record(Tuple.Create(treeView.SelectedItems.OfType().ToList(), decompilerTextView.GetState())); // Set both the selection and focus to ensure that keyboard navigation works as expected. treeView.FocusNode(obj); treeView.SelectedItem = obj; @@ -437,12 +438,22 @@ namespace ICSharpCode.ILSpy #region Decompile (TreeView_SelectionChanged) void TreeView_SelectionChanged(object sender, SelectionChangedEventArgs e) { + DecompileSelectedNodes(); + } + + private bool ignoreDecompilationRequests; + + private void DecompileSelectedNodes(DecompilerTextViewState state = null) + { + if (ignoreDecompilationRequests) + return; + if (treeView.SelectedItems.Count == 1) { ILSpyTreeNode node = treeView.SelectedItem as ILSpyTreeNode; if (node != null && node.View(decompilerTextView)) return; } - decompilerTextView.Decompile(this.CurrentLanguage, this.SelectedNodes, new DecompilationOptions()); + decompilerTextView.Decompile(this.CurrentLanguage, this.SelectedNodes, new DecompilationOptions() { TextViewState = state }); } void SaveCommandExecuted(object sender, ExecutedRoutedEventArgs e) @@ -458,7 +469,7 @@ namespace ICSharpCode.ILSpy public void RefreshDecompiledView() { - TreeView_SelectionChanged(null, null); + DecompileSelectedNodes(); } public DecompilerTextView TextView { @@ -489,7 +500,7 @@ namespace ICSharpCode.ILSpy { if (history.CanNavigateBack) { e.Handled = true; - SelectNode(history.GoBack(treeView.SelectedItem as SharpTreeNode), false); + NavigateHistory(false); } } @@ -503,9 +514,27 @@ namespace ICSharpCode.ILSpy { if (history.CanNavigateForward) { e.Handled = true; - SelectNode(history.GoForward(treeView.SelectedItem as SharpTreeNode), false); + NavigateHistory(true); } } + + void NavigateHistory(bool forward) + { + var currentSelection = treeView.SelectedItems.OfType().ToList(); + var state = decompilerTextView.GetState(); + var combinedState = Tuple.Create(currentSelection, state); + var newState = forward ? history.GoForward(combinedState) : history.GoBack(combinedState); + + this.ignoreDecompilationRequests = true; + treeView.SelectedItems.Clear(); + foreach (var node in newState.Item1) + { + treeView.SelectedItems.Add(node); + } + ignoreDecompilationRequests = false; + DecompileSelectedNodes(newState.Item2); + } + #endregion #region Analyzer diff --git a/ILSpy/NavigationHistory.cs b/ILSpy/NavigationHistory.cs index f15cfa732..26118321d 100644 --- a/ILSpy/NavigationHistory.cs +++ b/ILSpy/NavigationHistory.cs @@ -10,10 +10,10 @@ namespace ICSharpCode.ILSpy /// /// Stores the navigation history. /// - sealed class NavigationHistory + sealed class NavigationHistory { - List back = new List(); - List forward = new List(); + List back = new List(); + List forward = new List(); public bool CanNavigateBack { get { return back.Count > 0; } @@ -23,27 +23,27 @@ namespace ICSharpCode.ILSpy get { return forward.Count > 0; } } - public SharpTreeNode GoBack(SharpTreeNode oldNode) + public T GoBack(T oldNode) { if (oldNode != null) forward.Add(oldNode); - SharpTreeNode node = back[back.Count - 1]; + T node = back[back.Count - 1]; back.RemoveAt(back.Count - 1); return node; } - public SharpTreeNode GoForward(SharpTreeNode oldNode) + public T GoForward(T oldNode) { if (oldNode != null) back.Add(oldNode); - SharpTreeNode node = forward[forward.Count - 1]; + T node = forward[forward.Count - 1]; forward.RemoveAt(forward.Count - 1); return node; } - public void RemoveAll(Predicate predicate) + public void RemoveAll(Predicate predicate) { back.RemoveAll(predicate); forward.RemoveAll(predicate); @@ -55,7 +55,7 @@ namespace ICSharpCode.ILSpy forward.Clear(); } - public void Record(SharpTreeNode node) + public void Record(T node) { forward.Clear(); back.Add(node); diff --git a/ILSpy/TextView/DecompilerTextView.cs b/ILSpy/TextView/DecompilerTextView.cs index b9c357401..38c21a940 100644 --- a/ILSpy/TextView/DecompilerTextView.cs +++ b/ILSpy/TextView/DecompilerTextView.cs @@ -298,9 +298,9 @@ namespace ICSharpCode.ILSpy.TextView /// /// Shows the given output in the text view. /// - void ShowOutput(AvalonEditTextOutput textOutput, IHighlightingDefinition highlighting = null) + void ShowOutput(AvalonEditTextOutput textOutput, IHighlightingDefinition highlighting = null, DecompilerTextViewState state = null) { - Debug.WriteLine("Showing {0} characters of output", textOutput.TextLength); + Debug.WriteLine("Showing {0} characters of output", textOutput.TextLength); Stopwatch w = Stopwatch.StartNew(); textEditor.ScrollToHome(); @@ -318,6 +318,11 @@ namespace ICSharpCode.ILSpy.TextView textEditor.Document = textOutput.GetDocument(); Debug.WriteLine(" Assigning document: {0}", w.Elapsed); w.Restart(); if (textOutput.Foldings.Count > 0) { + if (state != null) { + state.RestoreFoldings(textOutput.Foldings); + textEditor.ScrollToVerticalOffset(state.VerticalOffset); + textEditor.ScrollToHorizontalOffset(state.HorizontalOffset); + } foldingManager = FoldingManager.Install(textEditor.TextArea); foldingManager.UpdateFoldings(textOutput.Foldings.OrderBy(f => f.StartOffset), -1); Debug.WriteLine(" Updating folding: {0}", w.Elapsed); w.Restart(); @@ -383,7 +388,7 @@ namespace ICSharpCode.ILSpy.TextView delegate (Task task) { // handling the result try { AvalonEditTextOutput textOutput = task.Result; - ShowOutput(textOutput, context.Language.SyntaxHighlighting); + ShowOutput(textOutput, context.Language.SyntaxHighlighting, context.Options.TextViewState); } catch (AggregateException aggregateException) { textEditor.SyntaxHighlighting = null; Debug.WriteLine("Decompiler crashed: " + aggregateException.ToString()); @@ -656,6 +661,16 @@ namespace ICSharpCode.ILSpy.TextView return text; } #endregion + + public DecompilerTextViewState GetState() + { + var state = new DecompilerTextViewState(); + if (foldingManager != null) + state.SaveFoldingsState(foldingManager.AllFoldings); + state.VerticalOffset = textEditor.VerticalOffset; + state.HorizontalOffset = textEditor.HorizontalOffset; + return state; + } #region Unfold public void UnfoldAndScroll(int lineNumber) @@ -679,4 +694,26 @@ namespace ICSharpCode.ILSpy.TextView } #endregion } + + public class DecompilerTextViewState + { + private List> ExpandedFoldings; + private int FoldingsChecksum; + public double VerticalOffset; + public double HorizontalOffset; + + public void SaveFoldingsState(IEnumerable foldings) + { + ExpandedFoldings = foldings.Where(f => !f.IsFolded).Select(f => Tuple.Create(f.StartOffset, f.EndOffset)).ToList(); + FoldingsChecksum = foldings.Select(f => f.StartOffset * 3 - f.EndOffset).Aggregate((a, b) => a + b); + } + + internal void RestoreFoldings(List list) + { + var checksum = list.Select(f => f.StartOffset * 3 - f.EndOffset).Aggregate((a, b) => a + b); + if (FoldingsChecksum == checksum) + foreach (var folding in list) + folding.DefaultClosed = !ExpandedFoldings.Any(f => f.Item1 == folding.StartOffset && f.Item2 == folding.EndOffset); + } + } } diff --git a/ILSpy/TreeNodes/Analyzer/AnalyzedMethodTreeNode.cs b/ILSpy/TreeNodes/Analyzer/AnalyzedMethodTreeNode.cs index f577e4a85..41a7b7074 100644 --- a/ILSpy/TreeNodes/Analyzer/AnalyzedMethodTreeNode.cs +++ b/ILSpy/TreeNodes/Analyzer/AnalyzedMethodTreeNode.cs @@ -52,6 +52,8 @@ namespace ICSharpCode.ILSpy.TreeNodes.Analyzer if (analyzedMethod.HasBody) this.Children.Add(new AnalyzedMethodUsesNode(analyzedMethod)); this.Children.Add(new AnalyzedMethodUsedByTreeNode(analyzedMethod)); + if (analyzedMethod.IsVirtual && !analyzedMethod.IsFinal && !analyzedMethod.DeclaringType.IsInterface) // interfaces are temporarly disabled + this.Children.Add(new AnalyzerMethodOverridesTreeNode(analyzedMethod)); } } } diff --git a/ILSpy/TreeNodes/Analyzer/AnalyzerMethodOverridesTreeNode.cs b/ILSpy/TreeNodes/Analyzer/AnalyzerMethodOverridesTreeNode.cs new file mode 100644 index 000000000..709323e02 --- /dev/null +++ b/ILSpy/TreeNodes/Analyzer/AnalyzerMethodOverridesTreeNode.cs @@ -0,0 +1,203 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using ICSharpCode.NRefactory.Utils; +using ICSharpCode.TreeView; +using Mono.Cecil; + +namespace ICSharpCode.ILSpy.TreeNodes.Analyzer +{ + /// + /// Searches for overrides of the analyzed method. + /// + class AnalyzerMethodOverridesTreeNode : AnalyzerTreeNode + { + readonly MethodDefinition analyzedMethod; + readonly ThreadingSupport threading; + + /// + /// Controls whether overrides of already overriden method should be included. + /// + readonly bool onlyDirectOverrides = false; + + public AnalyzerMethodOverridesTreeNode(MethodDefinition analyzedMethod) + { + if (analyzedMethod == null) + throw new ArgumentNullException("analyzedMethod"); + + this.analyzedMethod = analyzedMethod; + this.threading = new ThreadingSupport(); + this.LazyLoading = true; + } + + public override object Text + { + get { return "Overrided By"; } + } + + public override object Icon + { + get { return Images.Search; } + } + + protected override void LoadChildren() + { + threading.LoadChildren(this, FetchChildren); + } + + protected override void OnCollapsing() + { + if (threading.IsRunning) + { + this.LazyLoading = true; + threading.Cancel(); + this.Children.Clear(); + } + } + + IEnumerable FetchChildren(CancellationToken ct) + { + return FindReferences(MainWindow.Instance.AssemblyList.GetAssemblies(), ct); + } + + IEnumerable FindReferences(LoadedAssembly[] assemblies, CancellationToken ct) + { + // use parallelism only on the assembly level (avoid locks within Cecil) + return assemblies.AsParallel().WithCancellation(ct).SelectMany((LoadedAssembly asm) => FindReferences(asm, ct)); + } + + IEnumerable FindReferences(LoadedAssembly asm, CancellationToken ct) + { + string asmName = asm.AssemblyDefinition.Name.Name; + string name = analyzedMethod.Name; + string declTypeName = analyzedMethod.DeclaringType.FullName; + foreach (TypeDefinition type in TreeTraversal.PreOrder(asm.AssemblyDefinition.MainModule.Types, t => t.NestedTypes)) + { + ct.ThrowIfCancellationRequested(); + + if (!IsDerived(type, analyzedMethod.DeclaringType)) + continue; + + foreach (MethodDefinition method in type.Methods) + { + ct.ThrowIfCancellationRequested(); + + if (HasCompatibleSpecification(method) && !method.IsNewSlot && DoesOverrideCorrectMethod(method)) + { + yield return new AnalyzedMethodTreeNode(method); + } + } + } + } + + /// + /// Tests whether the method could override analyzed method by comparing its name, return type and parameters. + /// + /// The method to test. + /// true if the method has the same specyfication as analyzed method, otherwise false. + private bool HasCompatibleSpecification(MethodDefinition method) + { + return method.Name == analyzedMethod.Name + && method.IsVirtual + && AreSameType(method.ReturnType, analyzedMethod.ReturnType) + && HaveTheSameParameters(method); + } + + /// + /// Checks whether between given and analyzed method are overrides with new (newSlot) modifier. + /// + /// The method to test. + /// true if the method overrides analyzed method, false if it overrides some other method that hides analyzed method. + private bool DoesOverrideCorrectMethod(MethodDefinition method) + { + var type = method.DeclaringType.BaseType.Resolve(); + while (type != analyzedMethod.DeclaringType) + { + var parentOverride = type.Methods.Where(m => HasCompatibleSpecification(m)).SingleOrDefault(); + if (parentOverride != null) + { + if (parentOverride.IsNewSlot) + return false; + else + return !onlyDirectOverrides; + } + type = type.BaseType.Resolve(); + } + return true; + } + + /// + /// Checks whether one type derives (directly or indirectly) from base type. + /// + /// The possible derived type. + /// The base type. + /// true if derives from , overwise false. + private static bool IsDerived(TypeDefinition derivedType, TypeDefinition baseType) + { + while (derivedType != null && derivedType.BaseType != null) + { + if (AreSameType(derivedType.BaseType, baseType)) + return true; + derivedType = derivedType.BaseType.Resolve(); + } + return false; + } + + /// + /// Checks whether both instances references the same type. + /// + /// The first type reference. + /// The second type reference. + /// true if both instances references the same type, overwise false. + private static bool AreSameType(TypeReference ref1, TypeReference ref2) + { + if (ref1 == ref2) + return true; + + if (ref1.Name != ref2.Name || ref1.FullName != ref2.FullName) + return false; + + return ref1.Resolve() == ref2.Resolve(); + } + + /// + /// Checkes whether the given method and the analyzed one has identical lists of parameters. + /// + /// The method to test. + /// true if both methods has the same parameters, otherwise false. + private bool HaveTheSameParameters(MethodDefinition method) + { + if (analyzedMethod.HasParameters) + { + return CompareParameterLists(analyzedMethod.Parameters, method.Parameters); + } + else + { + return !method.HasParameters; + } + } + + /// + /// Compares the list of method's parameters. + /// + /// The first list to compare. + /// The second list to copare. + /// true if both list have parameters of the same types at the same positions. + private static bool CompareParameterLists(Mono.Collections.Generic.Collection coll1, Mono.Collections.Generic.Collection coll2) + { + if (coll1.Count != coll2.Count) + return false; + + for (int index = 0; index < coll1.Count; index++) + { + var param1 = coll1[index]; + var param2 = coll2[index]; + if (param1.Attributes != param2.Attributes || !AreSameType(param1.ParameterType, param2.ParameterType)) + return false; + } + return true; + } + } +} diff --git a/ILSpy/TreeNodes/BaseTypesTreeNode.cs b/ILSpy/TreeNodes/BaseTypesTreeNode.cs index ff3f82c43..488807f05 100644 --- a/ILSpy/TreeNodes/BaseTypesTreeNode.cs +++ b/ILSpy/TreeNodes/BaseTypesTreeNode.cs @@ -94,7 +94,7 @@ namespace ICSharpCode.ILSpy.TreeNodes } public override object Text { - get { return tr.FullName; } + get { return this.Language.TypeToString(tr, true); } } public override object Icon { diff --git a/ILSpy/TreeNodes/DerivedTypesTreeNode.cs b/ILSpy/TreeNodes/DerivedTypesTreeNode.cs index 58a91e90c..35dc5f694 100644 --- a/ILSpy/TreeNodes/DerivedTypesTreeNode.cs +++ b/ILSpy/TreeNodes/DerivedTypesTreeNode.cs @@ -68,7 +68,18 @@ namespace ICSharpCode.ILSpy.TreeNodes static bool IsSameType(TypeReference typeRef, TypeDefinition type) { - return typeRef.FullName == type.FullName; + if (typeRef.FullName == type.FullName) + return true; + if (typeRef.Name != type.Name || type.Namespace != typeRef.Namespace) + return false; + if (typeRef.IsNested || type.IsNested) + if (!typeRef.IsNested || !type.IsNested || !IsSameType(typeRef.DeclaringType, type.DeclaringType)) + return false; + var gTypeRef = typeRef as GenericInstanceType; + if (gTypeRef != null || type.HasGenericParameters) + if (gTypeRef == null || !type.HasGenericParameters || gTypeRef.GenericArguments.Count != type.GenericParameters.Count) + return false; + return true; } public override void Decompile(Language language, ITextOutput output, DecompilationOptions options) @@ -98,7 +109,7 @@ namespace ICSharpCode.ILSpy.TreeNodes } public override object Text { - get { return def.FullName; } + get { return this.Language.TypeToString(def, true); } } public override object Icon { diff --git a/ILSpy/TreeNodes/TypeTreeNode.cs b/ILSpy/TreeNodes/TypeTreeNode.cs index 295add9c4..865b9c4ff 100644 --- a/ILSpy/TreeNodes/TypeTreeNode.cs +++ b/ILSpy/TreeNodes/TypeTreeNode.cs @@ -61,7 +61,7 @@ namespace ICSharpCode.ILSpy.TreeNodes } public override object Text { - get { return HighlightSearchMatch(type.Name); } + get { return HighlightSearchMatch(this.Language.TypeToString(type, includeNamespace: false)); } } public bool IsPublicAPI { diff --git a/NRefactory/ICSharpCode.NRefactory.Tests/CSharp/Analysis/DefiniteAssignmentTests.cs b/NRefactory/ICSharpCode.NRefactory.Tests/CSharp/Analysis/DefiniteAssignmentTests.cs index 247c1fb39..2e94a1c58 100644 --- a/NRefactory/ICSharpCode.NRefactory.Tests/CSharp/Analysis/DefiniteAssignmentTests.cs +++ b/NRefactory/ICSharpCode.NRefactory.Tests/CSharp/Analysis/DefiniteAssignmentTests.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using ICSharpCode.NRefactory.TypeSystem; using NUnit.Framework; namespace ICSharpCode.NRefactory.CSharp.Analysis @@ -39,7 +40,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis Statement stmt5 = tryCatchStatement.FinallyBlock.Statements.Single(); LabelStatement label = (LabelStatement)block.Statements.ElementAt(1); - DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(block); + DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(block, CecilLoaderTests.Mscorlib); da.Analyze("i"); Assert.AreEqual(0, da.UnassignedVariableUses.Count); Assert.AreEqual(DefiniteAssignmentStatus.PotentiallyAssigned, da.GetStatusBefore(tryCatchStatement)); @@ -89,7 +90,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis TrueStatement = new BlockStatement(), FalseStatement = new BlockStatement() }; - DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(ifStmt); + DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(ifStmt, CecilLoaderTests.Mscorlib); da.Analyze("i"); Assert.AreEqual(0, da.UnassignedVariableUses.Count); Assert.AreEqual(DefiniteAssignmentStatus.PotentiallyAssigned, da.GetStatusBefore(ifStmt)); @@ -120,7 +121,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis TrueStatement = new BlockStatement(), FalseStatement = new BlockStatement() }; - DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(ifStmt); + DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(ifStmt, CecilLoaderTests.Mscorlib); da.Analyze("i"); Assert.AreEqual(0, da.UnassignedVariableUses.Count); Assert.AreEqual(DefiniteAssignmentStatus.PotentiallyAssigned, da.GetStatusBefore(ifStmt)); @@ -128,5 +129,24 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis Assert.AreEqual(DefiniteAssignmentStatus.DefinitelyAssigned, da.GetStatusBefore(ifStmt.FalseStatement)); Assert.AreEqual(DefiniteAssignmentStatus.PotentiallyAssigned, da.GetStatusAfter(ifStmt)); } + + [Test] + public void WhileTrue() + { + WhileStatement loop = new WhileStatement { + Condition = new PrimitiveExpression(true), + EmbeddedStatement = new BlockStatement { + new AssignmentExpression(new IdentifierExpression("i"), new PrimitiveExpression(0)), + new BreakStatement() + } + }; + DefiniteAssignmentAnalysis da = new DefiniteAssignmentAnalysis(loop, CecilLoaderTests.Mscorlib); + da.Analyze("i"); + Assert.AreEqual(0, da.UnassignedVariableUses.Count); + Assert.AreEqual(DefiniteAssignmentStatus.PotentiallyAssigned, da.GetStatusBefore(loop)); + Assert.AreEqual(DefiniteAssignmentStatus.PotentiallyAssigned, da.GetStatusBefore(loop.EmbeddedStatement)); + Assert.AreEqual(DefiniteAssignmentStatus.CodeUnreachable, da.GetStatusAfter(loop.EmbeddedStatement)); + Assert.AreEqual(DefiniteAssignmentStatus.DefinitelyAssigned, da.GetStatusAfter(loop)); + } } } diff --git a/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.TestCase.cs b/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.TestCase.cs index cfce62245..5092401ed 100644 --- a/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.TestCase.cs +++ b/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.TestCase.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Runtime.InteropServices; [assembly: ICSharpCode.NRefactory.TypeSystem.TestCase.TypeTestAttribute( 42, typeof(System.Action<>), typeof(IDictionary>))] @@ -65,4 +66,31 @@ namespace ICSharpCode.NRefactory.TypeSystem.TestCase { public MyStructWithCtor(int a) {} } + + [Serializable] + public class NonCustomAttributes + { + [NonSerialized] + public readonly int NonSerializedField; + + [DllImport("unmanaged.dll", CharSet = CharSet.Unicode)] + [return: MarshalAs(UnmanagedType.Bool)] + public static extern bool DllMethod([In, Out] ref int p); + } + + [StructLayout(LayoutKind.Explicit, CharSet = CharSet.Unicode, Pack = 8)] + public struct ExplicitFieldLayoutStruct + { + [FieldOffset(0)] + public int Field0; + + [FieldOffset(100)] + public int Field100; + } + + public class ParameterTests + { + public void MethodWithOutParameter(out int x) { x = 0; } + public void MethodWithParamsArray(params object[] x) {} + } } diff --git a/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.cs b/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.cs index ba43e776d..2959f00ed 100644 --- a/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.cs +++ b/NRefactory/ICSharpCode.NRefactory.Tests/TypeSystem/TypeSystemTests.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using System.Runtime.InteropServices; using ICSharpCode.NRefactory.TypeSystem.Implementation; using ICSharpCode.NRefactory.TypeSystem.TestCase; using NUnit.Framework; @@ -261,5 +262,105 @@ namespace ICSharpCode.NRefactory.TypeSystem Assert.AreEqual(2, ctors.Count()); Assert.IsFalse(ctors.Any(c => c.IsStatic)); } + + [Test] + public void SerializableAttribute() + { + IAttribute attr = ctx.GetClass(typeof(NonCustomAttributes)).Attributes.Single(); + Assert.AreEqual("System.SerializableAttribute", attr.AttributeType.Resolve(ctx).FullName); + } + + [Test] + public void NonSerializedAttribute() + { + IField field = ctx.GetClass(typeof(NonCustomAttributes)).Fields.Single(f => f.Name == "NonSerializedField"); + Assert.AreEqual("System.NonSerializedAttribute", field.Attributes.Single().AttributeType.Resolve(ctx).FullName); + } + + [Test] + public void ExplicitStructLayoutAttribute() + { + IAttribute attr = ctx.GetClass(typeof(ExplicitFieldLayoutStruct)).Attributes.Single(); + Assert.AreEqual("System.Runtime.InteropServices.StructLayoutAttribute", attr.AttributeType.Resolve(ctx).FullName); + IConstantValue arg1 = attr.PositionalArguments.Single(); + Assert.AreEqual("System.Runtime.InteropServices.LayoutKind", arg1.GetValueType(ctx).FullName); + Assert.AreEqual((int)LayoutKind.Explicit, arg1.GetValue(ctx)); + + var arg2 = attr.NamedArguments[0]; + Assert.AreEqual("CharSet", arg2.Key); + Assert.AreEqual("System.Runtime.InteropServices.CharSet", arg2.Value.GetValueType(ctx).FullName); + Assert.AreEqual((int)CharSet.Unicode, arg2.Value.GetValue(ctx)); + + var arg3 = attr.NamedArguments[1]; + Assert.AreEqual("Pack", arg3.Key); + Assert.AreEqual("System.Int32", arg3.Value.GetValueType(ctx).FullName); + Assert.AreEqual(8, arg3.Value.GetValue(ctx)); + } + + [Test] + public void FieldOffsetAttribute() + { + IField field = ctx.GetClass(typeof(ExplicitFieldLayoutStruct)).Fields.Single(f => f.Name == "Field0"); + Assert.AreEqual("System.Runtime.InteropServices.FieldOffsetAttribute", field.Attributes.Single().AttributeType.Resolve(ctx).FullName); + IConstantValue arg = field.Attributes.Single().PositionalArguments.Single(); + Assert.AreEqual("System.Int32", arg.GetValueType(ctx).FullName); + Assert.AreEqual(0, arg.GetValue(ctx)); + + field = ctx.GetClass(typeof(ExplicitFieldLayoutStruct)).Fields.Single(f => f.Name == "Field100"); + Assert.AreEqual("System.Runtime.InteropServices.FieldOffsetAttribute", field.Attributes.Single().AttributeType.Resolve(ctx).FullName); + arg = field.Attributes.Single().PositionalArguments.Single(); + Assert.AreEqual("System.Int32", arg.GetValueType(ctx).FullName); + Assert.AreEqual(100, arg.GetValue(ctx)); + } + + [Test] + public void DllImportAttribute() + { + IMethod method = ctx.GetClass(typeof(NonCustomAttributes)).Methods.Single(m => m.Name == "DllMethod"); + IAttribute dllImport = method.Attributes.Single(); + Assert.AreEqual("System.Runtime.InteropServices.DllImportAttribute", dllImport.AttributeType.Resolve(ctx).FullName); + Assert.AreEqual("unmanaged.dll", dllImport.PositionalArguments[0].GetValue(ctx)); + Assert.AreEqual((int)CharSet.Unicode, dllImport.NamedArguments.Single().Value.GetValue(ctx)); + } + + [Test] + public void InOutParametersOnRefMethod() + { + IParameter p = ctx.GetClass(typeof(NonCustomAttributes)).Methods.Single(m => m.Name == "DllMethod").Parameters.Single(); + Assert.IsTrue(p.IsRef); + Assert.IsFalse(p.IsOut); + Assert.AreEqual(2, p.Attributes.Count); + Assert.AreEqual("System.Runtime.InteropServices.InAttribute", p.Attributes[0].AttributeType.Resolve(ctx).FullName); + Assert.AreEqual("System.Runtime.InteropServices.OutAttribute", p.Attributes[1].AttributeType.Resolve(ctx).FullName); + } + + [Test] + public void MarshalAsAttributeOnMethod() + { + IMethod method = ctx.GetClass(typeof(NonCustomAttributes)).Methods.Single(m => m.Name == "DllMethod"); + IAttribute marshalAs = method.ReturnTypeAttributes.Single(); + Assert.AreEqual((int)UnmanagedType.Bool, marshalAs.PositionalArguments.Single().GetValue(ctx)); + } + + [Test] + public void MethodWithOutParameter() + { + IParameter p = ctx.GetClass(typeof(ParameterTests)).Methods.Single(m => m.Name == "MethodWithOutParameter").Parameters.Single(); + Assert.IsFalse(p.IsRef); + Assert.IsTrue(p.IsOut); + Assert.AreEqual(0, p.Attributes.Count); + Assert.IsTrue(p.Type is ByReferenceTypeReference); + } + + [Test] + public void MethodWithParamsArray() + { + IParameter p = ctx.GetClass(typeof(ParameterTests)).Methods.Single(m => m.Name == "MethodWithParamsArray").Parameters.Single(); + Assert.IsFalse(p.IsRef); + Assert.IsFalse(p.IsOut); + Assert.IsTrue(p.IsParams); + Assert.AreEqual(0, p.Attributes.Count); + Assert.IsTrue(p.Type is ArrayTypeReference); + } } } diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/ControlFlow.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/ControlFlow.cs index 2f1e2b47e..9a2fa141d 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/ControlFlow.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/ControlFlow.cs @@ -5,7 +5,10 @@ using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Threading; + using ICSharpCode.NRefactory.CSharp.Resolver; +using ICSharpCode.NRefactory.TypeSystem; namespace ICSharpCode.NRefactory.CSharp.Analysis { @@ -135,12 +138,24 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis } Statement rootStatement; + ResolveVisitor resolveVisitor; List nodes; Dictionary labels; List gotoStatements; - public IList BuildControlFlowGraph(Statement statement) + public IList BuildControlFlowGraph(Statement statement, ITypeResolveContext context, CancellationToken cancellationToken = default(CancellationToken)) { + return BuildControlFlowGraph(statement, new ResolveVisitor( + new CSharpResolver(context, cancellationToken), null, ConstantModeResolveVisitorNavigator.Skip)); + } + + public IList BuildControlFlowGraph(Statement statement, ResolveVisitor resolveVisitor) + { + if (statement == null) + throw new ArgumentNullException("statement"); + if (resolveVisitor == null) + throw new ArgumentNullException("resolveVisitor"); + NodeCreationVisitor nodeCreationVisitor = new NodeCreationVisitor(); nodeCreationVisitor.builder = this; try { @@ -148,6 +163,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis this.labels = new Dictionary(); this.gotoStatements = new List(); this.rootStatement = statement; + this.resolveVisitor = resolveVisitor; ControlFlowNode entryPoint = CreateStartNode(statement); statement.AcceptVisitor(nodeCreationVisitor, entryPoint); @@ -167,6 +183,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis this.labels = null; this.gotoStatements = null; this.rootStatement = null; + this.resolveVisitor = null; } } @@ -206,7 +223,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis ControlFlowNode CreateSpecialNode(Statement statement, ControlFlowNodeType type) { - ControlFlowNode node = CreateNode(statement, null, type); + ControlFlowNode node = CreateNode(null, statement, type); nodes.Add(node); return node; } @@ -238,7 +255,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis /// The constant value of the expression; or null if the expression is not a constant. ConstantResolveResult EvaluateConstant(Expression expr) { - return null; // TODO: implement this using the C# resolver + return resolveVisitor.Resolve(expr) as ConstantResolveResult; } /// @@ -256,7 +273,11 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis bool AreEqualConstants(ConstantResolveResult c1, ConstantResolveResult c2) { - return false; // TODO: implement this using the resolver's operator== + if (c1 == null || c2 == null) + return false; + CSharpResolver r = new CSharpResolver(resolveVisitor.TypeResolveContext, resolveVisitor.CancellationToken); + ResolveResult c = r.ResolveBinaryOperator(BinaryOperatorType.Equality, c1, c2); + return c.IsCompileTimeConstant && (c.ConstantValue as bool?) == true; } #endregion diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/DefiniteAssignmentAnalysis.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/DefiniteAssignmentAnalysis.cs index 534e42325..e37702239 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/DefiniteAssignmentAnalysis.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Analysis/DefiniteAssignmentAnalysis.cs @@ -5,7 +5,9 @@ using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Threading; using ICSharpCode.NRefactory.CSharp.Resolver; +using ICSharpCode.NRefactory.TypeSystem; using ICSharpCode.NRefactory.Utils; namespace ICSharpCode.NRefactory.CSharp.Analysis @@ -46,6 +48,8 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis readonly List allNodes = new List(); readonly Dictionary beginNodeDict = new Dictionary(); readonly Dictionary endNodeDict = new Dictionary(); + readonly ResolveVisitor resolveVisitor; + readonly CancellationToken cancellationToken; Dictionary nodeStatus = new Dictionary(); Dictionary edgeStatus = new Dictionary(); @@ -54,19 +58,30 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis Queue nodesWithModifiedInput = new Queue(); - public DefiniteAssignmentAnalysis(Statement rootStatement) + public DefiniteAssignmentAnalysis(Statement rootStatement, ITypeResolveContext context, CancellationToken cancellationToken = default(CancellationToken)) + : this(rootStatement, new ResolveVisitor(new CSharpResolver(context, cancellationToken), null, ConstantModeResolveVisitorNavigator.Skip)) { + } + + public DefiniteAssignmentAnalysis(Statement rootStatement, ResolveVisitor resolveVisitor) + { + if (rootStatement == null) + throw new ArgumentNullException("rootStatement"); + if (resolveVisitor == null) + throw new ArgumentNullException("resolveVisitor"); + this.resolveVisitor = resolveVisitor; + this.cancellationToken = resolveVisitor.CancellationToken; visitor.analysis = this; ControlFlowGraphBuilder b = new ControlFlowGraphBuilder(); - allNodes.AddRange(b.BuildControlFlowGraph(rootStatement)); + allNodes.AddRange(b.BuildControlFlowGraph(rootStatement, resolveVisitor)); foreach (AstNode descendant in rootStatement.Descendants) { // Anonymous methods have separate control flow graphs, but we also need to analyze those. AnonymousMethodExpression ame = descendant as AnonymousMethodExpression; if (ame != null) - allNodes.AddRange(b.BuildControlFlowGraph(ame.Body)); + allNodes.AddRange(b.BuildControlFlowGraph(ame.Body, resolveVisitor)); LambdaExpression lambda = descendant as LambdaExpression; if (lambda != null && lambda.Body is Statement) - allNodes.AddRange(b.BuildControlFlowGraph((Statement)lambda.Body)); + allNodes.AddRange(b.BuildControlFlowGraph((Statement)lambda.Body, resolveVisitor)); } // Verify that we created nodes for all statements: Debug.Assert(!rootStatement.DescendantsAndSelf.OfType().Except(allNodes.Select(n => n.NextStatement)).Any()); @@ -289,7 +304,7 @@ namespace ICSharpCode.NRefactory.CSharp.Analysis /// The constant value of the expression; or null if the expression is not a constant. ConstantResolveResult EvaluateConstant(Expression expr) { - return null; // TODO: implement this using the C# resolver + return resolveVisitor.Resolve(expr) as ConstantResolveResult; } /// diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNode.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNode.cs index 8264bdc9d..bd5fbcc58 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNode.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNode.cs @@ -129,6 +129,7 @@ namespace ICSharpCode.NRefactory.CSharp get { AstNode next; for (AstNode cur = firstChild; cur != null; cur = next) { + Debug.Assert(cur.parent == this); // Remember next before yielding cur. // This allows removing/replacing nodes while iterating through the list. next = cur.nextSibling; diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNodeCollection.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNodeCollection.cs index 07e062672..a8c11cb99 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNodeCollection.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Ast/AstNodeCollection.cs @@ -117,6 +117,7 @@ namespace ICSharpCode.NRefactory.CSharp { AstNode next; for (AstNode cur = node.FirstChild; cur != null; cur = next) { + Debug.Assert(cur.Parent == node); // Remember next before yielding cur. // This allows removing/replacing nodes while iterating through the list. next = cur.NextSibling; diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Parser/CSharpParser.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Parser/CSharpParser.cs index 68ac7ee2f..a056c8ba1 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Parser/CSharpParser.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Parser/CSharpParser.cs @@ -374,9 +374,9 @@ namespace ICSharpCode.NRefactory.CSharp variable.AddChild (new Identifier (em.Name, Convert (em.Location)), AstNode.Roles.Identifier); if (em.Initializer != null) { - var initializer = (VariableInitializer)em.Initializer.Accept (this); + var initializer = (Expression)em.Initializer.Accept (this); if (initializer != null) - variable.AddChild (initializer, AstNode.Roles.Variable); + variable.AddChild (initializer, EnumMemberDeclaration.InitializerRole); } newField.AddChild (variable, AstNode.Roles.Variable); diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/CSharpResolver.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/CSharpResolver.cs index 904bbae99..fd43e18c4 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/CSharpResolver.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/CSharpResolver.cs @@ -33,7 +33,6 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver this.context = context; } - #if !DOTNET35 public CSharpResolver(ITypeResolveContext context, CancellationToken cancellationToken) { if (context == null) @@ -41,7 +40,6 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver this.context = context; this.cancellationToken = cancellationToken; } - #endif #endregion #region Properties diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/IResolveVisitorNavigator.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/IResolveVisitorNavigator.cs index 0ae7b02a2..be956b22f 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/IResolveVisitorNavigator.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/IResolveVisitorNavigator.cs @@ -37,4 +37,16 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver /// ResolveAll } + + sealed class ConstantModeResolveVisitorNavigator : IResolveVisitorNavigator + { + ResolveVisitorNavigationMode mode; + + public static readonly IResolveVisitorNavigator Skip = new ConstantModeResolveVisitorNavigator { mode = ResolveVisitorNavigationMode.Skip }; + + ResolveVisitorNavigationMode IResolveVisitorNavigator.Scan(AstNode node) + { + return mode; + } + } } diff --git a/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/ResolveVisitor.cs b/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/ResolveVisitor.cs index 060bdfa7f..2c4002a41 100644 --- a/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/ResolveVisitor.cs +++ b/NRefactory/ICSharpCode.NRefactory/CSharp/Resolver/ResolveVisitor.cs @@ -74,6 +74,20 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } #endregion + /// + /// Gets the TypeResolveContext used by this ResolveVisitor. + /// + public ITypeResolveContext TypeResolveContext { + get { return resolver.Context; } + } + + /// + /// Gets the CancellationToken used by this ResolveVisitor. + /// + public CancellationToken CancellationToken { + get { return resolver.cancellationToken; } + } + #region Scan / Resolve bool resolverEnabled { get { return mode != ResolveVisitorNavigationMode.Scan; } @@ -118,6 +132,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver mode = ResolveVisitorNavigationMode.Resolve; ResolveResult result; if (!cache.TryGetValue(node, out result)) { + resolver.cancellationToken.ThrowIfCancellationRequested(); result = cache[node] = node.AcceptVisitor(this, null) ?? errorResult; } if (wasScan) diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/CecilLoader.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/CecilLoader.cs index 7564ed995..74cbdb91c 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/CecilLoader.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/CecilLoader.cs @@ -46,6 +46,10 @@ namespace ICSharpCode.NRefactory.TypeSystem #endregion #region Load From AssemblyDefinition + /// + /// Loads the assembly definition into a project content. + /// + /// IProjectContent that represents the assembly public IProjectContent LoadAssembly(AssemblyDefinition assemblyDefinition) { if (assemblyDefinition == null) @@ -144,7 +148,7 @@ namespace ICSharpCode.NRefactory.TypeSystem public void Dispose() { - // Disposibng the synchronization context has no effect + // Disposing the synchronization context has no effect } string IDocumentationProvider.GetDocumentation(IEntity entity) @@ -297,14 +301,13 @@ namespace ICSharpCode.NRefactory.TypeSystem } } - const string DynamicAttributeFullName = "System.Runtime.CompilerServices.DynamicAttribute"; - static bool HasDynamicAttribute(ICustomAttributeProvider attributeProvider, int typeIndex) { if (attributeProvider == null || !attributeProvider.HasCustomAttributes) return false; foreach (CustomAttribute a in attributeProvider.CustomAttributes) { - if (a.Constructor.DeclaringType.FullName == DynamicAttributeFullName) { + TypeReference type = a.AttributeType; + if (type.Name == "DynamicAttribute" && type.Namespace == "System.Runtime.CompilerServices") { if (a.ConstructorArguments.Count == 1) { CustomAttributeArgument[] values = a.ConstructorArguments[0].Value as CustomAttributeArgument[]; if (values != null && typeIndex < values.Length && values[typeIndex].Value is bool) @@ -325,21 +328,150 @@ namespace ICSharpCode.NRefactory.TypeSystem } } + static readonly IAttribute inAttribute = new DefaultAttribute(typeof(InAttribute).ToTypeReference(), null); + static readonly IAttribute outAttribute = new DefaultAttribute(typeof(OutAttribute).ToTypeReference(), null); + void AddAttributes(ParameterDefinition parameter, DefaultParameter targetParameter) { + if (!targetParameter.IsOut) { + if (parameter.IsIn) + targetParameter.Attributes.Add(inAttribute); + if (parameter.IsOut) + targetParameter.Attributes.Add(outAttribute); + } if (parameter.HasCustomAttributes) { AddCustomAttributes(parameter.CustomAttributes, targetParameter.Attributes); } } - void AddAttributes(MethodDefinition accessorMethod, DefaultAccessor targetAccessor) + static readonly ITypeReference dllImportAttributeTypeRef = typeof(DllImportAttribute).ToTypeReference(); + static readonly SimpleConstantValue trueValue = new SimpleConstantValue(KnownTypeReference.Boolean, true); + static readonly SimpleConstantValue falseValue = new SimpleConstantValue(KnownTypeReference.Boolean, true); + static readonly ITypeReference callingConventionTypeRef = typeof(CallingConvention).ToTypeReference(); + static readonly IAttribute preserveSigAttribute = new DefaultAttribute(typeof(PreserveSigAttribute).ToTypeReference(), null); + static readonly ITypeReference methodImplAttributeTypeRef = typeof(MethodImplAttribute).ToTypeReference(); + static readonly ITypeReference methodImplOptionsTypeRef = typeof(MethodImplOptions).ToTypeReference(); + + bool HasAnyAttributes(MethodDefinition methodDefinition) { - if (accessorMethod.HasCustomAttributes) { - AddCustomAttributes(accessorMethod.CustomAttributes, targetAccessor.Attributes); + if (methodDefinition.HasPInvokeInfo) + return true; + if ((methodDefinition.ImplAttributes & ~MethodImplAttributes.CodeTypeMask) != 0) + return true; + if (methodDefinition.MethodReturnType.HasFieldMarshal) + return true; + return methodDefinition.HasCustomAttributes || methodDefinition.MethodReturnType.HasCustomAttributes; + } + + void AddAttributes(MethodDefinition methodDefinition, IList attributes, IList returnTypeAttributes) + { + MethodImplAttributes implAttributes = methodDefinition.ImplAttributes & ~MethodImplAttributes.CodeTypeMask; + + #region DllImportAttribute + if (methodDefinition.HasPInvokeInfo) { + PInvokeInfo info = methodDefinition.PInvokeInfo; + DefaultAttribute dllImport = new DefaultAttribute(dllImportAttributeTypeRef, new[] { KnownTypeReference.String }); + dllImport.PositionalArguments.Add(new SimpleConstantValue(KnownTypeReference.String, info.Module.Name)); + + if (info.IsBestFitDisabled) + AddNamedArgument(dllImport, "BestFitMapping", falseValue); + if (info.IsBestFitEnabled) + AddNamedArgument(dllImport, "BestFitMapping", trueValue); + + CallingConvention callingConvention; + switch (info.Attributes & PInvokeAttributes.CallConvMask) { + case PInvokeAttributes.CallConvCdecl: + callingConvention = CallingConvention.Cdecl; + break; + case PInvokeAttributes.CallConvFastcall: + callingConvention = CallingConvention.FastCall; + break; + case PInvokeAttributes.CallConvStdCall: + callingConvention = CallingConvention.StdCall; + break; + case PInvokeAttributes.CallConvThiscall: + callingConvention = CallingConvention.ThisCall; + break; + case PInvokeAttributes.CallConvWinapi: + callingConvention = CallingConvention.Winapi; + break; + default: + throw new NotSupportedException("unknown calling convention"); + } + if (callingConvention != CallingConvention.Winapi) + AddNamedArgument(dllImport, "CallingConvention", new SimpleConstantValue(callingConventionTypeRef, (int)callingConvention)); + + CharSet charSet = CharSet.None; + switch (info.Attributes & PInvokeAttributes.CharSetMask) { + case PInvokeAttributes.CharSetAnsi: + charSet = CharSet.Ansi; + break; + case PInvokeAttributes.CharSetAuto: + charSet = CharSet.Auto; + break; + case PInvokeAttributes.CharSetUnicode: + charSet = CharSet.Unicode; + break; + } + if (charSet != CharSet.None) + dllImport.NamedArguments.Add(new KeyValuePair( + "CharSet", new SimpleConstantValue(charSetTypeRef, (int)charSet))); + + if (!string.IsNullOrEmpty(info.EntryPoint) && info.EntryPoint != methodDefinition.Name) + AddNamedArgument(dllImport, "EntryPoint", new SimpleConstantValue(KnownTypeReference.String, info.EntryPoint)); + + if (info.IsNoMangle) + AddNamedArgument(dllImport, "ExactSpelling", trueValue); + + if ((implAttributes & MethodImplAttributes.PreserveSig) == MethodImplAttributes.PreserveSig) + implAttributes &= ~MethodImplAttributes.PreserveSig; + else + AddNamedArgument(dllImport, "PreserveSig", falseValue); + + if (info.SupportsLastError) + AddNamedArgument(dllImport, "SetLastError", trueValue); + + if (info.IsThrowOnUnmappableCharDisabled) + AddNamedArgument(dllImport, "ThrowOnUnmappableChar", falseValue); + if (info.IsThrowOnUnmappableCharEnabled) + AddNamedArgument(dllImport, "ThrowOnUnmappableChar", trueValue); + + attributes.Add(dllImport); + } + #endregion + + #region PreserveSigAttribute + if (implAttributes == MethodImplAttributes.PreserveSig) { + attributes.Add(preserveSigAttribute); + implAttributes = 0; + } + #endregion + + #region MethodImplAttribute + if (implAttributes != 0) { + DefaultAttribute methodImpl = new DefaultAttribute(methodImplAttributeTypeRef, new[] { methodImplOptionsTypeRef }); + methodImpl.PositionalArguments.Add(new SimpleConstantValue(methodImplOptionsTypeRef, (int)implAttributes)); + attributes.Add(methodImpl); + } + #endregion + + if (methodDefinition.HasCustomAttributes) { + AddCustomAttributes(methodDefinition.CustomAttributes, attributes); + } + if (methodDefinition.MethodReturnType.HasMarshalInfo) { + returnTypeAttributes.Add(ConvertMarshalInfo(methodDefinition.MethodReturnType.MarshalInfo)); } + if (methodDefinition.MethodReturnType.HasCustomAttributes) { + AddCustomAttributes(methodDefinition.MethodReturnType.CustomAttributes, returnTypeAttributes); + } + } + + static void AddNamedArgument(DefaultAttribute attribute, string name, IConstantValue value) + { + attribute.NamedArguments.Add(new KeyValuePair(name, value)); } - static readonly DefaultAttribute serializableAttribute = new DefaultAttribute(typeof(SerializableAttribute).ToTypeReference()); + static readonly DefaultAttribute serializableAttribute = new DefaultAttribute(typeof(SerializableAttribute).ToTypeReference(), null); static readonly ITypeReference structLayoutAttributeTypeRef = typeof(StructLayoutAttribute).ToTypeReference(); static readonly ITypeReference layoutKindTypeRef = typeof(LayoutKind).ToTypeReference(); static readonly ITypeReference charSetTypeRef = typeof(CharSet).ToTypeReference(); @@ -373,8 +505,9 @@ namespace ICSharpCode.NRefactory.TypeSystem charSet = CharSet.Unicode; break; } - if (layoutKind != LayoutKind.Auto || charSet != CharSet.Ansi || typeDefinition.PackingSize > 0 || typeDefinition.ClassSize > 0) { - DefaultAttribute structLayout = new DefaultAttribute(structLayoutAttributeTypeRef); + LayoutKind defaultLayoutKind = (typeDefinition.IsValueType && !typeDefinition.IsEnum) ? LayoutKind.Sequential: LayoutKind.Auto; + if (layoutKind != defaultLayoutKind || charSet != CharSet.Ansi || typeDefinition.PackingSize > 0 || typeDefinition.ClassSize > 0) { + DefaultAttribute structLayout = new DefaultAttribute(structLayoutAttributeTypeRef, new[] { layoutKindTypeRef }); structLayout.PositionalArguments.Add(new SimpleConstantValue(layoutKindTypeRef, (int)layoutKind)); if (charSet != CharSet.Ansi) { structLayout.NamedArguments.Add(new KeyValuePair( @@ -400,12 +533,58 @@ namespace ICSharpCode.NRefactory.TypeSystem } } + static readonly ITypeReference fieldOffsetAttributeTypeRef = typeof(FieldOffsetAttribute).ToTypeReference(); + static readonly DefaultAttribute nonSerializedAttribute = new DefaultAttribute(typeof(NonSerializedAttribute).ToTypeReference(), null); + + void AddAttributes(FieldDefinition fieldDefinition, IEntity targetEntity) + { + #region FieldOffsetAttribute + if (fieldDefinition.HasLayoutInfo) { + DefaultAttribute fieldOffset = new DefaultAttribute(fieldOffsetAttributeTypeRef, new[] { KnownTypeReference.Int32 }); + fieldOffset.PositionalArguments.Add(new SimpleConstantValue(KnownTypeReference.Int32, fieldDefinition.Offset)); + targetEntity.Attributes.Add(fieldOffset); + } + #endregion + + #region NonSerializedAttribute + if (fieldDefinition.IsNotSerialized) { + targetEntity.Attributes.Add(nonSerializedAttribute); + } + #endregion + + if (fieldDefinition.HasMarshalInfo) { + targetEntity.Attributes.Add(ConvertMarshalInfo(fieldDefinition.MarshalInfo)); + } + + if (fieldDefinition.HasCustomAttributes) { + AddCustomAttributes(fieldDefinition.CustomAttributes, targetEntity.Attributes); + } + } + + #region MarshalAsAttribute (ConvertMarshalInfo) + static readonly ITypeReference marshalAsAttributeTypeRef = typeof(MarshalAsAttribute).ToTypeReference(); + static readonly ITypeReference unmanagedTypeTypeRef = typeof(UnmanagedType).ToTypeReference(); + + static IAttribute ConvertMarshalInfo(MarshalInfo marshalInfo) + { + DefaultAttribute attr = new DefaultAttribute(marshalAsAttributeTypeRef, new[] { unmanagedTypeTypeRef }); + attr.PositionalArguments.Add(new SimpleConstantValue(unmanagedTypeTypeRef, (int)marshalInfo.NativeType)); + // TODO: handle classes derived from MarshalInfo + return attr; + } + #endregion + void AddCustomAttributes(Mono.Collections.Generic.Collection attributes, IList targetCollection) { foreach (var cecilAttribute in attributes) { - if (cecilAttribute.AttributeType.FullName != DynamicAttributeFullName) { - targetCollection.Add(ReadAttribute(cecilAttribute)); + TypeReference type = cecilAttribute.AttributeType; + if (type.Namespace == "System.Runtime.CompilerServices") { + if (type.Name == "DynamicAttribute" || type.Name == "ExtensionAttribute") + continue; + } else if (type.Name == "ParamArrayAttribute" && type.Namespace == "System") { + continue; } + targetCollection.Add(ReadAttribute(cecilAttribute)); } } @@ -413,7 +592,15 @@ namespace ICSharpCode.NRefactory.TypeSystem { if (attribute == null) throw new ArgumentNullException("attribute"); - DefaultAttribute a = new DefaultAttribute(ReadTypeReference(attribute.AttributeType)); + MethodReference ctor = attribute.Constructor; + ITypeReference[] ctorParameters = null; + if (ctor.HasParameters) { + ctorParameters = new ITypeReference[ctor.Parameters.Count]; + for (int i = 0; i < ctorParameters.Length; i++) { + ctorParameters[i] = ReadTypeReference(ctor.Parameters[i].ParameterType); + } + } + DefaultAttribute a = new DefaultAttribute(ReadTypeReference(attribute.AttributeType), ctorParameters); try { if (attribute.HasConstructorArguments) { foreach (var arg in attribute.ConstructorArguments) { @@ -442,8 +629,13 @@ namespace ICSharpCode.NRefactory.TypeSystem #region Read Constant Value public IConstantValue ReadConstantValue(CustomAttributeArgument arg) { - ITypeReference type = ReadTypeReference(arg.Type); object value = arg.Value; + if (value is CustomAttributeArgument) { + // Cecil uses this representation for boxed values + arg = (CustomAttributeArgument)value; + value = arg.Value; + } + ITypeReference type = ReadTypeReference(arg.Type); CustomAttributeArgument[] array = value as CustomAttributeArgument[]; if (array != null) { // TODO: write unit test for this @@ -500,9 +692,9 @@ namespace ICSharpCode.NRefactory.TypeSystem InitNestedTypes(loader); // nested types can be initialized only after generic parameters were created - if (typeDefinition.HasCustomAttributes) { - loader.AddAttributes(typeDefinition, this); - } + loader.AddAttributes(typeDefinition, this); + + this.HasExtensionMethods = HasExtensionAttribute(typeDefinition); // set base classes if (typeDefinition.IsEnum) { @@ -699,7 +891,8 @@ namespace ICSharpCode.NRefactory.TypeSystem else m.ReturnType = ReadTypeReference(method.ReturnType, typeAttributes: method.MethodReturnType, entity: m); - AddAttributes(method, m); + if (HasAnyAttributes(method)) + AddAttributes(method, m.Attributes, m.ReturnTypeAttributes); TranslateModifiers(method, m); if (method.HasParameters) { @@ -708,18 +901,26 @@ namespace ICSharpCode.NRefactory.TypeSystem } } - // mark as extension method is the attribute is set - if (method.IsStatic && method.HasCustomAttributes) { - foreach (var attr in method.CustomAttributes) { - if (attr.AttributeType.FullName == typeof(ExtensionAttribute).FullName) - m.IsExtensionMethod = true; - } + // mark as extension method if the attribute is set + if (method.IsStatic && HasExtensionAttribute(method)) { + m.IsExtensionMethod = true; } FinishReadMember(m); return m; } + static bool HasExtensionAttribute(ICustomAttributeProvider provider) + { + if (provider.HasCustomAttributes) { + foreach (var attr in provider.CustomAttributes) { + if (attr.AttributeType.Name == "ExtensionAttribute" && attr.AttributeType.Namespace == "System.Runtime.CompilerServices") + return true; + } + } + return false; + } + bool IsVisible(MethodAttributes att) { att &= MethodAttributes.MemberAccessMask; @@ -782,14 +983,13 @@ namespace ICSharpCode.NRefactory.TypeSystem var type = ReadTypeReference(parameter.ParameterType, typeAttributes: parameter, entity: parentMember); DefaultParameter p = new DefaultParameter(type, parameter.Name); - AddAttributes(parameter, p); - if (parameter.ParameterType is Mono.Cecil.ByReferenceType) { - if (parameter.IsOut) + if (!parameter.IsIn && parameter.IsOut) p.IsOut = true; else p.IsRef = true; } + AddAttributes(parameter, p); if (parameter.IsOptional) { p.DefaultValue = ReadConstantValue(new CustomAttributeArgument(parameter.ParameterType, parameter.Constant)); @@ -917,10 +1117,10 @@ namespace ICSharpCode.NRefactory.TypeSystem { if (accessorMethod != null && IsVisible(accessorMethod.Attributes)) { Accessibility accessibility = GetAccessibility(accessorMethod.Attributes); - if (accessorMethod.HasCustomAttributes) { + if (HasAnyAttributes(accessorMethod)) { DefaultAccessor a = new DefaultAccessor(); a.Accessibility = accessibility; - AddAttributes(accessorMethod, a); + AddAttributes(accessorMethod, a.Attributes, a.ReturnTypeAttributes); return a; } else { return DefaultAccessor.GetFromAccessibility(accessibility); diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAccessor.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAccessor.cs index ed9278b25..9c33b586e 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAccessor.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAccessor.cs @@ -21,6 +21,11 @@ namespace ICSharpCode.NRefactory.TypeSystem /// IList Attributes { get; } + /// + /// Gets the attributes defined on the return type of the accessor. (e.g. [return: MarshalAs(...)]) + /// + IList ReturnTypeAttributes { get; } + /// /// Gets the accessibility of this accessor. /// diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAttribute.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAttribute.cs index b9bb2dd97..42f913f45 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAttribute.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/IAttribute.cs @@ -35,6 +35,12 @@ namespace ICSharpCode.NRefactory.TypeSystem /// Gets the named arguments passed to the attribute. /// IList> NamedArguments { get; } + + /// + /// Resolves the constructor method used for this attribute invocation. + /// Returns null if the constructor cannot be found. + /// + IMethod ResolveConstructor(ITypeResolveContext context); } #if WITH_CONTRACTS diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/IMethod.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/IMethod.cs index e66ab0782..7f6241b3d 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/IMethod.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/IMethod.cs @@ -16,7 +16,7 @@ namespace ICSharpCode.NRefactory.TypeSystem public interface IMethod : IParameterizedMember { /// - /// Gets the attributes associated with the return type. + /// Gets the attributes associated with the return type. (e.g. [return: MarshalAs(...)]) /// IList ReturnTypeAttributes { get; } diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/ITypeDefinition.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/ITypeDefinition.cs index 905a85fdc..ef6568eda 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/ITypeDefinition.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/ITypeDefinition.cs @@ -44,6 +44,12 @@ namespace ICSharpCode.NRefactory.TypeSystem /// Gets all members declared in this class. This is the union of Fields,Properties,Methods and Events. /// IEnumerable Members { get; } + + /// + /// Gets whether this type contains extension methods. + /// + /// This property is used to speed up the search for extension methods. + bool HasExtensionMethods { get; } } #if WITH_CONTRACTS diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAccessor.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAccessor.cs index 9a8580f0a..917b1f473 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAccessor.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAccessor.cs @@ -44,6 +44,7 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation Accessibility accessibility; DomRegion region; IList attributes; + IList returnTypeAttributes; protected override void FreezeInternal() { @@ -75,20 +76,32 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation } } + public IList ReturnTypeAttributes { + get { + if (returnTypeAttributes == null) + returnTypeAttributes = new List(); + return returnTypeAttributes; + } + } + void ISupportsInterning.PrepareForInterning(IInterningProvider provider) { attributes = provider.InternList(attributes); + returnTypeAttributes = provider.InternList(returnTypeAttributes); } int ISupportsInterning.GetHashCodeForInterning() { - return (attributes != null ? attributes.GetHashCode() : 0) ^ region.GetHashCode() ^ (int)accessibility; + return (attributes != null ? attributes.GetHashCode() : 0) + ^ (returnTypeAttributes != null ? returnTypeAttributes.GetHashCode() : 0) + ^ region.GetHashCode() ^ (int)accessibility; } bool ISupportsInterning.EqualsForInterning(ISupportsInterning other) { DefaultAccessor a = other as DefaultAccessor; - return a != null && (attributes == a.attributes && accessibility == a.accessibility && region == a.region); + return a != null && (attributes == a.attributes && returnTypeAttributes == a.returnTypeAttributes + && accessibility == a.accessibility && region == a.region); } } } diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAttribute.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAttribute.cs index cb085ca18..8607ffca8 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAttribute.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultAttribute.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Text; @@ -13,8 +14,9 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation /// public sealed class DefaultAttribute : AbstractFreezable, IAttribute, ISupportsInterning { - DomRegion region; ITypeReference attributeType; + readonly ITypeReference[] constructorParameterTypes; + DomRegion region; IList positionalArguments; IList> namedArguments; @@ -34,11 +36,20 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation base.FreezeInternal(); } - public DefaultAttribute(ITypeReference attributeType) + public DefaultAttribute(ITypeReference attributeType, IEnumerable constructorParameterTypes) { if (attributeType == null) throw new ArgumentNullException("attributeType"); this.attributeType = attributeType; + this.constructorParameterTypes = constructorParameterTypes != null ? constructorParameterTypes.ToArray() : null; + } + + public ITypeReference AttributeType { + get { return attributeType; } + } + + public ReadOnlyCollection ConstructorParameterTypes { + get { return Array.AsReadOnly(constructorParameterTypes); } } public DomRegion Region { @@ -49,14 +60,6 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation } } - public ITypeReference AttributeType { - get { return attributeType; } - set { - CheckBeforeMutation(); - attributeType = value; - } - } - public IList PositionalArguments { get { if (positionalArguments == null) @@ -73,6 +76,38 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation } } + public IMethod ResolveConstructor(ITypeResolveContext context) + { + IType[] parameterTypes = null; + if (constructorParameterTypes != null && constructorParameterTypes.Length > 0) { + parameterTypes = new IType[constructorParameterTypes.Length]; + for (int i = 0; i < parameterTypes.Length; i++) { + parameterTypes[i] = constructorParameterTypes[i].Resolve(context); + } + } + IMethod bestMatch = null; + foreach (IMethod ctor in attributeType.Resolve(context).GetConstructors(context)) { + if (ctor.IsStatic) + continue; + if (parameterTypes == null) { + if (ctor.Parameters.Count == 0) + return ctor; + } else if (ctor.Parameters.Count == parameterTypes.Length) { + bestMatch = ctor; + bool ok = true; + for (int i = 0; i < parameterTypes.Length; i++) { + if (ctor.Parameters[i].Type != parameterTypes[i]) { + ok = false; + break; + } + } + if (ok) + return ctor; + } + } + return bestMatch; + } + public override string ToString() { StringBuilder b = new StringBuilder(); @@ -100,6 +135,11 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation void ISupportsInterning.PrepareForInterning(IInterningProvider provider) { attributeType = provider.Intern(attributeType); + if (constructorParameterTypes != null) { + for (int i = 0; i < constructorParameterTypes.Length; i++) { + constructorParameterTypes[i] = provider.Intern(constructorParameterTypes[i]); + } + } positionalArguments = provider.InternList(positionalArguments); } diff --git a/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultTypeDefinition.cs b/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultTypeDefinition.cs index 881d38585..c707780c4 100644 --- a/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultTypeDefinition.cs +++ b/NRefactory/ICSharpCode.NRefactory/TypeSystem/Implementation/DefaultTypeDefinition.cs @@ -40,6 +40,7 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation const ushort FlagShadowing = 0x0004; const ushort FlagSynthetic = 0x0008; const ushort FlagAddDefaultConstructorIfRequired = 0x0010; + const ushort FlagHasExtensionMethods = 0x0020; protected override void FreezeInternal() { @@ -312,6 +313,14 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation } } + public bool HasExtensionMethods { + get { return flags[FlagHasExtensionMethods]; } + set { + CheckBeforeMutation(); + flags[FlagHasExtensionMethods] = value; + } + } + public IProjectContent ProjectContent { get { return projectContent; } }