diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 50c61c7d..9064ac37 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -146,13 +146,25 @@ namespace CppSharp.Generators.CSharp var pointee = pointer.Pointee.Desugar(); var finalPointee = pointer.GetFinalPointee().Desugar(); - var type = Context.ReturnType.Type.Desugar( + var returnType = Context.ReturnType.Type.Desugar( resolveTemplateSubstitution: false); - PrimitiveType primitive; - if ((pointee.IsConstCharString() && (isRefParam || type.IsReference())) || - (!finalPointee.IsPrimitiveType(out primitive) && + if ((pointee.IsConstCharString() && (isRefParam || returnType.IsReference())) || + (!finalPointee.IsPrimitiveType(out PrimitiveType primitive) && !finalPointee.IsEnumType())) + { + if (Context.MarshalKind != MarshalKind.NativeField && + pointee.IsPointerTo(out Type type) && + type.Desugar().TryGetClass(out Class c)) + { + string ret = Generator.GeneratedIdentifier(Context.ReturnVarName); + Context.Before.WriteLine($@"{typePrinter.IntPtrType} {ret} = { + Context.ReturnVarName} == {typePrinter.IntPtrType}.Zero ? { + typePrinter.IntPtrType}.Zero : new { + typePrinter.IntPtrType}(*(void**) {Context.ReturnVarName});"); + Context.ReturnVarName = ret; + } return pointer.QualifiedPointee.Visit(this); + } if (isRefParam) { @@ -167,7 +179,7 @@ namespace CppSharp.Generators.CSharp if (Context.Function != null && Context.Function.OperatorKind == CXXOperatorKind.Subscript) { - if (type.IsPrimitiveType(primitive)) + if (returnType.IsPrimitiveType(primitive)) { Context.Return.Write("*"); } @@ -517,7 +529,7 @@ namespace CppSharp.Generators.CSharp if (param.IsOut) { MarshalString(pointee); - Context.Return.Write("IntPtr.Zero"); + Context.Return.Write($"{typePrinter.IntPtrType}.Zero"); Context.ArgumentPrefix.Write("&"); return true; } @@ -604,7 +616,7 @@ namespace CppSharp.Generators.CSharp arg, Context.Parameter.Name, Helpers.InstanceIdentifier); } - Context.Return.Write($"new global::System.IntPtr(&{arg})"); + Context.Return.Write($"new {typePrinter.IntPtrType}(&{arg})"); return true; } @@ -614,7 +626,7 @@ namespace CppSharp.Generators.CSharp pointer.QualifiedPointee.Visit(this); Context.Before.WriteLine($"var {arg} = {Context.Return};"); Context.Return.StringBuilder.Clear(); - Context.Return.Write($"new global::System.IntPtr(&{arg})"); + Context.Return.Write($"new {typePrinter.IntPtrType}(&{arg})"); return true; } diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 879c66b3..eb93c6fc 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -1741,15 +1741,16 @@ namespace CppSharp.Generators.CSharp } } - bool isVoid = method.OriginalReturnType.Type.Desugar().IsPrimitiveType( - PrimitiveType.Void); + Type returnType = method.OriginalReturnType.Type.Desugar(); + bool isPrimitive = returnType.IsPrimitiveType(); + bool isVoid = returnType.IsPrimitiveType(PrimitiveType.Void); var property = ((Class) method.Namespace).Properties.Find( p => p.GetMethod == method || p.SetMethod == method); bool isSetter = property != null && property.SetMethod == method; var hasReturn = !isVoid && !isSetter; if (hasReturn) - Write($"var {Helpers.ReturnIdentifier} = "); + Write(isPrimitive && !isSetter ? "return " : $"var {Helpers.ReturnIdentifier} = "); Write($"{Helpers.TargetIdentifier}."); string marshalsCode = string.Join(", ", marshals); @@ -1764,6 +1765,8 @@ namespace CppSharp.Generators.CSharp Write($" = {marshalsCode}"); } WriteLine(";"); + if (isPrimitive && !isSetter) + return; if (hasReturn) { diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index b2eb2338..c45c99e6 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -1285,8 +1285,12 @@ public unsafe class CSharpTests : GeneratorTestFixture { using (Foo foo = new Foo { A = 25 }) { - Foo returnedFoo = CSharp.CSharp.TakeRefToPointerToObject(foo); + Foo returnedFoo = CSharp.CSharp.TakeReturnReferenceToPointer(foo); Assert.That(returnedFoo.A, Is.EqualTo(foo.A)); + using (Qux qux = new Qux()) + { + Assert.That(qux.TakeReferenceToPointer(foo), Is.EqualTo(foo.A)); + } } } diff --git a/tests/CSharp/CSharp.cpp b/tests/CSharp/CSharp.cpp index e71fa516..35611676 100644 --- a/tests/CSharp/CSharp.cpp +++ b/tests/CSharp/CSharp.cpp @@ -225,6 +225,11 @@ void Qux::makeClassDynamic() { } +int Qux::takeReferenceToPointer(Foo*& ret) +{ + return ret->A; +} + Bar::Bar(Qux qux) { } @@ -1623,7 +1628,7 @@ const void*& rValueReferenceToPointer(void*&& v) return (const void*&) v; } -const Foo* takeRefToPointerToObject(const Foo*& foo) +const Foo*& takeReturnReferenceToPointer(const Foo*& foo) { return foo; } diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h index 25d2368f..5e697e47 100644 --- a/tests/CSharp/CSharp.h +++ b/tests/CSharp/CSharp.h @@ -84,6 +84,7 @@ public: Qux* getInterface(); void setInterface(Qux* qux); virtual void makeClassDynamic(); + virtual int takeReferenceToPointer(Foo*& ret); }; class DLL_API Bar : public Qux @@ -1325,7 +1326,7 @@ DLL_API char* returnCharPointer(); DLL_API char* takeConstCharRef(const char& c); DLL_API const char*& takeConstCharStarRef(const char*& c); DLL_API const void*& rValueReferenceToPointer(void*&& v); -DLL_API const Foo* takeRefToPointerToObject(const Foo*& foo); +DLL_API const Foo*& takeReturnReferenceToPointer(const Foo*& foo); struct { struct {