diff --git a/src/Main/Base/Project/Src/Services/RefactoringService/FindReferencesAndRenameHelper.cs b/src/Main/Base/Project/Src/Services/RefactoringService/FindReferencesAndRenameHelper.cs
index d9cead43a1..270e898285 100644
--- a/src/Main/Base/Project/Src/Services/RefactoringService/FindReferencesAndRenameHelper.cs
+++ b/src/Main/Base/Project/Src/Services/RefactoringService/FindReferencesAndRenameHelper.cs
@@ -5,11 +5,11 @@
// $Revision$
//
+using ICSharpCode.NRefactory.Ast;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
-
using ICSharpCode.Core;
using ICSharpCode.SharpDevelop.DefaultEditor.Gui.Editor;
using ICSharpCode.SharpDevelop.Dom;
@@ -62,7 +62,7 @@ namespace ICSharpCode.SharpDevelop.Refactoring
LanguageProperties language = c.ProjectContent.Language;
string classFileName = c.CompilationUnit.FileName;
string existingClassCode = ParserService.GetParseableFileContent(classFileName);
-
+
// build the new interface...
string newInterfaceCode =
language.RefactoringProvider.GenerateInterfaceForClass(extractInterface.NewInterfaceName,
@@ -83,7 +83,6 @@ namespace ICSharpCode.SharpDevelop.Refactoring
// simply update it
editable.Text = newInterfaceCode;
viewContent.PrimaryFile.SaveToDisk();
-
} else {
// create it
viewContent = FileService.NewFile(newInterfaceFileName, newInterfaceCode);
@@ -101,9 +100,12 @@ namespace ICSharpCode.SharpDevelop.Refactoring
}
}
+ ICompilationUnit newCompilationUnit = ParserService.ParseFile(newInterfaceFileName).MostRecentCompilationUnit;
+ IClass newInterfaceDef = newCompilationUnit.Classes[0];
+
// finally, add the interface to the base types of the class that we're extracting from
if (extractInterface.AddInterfaceToClass) {
- string modifiedClassCode = language.RefactoringProvider.AddBaseTypeToClass(existingClassCode, extractInterface.NewInterfaceName);
+ string modifiedClassCode = language.RefactoringProvider.AddBaseTypeToClass(existingClassCode, c, newInterfaceDef);
if (modifiedClassCode == null) {
return;
}
diff --git a/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/NRefactoryRefactoringProvider.cs b/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/NRefactoryRefactoringProvider.cs
index 434af99cac..b3c72b0f6d 100644
--- a/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/NRefactoryRefactoringProvider.cs
+++ b/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/NRefactoryRefactoringProvider.cs
@@ -93,79 +93,117 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
}
}
- private class ExtractInterfaceVisitor : NR.Visitors.AbstractAstVisitor
+ class ExtractInterfaceTransformer : NR.Visitors.AbstractAstTransformer
{
string newInterfaceName;
string sourceClassName;
string sourceNamespace;
- Dictionary membersToInclude;
+ List membersToInclude;
- public ExtractInterfaceVisitor(string newInterfaceName,
- string sourceNamespace,
- string sourceClassName,
- IList chosenMembers) {
+ public ExtractInterfaceTransformer(string newInterfaceName,
+ string sourceNamespace,
+ string sourceClassName,
+ IList chosenMembers) {
this.newInterfaceName = newInterfaceName;
this.sourceNamespace = sourceNamespace;
this.sourceClassName = sourceClassName;
- // store the chosen members in a dictionary for easy lookup
- membersToInclude = new Dictionary();
- foreach(IMember m in chosenMembers) {
- membersToInclude.Add(m.Name, m);
- }
+ membersToInclude = chosenMembers.ToList();
}
+
+ public override object VisitNamespaceDeclaration(NamespaceDeclaration namespaceDeclaration, object data)
+ {
+ TypeDeclaration type = LookupTypeDeclaration(namespaceDeclaration.Children, namespaceDeclaration.Name);
- public override object VisitCompilationUnit(CompilationUnit compilationUnit, object data)
+ RemoveCurrentNode();
+
+ if (type != null && namespaceDeclaration.Parent is CompilationUnit)
+ ((CompilationUnit)namespaceDeclaration.Parent).AddChild(type);
+
+ return base.VisitNamespaceDeclaration(namespaceDeclaration, data);
+ }
+
+ public override object VisitUsingDeclaration(UsingDeclaration usingDeclaration, object data)
{
- // strip out any usings & extract our TypeReference from the NameSpace
- // we walk backwards so that deletions don't affect the iteration
- NamespaceDeclaration ns;
- TypeDeclaration td;
- object child;
- object nsChild;
- for(int i = compilationUnit.Children.Count-1; i>=0; i--) {
- child = compilationUnit.Children[i];
- if (child is UsingDeclaration) {
- // we don't want our usings here...
- compilationUnit.Children.RemoveAt(i);
- }
- else if (child is NamespaceDeclaration) {
- ns = (NamespaceDeclaration)child;
- if (ns.Name != this.sourceNamespace) {
- // we're not interested in this namespace...
- compilationUnit.Children.RemoveAt(i);
- } else {
-
- // this NamespaceDeclaration presumably contains our source class
- // walk its children backwards to that removing them won't break the iteration
- for(int j = ns.Children.Count-1; j>=0; j--) {
- nsChild = ns.Children[j];
- if (nsChild is TypeDeclaration) {
- td = (TypeDeclaration)nsChild;
-
- if (td.Name == this.sourceClassName) {
- // keep it, and substitute it for the current NamespaceDeclaration
- compilationUnit.Children[i] = td;
- } else {
- // it's not the class we're extracting from
- ns.Children.RemoveAt(j);
- }
- } else {
- // it's not even a class... (e.g. using, etc)
- ns.Children.RemoveAt(j);
- }
- }
- }
- } else {
- // we don't actually want to throw an exception here just because we havn't forseen the node type...
- //throw new NotSupportedException("trimming "+compilationUnit.Children[i].ToString()+" is not supported.");
+ RemoveCurrentNode();
+ return base.VisitUsingDeclaration(usingDeclaration, data);
+ }
+
+ TypeDeclaration LookupTypeDeclaration(List nodes, string nameSpace)
+ {
+ TypeDeclaration td = null;
+
+ foreach (INode node in nodes) {
+ if (node is TypeDeclaration) {
+ TypeDeclaration type = node as TypeDeclaration;
+ string name = nameSpace + "." + type.Name;
+ string lookFor = sourceNamespace + "." + sourceClassName;
+ if (lookFor == name)
+ return type;
+ else
+ td = LookupTypeDeclaration(node.Children, name);
}
}
- return base.VisitCompilationUnit(compilationUnit, data);
+
+ return td;
+ }
+
+ static bool MethodEquals(MethodDeclaration md1, MethodDeclaration md2)
+ {
+ if (md2 == null)
+ throw new ArgumentNullException("md2");
+ if (md1 == null)
+ throw new ArgumentNullException("md1");
+
+ // see C# Spec 3, page 65, 3.6
+
+ return md1.Name == md2.Name &&
+ (MethodDeclaration.GetCollectionString(md1.Parameters) == MethodDeclaration.GetCollectionString(md2.Parameters)) &&
+ (MethodDeclaration.GetCollectionString(md1.Templates) == MethodDeclaration.GetCollectionString(md2.Templates));
+ }
+
+ static bool PropertyEquals(PropertyDeclaration pd1, PropertyDeclaration pd2)
+ {
+ if (pd1 == null)
+ throw new ArgumentNullException("pd1");
+ if (pd2 == null)
+ throw new ArgumentNullException("pd2");
+
+ return pd1.Name == pd2.Name;
+ }
+
+ bool ContainsMethod(MethodDeclaration md)
+ {
+ if (md == null)
+ throw new ArgumentNullException("md");
+
+ foreach (IMember mem in this.membersToInclude) {
+ if (mem is IMethod && MethodEquals(CodeGenerator.ConvertMember(mem as IMethod, new ClassFinder(mem)) as MethodDeclaration, md))
+ return true;
+ }
+
+ return false;
+ }
+
+ bool ContainsProperty(PropertyDeclaration pd)
+ {
+ if (pd == null)
+ throw new ArgumentNullException("pd");
+
+ foreach (IMember mem in this.membersToInclude) {
+ if (mem is IProperty && PropertyEquals(CodeGenerator.ConvertMember(mem as IProperty, new ClassFinder(mem)) as PropertyDeclaration, pd))
+ return true;
+ }
+
+ return false;
}
public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data)
{
+ if (typeDeclaration.Name != sourceClassName) {
+ return base.VisitTypeDeclaration(typeDeclaration, data);
+ }
+
// rewrite the type declaration to an interface
typeDeclaration.Attributes.Clear();
typeDeclaration.BaseTypes.Clear();
@@ -184,16 +222,12 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
child = typeDeclaration.Children[i];
if (child is MethodDeclaration) {
method = (MethodDeclaration)child;
- if (membersToInclude.ContainsKey(method.Name)
- && ((method.Modifier & Modifiers.Static) == Modifiers.None)) {
+ if (ContainsMethod(method) && ((method.Modifier & Modifiers.Static) != Modifiers.Static))
keepIt = true;
- }
} else if (child is PropertyDeclaration) {
property = (PropertyDeclaration)child;
- if (membersToInclude.ContainsKey(property.Name)
- && ((property.Modifier & Modifiers.Static) == Modifiers.None)) {
+ if (ContainsProperty(property) && ((property.Modifier & Modifiers.Static) != Modifiers.Static))
keepIt = true;
- }
}
if (!keepIt) {
@@ -207,28 +241,36 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
public override object VisitMethodDeclaration(MethodDeclaration methodDeclaration, object data)
{
- // strip out the public modifier...
- methodDeclaration.Modifier = NR.Ast.Modifiers.None;
-
- // ...and the method body
- methodDeclaration.Body = BlockStatement.Null;
+ if (ContainsMethod(methodDeclaration) && ((methodDeclaration.Modifier & Modifiers.Static) != Modifiers.Static)) {
+ // strip out the public modifier...
+ methodDeclaration.Modifier = NR.Ast.Modifiers.None;
+
+ // ...and the method body
+ methodDeclaration.Body = BlockStatement.Null;
+ } else {
+ RemoveCurrentNode();
+ }
return null;
}
public override object VisitPropertyDeclaration(PropertyDeclaration propertyDeclaration, object data)
{
- // strip out the public modifiers...
- propertyDeclaration.Modifier = NR.Ast.Modifiers.None;
+ if (ContainsProperty(propertyDeclaration) && ((propertyDeclaration.Modifier & Modifiers.Static) != Modifiers.Static)) {
+ // strip out the public modifiers...
+ propertyDeclaration.Modifier = NR.Ast.Modifiers.None;
- // ... and the body of any get block...
- if (propertyDeclaration.HasGetRegion) {
- propertyDeclaration.GetRegion.Block = BlockStatement.Null;
- }
+ // ... and the body of any get block...
+ if (propertyDeclaration.HasGetRegion) {
+ propertyDeclaration.GetRegion.Block = BlockStatement.Null;
+ }
- // ... and the body of any set block...
- if (propertyDeclaration.HasSetRegion) {
- propertyDeclaration.SetRegion.Block = BlockStatement.Null;
+ // ... and the body of any set block...
+ if (propertyDeclaration.HasSetRegion) {
+ propertyDeclaration.SetRegion.Block = BlockStatement.Null;
+ }
+ } else {
+ RemoveCurrentNode();
}
return null;
@@ -251,11 +293,11 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
// use a custom IAstVisitor to strip our class out of this file,
// rewrite it as our desired interface, and strip out every
// member except those we want to keep in our new interface.
- ExtractInterfaceVisitor extractInterfaceVisitor = new ExtractInterfaceVisitor(newInterfaceName,
- sourceNamespace,
- sourceClassName,
- membersToKeep);
- parser.CompilationUnit.AcceptVisitor(extractInterfaceVisitor, null);
+ ExtractInterfaceTransformer extractInterfaceTransformer = new ExtractInterfaceTransformer(newInterfaceName,
+ sourceNamespace,
+ sourceClassName,
+ membersToKeep);
+ parser.CompilationUnit.AcceptVisitor(extractInterfaceTransformer, null);
// now use an output visitor for the appropriate language (based on
// extension of the existing code file) to format the new interface.
@@ -272,6 +314,7 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
// run the output visitor without the specials inserter
parser.CompilationUnit.AcceptVisitor(output, null);
}
+
parser.Dispose();
if (output.Errors.Count == 0) {
@@ -289,13 +332,14 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
return newFileContent;
}
- private class AddTypeToBaseTypesVisitor : NR.Visitors.AbstractAstVisitor
+ class AddTypeToBaseTypesVisitor : NR.Visitors.AbstractAstVisitor
{
- private TypeReference typeReference;
+ IClass target, newBaseType;
- public AddTypeToBaseTypesVisitor(string newTypeName)
+ public AddTypeToBaseTypesVisitor(IClass target, IClass newBaseType)
{
- this.typeReference = new TypeReference(newTypeName);
+ this.target = target;
+ this.newBaseType = newBaseType;
}
public override object VisitTypeDeclaration(TypeDeclaration typeDeclaration, object data)
@@ -303,29 +347,32 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
// test the Type string property explicitly (rather than .BaseTypes.Contains())
// to ensure that a matching type name is enough to prevent adding a second
// reference.
+
+ if (typeDeclaration.Name != target.Name)
+ return base.VisitTypeDeclaration(typeDeclaration, data);
+
bool exists = false;
foreach(TypeReference type in typeDeclaration.BaseTypes) {
- if (type.Type == this.typeReference.Type) {
+ if (type.Type == this.newBaseType.Name) {
exists = true;
break;
}
}
if (!exists) {
- typeDeclaration.BaseTypes.Add(this.typeReference);
+ typeDeclaration.BaseTypes.Add(new TypeReference(newBaseType.Name, newBaseType.TypeParameters.Select(p => new TypeReference(p.Name)).ToList()));
}
return base.VisitTypeDeclaration(typeDeclaration, data);
}
}
- public override string AddBaseTypeToClass(string existingCode, string newInterfaceName)
+ public override string AddBaseTypeToClass(string existingCode, IClass targetClass, IClass newBaseType)
{
- string newCode = existingCode;
NR.IParser parser = ParseFile(null, existingCode);
if (parser == null) {
return null;
}
- AddTypeToBaseTypesVisitor addTypeToBaseTypesVisitor = new AddTypeToBaseTypesVisitor(newInterfaceName);
+ AddTypeToBaseTypesVisitor addTypeToBaseTypesVisitor = new AddTypeToBaseTypesVisitor(targetClass, newBaseType);
parser.CompilationUnit.AcceptVisitor(addTypeToBaseTypesVisitor, null);
diff --git a/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/RefactoringProvider.cs b/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/RefactoringProvider.cs
index 6ecde72062..6cefdf26c4 100644
--- a/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/RefactoringProvider.cs
+++ b/src/Main/ICSharpCode.SharpDevelop.Dom/Project/Src/Refactoring/RefactoringProvider.cs
@@ -5,6 +5,7 @@
// $Revision$
//
+using ICSharpCode.NRefactory.Ast;
using System;
using System.Collections.Generic;
@@ -46,8 +47,9 @@ namespace ICSharpCode.SharpDevelop.Dom.Refactoring
throw new NotSupportedException();
}
- public virtual string AddBaseTypeToClass(string existingCode, string newInterfaceName) {
- throw new NotSupportedException();
+ public virtual string AddBaseTypeToClass(string existingCode, IClass targetClass, IClass newBaseType)
+ {
+ throw new NotImplementedException();
}
#endregion