From c3ca3c241b582d164acb71683ecc8a369f57c083 Mon Sep 17 00:00:00 2001 From: josetr <37419832+josetr@users.noreply.github.com> Date: Tue, 20 Oct 2020 19:01:15 +0100 Subject: [PATCH] Fix wrong [MarshalAs(UnmanagedType)] for strings (#1438) --- src/Generator/Types/Std/Stdlib.cs | 40 +++++++++++++++++++++++++++++-- tests/CSharp/CSharp.Tests.cs | 7 ++++++ tests/CSharp/CSharp.cpp | 10 ++++++++ tests/CSharp/CSharp.h | 8 +++++++ 4 files changed, 63 insertions(+), 2 deletions(-) diff --git a/src/Generator/Types/Std/Stdlib.cs b/src/Generator/Types/Std/Stdlib.cs index 0aaffb95..8b0437b8 100644 --- a/src/Generator/Types/Std/Stdlib.cs +++ b/src/Generator/Types/Std/Stdlib.cs @@ -155,18 +155,54 @@ namespace CppSharp.Types.Std return new CustomType(typePrinter.IntPtrType); } - if (Context.Options.Encoding == Encoding.ASCII || - Context.Options.Encoding == Encoding.UTF8) + Type type = Type.Desugar(); + + uint charWidth = 0; + if (type is PointerType pointerType) + charWidth = GetCharPtrWidth(pointerType); + else if (type.GetPointee()?.Desugar() is PointerType pointeePointerType) + charWidth = GetCharPtrWidth(pointeePointerType); + + if (charWidth == 8) + { + if (Context.Options.Encoding == Encoding.ASCII) + return new CustomType("[MarshalAs(UnmanagedType.LPStr)] string"); return new CustomType("[MarshalAs(UnmanagedType.LPUTF8Str)] string"); + } if (Context.Options.Encoding == Encoding.Unicode || Context.Options.Encoding == Encoding.BigEndianUnicode) + { + // NOTE: This will break if charWidth is not 16 bit which may + // happen if someone uses wchar_t on platforms where it's size is 32 bit. + // TODO: Create a custom marshaller to support that scenario. + // Once that's done consider supporting char32_t as well return new CustomType("[MarshalAs(UnmanagedType.LPWStr)] string"); + } + + if (Context.Options.Encoding == Encoding.UTF8 && charWidth == 16) + return new CustomType("[MarshalAs(UnmanagedType.LPWStr)] string"); // fallback; throw new System.NotSupportedException( $"{Context.Options.Encoding.EncodingName} is not supported yet."); } + public uint GetCharPtrWidth(PointerType pointer) + { + var pointee = pointer?.Pointee?.Desugar(); + if (pointee != null) + { + if (pointee.IsPrimitiveType(PrimitiveType.Char)) + return Context.TargetInfo.CharWidth; + if (pointee.IsPrimitiveType(PrimitiveType.WideChar)) + return Context.TargetInfo.WCharWidth; + if (pointee.IsPrimitiveType(PrimitiveType.Char16)) + return Context.TargetInfo.Char16Width; + } + + return 0; + } + public override void CSharpMarshalToNative(CSharpMarshalContext ctx) { string param = ctx.Parameter.Name; diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index 30af9e33..a60da6c3 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -1479,4 +1479,11 @@ public unsafe class CSharpTests : GeneratorTestFixture { StringAssert.EndsWith(nameof(CSharp.TestAnonymousMemberNameCollision._0.__0), "__0"); } + + [Test] + public void TestStringMarshall() + { + Assert.IsTrue(StringMarshall.CSharpString8(StringMarshall.CSharpString)); + Assert.IsTrue(StringMarshall.CSharpString16(StringMarshall.CSharpString)); + } } diff --git a/tests/CSharp/CSharp.cpp b/tests/CSharp/CSharp.cpp index e8b5b8f4..bbc63e53 100644 --- a/tests/CSharp/CSharp.cpp +++ b/tests/CSharp/CSharp.cpp @@ -1777,3 +1777,13 @@ boolean_t takeTypemapTypedefParam(boolean_t b) { return b; } + +bool StringMarshall::CSharpString8(const char* in) +{ + return in[0] == 'C' && in[1] == '#'; +} + +bool StringMarshall::CSharpString16(const char16_t* in) +{ + return in[0] == 'C' && in[1] == '#'; +} diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h index 8b545857..31466c8d 100644 --- a/tests/CSharp/CSharp.h +++ b/tests/CSharp/CSharp.h @@ -1481,3 +1481,11 @@ struct TestVariableWithoutType static constexpr auto variable = create(n...); }; + +class DLL_API StringMarshall +{ +public: + static constexpr const char* CSharpString = "C#"; + static bool CSharpString8(const char* in); + static bool CSharpString16(const char16_t* in); +};