From 3cf6c7b3f16262ec4ca750b9e8def83eb85e014b Mon Sep 17 00:00:00 2001 From: Dimitar Dobrev Date: Wed, 13 Nov 2013 00:31:05 +0200 Subject: [PATCH] Added a test for the pass that moves functions to a class. Signed-off-by: Dimitar Dobrev --- src/AST/ASTContext.cs | 18 ++++--- src/AST/Namespace.cs | 10 ++-- src/Generator/Driver.cs | 1 + .../Passes/MoveFunctionToClassPass.cs | 53 +++++++++++-------- tests/Basic/Basic.Tests.cs | 6 +++ tests/Basic/Basic.cpp | 5 ++ tests/Basic/Basic.h | 7 +++ 7 files changed, 65 insertions(+), 35 deletions(-) diff --git a/src/AST/ASTContext.cs b/src/AST/ASTContext.cs index 656d6d1b..8ad4c139 100644 --- a/src/AST/ASTContext.cs +++ b/src/AST/ASTContext.cs @@ -51,7 +51,8 @@ namespace CppSharp.AST /// Finds an existing enum in the library modules. public IEnumerable FindEnum(string name) { - return TranslationUnits.Select(module => module.FindEnum(name)).Where(type => type != null); + return TranslationUnits.Select( + module => module.FindEnum(name)).Where(type => type != null); } /// Finds the complete declaration of an enum. @@ -61,10 +62,13 @@ namespace CppSharp.AST } /// Finds an existing struct/class in the library modules. - public IEnumerable FindClass(string name, bool create = false, bool ignoreCase = false) + public IEnumerable FindClass(string name, bool create = false, + bool ignoreCase = false) { return TranslationUnits.Select( - module => module.FindClass(name, ignoreCase)).Where(type => type != null); + module => module.FindClass(name, + ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal)) + .Where(type => type != null); } /// Finds the complete declaration of a class. @@ -77,15 +81,15 @@ namespace CppSharp.AST /// Finds an existing function in the library modules. public IEnumerable FindFunction(string name) { - return TranslationUnits.Select(module => module.FindFunction(name)).Where( - type => type != null); + return TranslationUnits.Select(module => module.FindFunction(name)) + .Where(type => type != null); } /// Finds an existing typedef in the library modules. public IEnumerable FindTypedef(string name) { - return TranslationUnits.Select(module => module.FindTypedef(name)).Where( - type => type != null); + return TranslationUnits.Select(module => module.FindTypedef(name)) + .Where(type => type != null); } /// Finds an existing declaration by name. diff --git a/src/AST/Namespace.cs b/src/AST/Namespace.cs index cf9d2037..4d29d693 100644 --- a/src/AST/Namespace.cs +++ b/src/AST/Namespace.cs @@ -168,7 +168,8 @@ namespace CppSharp.AST return @class; } - public Class FindClass(string name, bool ignoreCase = false) + public Class FindClass(string name, + StringComparison stringComparison = StringComparison.Ordinal) { if (string.IsNullOrEmpty(name)) return null; @@ -177,12 +178,7 @@ namespace CppSharp.AST if (entries.Count <= 1) { - Class @class; - if (ignoreCase) - @class = Classes.Find(e => e.Name.Equals(name, StringComparison.OrdinalIgnoreCase)); - else - @class = Classes.Find(e => e.Name.Equals(name)); - return @class; + return Classes.Find(e => e.Name.Equals(name, stringComparison)); } var className = entries[entries.Count - 1]; diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index 728b828a..5e78bfc9 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -201,6 +201,7 @@ namespace CppSharp TranslationUnitPasses.AddPass(new FindSymbolsPass()); TranslationUnitPasses.AddPass(new MoveOperatorToClassPass()); + TranslationUnitPasses.AddPass(new MoveFunctionToClassPass()); TranslationUnitPasses.AddPass(new CheckAmbiguousFunctions()); TranslationUnitPasses.AddPass(new CheckOperatorsOverloadsPass()); TranslationUnitPasses.AddPass(new CheckVirtualOverrideReturnCovariance()); diff --git a/src/Generator/Passes/MoveFunctionToClassPass.cs b/src/Generator/Passes/MoveFunctionToClassPass.cs index ead334a2..e09d6a70 100644 --- a/src/Generator/Passes/MoveFunctionToClassPass.cs +++ b/src/Generator/Passes/MoveFunctionToClassPass.cs @@ -3,36 +3,36 @@ using CppSharp.AST; namespace CppSharp.Passes { + /// + /// Moves a function to a class, if any, named after the function's header. + /// public class MoveFunctionToClassPass : TranslationUnitPass { public override bool VisitFunctionDecl(Function function) { - if (!AlreadyVisited(function) && !function.Ignore && !(function.Namespace is Class) - // HACK: there are bugs with operators generated by Q_DECLARE_OPERATORS_FOR_FLAGS, an incorrect argument type, to say the least - && !function.IsOperator) + if (AlreadyVisited(function) || function.Ignore || function.Namespace is Class) + return base.VisitFunctionDecl(function); + + Class @class = FindClassToMoveFunctionTo(function.Namespace); + if (@class != null) { - TranslationUnit unit = function.Namespace as TranslationUnit; - Class @class; - if (unit != null) - { - @class = Driver.ASTContext.FindCompleteClass( - unit.FileNameWithoutExtension.ToLowerInvariant(), true); - if (@class != null) - { - MoveFunction(function, @class); - return base.VisitFunctionDecl(function); - } - } - @class = Driver.ASTContext.FindClass( - function.Namespace.Name, ignoreCase: true).FirstOrDefault(); - if (@class != null) - { - MoveFunction(function, @class); - } + MoveFunction(function, @class); } return base.VisitFunctionDecl(function); } + private Class FindClassToMoveFunctionTo(INamedDecl @namespace) + { + TranslationUnit unit = @namespace as TranslationUnit; + if (unit == null) + { + return Driver.ASTContext.FindClass( + @namespace.Name, ignoreCase: true).FirstOrDefault(); + } + return Driver.ASTContext.FindCompleteClass( + unit.FileNameWithoutExtension.ToLowerInvariant(), true); + } + private static void MoveFunction(Function function, Class @class) { var method = new Method(function) @@ -41,6 +41,17 @@ namespace CppSharp.Passes IsStatic = true }; + if (method.OperatorKind != CXXOperatorKind.None) + { + var param = function.Parameters[0]; + Class type; + if (!FunctionToInstanceMethodPass.GetClassParameter(param, out type)) + return; + method.Kind = CXXMethodKind.Operator; + method.SynthKind = FunctionSynthKind.NonMemberOperator; + method.OriginalFunction = null; + } + function.ExplicityIgnored = true; @class.Methods.Add(method); diff --git a/tests/Basic/Basic.Tests.cs b/tests/Basic/Basic.Tests.cs index 2b22b1d7..7c51786e 100644 --- a/tests/Basic/Basic.Tests.cs +++ b/tests/Basic/Basic.Tests.cs @@ -123,6 +123,12 @@ public class BasicTests Assert.That(foo.GetANSI(), Is.EqualTo("ANSI")); } + [Test] + public void TestMoveFunctionToClass() + { + Assert.That(basic.test(new basic()), Is.EqualTo(5)); + } + [Test, Ignore] public void TestConversionOperator() { diff --git a/tests/Basic/Basic.cpp b/tests/Basic/Basic.cpp index 056e1dd4..b7df47d6 100644 --- a/tests/Basic/Basic.cpp +++ b/tests/Basic/Basic.cpp @@ -203,3 +203,8 @@ void DefaultParameters::Bar() const void DefaultParameters::Bar() { } + +int test(basic& s) +{ + return 5; +} diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index 47436670..7296ce01 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -178,3 +178,10 @@ class Base class Derived : public Base { }; + +class DLL_API basic +{ + +}; + +DLL_API int test(basic& s);