diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 7e99887f..35a08eb3 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -269,22 +269,35 @@ namespace CppSharp.Generators.CSharp "Internal" : ""); if (returnType.IsAddress()) - Context.Return.Write("({0} == IntPtr.Zero) ? {1} : {2}.{3}({0})", instance, - @class.IsRefType ? "null" : string.Format("new {0}()", type), - type, Helpers.CreateInstanceIdentifier); + { + var ret = Generator.GeneratedIdentifier("result") + Context.ParameterIndex; + Context.SupportBefore.WriteLine("{0} {1};", type, ret); + Context.SupportBefore.WriteLine("if ({0} == IntPtr.Zero) {1} = {2};", instance, ret, + @class.IsRefType ? "null" : string.Format("new {0}()", type)); + var dtor = @class.Destructors.FirstOrDefault(); + var map = @class.IsRefType && dtor != null && dtor.IsVirtual; + if (map) + { + Context.SupportBefore.WriteLine( + "else if (CppSharp.Runtime.Helpers.NativeToManagedMap.ContainsKey({0}))", instance); + Context.SupportBefore.WriteLineIndent("{0} = ({1}) CppSharp.Runtime.Helpers.NativeToManagedMap[{2}];", + ret, type, instance); + Context.SupportBefore.WriteLine("else CppSharp.Runtime.Helpers.NativeToManagedMap[{3}] = {0} = {1}.{2}({3});", ret, type, + Helpers.CreateInstanceIdentifier, instance); + } + else + { + Context.SupportBefore.WriteLine("else {0} = {1}.{2}({3});", ret, type, + Helpers.CreateInstanceIdentifier, instance); + } + Context.Return.Write(ret); + } else Context.Return.Write("{0}.{1}({2})", type, Helpers.CreateInstanceIdentifier, instance); return true; } - private static bool FindTypeMap(ITypeMapDatabase typeMapDatabase, - Class @class, out TypeMap typeMap) - { - return typeMapDatabase.FindTypeMap(@class, out typeMap) || - (@class.HasBase && FindTypeMap(typeMapDatabase, @class.Bases[0].Class, out typeMap)); - } - public override bool VisitEnumDecl(Enumeration @enum) { Context.Return.Write("{0}", Context.ReturnVarName); diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 3dc5c89c..43d3ffca 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -44,6 +44,7 @@ namespace CppSharp.Generators.CSharp public static readonly string InstanceField = Generator.GeneratedIdentifier("instance"); public static readonly string InstanceIdentifier = Generator.GeneratedIdentifier("Instance"); public static readonly string ReturnIdentifier = Generator.GeneratedIdentifier("ret"); + public static readonly string DummyIdentifier = Generator.GeneratedIdentifier("dummy"); public static readonly string OwnsNativeInstanceIdentifier = Generator.GeneratedIdentifier("ownsNativeInstance"); @@ -1423,7 +1424,8 @@ namespace CppSharp.Generators.CSharp var ctx = new CSharpMarshalContext(Driver) { ReturnType = param.QualifiedType, - ReturnVarName = param.Name + ReturnVarName = param.Name, + ParameterIndex = i }; var marshal = new CSharpMarshalNativeToManagedPrinter(ctx); @@ -1758,20 +1760,24 @@ namespace CppSharp.Generators.CSharp WriteLine("void Dispose(bool disposing)"); WriteStartBraceIndent(); - if (ShouldGenerateClassNativeField(@class)) + var dtor = @class.Destructors.FirstOrDefault(); + if (@class.IsRefType && dtor != null && dtor.IsVirtual) { - var dtor = @class.Methods.FirstOrDefault(method => method.IsDestructor); - if (dtor != null) + WriteLine("object {0};", Helpers.DummyIdentifier); + WriteLine("CppSharp.Runtime.Helpers.NativeToManagedMap.TryRemove({0}, out {1});", + Helpers.InstanceIdentifier, Helpers.DummyIdentifier); + } + + if (ShouldGenerateClassNativeField(@class) && dtor != null) + { + if (dtor.Access != AccessSpecifier.Private && @class.HasNonTrivialDestructor && !dtor.IsPure) { - if (dtor.Access != AccessSpecifier.Private && @class.HasNonTrivialDestructor && !dtor.IsPure) + NativeLibrary library; + if (!Options.CheckSymbols || + Driver.Symbols.FindLibraryBySymbol(dtor.Mangled, out library)) { - NativeLibrary library; - if (!Options.CheckSymbols || - Driver.Symbols.FindLibraryBySymbol(dtor.Mangled, out library)) - { - WriteLine("Internal.{0}({1});", GetFunctionNativeIdentifier(dtor), - Helpers.InstanceIdentifier); - } + WriteLine("Internal.{0}({1});", GetFunctionNativeIdentifier(dtor), + Helpers.InstanceIdentifier); } } } @@ -2389,33 +2395,6 @@ namespace CppSharp.Generators.CSharp if (needsReturn) { - TypePrinter.PushContext(CSharpTypePrinterContextKind.Native); - var retTypeName = retType.CSharpType(TypePrinter).Type; - TypePrinter.PopContext(); - - var isIntPtr = retTypeName.Contains("IntPtr"); - - Type pointee; - if (retType.Type.IsPointerTo(out pointee) && isIntPtr) - { - pointee = pointee.Desugar(); - string @null; - Class @class; - if (pointee.TryGetClass(out @class) && @class.IsValueType) - { - @null = string.Format("new {0}()", pointee); - } - else - { - @null = (pointee.IsPrimitiveType() || - pointee.IsPointer()) && - !CSharpTypePrinter.IsConstCharString(retType) ? - "IntPtr.Zero" : "null"; - } - WriteLine("if ({0} == global::System.IntPtr.Zero) return {1};", - Generator.GeneratedIdentifier("ret"), @null); - } - var ctx = new CSharpMarshalContext(Driver) { ArgName = Helpers.ReturnIdentifier, diff --git a/src/Runtime/Helpers.cs b/src/Runtime/Helpers.cs index f0563bb8..5f5b0e29 100644 --- a/src/Runtime/Helpers.cs +++ b/src/Runtime/Helpers.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Runtime.InteropServices; namespace CppSharp.Runtime @@ -11,5 +12,12 @@ namespace CppSharp.Runtime [DllImport("libc", EntryPoint = "memcpy")] #endif public static extern IntPtr memcpy(IntPtr dest, IntPtr src, UIntPtr count); + + public static ConcurrentDictionary NativeToManagedMap + { + get { return nativeToManagedMap; } + } + + private static readonly ConcurrentDictionary nativeToManagedMap = new ConcurrentDictionary(); } } diff --git a/tests/CSharpTemp/CSharpTemp.Tests.cs b/tests/CSharpTemp/CSharpTemp.Tests.cs index d9a03bd7..f14c4185 100644 --- a/tests/CSharpTemp/CSharpTemp.Tests.cs +++ b/tests/CSharpTemp/CSharpTemp.Tests.cs @@ -1,5 +1,6 @@ using System; using System.Reflection; +using CppSharp.Runtime; using CSharpTemp; using CppSharp.Utils; using NUnit.Framework; @@ -206,4 +207,22 @@ public class CSharpTempTests : GeneratorTestFixture { QMap.Iterator test_iter; } + + [Test] + public void TestNativeToManagedMap() + { + IntPtr native1; + IntPtr native2; + using (var testNativeToManagedMap = new TestNativeToManagedMap()) + { + var hasVirtualDtor2 = testNativeToManagedMap.HasVirtualDtor2; + native2 = hasVirtualDtor2.__Instance; + native1 = hasVirtualDtor2.HasVirtualDtor1.__Instance; + Assert.IsTrue(Helpers.NativeToManagedMap.ContainsKey(native2)); + Assert.IsTrue(Helpers.NativeToManagedMap.ContainsKey(native1)); + Assert.AreSame(hasVirtualDtor2, testNativeToManagedMap.HasVirtualDtor2); + } + Assert.IsFalse(Helpers.NativeToManagedMap.ContainsKey(native2)); + Assert.IsFalse(Helpers.NativeToManagedMap.ContainsKey(native1)); + } } \ No newline at end of file diff --git a/tests/CSharpTemp/CSharpTemp.cpp b/tests/CSharpTemp/CSharpTemp.cpp index de875331..81877507 100644 --- a/tests/CSharpTemp/CSharpTemp.cpp +++ b/tests/CSharpTemp/CSharpTemp.cpp @@ -391,3 +391,41 @@ Foo StructWithPrivateFields::getComplexPrivateField() { return complexPrivateField; } + +HasVirtualDtor1::~HasVirtualDtor1() +{ +} + +HasVirtualDtor2::HasVirtualDtor2() +{ + hasVirtualDtor1 = new HasVirtualDtor1(); +} + +HasVirtualDtor2::~HasVirtualDtor2() +{ + delete hasVirtualDtor1; +} + +HasVirtualDtor1* HasVirtualDtor2::getHasVirtualDtor1() +{ + return hasVirtualDtor1; +} + +void HasVirtualDtor2::virtualFunction(const HasVirtualDtor1& param1, const HasVirtualDtor1& param2) +{ +} + +TestNativeToManagedMap::TestNativeToManagedMap() +{ + hasVirtualDtor2 = new HasVirtualDtor2(); +} + +TestNativeToManagedMap::~TestNativeToManagedMap() +{ + delete hasVirtualDtor2; +} + +HasVirtualDtor2* TestNativeToManagedMap::getHasVirtualDtor2() +{ + return hasVirtualDtor2; +} diff --git a/tests/CSharpTemp/CSharpTemp.h b/tests/CSharpTemp/CSharpTemp.h index 0d5a2409..9e5d4dbb 100644 --- a/tests/CSharpTemp/CSharpTemp.h +++ b/tests/CSharpTemp/CSharpTemp.h @@ -421,3 +421,30 @@ void TestPointers::TestTripleCharPointers(const char*** names) { } + +class DLL_API HasVirtualDtor1 +{ +public: + virtual ~HasVirtualDtor1(); +}; + +class DLL_API HasVirtualDtor2 +{ +public: + HasVirtualDtor2(); + virtual ~HasVirtualDtor2(); + HasVirtualDtor1* getHasVirtualDtor1(); + virtual void virtualFunction(const HasVirtualDtor1& param1, const HasVirtualDtor1& param2); +private: + HasVirtualDtor1* hasVirtualDtor1; +}; + +class DLL_API TestNativeToManagedMap +{ +public: + TestNativeToManagedMap(); + virtual ~TestNativeToManagedMap(); + HasVirtualDtor2* getHasVirtualDtor2(); +private: + HasVirtualDtor2* hasVirtualDtor2; +};