diff --git a/src/AST/Class.cs b/src/AST/Class.cs index 442fe937..758cc3eb 100644 --- a/src/AST/Class.cs +++ b/src/AST/Class.cs @@ -149,9 +149,7 @@ namespace CppSharp.AST public bool HasNonIgnoredBase => HasBaseClass && !IsValueType - && BaseClass != null - && !BaseClass.IsValueType - && BaseClass.IsGenerated; + && BaseClass is { IsValueType: false, IsGenerated: true }; public bool NeedsBase => HasNonIgnoredBase && IsGenerated; diff --git a/src/Generator.Tests/Passes/TestPasses.cs b/src/Generator.Tests/Passes/TestPasses.cs index 45387cdd..2a72a628 100644 --- a/src/Generator.Tests/Passes/TestPasses.cs +++ b/src/Generator.Tests/Passes/TestPasses.cs @@ -66,6 +66,32 @@ namespace CppSharp.Generator.Tests.Passes Assert.IsTrue(ucharClassEnum.BuiltinType.Type == PrimitiveType.UChar); } + [Test] + public void TestCheckStaticClassPass() + { + var staticClass = AstContext.Class("TestCheckStaticClass"); + var staticStruct = AstContext.Class("TestCheckStaticStruct"); + var staticClassDeletedCtor = AstContext.Class("TestCheckStaticClassDeleted"); + var nonStaticClass = AstContext.Class("TestCheckNonStaticClass"); + var nonStaticEmptyClass = AstContext.Class("TestCommentsPass"); + + Assert.IsFalse(staticClass.IsStatic); + Assert.IsFalse(staticStruct.IsStatic); + Assert.IsFalse(staticClassDeletedCtor.IsStatic); + Assert.IsFalse(nonStaticClass.IsStatic); + Assert.IsFalse(nonStaticEmptyClass.IsStatic); + + passBuilder.AddPass(new CheckStaticClassPass()); + passBuilder.RunPasses(pass => pass.VisitASTContext(AstContext)); + + Assert.IsTrue(staticClass.IsStatic, "`TestCheckStaticClass` should be static"); + Assert.IsTrue(staticStruct.IsStatic, "`TestCheckStaticStruct` should be static"); + Assert.IsTrue(staticClassDeletedCtor.IsStatic, "`TestCheckStaticClassDeleted` should be static"); + + Assert.IsFalse(nonStaticClass.IsStatic, "`TestCheckNonStaticClass` should NOT be static, since it has a private data field with default ctor"); + Assert.IsFalse(nonStaticEmptyClass.IsStatic, "`TestCommentsPass` should NOT be static, since it doesn't have any static declarations"); + } + [Test] public void TestFunctionToInstancePass() { diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index b8e15ecd..f9be9845 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -230,7 +230,7 @@ namespace CppSharp passes.AddPass(new FindSymbolsPass()); passes.AddPass(new CheckMacroPass()); - passes.AddPass(new CheckStaticClass()); + passes.AddPass(new CheckStaticClassPass()); if (Options.IsCLIGenerator || Options.IsCSharpGenerator || Options.IsCppGenerator) { diff --git a/src/Generator/Passes/CheckStaticClass.cs b/src/Generator/Passes/CheckStaticClassPass.cs similarity index 52% rename from src/Generator/Passes/CheckStaticClass.cs rename to src/Generator/Passes/CheckStaticClassPass.cs index 7b4e2844..79f896c8 100644 --- a/src/Generator/Passes/CheckStaticClass.cs +++ b/src/Generator/Passes/CheckStaticClassPass.cs @@ -7,9 +7,9 @@ namespace CppSharp.Passes /// /// Checks for classes that should be bound as static classes. /// - public class CheckStaticClass : TranslationUnitPass + public class CheckStaticClassPass : TranslationUnitPass { - public CheckStaticClass() + public CheckStaticClassPass() => VisitOptions.ResetFlags(VisitFlags.ClassMethods); public override bool VisitDeclaration(Declaration decl) @@ -32,8 +32,7 @@ namespace CppSharp.Passes if (decl.Access != AccessSpecifier.Protected) return false; - var @class = decl.Namespace as Class; - return @class != null && @class.IsStatic; + return decl.Namespace is Class { IsStatic: true }; } static void SetDeclarationAccessToPrivate(Declaration decl) @@ -42,7 +41,7 @@ namespace CppSharp.Passes // as an internal in the final C# wrapper. decl.Access = AccessSpecifier.Private; - // We need to explicity set the generation else the + // We need to explicitly set the generation else the // now private declaration will get filtered out later. decl.GenerationKind = GenerationKind.Generate; } @@ -51,53 +50,90 @@ namespace CppSharp.Passes { var returnType = function.ReturnType.Type.Desugar(); - var tag = returnType as TagType; - if (tag == null) + if (returnType is not TagType tag) returnType.IsPointerTo(out tag); - if (tag == null) + var decl = tag?.Declaration; + if (decl is not Class) return false; var @class = (Class)function.Namespace; - var decl = tag.Declaration; + return @class.QualifiedOriginalName == decl.QualifiedOriginalName; + } - if (!(decl is Class)) - return false; + static bool AcceptsClassInstance(Function function) + { + return function.Parameters.Any(param => + { + var paramType = param.Type.Desugar(); - return @class.QualifiedOriginalName == decl.QualifiedOriginalName; + if (paramType is not TagType tag) + paramType.IsPointerTo(out tag); + + var decl = tag?.Declaration; + if (decl is not Class) + return false; + + var @class = (Class)function.Namespace; + return @class.QualifiedOriginalName == decl.QualifiedOriginalName; + } + ); } public override bool VisitClassDecl(Class @class) { - // If the class has any non-private constructors then it cannot - // be bound as a static class and we bail out early. - if (@class.Constructors.Any(m => - !(m.IsCopyConstructor || m.IsMoveConstructor) - && m.Access != AccessSpecifier.Private)) + // If the class is to be used as an opaque type, then it cannot be + // bound as static. + if (@class.IsOpaque) + return false; + + if (@class.IsDependent) + return false; + + // Polymorphic classes are currently not supported. + // TODO: We could support this if the base class is also static, since it's composition then. + if (@class.IsPolymorphic) + return false; + + // Make sure we have at least one accessible static method or field + if (!@class.Methods.Any(m => m.Kind == CXXMethodKind.Normal && m.Access != AccessSpecifier.Private && m.IsStatic) + && @class.Variables.All(v => v.Access == AccessSpecifier.Private)) return false; // Check for any non-static fields or methods, in which case we // assume the class is not meant to be static. // Note: Static fields are represented as variables in the AST. - if (@class.Fields.Any() || - @class.Methods.Any(m => m.Kind == CXXMethodKind.Normal - && !m.IsStatic)) + if (@class.Fields.Count != 0) return false; - // Check for any static function that return a pointer to the class. - // If one exists, we assume it's a factory function and the class is - // not meant to be static. It's a simple heuristic but it should be - // good enough for the time being. - if (@class.Functions.Any(m => !m.IsOperator && ReturnsClassInstance(m)) || - @class.Methods.Any(m => !m.IsOperator && ReturnsClassInstance(m))) + if (@class.Methods.Any(m => m.Kind == CXXMethodKind.Normal && !m.IsStatic)) return false; - // If the class is to be used as an opaque type, then it cannot be - // bound as static. - if (@class.IsOpaque) + if (@class.Constructors.Any(m => + { + // Implicit constructors are not user-defined, so assume this was unintentional. + if (m.IsImplicit) + return false; + + // Ignore deleted constructors. + if (m.IsDeleted) + return false; + + // If the class has a copy or move constructor, it cannot be static. + if (m.IsCopyConstructor || m.IsMoveConstructor) + return true; + + // If the class has any (user defined) non-private constructors then it cannot be static + return m.Access != AccessSpecifier.Private; + })) + { return false; + } - if (@class.IsDependent) + // Check for any static function that accepts/returns a pointer to the class. + // If one exists, we assume it's not meant to be static + if (@class.Functions.Any(m => !m.IsOperator && (ReturnsClassInstance(m) || AcceptsClassInstance(m))) || + @class.Methods.Any(m => !m.IsOperator && !m.IsConstructor && (ReturnsClassInstance(m) || AcceptsClassInstance(m)))) return false; // TODO: We should take C++ friends into account here, they might allow diff --git a/tests/dotnet/Native/Passes.h b/tests/dotnet/Native/Passes.h index e9bdd32f..ee3c8c7a 100644 --- a/tests/dotnet/Native/Passes.h +++ b/tests/dotnet/Native/Passes.h @@ -130,6 +130,60 @@ struct TestCheckAmbiguousFunctionsPass int Method(int x) const; }; +class TestCheckStaticClass +{ +public: + static int Method(); + static int Method(int x); + + constexpr static float ConstExprStatic = 3.0f; + inline static float InlineStatic = 1.0f; + +private: + inline static float PrivateInlineStatic = 1.0f; +}; + +struct TestCheckStaticStruct +{ + static int Method(); + static int Method(int x); + + constexpr static float ConstExprStatic = 3.0f; + inline static float InlineStatic = 1.0f; +}; + +class TestCheckStaticClassDeleted +{ +public: + TestCheckStaticClassDeleted() = delete; + + static int Method(); + static int Method(int x); + + constexpr static float ConstExprStatic = 3.0f; + inline static float InlineStatic = 1.0f; + +private: + inline static float PrivateInlineStatic = 1.0f; +}; + +class TestCheckNonStaticClass +{ +public: + TestCheckNonStaticClass() = default; + + static int Method(); + static int Method(int x); + + constexpr static float ConstExprStatic = 3.0f; + inline static float InlineStatic = 1.0f; + +private: + inline static float PrivateInlineStatic = 1.0f; + + float NonStatic = 1.0f; +}; + #define CS_INTERNAL struct TestMethodAsInternal {