diff --git a/src/Generator/Generators/CLI/CLIMarshal.cs b/src/Generator/Generators/CLI/CLIMarshal.cs index 2bfe6580..4ad3d326 100644 --- a/src/Generator/Generators/CLI/CLIMarshal.cs +++ b/src/Generator/Generators/CLI/CLIMarshal.cs @@ -305,7 +305,7 @@ namespace CppSharp.Generators.CLI Context.Return.Write("({0} == nullptr) ? nullptr : gcnew ", instance); - Context.Return.Write("{0}(", QualifiedIdentifier(@class)); + Context.Return.Write("::{0}(", QualifiedIdentifier(@class)); Context.Return.Write("(::{0}*)", @class.QualifiedOriginalName); Context.Return.Write("{0}{1})", instance, ownNativeInstance ? ", true" : ""); } @@ -433,39 +433,42 @@ namespace CppSharp.Generators.CLI case ArrayType.ArraySize.Constant: if (string.IsNullOrEmpty(Context.ReturnVarName)) { - const string pinnedPtr = "__pinnedPtr"; - Context.Before.WriteLine("cli::pin_ptr<{0}> {1} = &{2}[0];", - array.Type, pinnedPtr, Context.Parameter.Name); - const string arrayPtr = "__arrayPtr"; - Context.Before.WriteLine("{0}* {1} = {2};", array.Type, arrayPtr, pinnedPtr); - Context.Return.Write("({0} (&)[{1}]) {2}", array.Type, array.Size, arrayPtr); + string arrayPtrRet = $"__{Context.ParameterIndex}ArrayPtr"; + Context.Before.WriteLine($"{array.Type} {arrayPtrRet}[{array.Size}];"); + + Context.ReturnVarName = arrayPtrRet; + + Context.Return.Write(arrayPtrRet); } - else + + bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void); + bool isPrimitive = array.Type.IsPrimitiveType(); + var supportBefore = Context.Before; + supportBefore.WriteLine("if ({0} != nullptr)", Context.Parameter.Name); + supportBefore.WriteOpenBraceAndIndent(); + + supportBefore.WriteLine($"if ({Context.Parameter.Name}->Length != {array.Size})"); + supportBefore.WriteOpenBraceAndIndent(); + supportBefore.WriteLine($"throw gcnew System::InvalidOperationException(\"Source array size must equal destination array size.\");"); + supportBefore.UnindentAndWriteCloseBrace(); + + string nativeVal = string.Empty; + if (isPointerToPrimitive) + { + nativeVal = ".ToPointer()"; + } + else if (!isPrimitive) { - bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void); - bool isPrimitive = array.Type.IsPrimitiveType(); - var supportBefore = Context.Before; - supportBefore.WriteLine("if ({0} != nullptr)", Context.ArgName); - supportBefore.WriteOpenBraceAndIndent(); - - string nativeVal = string.Empty; - if (isPointerToPrimitive) - { - nativeVal = ".ToPointer()"; - } - else if (!isPrimitive) - { - nativeVal = "->NativePtr"; - } - - supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size); - supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};", - Context.ReturnVarName, - isPointerToPrimitive || isPrimitive ? string.Empty : "*", - Context.ArgName, - nativeVal); - supportBefore.UnindentAndWriteCloseBrace(); + nativeVal = "->NativePtr"; } + + supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size); + supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};", + Context.ReturnVarName, + isPointerToPrimitive || isPrimitive ? string.Empty : "*", + Context.Parameter.Name, + nativeVal); + supportBefore.UnindentAndWriteCloseBrace(); break; default: Context.Return.Write("null"); @@ -778,7 +781,8 @@ namespace CppSharp.Generators.CLI { ArgName = fieldRef, ParameterIndex = Context.ParameterIndex++, - MarshalVarPrefix = Context.MarshalVarPrefix + MarshalVarPrefix = Context.MarshalVarPrefix, + ReturnVarName = $"{marshalVar}.{property.Field.OriginalName}" }; var marshal = new CLIMarshalManagedToNativePrinter(marshalCtx); @@ -789,23 +793,26 @@ namespace CppSharp.Generators.CLI if (!string.IsNullOrWhiteSpace(marshal.Context.Before)) Context.Before.Write(marshal.Context.Before); - Type type; - Class @class; - var isRef = property.Type.IsPointerTo(out type) && - !(type.TryGetClass(out @class) && @class.IsValueType) && - !type.IsPrimitiveType(); - - if (isRef) + if (!string.IsNullOrWhiteSpace(marshal.Context.Return)) { - Context.Before.WriteLine("if ({0} != nullptr)", fieldRef); - Context.Before.Indent(); - } + Type type; + Class @class; + var isRef = property.Type.IsPointerTo(out type) && + !(type.TryGetClass(out @class) && @class.IsValueType) && + !type.IsPrimitiveType(); - Context.Before.WriteLine("{0}.{1} = {2};", marshalVar, - property.Field.OriginalName, marshal.Context.Return); + if (isRef) + { + Context.Before.WriteLine("if ({0} != nullptr)", fieldRef); + Context.Before.Indent(); + } - if (isRef) - Context.Before.Unindent(); + Context.Before.WriteLine("{0}.{1} = {2};", marshalVar, + property.Field.OriginalName, marshal.Context.Return); + + if (isRef) + Context.Before.Unindent(); + } } public override bool VisitFieldDecl(Field field) diff --git a/tests/CLI/CLI.Tests.cs b/tests/CLI/CLI.Tests.cs index bf3d632d..b41956f3 100644 --- a/tests/CLI/CLI.Tests.cs +++ b/tests/CLI/CLI.Tests.cs @@ -1,6 +1,8 @@ using CppSharp.Utils; using NUnit.Framework; using CLI; +using System.Text; +using System; public class CLITests : GeneratorTestFixture { @@ -66,4 +68,114 @@ public class CLITests : GeneratorTestFixture Assert.AreEqual("VectorPointerGetter", list[0]); } } + + [Test] + public void TestMultipleConstantArraysParamsTestMethod() + { + byte[] bytes = Encoding.ASCII.GetBytes("TestMulti"); + sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q)); + + byte[] bytes2 = Encoding.ASCII.GetBytes("TestMulti2"); + sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q)); + + string s = CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, sbytes2); + Assert.AreEqual("TestMultiTestMulti2", s); + } + + [Test] + public void TestMultipleConstantArraysParamsTestMethodLongerSourceArray() + { + byte[] bytes = Encoding.ASCII.GetBytes("TestMultipleConstantArraysParamsTestMethodLongerSourceArray"); + sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q)); + + Assert.Throws(() => CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, new sbyte[] { })); + } + + [Test] + public void TestStructWithNestedUnionTestMethod() + { + using (var val = new StructWithNestedUnion()) + { + byte[] bytes = Encoding.ASCII.GetBytes("TestUnions"); + sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q)); + + UnionNestedInsideStruct unionNestedInsideStruct; + unionNestedInsideStruct.SzText = sbytes; + + Assert.AreEqual(sbytes.Length, unionNestedInsideStruct.SzText.Length); + Assert.AreEqual("TestUnions", unionNestedInsideStruct.SzText); + + val.NestedUnion = unionNestedInsideStruct; + + Assert.AreEqual(10, val.NestedUnion.SzText.Length); + Assert.AreEqual("TestUnions", val.NestedUnion.SzText); + + string ret = CLI.CLI.StructWithNestedUnionTestMethod(val); + + Assert.AreEqual("TestUnions", ret); + } + } + + [Test] + public void TestStructWithNestedUnionLongerSourceArray() + { + using (var val = new StructWithNestedUnion()) + { + byte[] bytes = Encoding.ASCII.GetBytes("TestStructWithNestedUnionLongerSourceArray"); + sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q)); + + UnionNestedInsideStruct unionNestedInsideStruct; + unionNestedInsideStruct.SzText = sbytes; + + Assert.Throws(() => val.NestedUnion = unionNestedInsideStruct); + } + } + + [Test] + public void TestUnionWithNestedStructTestMethod() + { + using (var val = new StructNestedInsideUnion()) + { + byte[] bytes = Encoding.ASCII.GetBytes("TestUnions"); + sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q)); + val.SzText = sbytes; + + UnionWithNestedStruct unionWithNestedStruct; + unionWithNestedStruct.NestedStruct = val; + + Assert.AreEqual(10, unionWithNestedStruct.NestedStruct.SzText.Length); + Assert.AreEqual("TestUnions", unionWithNestedStruct.NestedStruct.SzText); + + string ret = CLI.CLI.UnionWithNestedStructTestMethod(unionWithNestedStruct); + + Assert.AreEqual("TestUnions", ret); + } + } + + [Test] + public void TestUnionWithNestedStructArrayTestMethod() + { + using (var val = new StructNestedInsideUnion()) + { + using (var val2 = new StructNestedInsideUnion()) + { + byte[] bytes = Encoding.ASCII.GetBytes("TestUnion1"); + sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q)); + val.SzText = sbytes; + + byte[] bytes2 = Encoding.ASCII.GetBytes("TestUnion2"); + sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q)); + val2.SzText = sbytes2; + + UnionWithNestedStructArray unionWithNestedStructArray; + unionWithNestedStructArray.NestedStructs = new StructNestedInsideUnion[] { val, val2 }; + + Assert.AreEqual(2, unionWithNestedStructArray.NestedStructs.Length); + + string ret = CLI.CLI.UnionWithNestedStructArrayTestMethod(unionWithNestedStructArray); + + Assert.AreEqual("TestUnion1TestUnion2", ret); + } + } + } } \ No newline at end of file diff --git a/tests/CLI/CLI.cpp b/tests/CLI/CLI.cpp index d5626465..e006f2b8 100644 --- a/tests/CLI/CLI.cpp +++ b/tests/CLI/CLI.cpp @@ -63,4 +63,25 @@ VectorPointerGetter::~VectorPointerGetter() std::vector* VectorPointerGetter::GetVecPtr() { return vecPtr; +} + +std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10]) +{ + return std::string(arr1, arr1 + 9) + std::string(arr2, arr2 + 10); +} + +std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val) +{ + return std::string(val.nestedUnion.szText, val.nestedUnion.szText + 10); +} + +std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val) +{ + return std::string(val.nestedStruct.szText, val.nestedStruct.szText + 10); +} + +std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray arr) +{ + return std::string(arr.nestedStructs[0].szText, arr.nestedStructs[0].szText + 10) + + std::string(arr.nestedStructs[1].szText, arr.nestedStructs[1].szText + 10); } \ No newline at end of file diff --git a/tests/CLI/CLI.h b/tests/CLI/CLI.h index 9cd5ed31..0b70ac0a 100644 --- a/tests/CLI/CLI.h +++ b/tests/CLI/CLI.h @@ -89,4 +89,43 @@ public: private: std::vector* vecPtr; -}; \ No newline at end of file +}; + +// Previously passing multiple constant arrays was generating the same variable name for each array inside the method body. +// This is fixed by using the same generation code in CLIMarshal.VisitArrayType for both when there is a return var name specified and +// for when no return var name is specified. +std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10]); + +// Ensures marshalling arrays is handled correctly for value types used within reference types. +union DLL_API UnionNestedInsideStruct +{ + char szText[10]; +}; + +struct DLL_API StructWithNestedUnion +{ + UnionNestedInsideStruct nestedUnion; +}; + +std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val); + +// Ensures marshalling arrays is handled correctly for reference types used within value types. +struct DLL_API StructNestedInsideUnion +{ + char szText[10]; +}; + +union DLL_API UnionWithNestedStruct +{ + StructNestedInsideUnion nestedStruct; +}; + +std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val); + +// Ensures marshalling arrays is handled corectly for arrays of reference types used within value types. +union DLL_API UnionWithNestedStructArray +{ + StructNestedInsideUnion nestedStructs[2]; +}; + +std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray val); \ No newline at end of file