diff --git a/src/AST/TypeExtensions.cs b/src/AST/TypeExtensions.cs index 32333d6e..af22e298 100644 --- a/src/AST/TypeExtensions.cs +++ b/src/AST/TypeExtensions.cs @@ -439,6 +439,22 @@ return false; } + public static bool IsTemplate(this Type type) + { + if (type is TemplateParameterType or TemplateParameterSubstitutionType) + return true; + + var ptr = type; + while (ptr is PointerType pType) + { + ptr = pType.Pointee; + if (ptr is TemplateParameterType or TemplateParameterSubstitutionType) + return true; + } + + return false; + } + public static Module GetModule(this Type type) { Declaration declaration; diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 77106453..dea8aa5d 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -601,7 +601,7 @@ namespace CppSharp.Generators.CSharp if (Context.Context.Options.MarshalCharAsManagedChar && primitive == PrimitiveType.Char) { - Context.Return.Write($"({typePrinter.PrintNative(pointer)})"); + Context.Return.StringBuilder.Insert(0, $"({typePrinter.PrintNative(pointer)}) "); if (isConst) Context.Return.Write("&"); Context.Return.Write(param.Name); @@ -643,8 +643,13 @@ namespace CppSharp.Generators.CSharp } else { - Context.Before.WriteLine("var {0} = {1}.{2};", - arg, Context.Parameter.Name, Helpers.InstanceIdentifier); + Context.Before.Write($"var {arg} = "); + if (pointer.Pointee.IsTemplate()) + Context.Before.Write($"(({Context.Parameter.Type}) (object) {Context.Parameter.Name})"); + else + Context.Before.WriteLine(Context.Parameter.Name); + Context.Before.WriteLine($".{Helpers.InstanceIdentifier};"); + Context.Return.Write($"new {typePrinter.IntPtrType}(&{arg})"); } @@ -805,7 +810,12 @@ namespace CppSharp.Generators.CSharp private void MarshalValueClass() { - Context.Return.Write("{0}.{1}", Context.Parameter.Name, Helpers.InstanceIdentifier); + if (Context.Parameter.Type.IsTemplate()) + Context.Return.Write($"(({Context.Parameter.Type}) (object) {Context.Parameter.Name})"); + else + Context.Return.Write(Context.Parameter.Name); + + Context.Return.Write($".{Helpers.InstanceIdentifier}"); } public override bool VisitFieldDecl(Field field) diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 1540de3f..7c0ae294 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -85,6 +85,7 @@ namespace CppSharp.Generators.CSharp GenerateUsings(); WriteLine("#pragma warning disable CS0109 // Member does not hide an inherited member; new keyword is not required"); + WriteLine("#pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference"); NewLine(); if (!string.IsNullOrEmpty(Module.OutputNamespace)) diff --git a/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs b/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs index 07ae0d79..a027390b 100644 --- a/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs +++ b/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs @@ -52,6 +52,7 @@ namespace CppSharp.Passes if (!methodsWithDependentPointers.Any()) return false; + var hasMethods = false; var classExtensions = new Class { Name = $"{@class.Name}Extensions", IsStatic = true }; foreach (var specialization in @class.Specializations.Where(s => s.IsGenerated)) foreach (var method in methodsWithDependentPointers.Where( @@ -59,9 +60,11 @@ namespace CppSharp.Passes { var specializedMethod = specialization.Methods.FirstOrDefault( m => m.InstantiatedFrom == method); - if (specializedMethod == null) + if (specializedMethod == null || specializedMethod.IsOperator) continue; + hasMethods = true; + Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod); classExtensions.Methods.Add(extensionMethod); extensionMethod.Namespace = classExtensions; @@ -75,6 +78,10 @@ namespace CppSharp.Passes extensionMethod.GenerationKind = GenerationKind.Generate; } } + + if (!hasMethods) + return false; + classExtensions.Namespace = @class.Namespace; classExtensions.OriginalClass = @class; extensions.Add(classExtensions); diff --git a/src/Generator/Types/Std/Stdlib.CSharp.cs b/src/Generator/Types/Std/Stdlib.CSharp.cs index 83df7b1c..c197a665 100644 --- a/src/Generator/Types/Std/Stdlib.CSharp.cs +++ b/src/Generator/Types/Std/Stdlib.CSharp.cs @@ -329,7 +329,10 @@ namespace CppSharp.Types.Std ctx.Return.Write($@"{qualifiedBasicString}Extensions.{ Helpers.InternalStruct}.{assign.Name}(new { typePrinter.IntPtrType}(&{ - ctx.ReturnVarName}), {ctx.Parameter.Name})"); + ctx.ReturnVarName}), "); + if (ctx.Parameter.Type.IsTemplate()) + ctx.Return.Write("(string) (object) "); + ctx.Return.Write($"{ctx.Parameter.Name})"); ctx.ReturnVarName = string.Empty; } else @@ -337,8 +340,13 @@ namespace CppSharp.Types.Std var varBasicString = $"__basicString{ctx.ParameterIndex}"; ctx.Before.WriteLine($@"var {varBasicString} = new { basicString.Visit(typePrinter)}();"); - ctx.Before.WriteLine($@"{qualifiedBasicString}Extensions.{ - assign.Name}({varBasicString}, {ctx.Parameter.Name});"); + + ctx.Before.Write($@"{qualifiedBasicString}Extensions.{ + assign.Name}({varBasicString}, "); + if (ctx.Parameter.Type.IsTemplate()) + ctx.Before.Write("(string) (object) "); + ctx.Before.WriteLine($"{ctx.Parameter.Name});"); + ctx.Return.Write($"{varBasicString}.{Helpers.InstanceIdentifier}"); ctx.Cleanup.WriteLine($@"{varBasicString}.Dispose({ (!Type.IsAddress() || ctx.Parameter?.IsIndirect == true ? "disposing: true, callNativeDtor:false" : string.Empty)});"); diff --git a/tests/dotnet/CSharp/CSharp.Tests.cs b/tests/dotnet/CSharp/CSharp.Tests.cs index ecf73afe..98e080b0 100644 --- a/tests/dotnet/CSharp/CSharp.Tests.cs +++ b/tests/dotnet/CSharp/CSharp.Tests.cs @@ -2003,4 +2003,25 @@ public unsafe class CSharpTests Assert.AreEqual(2, unionTestA.A); Assert.AreEqual(2, unionTestB.B); } + + [TestCase("hi")] + [TestCase(2u)] + public void TestOptional<T>(T value) + { + Assert.That(new CSharp.Optional<T>() != new CSharp.Optional<T>(value)); + Assert.That(new CSharp.Optional<T>() != value); + Assert.That(new CSharp.Optional<T>() == new CSharp.Optional<T>()); + Assert.That(new CSharp.Optional<T>(value) == new CSharp.Optional<T>(value)); + Assert.That(new CSharp.Optional<T>(value) == value); + } + + [Test] + public void TestOptionalIntPtr() + { + Assert.That(new CSharp.Optional<IntPtr>() != new CSharp.Optional<IntPtr>(IntPtr.MaxValue)); + Assert.That(new CSharp.Optional<IntPtr>() != IntPtr.MaxValue); + Assert.That(new CSharp.Optional<IntPtr>() == new CSharp.Optional<IntPtr>()); + Assert.That(new CSharp.Optional<IntPtr>(IntPtr.MaxValue) == new CSharp.Optional<IntPtr>(IntPtr.MaxValue)); + Assert.That(new CSharp.Optional<IntPtr>(IntPtr.MaxValue) == IntPtr.MaxValue); + } } diff --git a/tests/dotnet/CSharp/CSharp.h b/tests/dotnet/CSharp/CSharp.h index fcc638f5..78a8ea7d 100644 --- a/tests/dotnet/CSharp/CSharp.h +++ b/tests/dotnet/CSharp/CSharp.h @@ -733,22 +733,22 @@ class DLL_API TestParamToInterfacePassBaseOne class DLL_API TestParamToInterfacePassBaseTwo { - int m; + int m; public: - int getM(); - void setM(int n); - const TestParamToInterfacePassBaseTwo& operator++(); - TestParamToInterfacePassBaseTwo(); - TestParamToInterfacePassBaseTwo(int n); + int getM(); + void setM(int n); + const TestParamToInterfacePassBaseTwo& operator++(); + TestParamToInterfacePassBaseTwo(); + TestParamToInterfacePassBaseTwo(int n); }; class DLL_API TestParamToInterfacePass : public TestParamToInterfacePassBaseOne, public TestParamToInterfacePassBaseTwo { public: - TestParamToInterfacePassBaseTwo addM(TestParamToInterfacePassBaseTwo b); - TestParamToInterfacePassBaseTwo operator+(TestParamToInterfacePassBaseTwo b); - TestParamToInterfacePass(TestParamToInterfacePassBaseTwo b); - TestParamToInterfacePass(); + TestParamToInterfacePassBaseTwo addM(TestParamToInterfacePassBaseTwo b); + TestParamToInterfacePassBaseTwo operator+(TestParamToInterfacePassBaseTwo b); + TestParamToInterfacePass(TestParamToInterfacePassBaseTwo b); + TestParamToInterfacePass(); }; class DLL_API HasProtectedVirtual @@ -973,18 +973,18 @@ class DLL_API ClassWithVirtualBase : public virtual Foo namespace NamespaceA { - CS_VALUE_TYPE class DLL_API A - { - }; + CS_VALUE_TYPE class DLL_API A + { + }; } namespace NamespaceB { - class DLL_API B - { - public: - void Function(CS_OUT NamespaceA::A &a); - }; + class DLL_API B + { + public: + void Function(CS_OUT NamespaceA::A &a); + }; } class DLL_API HasPrivateVirtualProperty @@ -1607,6 +1607,37 @@ DLL_API extern PointerTester* PointerToClass; union DLL_API UnionTester { float a; int b; + inline bool operator ==(const UnionTester& other) const { + return b == other.b; + } }; int DLL_API ValueTypeOutParameter(CS_OUT UnionTester* testerA, CS_OUT UnionTester* testerB); + +template <class T> +class Optional { +public: + T m_value; + bool m_hasValue; + + Optional() { + m_hasValue = false; + } + + Optional(T value) { + m_value = std::move(value); + m_hasValue = true; + } + + inline bool operator ==(const Optional<T>& rhs) const { + return (m_hasValue == rhs.m_hasValue && (!m_hasValue || m_value == rhs.m_value)); + } + + inline bool operator ==(const T& rhs) const { + return (m_hasValue && m_value == rhs); + } +}; + +// We just need a method that uses various instantiations of Optional. +inline void DLL_API InstantiateOptionalTemplate(Optional<unsigned int>, Optional<std::string>, + Optional<TestComparison>, Optional<char*>, Optional<UnionTester>) { }