diff --git a/ICSharpCode.NRefactory.Tests/Analysis/SymbolCollectorTests.cs b/ICSharpCode.NRefactory.Tests/Analysis/SymbolCollectorTests.cs index 7675c03ec0..9ced7ed7ab 100644 --- a/ICSharpCode.NRefactory.Tests/Analysis/SymbolCollectorTests.cs +++ b/ICSharpCode.NRefactory.Tests/Analysis/SymbolCollectorTests.cs @@ -29,6 +29,10 @@ using ICSharpCode.NRefactory.CSharp; using ICSharpCode.NRefactory.TypeSystem; using NUnit.Framework; using ICSharpCode.NRefactory.CSharp.CodeCompletion; +using System.Text; +using System.Collections.Generic; +using ICSharpCode.NRefactory.Editor; +using ICSharpCode.NRefactory.CSharp.Resolver; namespace ICSharpCode.NRefactory.Analysis { @@ -36,6 +40,164 @@ namespace ICSharpCode.NRefactory.Analysis public class SymbolCollectorTests { + void CollectMembers(string code, string memberName, bool includeOverloads = true) + { + StringBuilder sb = new StringBuilder(); + List offsets = new List(); + foreach (var ch in code) { + if (ch == '$') { + offsets.Add(sb.Length); + continue; + } + sb.Append(ch); + } + var syntaxTree = SyntaxTree.Parse(sb.ToString (), "test.cs"); + var unresolvedFile = syntaxTree.ToTypeSystem(); + var compilation = TypeSystemHelper.CreateCompilation(unresolvedFile); + + var symbol = FindReferencesTest.GetSymbol(compilation, memberName); + + var result = SymbolCollector.GetRelatedSymbols (new TypeGraph (compilation.Assemblies), + symbol, includeOverloads); + if (offsets.Count != result.Count()) { + foreach (var a in result) + Console.WriteLine(a); + } + Assert.AreEqual(offsets.Count, result.Count()); + var doc = new ReadOnlyDocument(sb.ToString ()); + result + .Select(r => doc.GetOffset ((r as IEntity).Region.Begin)) + .SequenceEqual(offsets); + } + + [Test] + public void TestSingleInterfaceImpl () + { + var code = @" +interface IA +{ + void $Method(); +} + +class A : IA +{ + public virtual void $Method() { }; +} + +class B : A +{ + public override void Method() { }; +} + +class C : IA +{ + public void $Method() { }; +}"; + CollectMembers(code, "IA.Method"); + } + + + [Test] + public void TestMultiInterfacesImpl1 () + { + var code = @" +interface IA +{ + void $Method(); +} +interface IB +{ + void $Method(); +} +class A : IA, IB +{ + public void $Method() { } +} +class B : IA +{ + public void $Method() { } +} +class C : IB +{ + public void $Method() { } +}"; + CollectMembers(code, "A.Method"); + } + + + [Test] + public void TestOverloads () + { + var code = @" +class A +{ + public void $Method () { } + public void $Method (int i) { } + public void $Method (string i) { } +} +"; + CollectMembers(code, "A.Method"); + } + + [Test] + public void TestConstructor () + { + var code = @" +class $A +{ + public $A() { } + public $A(int i) { } +} +"; + CollectMembers(code, "A"); + } + + + [Test] + public void TestDestructor () + { + var code = @" +class $A +{ + $~A() { } +} +"; + CollectMembers(code, "A"); + } + + [Test] + public void TestStaticConstructor () + { + var code = @" +class $A +{ + static $A() { } + public $A(int i) { } +} +"; + CollectMembers(code, "A"); + } + + [Test] + public void TestShadowedMember () + { + var code = @" +class A +{ + public int $Prop + { get; set; } +} +class B : A +{ + public int Prop + { get; set; } +} +"; + CollectMembers(code, "A.Prop"); + } + + + } } diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs index 08bc38a260..a62eda24bf 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs @@ -389,7 +389,7 @@ namespace Foo #region Rename - ISymbol GetSymbol (string reflectionName) + internal static ISymbol GetSymbol (ICompilation compilation, string reflectionName) { Stack typeStack = new Stack(compilation.MainAssembly.TopLevelTypeDefinitions); while (typeStack.Count > 0) { @@ -408,7 +408,7 @@ namespace Foo IList Rename(string fullyQualifiedName, string newName, bool includeOverloads) { - var sym = GetSymbol(fullyQualifiedName); + var sym = GetSymbol(compilation, fullyQualifiedName); Assert.NotNull(sym); var graph = new TypeGraph(compilation.Assemblies); diff --git a/ICSharpCode.NRefactory/Analysis/SymbolCollector.cs b/ICSharpCode.NRefactory/Analysis/SymbolCollector.cs index 186f20488b..ebf2b353e4 100644 --- a/ICSharpCode.NRefactory/Analysis/SymbolCollector.cs +++ b/ICSharpCode.NRefactory/Analysis/SymbolCollector.cs @@ -1,4 +1,4 @@ -// Copyright (c) 2013 AlphaSierraPapa for the SharpDevelop Team +// Copyright (c) 2013 AlphaSierraPapa for the SharpDevelop Team // // Permission is hereby granted, free of charge, to any person obtaining a copy of this // software and associated documentation files (the "Software"), to deal in the Software @@ -19,14 +19,114 @@ using System; using System.Collections.Generic; using ICSharpCode.NRefactory.TypeSystem; +using System.Linq; namespace ICSharpCode.NRefactory.Analysis { + /// + /// The symbol collector collects related symbols that form a group of symbols that should be renamed + /// when a name of one symbol changes. For example if a type definition name should be changed + /// the constructors and destructor names should change as well. + /// public class SymbolCollector { + static IEnumerable CollectTypeRelatedMembers (ITypeDefinition type) + { + yield return type; + foreach (var c in type.GetDefinition ().GetMembers (m => !m.IsSynthetic && (m.SymbolKind == SymbolKind.Constructor || m.SymbolKind == SymbolKind.Destructor), GetMemberOptions.IgnoreInheritedMembers)) { + yield return c; + } + } + + static IEnumerable CollectOverloads (TypeGraph g, IMethod method) + { + return method.DeclaringType + .GetMethods (m => m.Name == method.Name) + .Where (m => m != method); + } + + static IMember SearchMember (ITypeDefinition derivedType, IMember method) + { + foreach (var m in derivedType.Members) { + if (m.ImplementedInterfaceMembers.Contains (method)) + return m; + } + return null; + } + + static IEnumerable MakeUnique (List symbols) + { + HashSet taken = new HashSet (); + foreach (var sym in symbols) { + if (taken.Contains (sym)) + continue; + taken.Add (sym); + yield return sym; + } + } + + /// + /// Gets the related symbols. + /// + /// The related symbols. + /// The type graph. + /// The symbol to search + /// If set to true overloads are included in the rename. public static IEnumerable GetRelatedSymbols(TypeGraph g, ISymbol m, bool includeOverloads) { - yield return m; + switch (m.SymbolKind) { + case SymbolKind.TypeDefinition: + return CollectTypeRelatedMembers ((ITypeDefinition)m); + + case SymbolKind.Field: + case SymbolKind.Operator: + case SymbolKind.Variable: + case SymbolKind.Parameter: + case SymbolKind.TypeParameter: + return new ISymbol[] { m }; + + case SymbolKind.Constructor: + case SymbolKind.Destructor: + return GetRelatedSymbols (g, ((IMethod)m).DeclaringTypeDefinition, includeOverloads); + + case SymbolKind.Indexer: + case SymbolKind.Event: + case SymbolKind.Property: + return new ISymbol[] { m }; + + case SymbolKind.Method: + var method = (IMethod)m; + List symbols = new List (); + if (method.ImplementedInterfaceMembers.Count > 0) { + foreach (var m2 in method.ImplementedInterfaceMembers) { + symbols.AddRange (GetRelatedSymbols (g, m2, includeOverloads)); + } + } else { + symbols.Add (method); + } + + if (method.DeclaringType.Kind == TypeKind.Interface) { + foreach (var derivedType in g.GetNode (method.DeclaringTypeDefinition).DerivedTypes) { + var member = SearchMember (derivedType.TypeDefinition, method); + if (member != null) + symbols.Add (member); + } + } + + + if (includeOverloads) { + foreach (var m3 in CollectOverloads (g, method)) { + symbols.AddRange (GetRelatedSymbols (g, m3, false)); + } + } + return MakeUnique (symbols); + + case SymbolKind.Namespace: + // TODO? + return new ISymbol[] { m }; + default: + throw new ArgumentOutOfRangeException ("symbol:"+m.SymbolKind); + } } } }