diff --git a/src/AST/Class.cs b/src/AST/Class.cs index f4812594..15061f45 100644 --- a/src/AST/Class.cs +++ b/src/AST/Class.cs @@ -149,8 +149,8 @@ namespace CppSharp.AST { foreach (var @base in Bases) { - if (@base.IsClass && @base.Class.IsDeclared) - return @base.Class; + if (@base.IsClass && @base.Class.IsDeclared) + return @base.Class; } return null; @@ -199,11 +199,16 @@ namespace CppSharp.AST } } + public override IEnumerable FindOperator(CXXOperatorKind kind) + { + return Methods.Where(m => m.OperatorKind == kind); + } + public override IEnumerable GetOverloads(Function function) { - if (function.IsOperator) - return Methods.Where(fn => fn.OperatorKind == function.OperatorKind); - + if (function.IsOperator) + return Methods.Where(fn => fn.OperatorKind == function.OperatorKind); + var methods = Methods.Where(m => m.Name == function.Name); if (methods.ToList().Count != 0) return methods; diff --git a/src/AST/Namespace.cs b/src/AST/Namespace.cs index 6adaab95..e4927284 100644 --- a/src/AST/Namespace.cs +++ b/src/AST/Namespace.cs @@ -25,16 +25,16 @@ namespace CppSharp.AST public Dictionary Anonymous; // True if the context is inside an extern "C" context. - public bool IsExternCContext; - - public override string LogicalName - { - get { return IsAnonymous ? "" : base.Name; } - } - - public override string LogicalOriginalName - { - get { return IsAnonymous ? "" : base.OriginalName; } + public bool IsExternCContext; + + public override string LogicalName + { + get { return IsAnonymous ? "" : base.Name; } + } + + public override string LogicalOriginalName + { + get { return IsAnonymous ? "" : base.OriginalName; } } protected DeclarationContext() @@ -71,25 +71,25 @@ namespace CppSharp.AST public Declaration FindAnonymous(ulong key) { return Anonymous.ContainsKey(key) ? Anonymous[key] : null; - } - - public DeclarationContext FindDeclaration(IEnumerable declarations) - { - DeclarationContext currentDeclaration = this; - - foreach (var declaration in declarations) - { - var subDeclaration = currentDeclaration.Namespaces - .Concat(currentDeclaration.Classes) - .FirstOrDefault(e => e.Name.Equals(declaration)); - - if (subDeclaration == null) - return null; - - currentDeclaration = subDeclaration; - } - - return currentDeclaration as DeclarationContext; + } + + public DeclarationContext FindDeclaration(IEnumerable declarations) + { + DeclarationContext currentDeclaration = this; + + foreach (var declaration in declarations) + { + var subDeclaration = currentDeclaration.Namespaces + .Concat(currentDeclaration.Classes) + .FirstOrDefault(e => e.Name.Equals(declaration)); + + if (subDeclaration == null) + return null; + + currentDeclaration = subDeclaration; + } + + return currentDeclaration as DeclarationContext; } public Namespace FindNamespace(string name) @@ -165,34 +165,34 @@ namespace CppSharp.AST } public Function FindFunction(string name, bool createDecl = false) - { - if (string.IsNullOrEmpty(name)) - return null; - - var entries = name.Split(new string[] { "::" }, - StringSplitOptions.RemoveEmptyEntries).ToList(); - - if (entries.Count <= 1) - { - var function = Functions.Find(e => e.Name.Equals(name)); - - if (function == null && createDecl) - { - function = new Function() { Name = name, Namespace = this }; - Functions.Add(function); - } - - return function; - } - - var funcName = entries[entries.Count - 1]; - var namespaces = entries.Take(entries.Count - 1); - - var @namespace = FindNamespace(namespaces); - if (@namespace == null) - return null; - - return @namespace.FindFunction(funcName, createDecl); + { + if (string.IsNullOrEmpty(name)) + return null; + + var entries = name.Split(new string[] { "::" }, + StringSplitOptions.RemoveEmptyEntries).ToList(); + + if (entries.Count <= 1) + { + var function = Functions.Find(e => e.Name.Equals(name)); + + if (function == null && createDecl) + { + function = new Function() { Name = name, Namespace = this }; + Functions.Add(function); + } + + return function; + } + + var funcName = entries[entries.Count - 1]; + var namespaces = entries.Take(entries.Count - 1); + + var @namespace = FindNamespace(namespaces); + if (@namespace == null) + return null; + + return @namespace.FindFunction(funcName, createDecl); } Class CreateClass(string name, bool isComplete) @@ -229,7 +229,7 @@ namespace CppSharp.AST DeclarationContext declContext = FindDeclaration(namespaces); if (declContext == null) - { + { declContext = FindClass(entries[0]); if (declContext == null) return null; @@ -320,13 +320,13 @@ namespace CppSharp.AST return Enums.Find(e => e.ItemsByName.ContainsKey(name)); } - public IEnumerable FindOperator(CXXOperatorKind kind) + public virtual IEnumerable FindOperator(CXXOperatorKind kind) { return Functions.Where(fn => fn.OperatorKind == kind); } - public virtual IEnumerable GetOverloads(Function function) - { + public virtual IEnumerable GetOverloads(Function function) + { if (function.IsOperator) return FindOperator(function.OperatorKind); return Functions.Where(fn => fn.Name == function.Name); @@ -335,17 +335,17 @@ namespace CppSharp.AST public bool HasDeclarations { get - { + { Predicate pred = (t => t.IsGenerated); return Enums.Exists(pred) || HasFunctions || Typedefs.Exists(pred) - || Classes.Any() || Namespaces.Exists(n => n.HasDeclarations); + || Classes.Any() || Namespaces.Exists(n => n.HasDeclarations); } } public bool HasFunctions { get - { + { Predicate pred = (t => t.IsGenerated); return Functions.Exists(pred) || Namespaces.Exists(n => n.HasFunctions); } @@ -357,18 +357,18 @@ namespace CppSharp.AST /// /// Represents a C++ namespace. /// - public class Namespace : DeclarationContext - { - public override string LogicalName - { - get { return IsInline ? string.Empty : base.Name; } - } - - public override string LogicalOriginalName - { - get { return IsInline ? string.Empty : base.OriginalName; } - } - + public class Namespace : DeclarationContext + { + public override string LogicalName + { + get { return IsInline ? string.Empty : base.Name; } + } + + public override string LogicalOriginalName + { + get { return IsInline ? string.Empty : base.OriginalName; } + } + public bool IsInline; public override T Visit(IDeclVisitor visitor) diff --git a/src/Generator.Tests/AST/TestAST.cs b/src/Generator.Tests/AST/TestAST.cs index 872a7081..ceb50d29 100644 --- a/src/Generator.Tests/AST/TestAST.cs +++ b/src/Generator.Tests/AST/TestAST.cs @@ -57,5 +57,14 @@ namespace CppSharp.Generator.Tests.AST } Assert.IsTrue(func.Parameters[2].HasDefaultValue, "Parameter.HasDefaultValue"); } + + [Test] + public void TestASTHelperMethods() + { + var @class = AstContext.FindClass("Math::Complex").FirstOrDefault(); + Assert.IsNotNull(@class, "Couldn't find Math::Complex class."); + var plusOperator = @class.FindOperator(CXXOperatorKind.Plus).FirstOrDefault(); + Assert.IsNotNull(plusOperator, "Couldn't find operator+ in Math::Complex class."); + } } } diff --git a/tests/Native/AST.h b/tests/Native/AST.h index 3e917755..ac8e6f1d 100644 --- a/tests/Native/AST.h +++ b/tests/Native/AST.h @@ -1,2 +1,18 @@ // Tests assignment of AST.Parameter properties void TestParameterProperties(bool a, const short& b, int* c = nullptr) {}; + +// Tests various AST helper methods (like FindClass, FindOperator etc.) +namespace Math +{ + struct Complex { + Complex(double r, double i) : re(r), im(i) {} + Complex operator+(Complex &other); + private: + double re, im; + }; + + // Operator overloaded using a member function + Complex Complex::operator+(Complex &other) { + return Complex(re + other.re, im + other.im); + } +} \ No newline at end of file