diff --git a/src/AST/ClassExtensions.cs b/src/AST/ClassExtensions.cs index bc02d13c..2355cc0b 100644 --- a/src/AST/ClassExtensions.cs +++ b/src/AST/ClassExtensions.cs @@ -43,6 +43,17 @@ namespace CppSharp.AST } } + public static Class GetRootBase(this Class @class) + { + while (true) + { + if (@class.BaseClass == null) + return @class; + + @class = @class.BaseClass; + } + } + public static Method GetRootBaseMethod(this Class c, Method @override, bool onlyFirstBase = false) { return (from @base in c.Bases diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 35a08eb3..a29ae23a 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -259,41 +259,19 @@ namespace CppSharp.Generators.CSharp public override bool VisitClassDecl(Class @class) { - var instance = Context.ReturnVarName; - - @class = @class.OriginalClass ?? @class; + var originalClass = @class.OriginalClass ?? @class; Type returnType = Context.ReturnType.Type.Desugar(); - var type = QualifiedIdentifier(@class) + - (Context.Driver.Options.GenerateAbstractImpls && @class.IsAbstract ? + // if the class is an abstract impl, use the original for the object map + var qualifiedClass = QualifiedIdentifier(originalClass); + var type = qualifiedClass + + (Context.Driver.Options.GenerateAbstractImpls && originalClass.IsAbstract ? "Internal" : ""); if (returnType.IsAddress()) - { - var ret = Generator.GeneratedIdentifier("result") + Context.ParameterIndex; - Context.SupportBefore.WriteLine("{0} {1};", type, ret); - Context.SupportBefore.WriteLine("if ({0} == IntPtr.Zero) {1} = {2};", instance, ret, - @class.IsRefType ? "null" : string.Format("new {0}()", type)); - var dtor = @class.Destructors.FirstOrDefault(); - var map = @class.IsRefType && dtor != null && dtor.IsVirtual; - if (map) - { - Context.SupportBefore.WriteLine( - "else if (CppSharp.Runtime.Helpers.NativeToManagedMap.ContainsKey({0}))", instance); - Context.SupportBefore.WriteLineIndent("{0} = ({1}) CppSharp.Runtime.Helpers.NativeToManagedMap[{2}];", - ret, type, instance); - Context.SupportBefore.WriteLine("else CppSharp.Runtime.Helpers.NativeToManagedMap[{3}] = {0} = {1}.{2}({3});", ret, type, - Helpers.CreateInstanceIdentifier, instance); - } - else - { - Context.SupportBefore.WriteLine("else {0} = {1}.{2}({3});", ret, type, - Helpers.CreateInstanceIdentifier, instance); - } - Context.Return.Write(ret); - } + Context.Return.Write(HandleReturnedPointer(@class, type, qualifiedClass)); else - Context.Return.Write("{0}.{1}({2})", type, Helpers.CreateInstanceIdentifier, instance); + Context.Return.Write("{0}.{1}({2})", type, Helpers.CreateInstanceIdentifier, Context.ReturnVarName); return true; } @@ -330,6 +308,40 @@ namespace CppSharp.Generators.CSharp Context.Return.Write("_{0}", parameter.Name); return true; } + + private string HandleReturnedPointer(Class @class, string type, string qualifiedClass) + { + var originalClass = @class.OriginalClass ?? @class; + var ret = Generator.GeneratedIdentifier("result") + Context.ParameterIndex; + Context.SupportBefore.WriteLine("{0} {1};", @class.Name, ret); + Context.SupportBefore.WriteLine("if ({0} == IntPtr.Zero) {1} = {2};", Context.ReturnVarName, ret, + originalClass.IsRefType ? "null" : string.Format("new {0}()", type)); + if (originalClass.IsRefType) + { + Context.SupportBefore.WriteLine( + "else if ({0}.NativeToManagedMap.ContainsKey({1}))", qualifiedClass, Context.ReturnVarName); + Context.SupportBefore.WriteLineIndent("{0} = ({1}) {2}.NativeToManagedMap[{3}];", + ret, @class.Name, qualifiedClass, Context.ReturnVarName); + var dtor = originalClass.Destructors.FirstOrDefault(); + if (dtor != null && dtor.IsVirtual) + { + Context.SupportBefore.WriteLine("else {0}.NativeToManagedMap[{1}] = {2} = {3}.{4}({1});", + qualifiedClass, Context.ReturnVarName, ret, type, + Helpers.CreateInstanceIdentifier, Context.ReturnVarName); + } + else + { + Context.SupportBefore.WriteLine("else {0} = {1}.{2}({3});", ret, type, + Helpers.CreateInstanceIdentifier, Context.ReturnVarName); + } + } + else + { + Context.SupportBefore.WriteLine("else {0} = {1}.{2}({3});", ret, type, + Helpers.CreateInstanceIdentifier, Context.ReturnVarName); + } + return ret; + } } public class CSharpMarshalManagedToNativePrinter : CSharpMarshalPrinter diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 43d3ffca..dc8353b4 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -399,6 +399,12 @@ namespace CppSharp.Generators.CSharp { WriteLine("public {0} {1} {{ get; protected set; }}", "global::System.IntPtr", Helpers.InstanceIdentifier); + + // use interfaces if any - derived types with a secondary base this class must be compatible with the map + var @interface = @class.Namespace.Classes.Find(c => c.OriginalClass == @class); + WriteLine( + "public static readonly System.Collections.Concurrent.ConcurrentDictionary NativeToManagedMap = new System.Collections.Concurrent.ConcurrentDictionary();", + @interface != null ? @interface.Name : @class.Name); } PopBlock(NewLineKind.BeforeNextBlock); } @@ -1760,14 +1766,21 @@ namespace CppSharp.Generators.CSharp WriteLine("void Dispose(bool disposing)"); WriteStartBraceIndent(); - var dtor = @class.Destructors.FirstOrDefault(); - if (@class.IsRefType && dtor != null && dtor.IsVirtual) + if (@class.IsRefType) { - WriteLine("object {0};", Helpers.DummyIdentifier); - WriteLine("CppSharp.Runtime.Helpers.NativeToManagedMap.TryRemove({0}, out {1});", + var @base = @class.GetRootBase(); + var className = @base.IsAbstractImpl ? @base.BaseClass.Name : @base.Name; + + // Use interfaces if any - derived types with a this class as a seconary base, must be compatible with the map + var @interface = @base.Namespace.Classes.Find(c => c.OriginalClass == @base); + + // The local var must be of the exact type in the object map because of TryRemove + WriteLine("{0} {1};", @interface != null ? @interface.Name : className, Helpers.DummyIdentifier); + WriteLine("NativeToManagedMap.TryRemove({0}, out {1});", Helpers.InstanceIdentifier, Helpers.DummyIdentifier); } + var dtor = @class.Destructors.FirstOrDefault(); if (ShouldGenerateClassNativeField(@class) && dtor != null) { if (dtor.Access != AccessSpecifier.Private && @class.HasNonTrivialDestructor && !dtor.IsPure) @@ -1806,32 +1819,26 @@ namespace CppSharp.Generators.CSharp PopBlock(NewLineKind.BeforeNextBlock); } - string className = @class.Name; - string safeIdentifier = className; - if (@class.IsAbstractImpl) - { - className = className.Substring(0, - safeIdentifier.LastIndexOf("Internal", StringComparison.Ordinal)); - } + string className = @class.IsAbstractImpl ? @class.BaseClass.Name : @class.Name; if (!@class.IsAbstract) { PushBlock(CSharpBlockKind.Method); WriteLine("public static {0}{1} {2}(global::System.IntPtr native)", @class.HasNonIgnoredBase && !@class.BaseClass.IsAbstract ? "new " : string.Empty, - safeIdentifier, Helpers.CreateInstanceIdentifier); + @class.Name, Helpers.CreateInstanceIdentifier); WriteStartBraceIndent(); - WriteLine("return new {0}(({1}.Internal*) native);", safeIdentifier, className); + WriteLine("return new {0}(({1}.Internal*) native);", @class.Name, className); WriteCloseBraceIndent(); PopBlock(NewLineKind.BeforeNextBlock); - GenerateNativeConstructorByValue(@class, className, safeIdentifier); + GenerateNativeConstructorByValue(@class, className, @class.Name); } PushBlock(CSharpBlockKind.Method); WriteLine("{0} {1}({2}.Internal* native, bool isInternalImpl = false){3}", @class.IsRefType ? "protected" : "private", - safeIdentifier, className, @class.IsValueType ? " : this()" : string.Empty); + @class.Name, className, @class.IsValueType ? " : this()" : string.Empty); var hasBaseClass = @class.HasBaseClass && @class.BaseClass.IsRefType; if (hasBaseClass) @@ -1903,6 +1910,7 @@ namespace CppSharp.Generators.CSharp if (@class.IsRefType) { WriteLine("{0} = true;", Helpers.OwnsNativeInstanceIdentifier); + WriteLine("NativeToManagedMap[{0}] = this;", Helpers.InstanceIdentifier); } else { @@ -2217,6 +2225,7 @@ namespace CppSharp.Generators.CSharp WriteLine("{0} = Marshal.AllocHGlobal({1});", Helpers.InstanceIdentifier, @class.Layout.Size); WriteLine("{0} = true;", Helpers.OwnsNativeInstanceIdentifier); + WriteLine("NativeToManagedMap[{0}] = this;", Helpers.InstanceIdentifier); if (method.IsCopyConstructor) { @@ -2399,7 +2408,8 @@ namespace CppSharp.Generators.CSharp { ArgName = Helpers.ReturnIdentifier, ReturnVarName = Helpers.ReturnIdentifier, - ReturnType = retType + ReturnType = retType, + Parameter = operatorParam }; var marshal = new CSharpMarshalNativeToManagedPrinter(ctx); diff --git a/src/Generator/Passes/ParamTypeToInterfacePass.cs b/src/Generator/Passes/ParamTypeToInterfacePass.cs index b83275e4..34f1cea2 100644 --- a/src/Generator/Passes/ParamTypeToInterfacePass.cs +++ b/src/Generator/Passes/ParamTypeToInterfacePass.cs @@ -20,7 +20,7 @@ namespace CppSharp.Passes private static void ChangeToInterfaceType(QualifiedType type) { - var tagType = type.Type.SkipPointerRefs() as TagType; + var tagType = type.Type.GetFinalPointee() as TagType; if (tagType != null) { var @class = tagType.Declaration as Class; diff --git a/src/Runtime/Helpers.cs b/src/Runtime/Helpers.cs index 5f5b0e29..7a3bd791 100644 --- a/src/Runtime/Helpers.cs +++ b/src/Runtime/Helpers.cs @@ -12,12 +12,5 @@ namespace CppSharp.Runtime [DllImport("libc", EntryPoint = "memcpy")] #endif public static extern IntPtr memcpy(IntPtr dest, IntPtr src, UIntPtr count); - - public static ConcurrentDictionary NativeToManagedMap - { - get { return nativeToManagedMap; } - } - - private static readonly ConcurrentDictionary nativeToManagedMap = new ConcurrentDictionary(); } } diff --git a/tests/CSharpTemp/CSharpTemp.Tests.cs b/tests/CSharpTemp/CSharpTemp.Tests.cs index f14c4185..4ee90888 100644 --- a/tests/CSharpTemp/CSharpTemp.Tests.cs +++ b/tests/CSharpTemp/CSharpTemp.Tests.cs @@ -209,7 +209,7 @@ public class CSharpTempTests : GeneratorTestFixture } [Test] - public void TestNativeToManagedMap() + public void TestNativeToManagedMapWithForeignObjects() { IntPtr native1; IntPtr native2; @@ -218,11 +218,25 @@ public class CSharpTempTests : GeneratorTestFixture var hasVirtualDtor2 = testNativeToManagedMap.HasVirtualDtor2; native2 = hasVirtualDtor2.__Instance; native1 = hasVirtualDtor2.HasVirtualDtor1.__Instance; - Assert.IsTrue(Helpers.NativeToManagedMap.ContainsKey(native2)); - Assert.IsTrue(Helpers.NativeToManagedMap.ContainsKey(native1)); + Assert.IsTrue(HasVirtualDtor2.NativeToManagedMap.ContainsKey(native2)); + Assert.IsTrue(HasVirtualDtor1.NativeToManagedMap.ContainsKey(native1)); Assert.AreSame(hasVirtualDtor2, testNativeToManagedMap.HasVirtualDtor2); } - Assert.IsFalse(Helpers.NativeToManagedMap.ContainsKey(native2)); - Assert.IsFalse(Helpers.NativeToManagedMap.ContainsKey(native1)); + Assert.IsFalse(HasVirtualDtor2.NativeToManagedMap.ContainsKey(native2)); + Assert.IsFalse(HasVirtualDtor1.NativeToManagedMap.ContainsKey(native1)); + } + + [Test] + public void TestNativeToManagedMapWithOwnObjects() + { + using (var testNativeToManagedMap = new TestNativeToManagedMap()) + { + var bar = new Bar(); + testNativeToManagedMap.PropertyWithNoVirtualDtor = bar; + Assert.AreSame(bar, testNativeToManagedMap.PropertyWithNoVirtualDtor); + Assert.IsTrue(Bar.NativeToManagedMap.ContainsKey(bar.__Instance)); + bar.Dispose(); + Assert.IsFalse(Bar.NativeToManagedMap.ContainsKey(bar.__Instance)); + } } } \ No newline at end of file diff --git a/tests/CSharpTemp/CSharpTemp.cpp b/tests/CSharpTemp/CSharpTemp.cpp index 81877507..9f7ebda3 100644 --- a/tests/CSharpTemp/CSharpTemp.cpp +++ b/tests/CSharpTemp/CSharpTemp.cpp @@ -84,6 +84,15 @@ void Qux::obsolete() } +Qux* Qux::getInterface() +{ + return this; +} + +void Qux::setInterface(Qux *qux) +{ +} + int Bar::method() { return 2; @@ -429,3 +438,13 @@ HasVirtualDtor2* TestNativeToManagedMap::getHasVirtualDtor2() { return hasVirtualDtor2; } + +Bar* TestNativeToManagedMap::propertyWithNoVirtualDtor() const +{ + return bar; +} + +void TestNativeToManagedMap::setPropertyWithNoVirtualDtor(Bar* bar) +{ + this->bar = bar; +} diff --git a/tests/CSharpTemp/CSharpTemp.h b/tests/CSharpTemp/CSharpTemp.h index 9e5d4dbb..207f00de 100644 --- a/tests/CSharpTemp/CSharpTemp.h +++ b/tests/CSharpTemp/CSharpTemp.h @@ -37,6 +37,8 @@ public: int farAwayFunc() const; int array[3]; void obsolete(); + Qux* getInterface(); + void setInterface(Qux* qux); }; class DLL_API Bar : public Qux @@ -445,6 +447,9 @@ public: TestNativeToManagedMap(); virtual ~TestNativeToManagedMap(); HasVirtualDtor2* getHasVirtualDtor2(); + Bar* propertyWithNoVirtualDtor() const; + void setPropertyWithNoVirtualDtor(Bar* bar); private: HasVirtualDtor2* hasVirtualDtor2; + Bar* bar; };