diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index f157372f..37661e93 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -13,7 +13,8 @@ namespace CppSharp.Generators.CSharp Unknown, NativeField, GenericDelegate, - DefaultExpression + DefaultExpression, + VTableReturnValue } public class CSharpMarshalContext : MarshalContext @@ -160,7 +161,8 @@ namespace CppSharp.Generators.CSharp var pointee = pointer.Pointee.Desugar(); bool marshalPointeeAsString = CSharpTypePrinter.IsConstCharString(pointee) && isRefParam; - if (CSharpTypePrinter.IsConstCharString(pointer) || marshalPointeeAsString) + if ((CSharpTypePrinter.IsConstCharString(pointer) && !MarshalsParameter) || + marshalPointeeAsString) { Context.Return.Write(MarshalStringToManaged(Context.ReturnVarName, pointer.GetFinalPointee().Desugar() as BuiltinType)); @@ -505,9 +507,7 @@ namespace CppSharp.Generators.CSharp var isRefParam = param != null && (param.IsInOut || param.IsOut); var pointee = pointer.Pointee.Desugar(); - bool marshalPointeeAsString = CSharpTypePrinter.IsConstCharString(pointee) && isRefParam; - - if (CSharpTypePrinter.IsConstCharString(pointer) || marshalPointeeAsString) + if (CSharpTypePrinter.IsConstCharString(pointee) && isRefParam) { if (param.IsOut) { @@ -555,9 +555,11 @@ namespace CppSharp.Generators.CSharp return true; } + var marshalAsString = CSharpTypePrinter.IsConstCharString(pointer); var finalPointee = pointer.GetFinalPointee(); PrimitiveType primitive; - if (finalPointee.IsPrimitiveType(out primitive) || finalPointee.IsEnumType()) + if (finalPointee.IsPrimitiveType(out primitive) || finalPointee.IsEnumType() || + marshalAsString) { // From MSDN: "note that a ref or out parameter is classified as a moveable // variable". This means we must create a local variable to hold the result @@ -576,13 +578,19 @@ namespace CppSharp.Generators.CSharp } else { - if (Context.Driver.Options.MarshalCharAsManagedChar && primitive == PrimitiveType.Char) + if (!marshalAsString && + Context.Driver.Options.MarshalCharAsManagedChar && + primitive == PrimitiveType.Char) { - var typePrinter = new CSharpTypePrinter(Context.Driver); typePrinter.PushContext(CSharpTypePrinterContextKind.Native); Context.Return.Write(string.Format("({0}) ", pointer.Visit(typePrinter))); + typePrinter.PopContext(); } - Context.Return.Write(Context.Parameter.Name); + if (marshalAsString && (Context.Kind == CSharpMarshalKind.NativeField || + Context.Kind == CSharpMarshalKind.VTableReturnValue)) + Context.Return.Write(MarshalStringToUnmanaged(Context.Parameter.Name)); + else + Context.Return.Write(Context.Parameter.Name); } return true; diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 41466111..5de6991a 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -670,8 +670,7 @@ namespace CppSharp.Generators.CSharp { TypePrinter.PushContext(CSharpTypePrinterContextKind.Native); - var retParam = new Parameter { QualifiedType = function.ReturnType }; - retType = retParam.CSharpType(TypePrinter); + retType = function.ReturnType.CSharpType(TypePrinter); var @params = function.GatherInternalParams(Driver.Options.IsItaniumLikeAbi).Select(p => string.Format("{0} {1}", p.CSharpType(TypePrinter), p.Name)).ToList(); @@ -962,7 +961,8 @@ namespace CppSharp.Generators.CSharp if (marshal.Context.Return.StringBuilder.Length > 0) { WriteLine("{0} = {1}{2};", ctx.ReturnVarName, - field.Type.IsPointer() && field.Type.GetFinalPointee().IsPrimitiveType() ? + field.Type.IsPointer() && field.Type.GetFinalPointee().IsPrimitiveType() && + !CSharpTypePrinter.IsConstCharString(field.Type) ? string.Format("({0}) ", CSharpTypePrinter.IntPtrType) : string.Empty, marshal.Context.Return); @@ -1659,7 +1659,8 @@ namespace CppSharp.Generators.CSharp { ArgName = Helpers.ReturnIdentifier, Parameter = param, - Function = method + Function = method, + Kind = CSharpMarshalKind.VTableReturnValue }; var marshal = new CSharpMarshalManagedToNativePrinter(ctx); diff --git a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs index 86500e41..672cf4dc 100644 --- a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs +++ b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs @@ -6,6 +6,7 @@ using CppSharp.Types; using Type = CppSharp.AST.Type; using ParserTargetInfo = CppSharp.Parser.ParserTargetInfo; using System.Linq; +using System.Text; namespace CppSharp.Generators.CSharp { @@ -265,7 +266,19 @@ namespace CppSharp.Generators.CSharp var isManagedContext = ContextKind == CSharpTypePrinterContextKind.Managed; if (allowStrings && IsConstCharString(pointer)) - return isManagedContext ? "string" : IntPtrType; + { + if (isManagedContext || MarshalKind == CSharpMarshalKind.GenericDelegate) + return "string"; + if (Context.Parameter == null || Context.Parameter.Name == Helpers.ReturnIdentifier) + return IntPtrType; + if (driver.Options.Encoding == Encoding.ASCII) + return string.Format("[MarshalAs(UnmanagedType.LPStr)] string"); + if (driver.Options.Encoding == Encoding.Unicode || + driver.Options.Encoding == Encoding.BigEndianUnicode) + return string.Format("[MarshalAs(UnmanagedType.LPWStr)] string"); + throw new NotSupportedException(string.Format("{0} is not supported yet.", + driver.Options.Encoding.EncodingName)); + } var desugared = pointee.Desugar(); @@ -288,6 +301,9 @@ namespace CppSharp.Generators.CSharp if (pointee.IsPrimitiveType(PrimitiveType.Void)) return IntPtrType; + if (IsConstCharString(pointee) && isRefParam) + return IntPtrType + "*"; + // Do not allow strings inside primitive arrays case, else we'll get invalid types // like string* for const char **. allowStrings = isRefParam; @@ -300,9 +316,6 @@ namespace CppSharp.Generators.CSharp Enumeration @enum; if (desugared.TryGetEnum(out @enum)) { - if (MarshalKind == CSharpMarshalKind.GenericDelegate && isManagedContext) - return IntPtrType; - // Skip one indirection if passed by reference var param = Context.Parameter; if (isManagedContext && param != null && (param.IsOut || param.IsInOut) diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index def5de6c..0d187f99 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -555,6 +555,29 @@ public unsafe class CSharpTests : GeneratorTestFixture } } + [Test] + public void TestOverrideVirtualWithString() + { + using (var overrideVirtualWithString = new OverrideVirtualWithString()) + { + Assert.That(overrideVirtualWithString.CallsVirtualToReturnString("test"), Is.EqualTo("test_test")); + Assert.IsFalse(overrideVirtualWithString.CallsVirtualToReturnBool(true)); + } + } + + private class OverrideVirtualWithString : HasVirtualTakesReturnsProblematicTypes + { + public override string VirtualTakesAndReturnsString(string c) + { + return "test_test"; + } + + public override bool VirtualTakesAndReturnsBool(bool b) + { + return !base.VirtualTakesAndReturnsBool(b); + } + } + private class GetEnumFromNativePointer : UsesPointerToEnumInParamOfVirtual { public override Flags HasPointerToEnumInParam(Flags pointerToEnum) diff --git a/tests/CSharp/CSharp.cpp b/tests/CSharp/CSharp.cpp index 2319a81a..fc686a34 100644 --- a/tests/CSharp/CSharp.cpp +++ b/tests/CSharp/CSharp.cpp @@ -1036,3 +1036,31 @@ bool VirtualDtorAddedInDerived::dtorCalled = false; void NamespaceB::B::Function(CS_OUT NamespaceA::A &a) { } + +HasVirtualTakesReturnsProblematicTypes::HasVirtualTakesReturnsProblematicTypes() +{ +} + +HasVirtualTakesReturnsProblematicTypes::~HasVirtualTakesReturnsProblematicTypes() +{ +} + +const char* HasVirtualTakesReturnsProblematicTypes::virtualTakesAndReturnsString(const char* c) +{ + return c; +} + +const char* HasVirtualTakesReturnsProblematicTypes::callsVirtualToReturnString(const char* c) +{ + return virtualTakesAndReturnsString(c); +} + +bool HasVirtualTakesReturnsProblematicTypes::virtualTakesAndReturnsBool(bool b) +{ + return b; +} + +bool HasVirtualTakesReturnsProblematicTypes::callsVirtualToReturnBool(bool b) +{ + return virtualTakesAndReturnsBool(b); +} diff --git a/tests/CSharp/CSharp.h b/tests/CSharp/CSharp.h index e933f71c..4f30b830 100644 --- a/tests/CSharp/CSharp.h +++ b/tests/CSharp/CSharp.h @@ -969,3 +969,14 @@ private: }; class ForwardInOtherUnitButSameModule; + +class DLL_API HasVirtualTakesReturnsProblematicTypes +{ +public: + HasVirtualTakesReturnsProblematicTypes(); + ~HasVirtualTakesReturnsProblematicTypes(); + virtual const char* virtualTakesAndReturnsString(const char* c); + const char* callsVirtualToReturnString(const char* c); + virtual bool virtualTakesAndReturnsBool(bool b); + bool callsVirtualToReturnBool(bool b); +};