Browse Source

Correctly marshal constant arrays in C++/CLI (#1346)

Co-authored-by: Build Agent <admin@sage.com>
pull/1358/head
Ali Alamiri 5 years ago committed by GitHub
parent
commit
29adf57f83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 99
      src/Generator/Generators/CLI/CLIMarshal.cs
  2. 112
      tests/CLI/CLI.Tests.cs
  3. 21
      tests/CLI/CLI.cpp
  4. 41
      tests/CLI/CLI.h

99
src/Generator/Generators/CLI/CLIMarshal.cs

@ -305,7 +305,7 @@ namespace CppSharp.Generators.CLI
Context.Return.Write("({0} == nullptr) ? nullptr : gcnew ", Context.Return.Write("({0} == nullptr) ? nullptr : gcnew ",
instance); instance);
Context.Return.Write("{0}(", QualifiedIdentifier(@class)); Context.Return.Write("::{0}(", QualifiedIdentifier(@class));
Context.Return.Write("(::{0}*)", @class.QualifiedOriginalName); Context.Return.Write("(::{0}*)", @class.QualifiedOriginalName);
Context.Return.Write("{0}{1})", instance, ownNativeInstance ? ", true" : ""); Context.Return.Write("{0}{1})", instance, ownNativeInstance ? ", true" : "");
} }
@ -433,39 +433,42 @@ namespace CppSharp.Generators.CLI
case ArrayType.ArraySize.Constant: case ArrayType.ArraySize.Constant:
if (string.IsNullOrEmpty(Context.ReturnVarName)) if (string.IsNullOrEmpty(Context.ReturnVarName))
{ {
const string pinnedPtr = "__pinnedPtr"; string arrayPtrRet = $"__{Context.ParameterIndex}ArrayPtr";
Context.Before.WriteLine("cli::pin_ptr<{0}> {1} = &{2}[0];", Context.Before.WriteLine($"{array.Type} {arrayPtrRet}[{array.Size}];");
array.Type, pinnedPtr, Context.Parameter.Name);
const string arrayPtr = "__arrayPtr"; Context.ReturnVarName = arrayPtrRet;
Context.Before.WriteLine("{0}* {1} = {2};", array.Type, arrayPtr, pinnedPtr);
Context.Return.Write("({0} (&)[{1}]) {2}", array.Type, array.Size, arrayPtr); Context.Return.Write(arrayPtrRet);
} }
else
bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void);
bool isPrimitive = array.Type.IsPrimitiveType();
var supportBefore = Context.Before;
supportBefore.WriteLine("if ({0} != nullptr)", Context.Parameter.Name);
supportBefore.WriteOpenBraceAndIndent();
supportBefore.WriteLine($"if ({Context.Parameter.Name}->Length != {array.Size})");
supportBefore.WriteOpenBraceAndIndent();
supportBefore.WriteLine($"throw gcnew System::InvalidOperationException(\"Source array size must equal destination array size.\");");
supportBefore.UnindentAndWriteCloseBrace();
string nativeVal = string.Empty;
if (isPointerToPrimitive)
{
nativeVal = ".ToPointer()";
}
else if (!isPrimitive)
{ {
bool isPointerToPrimitive = array.Type.IsPointerToPrimitiveType(PrimitiveType.Void); nativeVal = "->NativePtr";
bool isPrimitive = array.Type.IsPrimitiveType();
var supportBefore = Context.Before;
supportBefore.WriteLine("if ({0} != nullptr)", Context.ArgName);
supportBefore.WriteOpenBraceAndIndent();
string nativeVal = string.Empty;
if (isPointerToPrimitive)
{
nativeVal = ".ToPointer()";
}
else if (!isPrimitive)
{
nativeVal = "->NativePtr";
}
supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size);
supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};",
Context.ReturnVarName,
isPointerToPrimitive || isPrimitive ? string.Empty : "*",
Context.ArgName,
nativeVal);
supportBefore.UnindentAndWriteCloseBrace();
} }
supportBefore.WriteLine("for (int i = 0; i < {0}; i++)", array.Size);
supportBefore.WriteLineIndent("{0}[i] = {1}{2}[i]{3};",
Context.ReturnVarName,
isPointerToPrimitive || isPrimitive ? string.Empty : "*",
Context.Parameter.Name,
nativeVal);
supportBefore.UnindentAndWriteCloseBrace();
break; break;
default: default:
Context.Return.Write("null"); Context.Return.Write("null");
@ -778,7 +781,8 @@ namespace CppSharp.Generators.CLI
{ {
ArgName = fieldRef, ArgName = fieldRef,
ParameterIndex = Context.ParameterIndex++, ParameterIndex = Context.ParameterIndex++,
MarshalVarPrefix = Context.MarshalVarPrefix MarshalVarPrefix = Context.MarshalVarPrefix,
ReturnVarName = $"{marshalVar}.{property.Field.OriginalName}"
}; };
var marshal = new CLIMarshalManagedToNativePrinter(marshalCtx); var marshal = new CLIMarshalManagedToNativePrinter(marshalCtx);
@ -789,23 +793,26 @@ namespace CppSharp.Generators.CLI
if (!string.IsNullOrWhiteSpace(marshal.Context.Before)) if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
Context.Before.Write(marshal.Context.Before); Context.Before.Write(marshal.Context.Before);
Type type; if (!string.IsNullOrWhiteSpace(marshal.Context.Return))
Class @class;
var isRef = property.Type.IsPointerTo(out type) &&
!(type.TryGetClass(out @class) && @class.IsValueType) &&
!type.IsPrimitiveType();
if (isRef)
{ {
Context.Before.WriteLine("if ({0} != nullptr)", fieldRef); Type type;
Context.Before.Indent(); Class @class;
} var isRef = property.Type.IsPointerTo(out type) &&
!(type.TryGetClass(out @class) && @class.IsValueType) &&
!type.IsPrimitiveType();
Context.Before.WriteLine("{0}.{1} = {2};", marshalVar, if (isRef)
property.Field.OriginalName, marshal.Context.Return); {
Context.Before.WriteLine("if ({0} != nullptr)", fieldRef);
Context.Before.Indent();
}
if (isRef) Context.Before.WriteLine("{0}.{1} = {2};", marshalVar,
Context.Before.Unindent(); property.Field.OriginalName, marshal.Context.Return);
if (isRef)
Context.Before.Unindent();
}
} }
public override bool VisitFieldDecl(Field field) public override bool VisitFieldDecl(Field field)

112
tests/CLI/CLI.Tests.cs

@ -1,6 +1,8 @@
using CppSharp.Utils; using CppSharp.Utils;
using NUnit.Framework; using NUnit.Framework;
using CLI; using CLI;
using System.Text;
using System;
public class CLITests : GeneratorTestFixture public class CLITests : GeneratorTestFixture
{ {
@ -66,4 +68,114 @@ public class CLITests : GeneratorTestFixture
Assert.AreEqual("VectorPointerGetter", list[0]); Assert.AreEqual("VectorPointerGetter", list[0]);
} }
} }
[Test]
public void TestMultipleConstantArraysParamsTestMethod()
{
byte[] bytes = Encoding.ASCII.GetBytes("TestMulti");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
byte[] bytes2 = Encoding.ASCII.GetBytes("TestMulti2");
sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q));
string s = CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, sbytes2);
Assert.AreEqual("TestMultiTestMulti2", s);
}
[Test]
public void TestMultipleConstantArraysParamsTestMethodLongerSourceArray()
{
byte[] bytes = Encoding.ASCII.GetBytes("TestMultipleConstantArraysParamsTestMethodLongerSourceArray");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
Assert.Throws<InvalidOperationException>(() => CLI.CLI.MultipleConstantArraysParamsTestMethod(sbytes, new sbyte[] { }));
}
[Test]
public void TestStructWithNestedUnionTestMethod()
{
using (var val = new StructWithNestedUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestUnions");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
UnionNestedInsideStruct unionNestedInsideStruct;
unionNestedInsideStruct.SzText = sbytes;
Assert.AreEqual(sbytes.Length, unionNestedInsideStruct.SzText.Length);
Assert.AreEqual("TestUnions", unionNestedInsideStruct.SzText);
val.NestedUnion = unionNestedInsideStruct;
Assert.AreEqual(10, val.NestedUnion.SzText.Length);
Assert.AreEqual("TestUnions", val.NestedUnion.SzText);
string ret = CLI.CLI.StructWithNestedUnionTestMethod(val);
Assert.AreEqual("TestUnions", ret);
}
}
[Test]
public void TestStructWithNestedUnionLongerSourceArray()
{
using (var val = new StructWithNestedUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestStructWithNestedUnionLongerSourceArray");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
UnionNestedInsideStruct unionNestedInsideStruct;
unionNestedInsideStruct.SzText = sbytes;
Assert.Throws<InvalidOperationException>(() => val.NestedUnion = unionNestedInsideStruct);
}
}
[Test]
public void TestUnionWithNestedStructTestMethod()
{
using (var val = new StructNestedInsideUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestUnions");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
val.SzText = sbytes;
UnionWithNestedStruct unionWithNestedStruct;
unionWithNestedStruct.NestedStruct = val;
Assert.AreEqual(10, unionWithNestedStruct.NestedStruct.SzText.Length);
Assert.AreEqual("TestUnions", unionWithNestedStruct.NestedStruct.SzText);
string ret = CLI.CLI.UnionWithNestedStructTestMethod(unionWithNestedStruct);
Assert.AreEqual("TestUnions", ret);
}
}
[Test]
public void TestUnionWithNestedStructArrayTestMethod()
{
using (var val = new StructNestedInsideUnion())
{
using (var val2 = new StructNestedInsideUnion())
{
byte[] bytes = Encoding.ASCII.GetBytes("TestUnion1");
sbyte[] sbytes = Array.ConvertAll(bytes, q => Convert.ToSByte(q));
val.SzText = sbytes;
byte[] bytes2 = Encoding.ASCII.GetBytes("TestUnion2");
sbyte[] sbytes2 = Array.ConvertAll(bytes2, q => Convert.ToSByte(q));
val2.SzText = sbytes2;
UnionWithNestedStructArray unionWithNestedStructArray;
unionWithNestedStructArray.NestedStructs = new StructNestedInsideUnion[] { val, val2 };
Assert.AreEqual(2, unionWithNestedStructArray.NestedStructs.Length);
string ret = CLI.CLI.UnionWithNestedStructArrayTestMethod(unionWithNestedStructArray);
Assert.AreEqual("TestUnion1TestUnion2", ret);
}
}
}
} }

21
tests/CLI/CLI.cpp

@ -63,4 +63,25 @@ VectorPointerGetter::~VectorPointerGetter()
std::vector<std::string>* VectorPointerGetter::GetVecPtr() std::vector<std::string>* VectorPointerGetter::GetVecPtr()
{ {
return vecPtr; return vecPtr;
}
std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10])
{
return std::string(arr1, arr1 + 9) + std::string(arr2, arr2 + 10);
}
std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val)
{
return std::string(val.nestedUnion.szText, val.nestedUnion.szText + 10);
}
std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val)
{
return std::string(val.nestedStruct.szText, val.nestedStruct.szText + 10);
}
std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray arr)
{
return std::string(arr.nestedStructs[0].szText, arr.nestedStructs[0].szText + 10)
+ std::string(arr.nestedStructs[1].szText, arr.nestedStructs[1].szText + 10);
} }

41
tests/CLI/CLI.h

@ -89,4 +89,43 @@ public:
private: private:
std::vector<std::string>* vecPtr; std::vector<std::string>* vecPtr;
}; };
// Previously passing multiple constant arrays was generating the same variable name for each array inside the method body.
// This is fixed by using the same generation code in CLIMarshal.VisitArrayType for both when there is a return var name specified and
// for when no return var name is specified.
std::string DLL_API MultipleConstantArraysParamsTestMethod(char arr1[9], char arr2[10]);
// Ensures marshalling arrays is handled correctly for value types used within reference types.
union DLL_API UnionNestedInsideStruct
{
char szText[10];
};
struct DLL_API StructWithNestedUnion
{
UnionNestedInsideStruct nestedUnion;
};
std::string DLL_API StructWithNestedUnionTestMethod(StructWithNestedUnion val);
// Ensures marshalling arrays is handled correctly for reference types used within value types.
struct DLL_API StructNestedInsideUnion
{
char szText[10];
};
union DLL_API UnionWithNestedStruct
{
StructNestedInsideUnion nestedStruct;
};
std::string DLL_API UnionWithNestedStructTestMethod(UnionWithNestedStruct val);
// Ensures marshalling arrays is handled corectly for arrays of reference types used within value types.
union DLL_API UnionWithNestedStructArray
{
StructNestedInsideUnion nestedStructs[2];
};
std::string DLL_API UnionWithNestedStructArrayTestMethod(UnionWithNestedStructArray val);
Loading…
Cancel
Save