diff --git a/src/AST/ASTContext.cs b/src/AST/ASTContext.cs index a6e895bb..8ad4c139 100644 --- a/src/AST/ASTContext.cs +++ b/src/AST/ASTContext.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; namespace CppSharp.AST { @@ -50,61 +51,45 @@ namespace CppSharp.AST /// Finds an existing enum in the library modules. public IEnumerable FindEnum(string name) { - foreach (var module in TranslationUnits) - { - var type = module.FindEnum(name); - if (type != null) yield return type; - } + return TranslationUnits.Select( + module => module.FindEnum(name)).Where(type => type != null); } /// Finds the complete declaration of an enum. public Enumeration FindCompleteEnum(string name) { - foreach (var @enum in FindEnum(name)) - if (!@enum.IsIncomplete) - return @enum; - - return null; + return FindEnum(name).FirstOrDefault(@enum => !@enum.IsIncomplete); } /// Finds an existing struct/class in the library modules. - public IEnumerable FindClass(string name, bool create = false) + public IEnumerable FindClass(string name, bool create = false, + bool ignoreCase = false) { - foreach (var module in TranslationUnits) - { - var type = module.FindClass(name); - if (type != null) yield return type; - } + return TranslationUnits.Select( + module => module.FindClass(name, + ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal)) + .Where(type => type != null); } /// Finds the complete declaration of a class. - public Class FindCompleteClass(string name) + public Class FindCompleteClass(string name, bool ignoreCase = false) { - foreach (var @class in FindClass(name)) - if (!@class.IsIncomplete) - return @class; - - return null; + return FindClass(name, ignoreCase: ignoreCase).FirstOrDefault( + @class => !@class.IsIncomplete); } /// Finds an existing function in the library modules. public IEnumerable FindFunction(string name) { - foreach (var module in TranslationUnits) - { - var type = module.FindFunction(name); - if (type != null) yield return type; - } + return TranslationUnits.Select(module => module.FindFunction(name)) + .Where(type => type != null); } /// Finds an existing typedef in the library modules. public IEnumerable FindTypedef(string name) { - foreach (var module in TranslationUnits) - { - var type = module.FindTypedef(name); - if (type != null) yield return type; - } + return TranslationUnits.Select(module => module.FindTypedef(name)) + .Where(type => type != null); } /// Finds an existing declaration by name. diff --git a/src/AST/Namespace.cs b/src/AST/Namespace.cs index 64c77d1d..4d29d693 100644 --- a/src/AST/Namespace.cs +++ b/src/AST/Namespace.cs @@ -168,17 +168,17 @@ namespace CppSharp.AST return @class; } - public Class FindClass(string name) + public Class FindClass(string name, + StringComparison stringComparison = StringComparison.Ordinal) { if (string.IsNullOrEmpty(name)) return null; - var entries = name.Split(new string[] { "::" }, + var entries = name.Split(new[] { "::" }, StringSplitOptions.RemoveEmptyEntries).ToList(); if (entries.Count <= 1) { - var @class = Classes.Find(e => e.Name.Equals(name)); - return @class; + return Classes.Find(e => e.Name.Equals(name, stringComparison)); } var className = entries[entries.Count - 1]; diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index 728b828a..5e78bfc9 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -201,6 +201,7 @@ namespace CppSharp TranslationUnitPasses.AddPass(new FindSymbolsPass()); TranslationUnitPasses.AddPass(new MoveOperatorToClassPass()); + TranslationUnitPasses.AddPass(new MoveFunctionToClassPass()); TranslationUnitPasses.AddPass(new CheckAmbiguousFunctions()); TranslationUnitPasses.AddPass(new CheckOperatorsOverloadsPass()); TranslationUnitPasses.AddPass(new CheckVirtualOverrideReturnCovariance()); diff --git a/src/Generator/Passes/MoveFunctionToClassPass.cs b/src/Generator/Passes/MoveFunctionToClassPass.cs new file mode 100644 index 00000000..e09d6a70 --- /dev/null +++ b/src/Generator/Passes/MoveFunctionToClassPass.cs @@ -0,0 +1,60 @@ +using System.Linq; +using CppSharp.AST; + +namespace CppSharp.Passes +{ + /// + /// Moves a function to a class, if any, named after the function's header. + /// + public class MoveFunctionToClassPass : TranslationUnitPass + { + public override bool VisitFunctionDecl(Function function) + { + if (AlreadyVisited(function) || function.Ignore || function.Namespace is Class) + return base.VisitFunctionDecl(function); + + Class @class = FindClassToMoveFunctionTo(function.Namespace); + if (@class != null) + { + MoveFunction(function, @class); + } + return base.VisitFunctionDecl(function); + } + + private Class FindClassToMoveFunctionTo(INamedDecl @namespace) + { + TranslationUnit unit = @namespace as TranslationUnit; + if (unit == null) + { + return Driver.ASTContext.FindClass( + @namespace.Name, ignoreCase: true).FirstOrDefault(); + } + return Driver.ASTContext.FindCompleteClass( + unit.FileNameWithoutExtension.ToLowerInvariant(), true); + } + + private static void MoveFunction(Function function, Class @class) + { + var method = new Method(function) + { + Namespace = @class, + IsStatic = true + }; + + if (method.OperatorKind != CXXOperatorKind.None) + { + var param = function.Parameters[0]; + Class type; + if (!FunctionToInstanceMethodPass.GetClassParameter(param, out type)) + return; + method.Kind = CXXMethodKind.Operator; + method.SynthKind = FunctionSynthKind.NonMemberOperator; + method.OriginalFunction = null; + } + + function.ExplicityIgnored = true; + + @class.Methods.Add(method); + } + } +} diff --git a/tests/Basic/Basic.Tests.cs b/tests/Basic/Basic.Tests.cs index 2b22b1d7..7c51786e 100644 --- a/tests/Basic/Basic.Tests.cs +++ b/tests/Basic/Basic.Tests.cs @@ -123,6 +123,12 @@ public class BasicTests Assert.That(foo.GetANSI(), Is.EqualTo("ANSI")); } + [Test] + public void TestMoveFunctionToClass() + { + Assert.That(basic.test(new basic()), Is.EqualTo(5)); + } + [Test, Ignore] public void TestConversionOperator() { diff --git a/tests/Basic/Basic.cpp b/tests/Basic/Basic.cpp index 056e1dd4..b7df47d6 100644 --- a/tests/Basic/Basic.cpp +++ b/tests/Basic/Basic.cpp @@ -203,3 +203,8 @@ void DefaultParameters::Bar() const void DefaultParameters::Bar() { } + +int test(basic& s) +{ + return 5; +} diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index 47436670..7296ce01 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -178,3 +178,10 @@ class Base class Derived : public Base { }; + +class DLL_API basic +{ + +}; + +DLL_API int test(basic& s);