diff --git a/src/AST/ASTContext.cs b/src/AST/ASTContext.cs index 25ab6b74..42302939 100644 --- a/src/AST/ASTContext.cs +++ b/src/AST/ASTContext.cs @@ -92,8 +92,7 @@ namespace CppSharp.AST /// Finds an existing function in the library modules. public IEnumerable FindFunction(string name) { - return TranslationUnits.Select(module => module.FindFunction(name)) - .Where(type => type != null); + return TranslationUnits.SelectMany(module => module.FindFunction(name)); } /// Finds an existing typedef in the library modules. diff --git a/src/AST/Namespace.cs b/src/AST/Namespace.cs index 99d19f50..14afdf8f 100644 --- a/src/AST/Namespace.cs +++ b/src/AST/Namespace.cs @@ -160,33 +160,33 @@ namespace CppSharp.AST return Enums.FirstOrDefault(f => f.OriginalPtr == ptr); } - public Function FindFunction(string name, bool createDecl = false) + public IEnumerable FindFunction(string name, bool createDecl = false) { if (string.IsNullOrEmpty(name)) - return null; + return Enumerable.Empty(); var entries = name.Split(new string[] { "::" }, StringSplitOptions.RemoveEmptyEntries).ToList(); if (entries.Count <= 1) { - var function = Functions.FirstOrDefault(e => e.Name.Equals(name)); + var functions = Functions.Where(e => e.Name.Equals(name)); - if (function == null && createDecl) + if (!functions.Any() && createDecl) { - function = new Function() { Name = name, Namespace = this }; + var function = new Function() { Name = name, Namespace = this }; Declarations.Add(function); } - return function; + return functions; } - var funcName = entries[entries.Count - 1]; + var funcName = entries[^1]; var namespaces = entries.Take(entries.Count - 1); var @namespace = FindNamespace(namespaces); if (@namespace == null) - return null; + return Enumerable.Empty(); return @namespace.FindFunction(funcName, createDecl); } @@ -201,14 +201,12 @@ namespace CppSharp.AST Class CreateClass(string name, bool isComplete) { - var @class = new Class + return new Class { Name = name, Namespace = this, IsIncomplete = !isComplete }; - - return @class; } public Class FindClass(string name, @@ -316,7 +314,7 @@ namespace CppSharp.AST public T FindType(string name) where T : Declaration { var type = FindEnum(name) - ?? FindFunction(name) + ?? FindFunction(name).FirstOrDefault() ?? (Declaration)FindClass(name) ?? FindTypedef(name); diff --git a/src/Generator/Library.cs b/src/Generator/Library.cs index bcbbf37c..dace412d 100644 --- a/src/Generator/Library.cs +++ b/src/Generator/Library.cs @@ -533,7 +533,7 @@ namespace CppSharp public static IEnumerable FindFunction(this ASTContext context, string name) { return context.TranslationUnits - .Select(module => module.FindFunction(name)) + .SelectMany(module => module.FindFunction(name)) .Where(function => function != null); }