From d8b855bfe6b908ab6b5656acd86918b8887b42f8 Mon Sep 17 00:00:00 2001 From: Elias Holzer Date: Tue, 15 Apr 2014 21:23:49 +0200 Subject: [PATCH] Fixed handling of primitive pointer types. --- src/AST/TypeExtensions.cs | 28 +++++++++++++++++++ src/Generator/Generators/CLI/CLIMarshal.cs | 8 +++--- .../Generators/CLI/CLITypePrinter.cs | 24 ++++++++++------ .../Generators/CSharp/CSharpTypePrinter.cs | 22 ++++++++++----- tests/Basic/Basic.Tests.cs | 12 +++++++- tests/Basic/Basic.cpp | 5 ++++ tests/Basic/Basic.h | 6 +++- 7 files changed, 84 insertions(+), 21 deletions(-) diff --git a/src/AST/TypeExtensions.cs b/src/AST/TypeExtensions.cs index 9165713f..837c47f7 100644 --- a/src/AST/TypeExtensions.cs +++ b/src/AST/TypeExtensions.cs @@ -154,6 +154,34 @@ } return t; + } + + public static Type GetFinalPointee(this PointerType pointer) + { + var pointee = pointer.Pointee; + while (pointee.IsPointer()) + { + var p = pointee as PointerType; + if (p != null) + pointee = p.Pointee; + else + return GetFinalPointee(pointee as MemberPointerType); + } + return pointee; + } + + public static Type GetFinalPointee(this MemberPointerType pointer) + { + var pointee = pointer.Pointee; + while (pointee.IsPointer()) + { + var p = pointee as MemberPointerType; + if (p != null) + pointee = p.Pointee; + else + return GetFinalPointee(pointee as PointerType); + } + return pointee; } } } \ No newline at end of file diff --git a/src/Generator/Generators/CLI/CLIMarshal.cs b/src/Generator/Generators/CLI/CLIMarshal.cs index 45eda0ca..a1f5d75f 100644 --- a/src/Generator/Generators/CLI/CLIMarshal.cs +++ b/src/Generator/Generators/CLI/CLIMarshal.cs @@ -427,10 +427,10 @@ namespace CppSharp.Generators.CLI if (Context.Function == null) Context.Return.Write("&"); return pointee.Visit(this, quals); - } - - PrimitiveType primitive; - if (pointee.IsPrimitiveType(out primitive)) + } + + var finalPointee = pointer.GetFinalPointee(); + if (finalPointee.IsPrimitiveType()) { var cppTypePrinter = new CppTypePrinter(Context.Driver.TypeDatabase); var cppTypeName = pointer.Visit(cppTypePrinter, quals); diff --git a/src/Generator/Generators/CLI/CLITypePrinter.cs b/src/Generator/Generators/CLI/CLITypePrinter.cs index 9875cfbe..c97210b1 100644 --- a/src/Generator/Generators/CLI/CLITypePrinter.cs +++ b/src/Generator/Generators/CLI/CLITypePrinter.cs @@ -140,16 +140,24 @@ namespace CppSharp.Generators.CLI if (pointee.IsPrimitiveType(PrimitiveType.Char) && quals.IsConst) { return "System::String^"; - } - - PrimitiveType primitive; - if (pointee.Desugar().IsPrimitiveType(out primitive)) + } + + // From http://msdn.microsoft.com/en-us/library/y31yhkeb.aspx + // Any of the following types may be a pointer type: + // * sbyte, byte, short, ushort, int, uint, long, ulong, char, float, double, decimal, or bool. + // * Any enum type. + // * Any pointer type. + // * Any user-defined struct type that contains fields of unmanaged types only. + var finalPointee = pointer.GetFinalPointee(); + if (finalPointee.IsPrimitiveType()) { - var param = Context.Parameter; - if (param != null && (param.IsOut || param.IsInOut)) - return VisitPrimitiveType(primitive); + // Skip one indirection if passed by reference + var param = Context.Parameter; + if (param != null && (param.IsOut || param.IsInOut) + && pointee == finalPointee) + return pointee.Visit(this, quals); - return VisitPrimitiveType(primitive, quals) + "*"; + return pointee.Visit(this, quals) + "*"; } return pointee.Visit(this, quals); diff --git a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs index b0c098fb..dfdef8b0 100644 --- a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs +++ b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs @@ -208,21 +208,29 @@ namespace CppSharp.Generators.CSharp if (IsConstCharString(pointer)) return isManagedContext ? "string" : "global::System.IntPtr"; - PrimitiveType primitive; - var desugared = pointee.Desugar(); - if (desugared.IsPrimitiveType(out primitive)) + // From http://msdn.microsoft.com/en-us/library/y31yhkeb.aspx + // Any of the following types may be a pointer type: + // * sbyte, byte, short, ushort, int, uint, long, ulong, char, float, double, decimal, or bool. + // * Any enum type. + // * Any pointer type. + // * Any user-defined struct type that contains fields of unmanaged types only. + var finalPointee = pointer.GetFinalPointee(); + if (finalPointee.IsPrimitiveType()) { - if (isManagedContext && Context.Parameter != null && - (Context.Parameter.IsOut || Context.Parameter.IsInOut)) - return VisitPrimitiveType(primitive, quals); + // Skip one indirection if passed by reference + var param = Context.Parameter; + if (isManagedContext && param != null && (param.IsOut || param.IsInOut) + && pointee == finalPointee) + return pointee.Visit(this, quals); if (ContextKind == CSharpTypePrinterContextKind.GenericDelegate) return "global::System.IntPtr"; - return VisitPrimitiveType(primitive, quals) + "*"; + return pointee.Visit(this, quals) + "*"; } Class @class; + var desugared = pointee.Desugar(); if ((desugared.IsDependent || desugared.IsTagDecl(out @class)) && ContextKind == CSharpTypePrinterContextKind.Native) { diff --git a/tests/Basic/Basic.Tests.cs b/tests/Basic/Basic.Tests.cs index 19095bb0..4f410326 100644 --- a/tests/Basic/Basic.Tests.cs +++ b/tests/Basic/Basic.Tests.cs @@ -22,7 +22,17 @@ public class BasicTests : GeneratorTestFixture Assert.That(hello.AddFoo(foo), Is.EqualTo(11)); Assert.That(hello.AddFooPtr(foo), Is.EqualTo(11)); Assert.That(hello.AddFooPtr(foo), Is.EqualTo(11)); - Assert.That(hello.AddFooRef(foo), Is.EqualTo(11)); + Assert.That(hello.AddFooRef(foo), Is.EqualTo(11)); + unsafe + { + var pointer = foo.SomePointer; + var pointerPointer = foo.SomePointerPointer; + for (int i = 0; i < 4; i++) + { + Assert.AreEqual(i, pointer[i]); + Assert.AreEqual(i, (*pointerPointer)[i]); + } + } var bar = new Bar { A = 4, B = 7 }; Assert.That(hello.AddBar(bar), Is.EqualTo(11)); diff --git a/tests/Basic/Basic.cpp b/tests/Basic/Basic.cpp index 54237caf..15fd7369 100644 --- a/tests/Basic/Basic.cpp +++ b/tests/Basic/Basic.cpp @@ -2,6 +2,11 @@ Foo::Foo() { + auto p = new int[4]; + for (int i = 0; i < 4; i++) + p[i] = i; + SomePointer = p; + SomePointerPointer = &SomePointer; } const char* Foo::GetANSI() diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index b47d7471..23fd08c8 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -18,7 +18,11 @@ public: // TODO: VC++ does not support char16 // char16 chr16; - float nested_array[2][2]; + // Not properly handled yet - ignore + float nested_array[2][2]; + // Primitive pointer types + int* SomePointer; + int** SomePointerPointer; }; struct DLL_API Bar