diff --git a/src/AST/Class.cs b/src/AST/Class.cs index b42042e8..14de8630 100644 --- a/src/AST/Class.cs +++ b/src/AST/Class.cs @@ -213,8 +213,8 @@ namespace CppSharp.AST if (function.IsOperator) return Methods.Where(fn => fn.OperatorKind == function.OperatorKind); - var methods = Methods.Where(m => m.Name == function.Name); - if (methods.ToList().Count != 0) + var methods = Methods.Where(m => m.Name == function.Name).ToList(); + if (methods.Count != 0) return methods; return base.GetOverloads(function); diff --git a/src/AST/ClassExtensions.cs b/src/AST/ClassExtensions.cs index c9ca3382..bddeb49b 100644 --- a/src/AST/ClassExtensions.cs +++ b/src/AST/ClassExtensions.cs @@ -52,7 +52,6 @@ namespace CppSharp.AST where method.Name == @override.Name && method.ReturnType == @override.ReturnType && - method.Parameters.Count == @override.Parameters.Count && method.Parameters.SequenceEqual(@override.Parameters, new ParameterTypeComparer()) select method).FirstOrDefault() @@ -69,7 +68,6 @@ namespace CppSharp.AST from property in @base.Class.Properties where property.Name == @override.Name && - property.Parameters.Count == @override.Parameters.Count && property.Parameters.SequenceEqual(@override.Parameters, new ParameterTypeComparer()) select property).FirstOrDefault() diff --git a/src/Generator.Tests/Passes/TestPasses.cs b/src/Generator.Tests/Passes/TestPasses.cs index d185a32c..6a6de8c0 100644 --- a/src/Generator.Tests/Passes/TestPasses.cs +++ b/src/Generator.Tests/Passes/TestPasses.cs @@ -1,7 +1,7 @@ -using System.Linq; -using CppSharp; +using CppSharp.AST; +using CppSharp.Generators.CSharp; +using System.Linq; using CppSharp.Passes; -using CppSharp.AST; using NUnit.Framework; namespace CppSharp.Generator.Tests.Passes @@ -70,6 +70,8 @@ namespace CppSharp.Generator.Tests.Passes [Test] public void TestCaseRenamePass() { + Type.TypePrinterDelegate += type => type.Visit(new CSharpTypePrinter(Driver)).Type; + var c = AstContext.Class("TestRename"); var method = c.Method("lowerCaseMethod"); diff --git a/src/Generator/Passes/GenerateAbstractImplementationsPass.cs b/src/Generator/Passes/GenerateAbstractImplementationsPass.cs index 9469b58a..25e69f74 100644 --- a/src/Generator/Passes/GenerateAbstractImplementationsPass.cs +++ b/src/Generator/Passes/GenerateAbstractImplementationsPass.cs @@ -105,7 +105,6 @@ namespace CppSharp.Passes var @abstract = abstractMethods[i]; if (overriddenMethods.Find(m => m.Name == @abstract.Name && m.ReturnType == @abstract.ReturnType && - m.Parameters.Count == @abstract.Parameters.Count && m.Parameters.SequenceEqual(@abstract.Parameters, paramTypeCmp)) != null) { abstractMethods.RemoveAt(i); diff --git a/src/Generator/Passes/RenamePass.cs b/src/Generator/Passes/RenamePass.cs index 0f1ed1b1..541ee46d 100644 --- a/src/Generator/Passes/RenamePass.cs +++ b/src/Generator/Passes/RenamePass.cs @@ -13,6 +13,19 @@ namespace CppSharp.Passes /// public abstract class RenamePass : TranslationUnitPass { + public class ParameterMappedTypeComparer : IEqualityComparer + { + public bool Equals(Parameter x, Parameter y) + { + return x.QualifiedType.ToString() == y.QualifiedType.ToString(); + } + + public int GetHashCode(Parameter obj) + { + return obj.Type.GetHashCode(); + } + } + public RenameTargets Targets = RenameTargets.Any; protected RenamePass() @@ -90,7 +103,14 @@ namespace CppSharp.Passes declarations.AddRange(decl.Namespace.Classes.Where(c => !c.IsIncomplete)); declarations.AddRange(decl.Namespace.Enums); declarations.AddRange(decl.Namespace.Events); - declarations.AddRange(decl.Namespace.Functions); + var function = decl as Function; + if (function != null) + { + // account for overloads + declarations.AddRange(GetFunctionsWithTheSameParams(function)); + } + else + declarations.AddRange(decl.Namespace.Functions); declarations.AddRange(decl.Namespace.Variables); declarations.AddRange(from typedefDecl in decl.Namespace.Typedefs let pointerType = typedefDecl.Type.Desugar() as PointerType @@ -108,6 +128,18 @@ namespace CppSharp.Passes return ((Class) method.Namespace).GetPropertyByName(newName) != null; } + private static IEnumerable GetFunctionsWithTheSameParams(Function function) + { + var method = function as Method; + if (method != null) + { + return ((Class) method.Namespace).Methods.Where( + m => m.Parameters.SequenceEqual(function.Parameters, new ParameterMappedTypeComparer())); + } + return function.Namespace.Functions.Where( + f => f.Parameters.SequenceEqual(function.Parameters, new ParameterMappedTypeComparer())); + } + public override bool VisitEnumItem(Enumeration.Item item) { if (!Targets.HasFlag(RenameTargets.EnumItem)) diff --git a/tests/CSharpTemp/CSharpTemp.h b/tests/CSharpTemp/CSharpTemp.h index e8f7aa15..457c87c7 100644 --- a/tests/CSharpTemp/CSharpTemp.h +++ b/tests/CSharpTemp/CSharpTemp.h @@ -158,3 +158,10 @@ public: int A; float B; }; + +class DLL_API TestRenaming +{ +public: + void name(); + void Name(); +};