diff --git a/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs b/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs
index f1ce4d74e..b7bde357d 100644
--- a/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs
+++ b/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs
@@ -128,7 +128,7 @@ namespace ICSharpCode.Decompiler.CSharp
new ConvertConstructorCallIntoInitializer(), // must run after DeclareVariables
new DecimalConstantTransform(),
new IntroduceUsingDeclarations(),
- //new IntroduceExtensionMethods(context), // must run after IntroduceUsingDeclarations
+ new IntroduceExtensionMethods(), // must run after IntroduceUsingDeclarations
//new IntroduceQueryExpressions(context), // must run after IntroduceExtensionMethods
//new CombineQueryExpressions(context),
//new FlattenSwitchBlocks(),
diff --git a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs
index c9d143422..4887d2bba 100644
--- a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs
+++ b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs
@@ -18,49 +18,119 @@
using System;
using System.Linq;
-using ICSharpCode.NRefactory.CSharp;
-using Mono.Cecil;
+using ICSharpCode.Decompiler.CSharp.Resolver;
+using ICSharpCode.Decompiler.CSharp.Syntax;
+using ICSharpCode.Decompiler.CSharp.TypeSystem;
+using ICSharpCode.Decompiler.Semantics;
+using ICSharpCode.Decompiler.TypeSystem;
+using ICSharpCode.Decompiler.Util;
namespace ICSharpCode.Decompiler.CSharp.Transforms
{
///
/// Converts extension method calls into infix syntax.
///
- public class IntroduceExtensionMethods : IAstTransform
+ public class IntroduceExtensionMethods : DepthFirstAstVisitor, IAstTransform
{
- readonly DecompilerContext context;
-
- public IntroduceExtensionMethods(DecompilerContext context)
+ TransformContext context;
+ UsingScope rootUsingScope;
+ UsingScope usingScope;
+ IMember currentMember;
+ CSharpTypeResolveContext resolveContext;
+ CSharpResolver resolver;
+
+ public void Run(AstNode rootNode, TransformContext context)
{
this.context = context;
+ this.usingScope = this.rootUsingScope = rootNode.Annotation();
+ rootNode.AcceptVisitor(this);
+ }
+
+ void SetContext()
+ {
+ this.usingScope = rootUsingScope;
+ foreach (var name in currentMember.Namespace.Split('.'))
+ usingScope = new UsingScope(usingScope, name);
+ resolveContext = new CSharpTypeResolveContext(currentMember.ParentAssembly, usingScope.Resolve(context.TypeSystem.Compilation), currentMember.DeclaringTypeDefinition, currentMember);
+ resolver = new CSharpResolver(resolveContext);
+ }
+
+ public override void VisitMethodDeclaration(MethodDeclaration methodDeclaration)
+ {
+ currentMember = methodDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitMethodDeclaration(methodDeclaration);
+ currentMember = null;
+ }
+
+ public override void VisitConstructorDeclaration(ConstructorDeclaration constructorDeclaration)
+ {
+ currentMember = constructorDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitConstructorDeclaration(constructorDeclaration);
+ currentMember = null;
+ }
+
+ public override void VisitDestructorDeclaration(DestructorDeclaration destructorDeclaration)
+ {
+ currentMember = destructorDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitDestructorDeclaration(destructorDeclaration);
+ currentMember = null;
}
-
- public void Run(AstNode compilationUnit)
+
+ public override void VisitPropertyDeclaration(PropertyDeclaration propertyDeclaration)
+ {
+ currentMember = propertyDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitPropertyDeclaration(propertyDeclaration);
+ currentMember = null;
+ }
+
+ public override void VisitFieldDeclaration(FieldDeclaration fieldDeclaration)
+ {
+ currentMember = fieldDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitFieldDeclaration(fieldDeclaration);
+ currentMember = null;
+ }
+
+ public override void VisitEventDeclaration(EventDeclaration eventDeclaration)
+ {
+ currentMember = eventDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitEventDeclaration(eventDeclaration);
+ currentMember = null;
+ }
+
+ public override void VisitCustomEventDeclaration(CustomEventDeclaration eventDeclaration)
+ {
+ currentMember = eventDeclaration.GetSymbol() as IMember;
+ SetContext();
+ base.VisitCustomEventDeclaration(eventDeclaration);
+ currentMember = null;
+ }
+
+ public override void VisitInvocationExpression(InvocationExpression invocationExpression)
{
- foreach (InvocationExpression invocation in compilationUnit.Descendants.OfType()) {
- MemberReferenceExpression mre = invocation.Target as MemberReferenceExpression;
- MethodReference methodReference = invocation.Annotation();
- if (mre != null && mre.Target is TypeReferenceExpression && methodReference != null && invocation.Arguments.Any()) {
- MethodDefinition d = methodReference.Resolve();
- if (d != null) {
- foreach (var ca in d.CustomAttributes) {
- if (ca.AttributeType.Name == "ExtensionAttribute" && ca.AttributeType.Namespace == "System.Runtime.CompilerServices") {
- var firstArgument = invocation.Arguments.First();
- if (firstArgument is NullReferenceExpression)
- firstArgument = firstArgument.ReplaceWith(expr => expr.CastTo(AstBuilder.ConvertType(d.Parameters.First().ParameterType)));
- else
- mre.Target = firstArgument.Detach();
- if (invocation.Arguments.Any()) {
- // HACK: removing type arguments should be done indepently from whether a method is an extension method,
- // just by testing whether the arguments can be inferred
- mre.TypeArguments.Clear();
- }
- break;
- }
- }
- }
- }
- }
+ base.VisitInvocationExpression(invocationExpression);
+ var mre = invocationExpression.Target as MemberReferenceExpression;
+ var method = invocationExpression.GetSymbol() as IMethod;
+ if (method == null || !method.IsExtensionMethod || mre == null || !(mre.Target is TypeReferenceExpression) || !invocationExpression.Arguments.Any())
+ return;
+ var firstArgument = invocationExpression.Arguments.First();
+ var target = firstArgument.GetResolveResult();
+ var args = method.Parameters.Skip(1).Select(p => new TypeResolveResult(p.Type)).ToArray();
+ var rr = resolver.ResolveMemberAccess(target, method.Name, method.TypeArguments) as MethodGroupResolveResult;
+ if (rr == null)
+ return;
+ var or = rr.PerformOverloadResolution(resolveContext.Compilation, args, allowExtensionMethods: true);
+ if (or == null || or.IsAmbiguous)
+ return;
+ if (firstArgument is NullReferenceExpression)
+ firstArgument = firstArgument.ReplaceWith(expr => new CastExpression(context.TypeSystemAstBuilder.ConvertType(method.Parameters[0].Type), expr.Detach()));
+ else
+ mre.Target = firstArgument.Detach();
}
}
-}
+}
\ No newline at end of file
diff --git a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs
index 4703f360a..1a0d8b8e5 100644
--- a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs
+++ b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceUsingDeclarations.cs
@@ -34,15 +34,15 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms
{
public bool FullyQualifyAmbiguousTypeNames = true;
- public void Run(AstNode compilationUnit, TransformContext context)
+ public void Run(AstNode rootNode, TransformContext context)
{
// First determine all the namespaces that need to be imported:
var requiredImports = new FindRequiredImports(context);
- compilationUnit.AcceptVisitor(requiredImports);
+ rootNode.AcceptVisitor(requiredImports);
var usingScope = new UsingScope();
- var insertionPoint = compilationUnit.Children.LastOrDefault(n => n is PreProcessorDirective p && p.Type == PreProcessorDirectiveType.Define);
+ var insertionPoint = rootNode.Children.LastOrDefault(n => n is PreProcessorDirective p && p.Type == PreProcessorDirectiveType.Define);
// Now add using declarations for those namespaces:
foreach (string ns in requiredImports.ImportedNamespaces.OrderByDescending(n => n)) {
@@ -58,14 +58,15 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms
if (reference != null)
usingScope.Usings.Add(reference);
}
- compilationUnit.InsertChildAfter(insertionPoint, new UsingDeclaration { Import = nsType }, SyntaxTree.MemberRole);
+ rootNode.InsertChildAfter(insertionPoint, new UsingDeclaration { Import = nsType }, SyntaxTree.MemberRole);
}
if (!FullyQualifyAmbiguousTypeNames)
return;
// verify that the SimpleTypes refer to the correct type (no ambiguities)
- compilationUnit.AcceptVisitor(new FullyQualifyAmbiguousTypeNamesVisitor(context, usingScope));
+ rootNode.AcceptVisitor(new FullyQualifyAmbiguousTypeNamesVisitor(context, usingScope));
+ rootNode.AddAnnotation(usingScope);
}
sealed class FindRequiredImports : DepthFirstAstVisitor
diff --git a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj
index 0d427c545..2ba1c4089 100644
--- a/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj
+++ b/ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj
@@ -242,6 +242,7 @@
+