diff --git a/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs b/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs index f1ce4d74e..b7bde357d 100644 --- a/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs +++ b/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs @@ -128,7 +128,7 @@ namespace ICSharpCode.Decompiler.CSharp new ConvertConstructorCallIntoInitializer(), // must run after DeclareVariables new DecimalConstantTransform(), new IntroduceUsingDeclarations(), - //new IntroduceExtensionMethods(context), // must run after IntroduceUsingDeclarations + new IntroduceExtensionMethods(), // must run after IntroduceUsingDeclarations //new IntroduceQueryExpressions(context), // must run after IntroduceExtensionMethods //new CombineQueryExpressions(context), //new FlattenSwitchBlocks(), diff --git a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs index c9d143422..4887d2bba 100644 --- a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs +++ b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs @@ -18,49 +18,119 @@ using System; using System.Linq; -using ICSharpCode.NRefactory.CSharp; -using Mono.Cecil; +using ICSharpCode.Decompiler.CSharp.Resolver; +using ICSharpCode.Decompiler.CSharp.Syntax; +using ICSharpCode.Decompiler.CSharp.TypeSystem; +using ICSharpCode.Decompiler.Semantics; +using ICSharpCode.Decompiler.TypeSystem; +using ICSharpCode.Decompiler.Util; namespace ICSharpCode.Decompiler.CSharp.Transforms { /// /// Converts extension method calls into infix syntax. /// - public class IntroduceExtensionMethods : IAstTransform + public class IntroduceExtensionMethods : DepthFirstAstVisitor, IAstTransform { - readonly DecompilerContext context; - - public IntroduceExtensionMethods(DecompilerContext context) + TransformContext context; + UsingScope rootUsingScope; + UsingScope usingScope; + IMember currentMember; + CSharpTypeResolveContext resolveContext; + CSharpResolver resolver; + + public void Run(AstNode rootNode, TransformContext context) { this.context = context; + this.usingScope = this.rootUsingScope = rootNode.Annotation(); + rootNode.AcceptVisitor(this); + } + + void SetContext() + { + this.usingScope = rootUsingScope; + foreach (var name in currentMember.Namespace.Split('.')) + usingScope = new UsingScope(usingScope, name); + resolveContext = new CSharpTypeResolveContext(currentMember.ParentAssembly, usingScope.Resolve(context.TypeSystem.Compilation), currentMember.DeclaringTypeDefinition, currentMember); + resolver = new CSharpResolver(resolveContext); + } + + public override void VisitMethodDeclaration(MethodDeclaration methodDeclaration) + { + currentMember = methodDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitMethodDeclaration(methodDeclaration); + currentMember = null; + } + + public override void VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration) + { + currentMember = constructorDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitConstructorDeclaration(constructorDeclaration); + currentMember = null; + } + + public override void VisitDestructorDeclaration(DestructorDeclaration destructorDeclaration) + { + currentMember = destructorDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitDestructorDeclaration(destructorDeclaration); + currentMember = null; } - - public void Run(AstNode compilationUnit) + + public override void VisitPropertyDeclaration(PropertyDeclaration propertyDeclaration) + { + currentMember = propertyDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitPropertyDeclaration(propertyDeclaration); + currentMember = null; + } + + public override void VisitFieldDeclaration(FieldDeclaration fieldDeclaration) + { + currentMember = fieldDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitFieldDeclaration(fieldDeclaration); + currentMember = null; + } + + public override void VisitEventDeclaration(EventDeclaration eventDeclaration) + { + currentMember = eventDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitEventDeclaration(eventDeclaration); + currentMember = null; + } + + public override void VisitCustomEventDeclaration(CustomEventDeclaration eventDeclaration) + { + currentMember = eventDeclaration.GetSymbol() as IMember; + SetContext(); + base.VisitCustomEventDeclaration(eventDeclaration); + currentMember = null; + } + + public override void VisitInvocationExpression(InvocationExpression invocationExpression) { - foreach (InvocationExpression invocation in compilationUnit.Descendants.OfType()) { - MemberReferenceExpression mre = invocation.Target as MemberReferenceExpression; - MethodReference methodReference = invocation.Annotation(); - if (mre != null && mre.Target is TypeReferenceExpression && methodReference != null && invocation.Arguments.Any()) { - MethodDefinition d = methodReference.Resolve(); - if (d != null) { - foreach (var ca in d.CustomAttributes) { - if (ca.AttributeType.Name == "ExtensionAttribute" && ca.AttributeType.Namespace == "System.Runtime.CompilerServices") { - var firstArgument = invocation.Arguments.First(); - if (firstArgument is NullReferenceExpression) - firstArgument = firstArgument.ReplaceWith(expr => expr.CastTo(AstBuilder.ConvertType(d.Parameters.First().ParameterType))); - else - mre.Target = firstArgument.Detach(); - if (invocation.Arguments.Any()) { - // HACK: removing type arguments should be done indepently from whether a method is an extension method, - // just by testing whether the arguments can be inferred - mre.TypeArguments.Clear(); - } - break; - } - } - } - } - } + base.VisitInvocationExpression(invocationExpression); + var mre = invocationExpression.Target as MemberReferenceExpression; + var method = invocationExpression.GetSymbol() as IMethod; + if (method == null || !method.IsExtensionMethod || mre == null || !(mre.Target is TypeReferenceExpression) || !invocationExpression.Arguments.Any()) + return; + var firstArgument = invocationExpression.Arguments.First(); + var target = firstArgument.GetResolveResult(); + var args = method.Parameters.Skip(1).Select(p => new TypeResolveResult(p.Type)).ToArray(); + var rr = resolver.ResolveMemberAccess(target, method.Name, method.TypeArguments) as MethodGroupResolveResult; + if (rr == null) + return; + var or = rr.PerformOverloadResolution(resolveContext.Compilation, args, allowExtensionMethods: true); + if (or == null || or.IsAmbiguous) + return; + if (firstArgument is NullReferenceExpression) + firstArgument = firstArgument.ReplaceWith(expr => new CastExpression(context.TypeSystemAstBuilder.ConvertType(method.Parameters[0].Type), expr.Detach())); + else + mre.Target = firstArgument.Detach(); } } -} +} \ No newline at end of file diff --git a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs index 4703f360a..1a0d8b8e5 100644 --- a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs +++ b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs @@ -34,15 +34,15 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms { public bool FullyQualifyAmbiguousTypeNames = true; - public void Run(AstNode compilationUnit, TransformContext context) + public void Run(AstNode rootNode, TransformContext context) { // First determine all the namespaces that need to be imported: var requiredImports = new FindRequiredImports(context); - compilationUnit.AcceptVisitor(requiredImports); + rootNode.AcceptVisitor(requiredImports); var usingScope = new UsingScope(); - var insertionPoint = compilationUnit.Children.LastOrDefault(n => n is PreProcessorDirective p && p.Type == PreProcessorDirectiveType.Define); + var insertionPoint = rootNode.Children.LastOrDefault(n => n is PreProcessorDirective p && p.Type == PreProcessorDirectiveType.Define); // Now add using declarations for those namespaces: foreach (string ns in requiredImports.ImportedNamespaces.OrderByDescending(n => n)) { @@ -58,14 +58,15 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms if (reference != null) usingScope.Usings.Add(reference); } - compilationUnit.InsertChildAfter(insertionPoint, new UsingDeclaration { Import = nsType }, SyntaxTree.MemberRole); + rootNode.InsertChildAfter(insertionPoint, new UsingDeclaration { Import = nsType }, SyntaxTree.MemberRole); } if (!FullyQualifyAmbiguousTypeNames) return; // verify that the SimpleTypes refer to the correct type (no ambiguities) - compilationUnit.AcceptVisitor(new FullyQualifyAmbiguousTypeNamesVisitor(context, usingScope)); + rootNode.AcceptVisitor(new FullyQualifyAmbiguousTypeNamesVisitor(context, usingScope)); + rootNode.AddAnnotation(usingScope); } sealed class FindRequiredImports : DepthFirstAstVisitor diff --git a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj index 0d427c545..2ba1c4089 100644 --- a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj +++ b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj @@ -242,6 +242,7 @@ +