Track ownership of unmanaged string memory in generated C# bindings

Patch summary (text before the first "diff --git" line is ignored when applying):

* src/Generator/Generators/CSharp/CSharpSources.cs
    * emits a private bool __<field>_OwnsNativeMemory for every layout field whose
      type is a const char string;
    * emits code that frees such a field with Marshal.FreeHGlobal when its flag is
      set, ahead of the existing code that frees the native instance;
    * in the block that copies the native struct from method.Parameters[0], emits
      code that re-assigns string properties whose memory the source instance owns;
    * treats Char16/Char32 like WideChar when deciding whether a field getter needs
      a pointer cast;
    * moves #region/#endregion directives to column 0.
* src/Generator/Types/Std/Stdlib.CSharp.cs
    * renames the misspelled local "enconding" to "encoding";
    * maps Encoding.Default to [MarshalAs(UnmanagedType.LPStr)] and recognises it in
      GetEncoding for 8-bit characters;
    * marshals managed strings into HGlobal memory (StringToHGlobalUni/Ansi, or
      GetBytes + AllocHGlobal + explicit NUL) instead of pinning a byte array;
    * for native fields, frees the previously owned buffer and records ownership
      only when a non-null string is stored.
* tests/CSharp
    * adds the TestChar32String / TestChar16String native fixtures and NUnit tests
      covering string ownership.

diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs
index fe2e2006..ec779186 100644
--- a/src/Generator/Generators/CSharp/CSharpSources.cs
+++ b/src/Generator/Generators/CSharp/CSharpSources.cs
@@ -455,6 +455,12 @@ namespace CppSharp.Generators.CSharp
                     var dict = $@"global::System.Collections.Concurrent.ConcurrentDictionary<IntPtr, {
                         printedClass}>";
                     WriteLine("internal static readonly {0} NativeToManagedMap = new {0}();", dict);
+
+                    // Add booleans to track who owns unmanaged memory for string fields
+                    foreach (var field in @class.Layout.Fields.Where(f => f.QualifiedType.Type.IsConstCharString()))
+                    {
+                        WriteLine($"private bool __{field.Name}_OwnsNativeMemory = false;");
+                    }
                 }
                 PopBlock(NewLineKind.BeforeNextBlock);
             }
@@ -871,7 +877,7 @@ namespace CppSharp.Generators.CSharp
             PopBlock(NewLineKind.BeforeNextBlock);
         }
 
-        #endregion
+#endregion
 
         private void GeneratePropertySetter<T>(T decl,
             Class @class, bool isAbstract = false, Property property = null)
@@ -1411,13 +1417,17 @@ namespace CppSharp.Generators.CSharp
                     if (templateSubstitution != null && returnType.Type.IsDependent)
                         Write($"({templateSubstitution.ReplacedParameter.Parameter.Name}) (object) ");
                     if ((final.IsPrimitiveType() && !final.IsPrimitiveType(PrimitiveType.Void) &&
-                        (!final.IsPrimitiveType(PrimitiveType.Char) &&
-                         !final.IsPrimitiveType(PrimitiveType.WideChar) ||
+                        ((!final.IsPrimitiveType(PrimitiveType.Char) &&
+                         !final.IsPrimitiveType(PrimitiveType.WideChar) &&
+                         !final.IsPrimitiveType(PrimitiveType.Char16) &&
+                         !final.IsPrimitiveType(PrimitiveType.Char32)) ||
                          (!Context.Options.MarshalCharAsManagedChar &&
                           !((PointerType) field.Type).QualifiedPointee.Qualifiers.IsConst)) &&
                         templateSubstitution == null) ||
                         (!((PointerType) field.Type).QualifiedPointee.Qualifiers.IsConst &&
-                         final.IsPrimitiveType(PrimitiveType.WideChar)))
+                         (final.IsPrimitiveType(PrimitiveType.WideChar) ||
+                          final.IsPrimitiveType(PrimitiveType.Char16) ||
+                          final.IsPrimitiveType(PrimitiveType.Char32))))
                         Write($"({field.Type.GetPointee().Desugar()}*) ");
                 }
                 WriteLine($"{@return};");
@@ -1626,7 +1636,7 @@ namespace CppSharp.Generators.CSharp
             PopBlock(NewLineKind.BeforeNextBlock);
         }
 
-        #region Virtual Tables
+#region Virtual Tables
 
         public List<VTableComponent> GetUniqueVTableMethodEntries(Class @class)
         {
@@ -2033,9 +2043,9 @@ namespace CppSharp.Generators.CSharp
             return @class.IsGenerated && @class.IsDynamic && GetUniqueVTableMethodEntries(@class).Count > 0;
         }
 
-        #endregion
+#endregion
 
-        #region Events
+#region Events
 
         public override bool VisitEvent(Event @event)
         {
@@ -2143,9 +2153,9 @@ namespace CppSharp.Generators.CSharp
             UnindentAndWriteCloseBrace();
         }
 
-        #endregion
+#endregion
 
-        #region Constructors
+#region Constructors
 
         public void GenerateClassConstructors(Class @class)
         {
@@ -2266,6 +2276,21 @@ namespace CppSharp.Generators.CSharp
                 }
             }
 
+            // If we have any fields holding references to unmanaged memory allocated here, free the
+            // referenced memory. Don't rely on testing if the field's IntPtr is IntPtr.Zero since
+            // unmanaged memory isn't always initialized and/or a reference may be owned by the
+            // native side.
+            //
+            // TODO: We should delegate to the dispose methods of references we hold to other
+            // generated type instances since those instances could also hold references to
+            // unmanaged memory.
+            foreach (var field in @class.Layout.Fields.Where(f => f.QualifiedType.Type.IsConstCharString()))
+            {
+                var ptr = $"(({Helpers.InternalStruct}*){Helpers.InstanceIdentifier})->{field.Name}";
+                WriteLine($"if (__{field.Name}_OwnsNativeMemory)");
+                WriteLineIndent($"Marshal.FreeHGlobal({ptr});");
+            }
+
             WriteLine("if ({0})", Helpers.OwnsNativeInstanceIdentifier);
             WriteLineIndent("Marshal.FreeHGlobal({0});", Helpers.InstanceIdentifier);
 
@@ -2482,9 +2507,9 @@ internal static{(@new ? " new" : string.Empty)} {printedClass} __GetInstance({Ty
             WriteLineIndent(": this()");
         }
 
-        #endregion
+#endregion
 
-        #region Methods / Functions
+#region Methods / Functions
 
         public void GenerateFunction(Function function, string parentName)
         {
@@ -2898,6 +2923,21 @@ internal static{(@new ? " new" : string.Empty)} {printedClass} __GetInstance({Ty
                     var classInternal = TypePrinter.PrintNative(@class);
                     WriteLine($@"*(({classInternal}*) {Helpers.InstanceIdentifier}) = *(({
                         classInternal}*) {method.Parameters[0].Name}.{Helpers.InstanceIdentifier});");
+
+                    // Copy any string references owned by the source to the new instance so we
+                    // don't have to ref count them.
+                    foreach (var field in @class.Fields.Where(f => f.QualifiedType.Type.IsConstCharString()))
+                    {
+                        var prop = @class.Properties.FirstOrDefault(p => p.Field == field);
+                        // If there is no property or no setter then this instance can never own the native
+                        // memory. Worry about the case where there's only a setter (write-only) when we
+                        // understand the use case and how it can occur.
+                        if (prop != null && prop.HasGetter && prop.HasSetter)
+                        {
+                            WriteLine($"if ({method.Parameters[0].Name}.__{field.OriginalName}_OwnsNativeMemory)");
+                            WriteLineIndent($@"this.{prop.Name} = {method.Parameters[0].Name}.{prop.Name};");
+                        }
+                    }
                 }
             }
             else
@@ -3230,7 +3270,7 @@ internal static{(@new ? " new" : string.Empty)} {printedClass} __GetInstance({Ty
             return TypePrinter.VisitParameters(@params, true).Type;
         }
 
-        #endregion
+#endregion
 
         public override bool VisitTypedefNameDecl(TypedefNameDecl typedef)
         {
diff --git a/src/Generator/Types/Std/Stdlib.CSharp.cs b/src/Generator/Types/Std/Stdlib.CSharp.cs
index 32ad5b3c..b31936d4 100644
--- a/src/Generator/Types/Std/Stdlib.CSharp.cs
+++ b/src/Generator/Types/Std/Stdlib.CSharp.cs
@@ -1,5 +1,6 @@
 using System.Collections.Generic;
 using System.Linq;
+using System.Runtime.InteropServices;
 using System.Text;
 using CppSharp.AST;
 using CppSharp.AST.Extensions;
@@ -96,15 +97,18 @@ namespace CppSharp.Types.Std
                 return new CustomType(typePrinter.IntPtrType);
             }
 
-            var (enconding, _) = GetEncoding();
+            var (encoding, _) = GetEncoding();
 
-            if (enconding == Encoding.ASCII)
-                return new CustomType("[MarshalAs(UnmanagedType.LPStr)] string");
-            else if (enconding == Encoding.UTF8)
+            if (encoding == Encoding.ASCII || encoding == Encoding.Default)
+                // This is not really right. ASCII is 7-bit only - the 8th bit is stripped; ANSI has
+                // multi-byte support via a code page. MarshalAs(UnmanagedType.LPStr) marshals as ANSI.
+                // Perhaps we need a CppSharp.Runtime.ASCIIMarshaller?
+                return new CustomType("[MarshalAs(UnmanagedType.LPStr)] string");
+            else if (encoding == Encoding.UTF8)
                 return new CustomType("[MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CppSharp.Runtime.UTF8Marshaller))] string");
-            else if (enconding == Encoding.Unicode || enconding == Encoding.BigEndianUnicode)
+            else if (encoding == Encoding.Unicode || encoding == Encoding.BigEndianUnicode)
                 return new CustomType("[MarshalAs(UnmanagedType.LPWStr)] string");
-            else if (enconding == Encoding.UTF32)
+            else if (encoding == Encoding.UTF32)
                 return new CustomType("[MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CppSharp.Runtime.UTF32Marshaller))] string");
 
             throw new System.NotSupportedException(
@@ -129,14 +133,63 @@ namespace CppSharp.Types.Std
             if (substitution != null)
                 param = $"({substitution.Replacement}) (object) {param}";
 
-            string bytes = $"__bytes{ctx.ParameterIndex}";
-            string bytePtr = $"__bytePtr{ctx.ParameterIndex}";
-            ctx.Before.WriteLine($@"byte[] {bytes} = global::System.Text.Encoding.{
-                GetEncoding().Name}.GetBytes({param});");
-            ctx.Before.WriteLine($"fixed (byte* {bytePtr} = {bytes})");
-            ctx.HasCodeBlock = true;
-            ctx.Before.WriteOpenBraceAndIndent();
-            ctx.Return.Write($"new global::System.IntPtr({bytePtr})");
+            // Allow setting native field to null via setter property.
+            if (ctx.MarshalKind == MarshalKind.NativeField)
+            {
+                // Free memory if we're holding a pointer to unmanaged memory that we (think we)
+                // allocated. We can't simply compare with IntPtr.Zero since the reference could be
+                // owned by the native side.
+
+                // TODO: Surely, we can do better than stripping out the name of the field using
+                // string manipulation on the ReturnVarName, but I don't see it yet. Seems like it
+                // would be really helpful to have ctx hold a Decl property representing the
+                // "appropriate" Decl when we get here. When MarshalKind == NativeField, Decl would
+                // be set to the Field we're operating on.
+                var fieldName = ctx.ReturnVarName.Substring(ctx.ReturnVarName.LastIndexOf("->") + 2);
+
+                ctx.Before.WriteLine($"if (__{fieldName}_OwnsNativeMemory)");
+                ctx.Before.WriteLineIndent($"Marshal.FreeHGlobal({ctx.ReturnVarName});");
+                ctx.Before.WriteLine($"__{fieldName}_OwnsNativeMemory = {param} != null;");
+                ctx.Before.WriteLine($"if ({param} == null)");
+                ctx.Before.WriteOpenBraceAndIndent();
+                ctx.Before.WriteLine($"{ctx.ReturnVarName} = global::System.IntPtr.Zero;");
+                ctx.Before.WriteLine("return;");
+                ctx.Before.UnindentAndWriteCloseBrace();
+            }
+
+            var bytes = $"__bytes{ctx.ParameterIndex}";
+            var bytePtr = $"__bytePtr{ctx.ParameterIndex}";
+            var encodingName = GetEncoding().Name;
+
+            switch (encodingName)
+            {
+                case nameof(Encoding.Unicode):
+                    ctx.Before.WriteLine($@"var {bytePtr} = Marshal.StringToHGlobalUni({param});");
+                    break;
+                case nameof(Encoding.Default):
+                    ctx.Before.WriteLine($@"var {bytePtr} = Marshal.StringToHGlobalAnsi({param});");
+                    break;
+                default:
+                    {
+                        var encodingBytesPerChar = GetCharWidth() / 8;
+                        var writeNulMethod = encodingBytesPerChar switch
+                        {
+                            1 => nameof(Marshal.WriteByte),
+                            2 => nameof(Marshal.WriteInt16),
+                            4 => nameof(Marshal.WriteInt32),
+                            _ => throw new System.NotImplementedException(
+                                    $"Encoding bytes per char: {encodingBytesPerChar} is not implemented.")
+                        };
+
+                        ctx.Before.WriteLine($@"var {bytes} = global::System.Text.Encoding.{encodingName}.GetBytes({param});");
+                        ctx.Before.WriteLine($@"var {bytePtr} = Marshal.AllocHGlobal({bytes}.Length + {encodingBytesPerChar});");
+                        ctx.Before.WriteLine($"Marshal.Copy({bytes}, 0, {bytePtr}, {bytes}.Length);");
+                        ctx.Before.WriteLine($"Marshal.{writeNulMethod}({bytePtr} + {bytes}.Length, 0);");
+                    }
+                    break;
+            }
+
+            ctx.Return.Write($"{bytePtr}");
         }
 
         public override void CSharpMarshalToManaged(CSharpMarshalContext ctx)
@@ -168,6 +221,8 @@ namespace CppSharp.Types.Std
             switch (GetCharWidth())
             {
                 case 8:
+                    if (Context.Options.Encoding == Encoding.Default) // aka ANSI with system default code page
+                        return (Context.Options.Encoding, nameof(Encoding.Default));
                     if (Context.Options.Encoding == Encoding.ASCII)
                         return (Context.Options.Encoding, nameof(Encoding.ASCII));
                     if (Context.Options.Encoding == Encoding.UTF8)
diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs
index f881a99d..25d5adc7 100644
--- a/tests/CSharp/CSharp.Tests.cs
+++ b/tests/CSharp/CSharp.Tests.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Reflection;
 using System.Runtime.InteropServices;
+using System.Text;
 using CSharp;
 using NUnit.Framework;
 
@@ -868,6 +869,144 @@ public unsafe class CSharpTests
         }
     }
 
+    [Test]
+    public void TestStringMemManagement()
+    {
+        const int instanceCount = 100;
+        const string otherString = @"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
+
+        var batch = new TestString[instanceCount];
+        for (var i = 0; i < instanceCount; i++)
+        {
+            batch[i] = new TestString { UnicodeConst = otherString };
+            Assert.That(batch[i].UnicodeConst, Is.EqualTo(otherString), $"iteration {i}");
+        }
+
+        GC.Collect();
+
+        for (var i = 0; i < instanceCount; i++)
+        {
+            Assert.That(batch[i].UnicodeConst, Is.EqualTo(otherString), $"iteration {i}");
+        }
+
+        Array.ForEach(batch, ts => ts.Dispose());
+    }
+
+    // Reads a generated private bool "__<field>_OwnsNativeMemory" flag via reflection.
+    static bool OwnsNativeMemory<T>(T instance, string fieldName)
+    {
+        return (bool)instance.GetType()
+            .GetField(fieldName, BindingFlags.Instance | BindingFlags.NonPublic)
+            .GetValue(instance);
+    }
+
+    [Test]
+    public void TestManagedOwnsChar32String()
+    {
+        const string constructorString = "ქართული ენა";
+        const string str = "ßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂ㥹ĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚěĜĝĞğĠġĢģĤĥĦħĨĩĪīĬĭĮįİıIJijĴĵ";
+
+        using (var ts = new TestChar32String())
+        {
+            Assert.That(ts.ThirtyTwoBitConst, Is.EqualTo(constructorString));
+            Assert.That(OwnsNativeMemory(ts, "__thirtyTwoBitConst_OwnsNativeMemory"), Is.EqualTo(false));
+
+            ts.ThirtyTwoBitConst = str;
+            Assert.That(ts.RetrieveString, Is.EqualTo(str));
+            Assert.That(OwnsNativeMemory(ts, "__thirtyTwoBitConst_OwnsNativeMemory"), Is.EqualTo(true));
+        }
+    }
+
+    [Test]
+    public void TestNativeOwnsChar32String()
+    {
+        const string constructorString = "ქართული ენა";
+        const string str = "ҪҫҬҭҮүҰұҲҳҴҵҶҷҸҹҺһҼҽҾҿӀӁӂӃӄӅӆӇӈӉӊӋӌӍӎӏӐӑӒӓӔӕӖӗӘәӚӛӜӝӞӟӠӡӢӣӤӥӦӧӨөӪӫӬӭӮӯӰӱӲӳӴӵӶӷӸӹӺӻӼӽ";
+        const string otherStr = "Test String";
+
+        using (var ts = new TestChar32String())
+        {
+            Assert.That(ts.ThirtyTwoBitConst, Is.EqualTo(constructorString));
+            Assert.That(OwnsNativeMemory(ts, "__thirtyTwoBitConst_OwnsNativeMemory"), Is.EqualTo(false));
+            ts.UpdateString(str);
+            Assert.That(ts.ThirtyTwoBitConst, Is.EqualTo(str));
+            Assert.That(OwnsNativeMemory(ts, "__thirtyTwoBitConst_OwnsNativeMemory"), Is.EqualTo(false));
+
+            var x = (uint *)ts.ThirtyTwoBitNonConst;
+            for (int i = 0; i < otherStr.Length; i++)
+            {
+                Assert.That(*x++, Is.EqualTo(otherStr[i]));
+            }
+            Assert.That(*x, Is.EqualTo(0));
+        }
+    }
+
+    [Test]
+    public void TestManagedOwnsChar16String()
+    {
+        const string constructorString = "ქართული ენა";
+        const string str = "ßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂ㥹ĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚěĜĝĞğĠġĢģĤĥĦħĨĩĪīĬĭĮįİıIJijĴĵ";
+
+        using (var ts = new TestChar16String())
+        {
+            Assert.That(ts.SixteenBitConst, Is.EqualTo(constructorString));
+            Assert.That(OwnsNativeMemory(ts, "__sixteenBitConst_OwnsNativeMemory"), Is.EqualTo(false));
+
+            ts.SixteenBitConst = str;
+            Assert.That(ts.RetrieveString, Is.EqualTo(str));
+            Assert.That(OwnsNativeMemory(ts, "__sixteenBitConst_OwnsNativeMemory"), Is.EqualTo(true));
+        }
+    }
+
+    [Test]
+    public void TestNativeOwnsChar16String()
+    {
+        const string constructorString = "ქართული ენა";
+        const string str = "ѐёђѓєѕіїјљњћќѝўџѠѡѢѣѤѥѦѧѨѩѪѫѬѭѮѯѰѱѲѳѴѵѶѷѸѹѺѻѼѽѾѿҀҁҊҋҌҍҎҏҐґҒғҔҕҖҗҘҙҚқҜҝҞҟҠҡҢңҤҥҦҧҨҩ";
+        const string otherStr = "Test String";
+
+        using (var ts = new TestChar16String())
+        {
+            Assert.That(ts.SixteenBitConst, Is.EqualTo(constructorString));
+            Assert.That(OwnsNativeMemory(ts, "__sixteenBitConst_OwnsNativeMemory"), Is.EqualTo(false));
+
+            ts.UpdateString(str);
+            Assert.That(ts.SixteenBitConst, Is.EqualTo(str));
+            Assert.That(OwnsNativeMemory(ts, "__sixteenBitConst_OwnsNativeMemory"), Is.EqualTo(false));
+
+            var x = ts.SixteenBitNonConst;
+            for (int i = 0; i < otherStr.Length; i++)
+            {
+                Assert.That(*x++, Is.EqualTo(otherStr[i]));
+            }
+            Assert.That(*x, Is.EqualTo(0));
+        }
+    }
+
+    [Test]
+    public void TestStringRefWithCopyConstructor()
+    {
+        const string otherString = @"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
+        var ts1 = new TestString { UnicodeConst = otherString };
+        var ts2 = new TestString(ts1);
+
+        // verify that the copy has its own reference to UnicodeConst.
+        Assert.That(OwnsNativeMemory(ts2, "__unicodeConst_OwnsNativeMemory"), Is.True);
+
+        var offset = Marshal.OffsetOf<TestString.__Internal>("unicodeConst");
+        var ts1PtrRef = IntPtr.Add(ts1.__Instance, (int)offset);
+        var ts2PtrRef = IntPtr.Add(ts2.__Instance, (int)offset);
+        var ts1Ptr = *(IntPtr*)ts1PtrRef;
+        var ts2Ptr = *(IntPtr*)ts2PtrRef;
+        Assert.That(ts2Ptr, Is.Not.EqualTo(ts1Ptr));
+
+        // should be able to dispose in any order.
+        Assert.That(ts1.UnicodeConst, Is.EqualTo(otherString));
+        ts1.Dispose();
+        Assert.That(ts2.UnicodeConst, Is.EqualTo(otherString));
+        ts2.Dispose();
+    }
+
     [Test]
     public void TestEnumProperty()
     {
diff --git a/tests/CSharp/CSharp.cpp b/tests/CSharp/CSharp.cpp
index 5f88ec65..68f9f4c9 100644
--- a/tests/CSharp/CSharp.cpp
+++ b/tests/CSharp/CSharp.cpp
@@ -1452,13 +1452,51 @@ TestString::TestString() : unicodeConst(L"ქართული ენა"), uni
 {
 }
 
-void decltypeFunctionPointer() {}
+TestString::~TestString()
+{
+}
 
-void usesDecltypeFunctionPointer(funcPtr func) {}
+TestChar32String::TestChar32String() :
+    thirtyTwoBitConst(U"ქართული ენა")
+{
+    static std::u32string nonConst = U"Test String";
+    thirtyTwoBitNonConst = &nonConst[0];
+}
 
-TestString::~TestString()
+TestChar32String::~TestChar32String() {}
+
+void TestChar32String::UpdateString(const char32_t* s)
+{
+    // Assign on every call; a static initializer would only capture the first argument.
+    static std::u32string nativeOwnedMemory;
+    nativeOwnedMemory = s;
+    thirtyTwoBitConst = nativeOwnedMemory.data();
+}
+
+const char32_t* TestChar32String::RetrieveString() { return thirtyTwoBitConst; }
+
+TestChar16String::TestChar16String() :
+    sixteenBitConst(u"ქართული ენა")
 {
+    static std::u16string nonConst = u"Test String";
+    sixteenBitNonConst = &nonConst[0];
+}
+
+TestChar16String::~TestChar16String() {}
+
+void TestChar16String::UpdateString(const char16_t* s)
+{
+    // Assign on every call; a static initializer would only capture the first argument.
+    static std::u16string nativeOwnedMemory;
+    nativeOwnedMemory = s;
+    sixteenBitConst = nativeOwnedMemory.data();
 }
+
+const char16_t* TestChar16String::RetrieveString() { return sixteenBitConst; }
+
+void decltypeFunctionPointer() {}
+
+void usesDecltypeFunctionPointer(funcPtr func) {}
 
 PrimaryBaseWithAbstractWithDefaultArg::PrimaryBaseWithAbstractWithDefaultArg()
 {
diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h
index 789fc6e8..5c501628 100644
--- a/tests/CSharp/CSharp.h
+++ b/tests/CSharp/CSharp.h
@@ -1146,6 +1146,31 @@ public:
     wchar_t* unicode;
 };
 
+class DLL_API TestChar32String
+{
+public:
+    TestChar32String();
+    ~TestChar32String();
+    const char32_t* thirtyTwoBitConst;
+    char32_t* thirtyTwoBitNonConst;
+
+    void UpdateString(const char32_t* s);
+    const char32_t* RetrieveString();
+};
+
+class DLL_API TestChar16String
+{
+public:
+    TestChar16String();
+    ~TestChar16String();
+    const char16_t* sixteenBitConst;
+    char16_t* sixteenBitNonConst;
+
+    void UpdateString(const char16_t* s);
+    const char16_t* RetrieveString();
+};
+
+
 DLL_API void decltypeFunctionPointer();
 
 using funcPtr = decltype(&decltypeFunctionPointer);