From 060afd53fce65cafed7231afb3ff649db5596e22 Mon Sep 17 00:00:00 2001 From: josetr <37419832+josetr@users.noreply.github.com> Date: Tue, 5 Jul 2022 01:09:30 +0100 Subject: [PATCH] Add partial `ref` param support --- src/AST/Assembly.cs | 3 +++ src/AST/TypeExtensions.cs | 26 +++++++++++++++++++ .../Generators/CSharp/CSharpSources.cs | 11 ++++++++ tests/CSharp/CSharp.Tests.cs | 17 ++++++++++++ tests/CSharp/CSharp.h | 17 +++++++++++- 5 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 src/AST/Assembly.cs diff --git a/src/AST/Assembly.cs b/src/AST/Assembly.cs new file mode 100644 index 00000000..8c01cdba --- /dev/null +++ b/src/AST/Assembly.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("CppSharp.Generator")] diff --git a/src/AST/TypeExtensions.cs b/src/AST/TypeExtensions.cs index ed740fb4..d929563d 100644 --- a/src/AST/TypeExtensions.cs +++ b/src/AST/TypeExtensions.cs @@ -442,5 +442,31 @@ { return array.Size * array.ElementSize; } + + internal static bool IsReferenceToPtrToClass(this Type type) + { + var @ref = type.Desugar().AsLvReference(); + if (@ref == null) + return false; + + var @ptr = @ref.Pointee.Desugar(false).AsPtr(); + if (@ptr == null) + return false; + + var @class = @ptr.Pointee; + return @class != null && @class.IsClass(); + } + + internal static PointerType AsLvReference(this Type type) + { + var reference = type as PointerType; + return reference?.Modifier == PointerType.TypeModifier.LVReference ? reference : null; + } + + internal static PointerType AsPtr(this Type type) + { + var ptr = type as PointerType; + return ptr?.Modifier == PointerType.TypeModifier.Pointer ? ptr : null; + } } } \ No newline at end of file diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 4453567b..6fa40255 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -3193,6 +3193,17 @@ WriteLines($@" } WriteLine("{0}({1});", functionName, string.Join(", ", names)); + foreach(var param in @params) + { + if (param.Param.IsInOut && param.Param.Type.IsReferenceToPtrToClass()) + { + var qualifiedClass = param.Param.Type.Visit(TypePrinter); + + WriteLine($"if ({param.Name} != {param.Param.Name}.__Instance)"); + WriteLine($"{param.Param.Name} = {qualifiedClass}.__GetOrCreateInstance(__{param.Name}, false);"); + } + } + foreach (TextGenerator cleanup in from p in @params select p.Context.Cleanup) Write(cleanup); diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index a24c582a..9b5b1e53 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -1908,4 +1908,21 @@ public unsafe class CSharpTests Assert.That(CSharpTemplates.FunctionTemplate(6f), Is.EqualTo(6 + 4.1f)); Assert.That(CSharpTemplates.FunctionTemplate(7), Is.EqualTo(7 + 4)); } + + [Test] + public void TestPatialRefSupport() + { + var myclass = new ClassWithIntValue(); + var backup = myclass; + myclass.Value = 7; + + CSharp.CSharp.ModifyCore(ref myclass); + Assert.That(myclass.Value, Is.EqualTo(10)); + Assert.That(myclass, Is.SameAs(myclass)); + + CSharp.CSharp.CreateCore(ref myclass); + Assert.That(myclass.Value, Is.EqualTo(20)); + Assert.That(myclass, Is.Not.SameAs(backup)); + } + } diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h index 4aee98d5..caa13b87 100644 --- a/tests/CSharp/CSharp.h +++ b/tests/CSharp/CSharp.h @@ -1535,4 +1535,19 @@ DLL_API int TestFunctionToInstanceMethodRefStruct(FTIStruct* bb, FTIStruct& defa DLL_API int TestFunctionToInstanceMethodConstStruct(FTIStruct* bb, const FTIStruct defaultValue); DLL_API int TestFunctionToInstanceMethodConstRefStruct(FTIStruct* bb, const FTIStruct& defaultValue); -class ClassWithoutNativeToManaged { }; \ No newline at end of file +class ClassWithoutNativeToManaged { }; + +struct DLL_API ClassWithIntValue { + int value; +}; + +DLL_API inline ClassWithIntValue* ModifyCore(CS_IN_OUT ClassWithIntValue*& pClass) { + pClass->value = 10; + return nullptr; +} + +DLL_API inline ClassWithIntValue* CreateCore(CS_IN_OUT ClassWithIntValue*& pClass) { + pClass = new ClassWithIntValue(); + pClass->value = 20; + return nullptr; +}