Browse Source

Correctly marshal constant arrays in C++/CLI (#1346)

Co-authored-by: Build Agent <admin@sage.com>
pull/1358/head
Ali Alamiri 5 years ago committed by GitHub
parent
commit
29adf57f83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 99
      src/Generator/Generators/CLI/CLIMarshal.cs
  2. 112
      tests/CLI/CLI.Tests.cs
  3. 21
      tests/CLI/CLI.cpp
  4. 41
      tests/CLI/CLI.h

99
src/Generator/Generators/CLI/CLIMarshal.cs

@ -305,7 +305,7 @@ namespace CppSharp.Generators.CLI @@ -305,7 +305,7 @@ namespace CppSharp.Generators.CLI
Context.Return.Write("({0} == nullptr) ? nullptr : gcnew ",
instance);
Context.Return.Write("{0}(", QualifiedIdentifier(@class));
Context.Return.Write("::{0}(", QualifiedIdentifier(@class));
Context.Return.Write("(::{0}*)", @class.QualifiedOriginalName);
Context.Return.Write("{0}{1})", instance, ownNativeInstance ? ", true" : "");
}
@ -433,39 +433,42 @@ namespace CppSharp.Generators.CLI @@ -433,39 +433,42 @@ namespace CppSharp.Generators.CLI
case ArrayType.ArraySize.Constant:
if (string.IsNullOrEmpty(Context.ReturnVarName))
{
const string pinnedPtr = "__pinnedPtr";
Context.Before.WriteLine("cli::pin_ptr<{0}> {1} = &{2}[0];",
array.Type, pinnedPtr, Context.Parameter.Name);
const string arrayPtr = "__arrayPtr";
Context.Before.WriteLine("{0}* {1} = {2};", array.Type, arrayPtr, pinnedPtr);
Context.Return.Write("({0} (&)[{1}]) {2}", array.Type, array.Size, arrayPtr);
string arrayPtrRet = $"__{Context.ParameterIndex}ArrayPtr";
Context.Before.WriteLine($"{array.Type} {arrayPtrRet}[{array.Size}];");
Context.ReturnVarName = arrayPtrRet;
Context.Return.Write(arrayPtrRet);
}
else
bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void);
bool isPrimitive = array.Type.IsPrimitiveType();
var supportBefore = Context.Before;
supportBefore.WriteLine("if ({0} != nullptr)", Context.Parameter.Name);
supportBefore.WriteOpenBraceAndIndent();
supportBefore.WriteLine($"if ({Context.Parameter.Name}->Length != {array.Size})");
supportBefore.WriteOpenBraceAndIndent();
supportBefore.WriteLine($"throw gcnew System::InvalidOperationException(\"Source array size must equal destination array size.\");");
supportBefore.UnindentAndWriteCloseBrace();
string nativeVal = string.Empty;
if (isPointerToPrimitive)
{
nativeVal = ".ToPointer()";
}
else if (!isPrimitive)
{
bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void);
bool isPrimitive = array.Type.IsPrimitiveType();
var supportBefore = Context.Before;
supportBefore.WriteLine("if ({0} != nullptr)", Context.ArgName);
supportBefore.WriteOpenBraceAndIndent();
string nativeVal = string.Empty;
if (isPointerToPrimitive)
{
nativeVal = ".ToPointer()";
}
else if (!isPrimitive)
{
nativeVal = "->NativePtr";
}
supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size);
supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};",
Context.ReturnVarName,
isPointerToPrimitive || isPrimitive ? string.Empty : "*",
Context.ArgName,
nativeVal);
supportBefore.UnindentAndWriteCloseBrace();
nativeVal = "->NativePtr";
}
supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size);
supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};",
Context.ReturnVarName,
isPointerToPrimitive || isPrimitive ? string.Empty : "*",
Context.Parameter.Name,
nativeVal);
supportBefore.UnindentAndWriteCloseBrace();
break;
default:
Context.Return.Write("null");
@ -778,7 +781,8 @@ namespace CppSharp.Generators.CLI @@ -778,7 +781,8 @@ namespace CppSharp.Generators.CLI
{
ArgName = fieldRef,
ParameterIndex = Context.ParameterIndex++,
MarshalVarPrefix = Context.MarshalVarPrefix
MarshalVarPrefix = Context.MarshalVarPrefix,
ReturnVarName = $"{marshalVar}.{property.Field.OriginalName}"
};
var marshal = new CLIMarshalManagedToNativePrinter(marshalCtx);
@ -789,23 +793,26 @@ namespace CppSharp.Generators.CLI @@ -789,23 +793,26 @@ namespace CppSharp.Generators.CLI
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
Context.Before.Write(marshal.Context.Before);
Type type;
Class @class;
var isRef = property.Type.IsPointerTo(out type) &&
!(type.TryGetClass(out @class) && @class.IsValueType) &&
!type.IsPrimitiveType();
if (isRef)
if (!string.IsNullOrWhiteSpace(marshal.Context.Return))
{
Context.Before.WriteLine("if ({0} != nullptr)", fieldRef);
Context.Before.Indent();
}
Type type;
Class @class;
var isRef = property.Type.IsPointerTo(out type) &&
!(type.TryGetClass(out @class) && @class.IsValueType) &&
!type.IsPrimitiveType();
Context.Before.WriteLine("{0}.{1} = {2};", marshalVar,
property.Field.OriginalName, marshal.Context.Return);
if (isRef)
{
Context.Before.WriteLine("if ({0} != nullptr)", fieldRef);
Context.Before.Indent();
}
if (isRef)
Context.Before.Unindent();
Context.Before.WriteLine("{0}.{1} = {2};", marshalVar,
property.Field.OriginalName, marshal.Context.Return);
if (isRef)
Context.Before.Unindent();
}
}
public override bool VisitFieldDecl(Field field)

112
tests/CLI/CLI.Tests.cs

@ -1,6 +1,8 @@ @@ -1,6 +1,8 @@
using CppSharp.Utils;
using NUnit.Framework;
using CLI;
using System.Text;
using System;
public class CLITests : GeneratorTestFixture
{
@ -66,4 +68,114 @@ public class CLITests : GeneratorTestFixture @@ -66,4 +68,114 @@ public class CLITests : GeneratorTestFixture
Assert.AreEqual("VectorPointerGetter", list[0]);
}
}
[Test]
public void TestMultipleConstantArraysParamsTestMethod()
{
byte[] bytes = Encoding.ASCII.GetBytes("TestMulti");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
byte[] bytes2 = Encoding.ASCII.GetBytes("TestMulti2");
sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q));
string s = CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, sbytes2);
Assert.AreEqual("TestMultiTestMulti2", s);
}
[Test]
public void TestMultipleConstantArraysParamsTestMethodLongerSourceArray()
{
byte[] bytes = Encoding.ASCII.GetBytes("TestMultipleConstantArraysParamsTestMethodLongerSourceArray");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
Assert.Throws<InvalidOperationException>(() => CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, new sbyte[] { }));
}
[Test]
public void TestStructWithNestedUnionTestMethod()
{
using (var val = new StructWithNestedUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestUnions");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
UnionNestedInsideStruct unionNestedInsideStruct;
unionNestedInsideStruct.SzText = sbytes;
Assert.AreEqual(sbytes.Length, unionNestedInsideStruct.SzText.Length);
Assert.AreEqual("TestUnions", unionNestedInsideStruct.SzText);
val.NestedUnion = unionNestedInsideStruct;
Assert.AreEqual(10, val.NestedUnion.SzText.Length);
Assert.AreEqual("TestUnions", val.NestedUnion.SzText);
string ret = CLI.CLI.StructWithNestedUnionTestMethod(val);
Assert.AreEqual("TestUnions", ret);
}
}
[Test]
public void TestStructWithNestedUnionLongerSourceArray()
{
using (var val = new StructWithNestedUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestStructWithNestedUnionLongerSourceArray");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
UnionNestedInsideStruct unionNestedInsideStruct;
unionNestedInsideStruct.SzText = sbytes;
Assert.Throws<InvalidOperationException>(() => val.NestedUnion = unionNestedInsideStruct);
}
}
[Test]
public void TestUnionWithNestedStructTestMethod()
{
using (var val = new StructNestedInsideUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestUnions");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
val.SzText = sbytes;
UnionWithNestedStruct unionWithNestedStruct;
unionWithNestedStruct.NestedStruct = val;
Assert.AreEqual(10, unionWithNestedStruct.NestedStruct.SzText.Length);
Assert.AreEqual("TestUnions", unionWithNestedStruct.NestedStruct.SzText);
string ret = CLI.CLI.UnionWithNestedStructTestMethod(unionWithNestedStruct);
Assert.AreEqual("TestUnions", ret);
}
}
[Test]
public void TestUnionWithNestedStructArrayTestMethod()
{
using (var val = new StructNestedInsideUnion())
{
using (var val2 = new StructNestedInsideUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestUnion1");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
val.SzText = sbytes;
byte[] bytes2 = Encoding.ASCII.GetBytes("TestUnion2");
sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q));
val2.SzText = sbytes2;
UnionWithNestedStructArray unionWithNestedStructArray;
unionWithNestedStructArray.NestedStructs = new StructNestedInsideUnion[] { val, val2 };
Assert.AreEqual(2, unionWithNestedStructArray.NestedStructs.Length);
string ret = CLI.CLI.UnionWithNestedStructArrayTestMethod(unionWithNestedStructArray);
Assert.AreEqual("TestUnion1TestUnion2", ret);
}
}
}
}

21
tests/CLI/CLI.cpp

@ -63,4 +63,25 @@ VectorPointerGetter::~VectorPointerGetter() @@ -63,4 +63,25 @@ VectorPointerGetter::~VectorPointerGetter()
std::vector<std::string>* VectorPointerGetter::GetVecPtr()
{
return vecPtr;
}
std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10])
{
return std::string(arr1, arr1 + 9) + std::string(arr2, arr2 + 10);
}
std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val)
{
return std::string(val.nestedUnion.szText, val.nestedUnion.szText + 10);
}
std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val)
{
return std::string(val.nestedStruct.szText, val.nestedStruct.szText + 10);
}
std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray arr)
{
return std::string(arr.nestedStructs[0].szText, arr.nestedStructs[0].szText + 10)
+ std::string(arr.nestedStructs[1].szText, arr.nestedStructs[1].szText + 10);
}

41
tests/CLI/CLI.h

@ -89,4 +89,43 @@ public: @@ -89,4 +89,43 @@ public:
private:
std::vector<std::string>* vecPtr;
};
};
// Previously passing multiple constant arrays was generating the same variable name for each array inside the method body.
// This is fixed by using the same generation code in CLIMarshal.VisitArrayType for both when there is a return var name specified and
// for when no return var name is specified.
std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10]);
// Ensures marshalling arrays is handled correctly for value types used within reference types.
union DLL_API UnionNestedInsideStruct
{
char szText[10];
};
struct DLL_API StructWithNestedUnion
{
UnionNestedInsideStruct nestedUnion;
};
std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val);
// Ensures marshalling arrays is handled correctly for reference types used within value types.
struct DLL_API StructNestedInsideUnion
{
char szText[10];
};
union DLL_API UnionWithNestedStruct
{
StructNestedInsideUnion nestedStruct;
};
std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val);
// Ensures marshalling arrays is handled corectly for arrays of reference types used within value types.
union DLL_API UnionWithNestedStructArray
{
StructNestedInsideUnion nestedStructs[2];
};
std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray val);
Loading…
Cancel
Save