From a6426913487ca241767967705b880eb467cada71 Mon Sep 17 00:00:00 2001 From: Fabio Anderegg Date: Thu, 6 Oct 2022 00:39:58 +0100 Subject: [PATCH] Call destructor on copied arguments when calling C# method from C++ (MS ABIs only) (#1685) * on MS abi call destructor on copy-by-value arguments after call to c# function * add tests for destructor call on call by value from c++ to c# * copy-by-value destructor call using Dispose() instead of Internal.dtor to handle destructors in base class --- .../Generators/CSharp/CSharpSources.cs | 33 ++++++++++- src/Generator/Generators/CodeGenerator.cs | 2 +- tests/CSharp/CSharp.Tests.cs | 52 +++++++++++++++++ tests/CSharp/CSharp.cpp | 56 +++++++++++++++++++ tests/CSharp/CSharp.h | 26 +++++++++ 5 files changed, 166 insertions(+), 3 deletions(-) diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 320b9e32..fc4aa9c6 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -1918,7 +1918,7 @@ internal static bool {Helpers.TryGetNativeToManagedMappingIdentifier}(IntPtr nat var hasReturn = !isVoid && !isSetter; if (hasReturn) - Write(isPrimitive && !isSetter ? "return " : $"var {Helpers.ReturnIdentifier} = "); + Write($"var {Helpers.ReturnIdentifier} = "); Write($"{Helpers.TargetIdentifier}."); string marshalsCode = string.Join(", ", marshals); @@ -1933,8 +1933,37 @@ internal static bool {Helpers.TryGetNativeToManagedMappingIdentifier}(IntPtr nat Write($" = {marshalsCode}"); } WriteLine(";"); - if (isPrimitive && !isSetter) + + // on Microsoft ABIs, the destructor on copy-by-value parameters is + // called by the called function, not the caller, so we are generating + // code to do that for classes that have a non-trivial destructor. + if (Context.ParserOptions.IsMicrosoftAbi) + { + for (int i = 0; i < method.Parameters.Count; i++) + { + var param = method.Parameters[i]; + if (param.Ignore) + continue; + + if (param.Kind == ParameterKind.IndirectReturnType) + continue; + + var paramType = param.Type.GetFinalPointee(); + + if (param.IsIndirect && + paramType.TryGetClass(out Class paramClass) && !(paramClass is ClassTemplateSpecialization) && + paramClass.HasNonTrivialDestructor) + { + WriteLine($"{Generator.GeneratedIdentifier("result")}{i}.Dispose(false, true);"); + } + } + } + + if (hasReturn && isPrimitive && !isSetter) + { + WriteLine($"return { Helpers.ReturnIdentifier};"); return; + } if (hasReturn) { diff --git a/src/Generator/Generators/CodeGenerator.cs b/src/Generator/Generators/CodeGenerator.cs index f6d7d1d9..39eb7e49 100644 --- a/src/Generator/Generators/CodeGenerator.cs +++ b/src/Generator/Generators/CodeGenerator.cs @@ -1292,7 +1292,7 @@ namespace CppSharp.Generators public static readonly string InstanceField = Generator.GeneratedIdentifier("instance"); public static readonly string InstanceIdentifier = Generator.GeneratedIdentifier("Instance"); public static readonly string PrimaryBaseOffsetIdentifier = Generator.GeneratedIdentifier("PrimaryBaseOffset"); - public static readonly string ReturnIdentifier = Generator.GeneratedIdentifier("ret"); + public static readonly string ReturnIdentifier = "___ret"; public static readonly string DummyIdentifier = Generator.GeneratedIdentifier("dummy"); public static readonly string TargetIdentifier = Generator.GeneratedIdentifier("target"); public static readonly string SlotIdentifier = Generator.GeneratedIdentifier("slot"); diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index 9b5b1e53..761e38f1 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -1925,4 +1925,56 @@ public unsafe class CSharpTests Assert.That(myclass, Is.Not.SameAs(backup)); } + + class CallByValueInterfaceImpl : CallByValueInterface + { + public override void CallByValue(RuleOfThreeTester value) + { + } + public override void CallByReference(RuleOfThreeTester value) + { + } + public override void CallByPointer(RuleOfThreeTester value) + { + } + } + + [Test] + public void TestCallByValueCppToCSharpValue() + { + RuleOfThreeTester.Reset(); + CallByValueInterface @interface = new CallByValueInterfaceImpl(); + CSharp.CSharp.CallCallByValueInterfaceValue(@interface); + + Assert.That(RuleOfThreeTester.ConstructorCalls, Is.EqualTo(1)); + Assert.That(RuleOfThreeTester.DestructorCalls, Is.EqualTo(2)); + Assert.That(RuleOfThreeTester.CopyConstructorCalls, Is.EqualTo(1)); + Assert.That(RuleOfThreeTester.CopyAssignmentCalls, Is.EqualTo(0)); + } + + [Test] + public void TestCallByValueCppToCSharpReference() + { + RuleOfThreeTester.Reset(); + CallByValueInterface @interface = new CallByValueInterfaceImpl(); + CSharp.CSharp.CallCallByValueInterfaceReference(@interface); + + Assert.That(RuleOfThreeTester.ConstructorCalls, Is.EqualTo(1)); + Assert.That(RuleOfThreeTester.DestructorCalls, Is.EqualTo(1)); + Assert.That(RuleOfThreeTester.CopyConstructorCalls, Is.EqualTo(0)); + Assert.That(RuleOfThreeTester.CopyAssignmentCalls, Is.EqualTo(0)); + } + + [Test] + public void TestCallByValueCppToCSharpPointer() + { + RuleOfThreeTester.Reset(); + CallByValueInterface @interface = new CallByValueInterfaceImpl(); + CSharp.CSharp.CallCallByValueInterfacePointer(@interface); + + Assert.That(RuleOfThreeTester.ConstructorCalls, Is.EqualTo(1)); + Assert.That(RuleOfThreeTester.DestructorCalls, Is.EqualTo(1)); + Assert.That(RuleOfThreeTester.CopyConstructorCalls, Is.EqualTo(0)); + Assert.That(RuleOfThreeTester.CopyAssignmentCalls, Is.EqualTo(0)); + } } diff --git a/tests/CSharp/CSharp.cpp b/tests/CSharp/CSharp.cpp index 372a0e82..fa8dbd5d 100644 --- a/tests/CSharp/CSharp.cpp +++ b/tests/CSharp/CSharp.cpp @@ -1691,3 +1691,59 @@ DLL_API int TestFunctionToInstanceMethodStruct(FTIStruct* bb, FTIStruct defaultV DLL_API int TestFunctionToInstanceMethodRefStruct(FTIStruct* bb, FTIStruct& defaultValue) { return defaultValue.a; } DLL_API int TestFunctionToInstanceMethodConstStruct(FTIStruct* bb, const FTIStruct defaultValue) { return defaultValue.a; } DLL_API int TestFunctionToInstanceMethodConstRefStruct(FTIStruct* bb, const FTIStruct& defaultValue) { return defaultValue.a; } + +int RuleOfThreeTester::constructorCalls = 0; +int RuleOfThreeTester::destructorCalls = 0; +int RuleOfThreeTester::copyConstructorCalls = 0; +int RuleOfThreeTester::copyAssignmentCalls = 0; + +void RuleOfThreeTester::reset() +{ + constructorCalls = 0; + destructorCalls = 0; + copyConstructorCalls = 0; + copyAssignmentCalls = 0; +} + +RuleOfThreeTester::RuleOfThreeTester() +{ + a = 0; + constructorCalls++; +} + +RuleOfThreeTester::RuleOfThreeTester(const RuleOfThreeTester& other) +{ + a = other.a; + copyConstructorCalls++; +} + +RuleOfThreeTester::~RuleOfThreeTester() +{ + destructorCalls++; +} + +RuleOfThreeTester& RuleOfThreeTester::operator=(const RuleOfThreeTester& other) +{ + a = other.a; + copyAssignmentCalls++; + return *this; +} + +// test if generated code correctly calls constructors and destructors when going from C++ to C# +void CallCallByValueInterfaceValue(CallByValueInterface* interface) +{ + RuleOfThreeTester value; + interface->CallByValue(value); +} + +void CallCallByValueInterfaceReference(CallByValueInterface* interface) +{ + RuleOfThreeTester value; + interface->CallByReference(value); +} + +void CallCallByValueInterfacePointer(CallByValueInterface* interface) +{ + RuleOfThreeTester value; + interface->CallByPointer(&value); +} diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h index caa13b87..5bc46f9e 100644 --- a/tests/CSharp/CSharp.h +++ b/tests/CSharp/CSharp.h @@ -1551,3 +1551,29 @@ DLL_API inline ClassWithIntValue* CreateCore(CS_IN_OUT ClassWithIntValue*& pClas pClass->value = 20; return nullptr; } + + +struct DLL_API RuleOfThreeTester { + int a; + static int constructorCalls; + static int destructorCalls; + static int copyConstructorCalls; + static int copyAssignmentCalls; + + static void reset(); + + RuleOfThreeTester(); + ~RuleOfThreeTester(); + RuleOfThreeTester(const RuleOfThreeTester& other); + RuleOfThreeTester& operator=(const RuleOfThreeTester& other); +}; + +struct DLL_API CallByValueInterface { + virtual void CallByValue(RuleOfThreeTester value) = 0; + virtual void CallByReference(RuleOfThreeTester& value) = 0; + virtual void CallByPointer(RuleOfThreeTester* value) = 0; +}; + +void DLL_API CallCallByValueInterfaceValue(CallByValueInterface*); +void DLL_API CallCallByValueInterfaceReference(CallByValueInterface*); +void DLL_API CallCallByValueInterfacePointer(CallByValueInterface*);