diff --git a/src/AST/Type.cs b/src/AST/Type.cs index f5533020..f41b3ca1 100644 --- a/src/AST/Type.cs +++ b/src/AST/Type.cs @@ -152,6 +152,21 @@ namespace CppSharp.AST return this; } + public Type SkipPointerRefs() + { + var type = this as PointerType; + + if (type != null) + { + var pointee = type.Pointee; + + if (type.IsReference()) + return pointee.Desugar().SkipPointerRefs(); + } + + return this; + } + public abstract T Visit(ITypeVisitor visitor, TypeQualifiers quals = new TypeQualifiers()); diff --git a/src/Generator/Generators/CLI/CLIMarshal.cs b/src/Generator/Generators/CLI/CLIMarshal.cs index 89d17feb..1ddebc2e 100644 --- a/src/Generator/Generators/CLI/CLIMarshal.cs +++ b/src/Generator/Generators/CLI/CLIMarshal.cs @@ -575,7 +575,7 @@ namespace CppSharp.Generators.CLI return; } - if (!Context.Parameter.Type.IsPointer()) + if (!Context.Parameter.Type.SkipPointerRefs().IsPointer()) { Context.Return.Write("*"); diff --git a/tests/Basic/Basic.Tests.cs b/tests/Basic/Basic.Tests.cs index e5a613a0..25c19391 100644 --- a/tests/Basic/Basic.Tests.cs +++ b/tests/Basic/Basic.Tests.cs @@ -19,6 +19,7 @@ public class BasicTests : GeneratorTestFixture var foo = new Foo { A = 4, B = 7 }; Assert.That(hello.AddFoo(foo), Is.EqualTo(11)); Assert.That(hello.AddFooPtr(foo), Is.EqualTo(11)); + Assert.That(hello.AddFooPtr(foo), Is.EqualTo(11)); Assert.That(hello.AddFooRef(foo), Is.EqualTo(11)); var bar = new Bar { A = 4, B = 7 }; diff --git a/tests/Basic/Basic.cpp b/tests/Basic/Basic.cpp index 7cd88cbe..17f6850a 100644 --- a/tests/Basic/Basic.cpp +++ b/tests/Basic/Basic.cpp @@ -107,6 +107,11 @@ int Hello::AddFooPtr(Foo* foo) return AddFoo(*foo); } +int Hello::AddFooPtrRef(Foo*& foo) +{ + return AddFoo(*foo); +} + int Hello::AddFoo2(Foo2 foo) { return (int)(foo.A + foo.B + foo.C); diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index 10ee3529..1ea8778d 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -103,6 +103,7 @@ public: int AddFoo(Foo); int AddFooRef(Foo&); int AddFooPtr(Foo*); + int AddFooPtrRef(Foo*&); Foo RetFoo(int a, float b); int AddFoo2(Foo2);