From 153a4d095be879997dd821830ba035c6e96ca819 Mon Sep 17 00:00:00 2001 From: Dimitar Dobrev Date: Fri, 28 Jun 2019 00:46:18 +0300 Subject: [PATCH] Guard against null for objects passed by value Fixes #1228 Signed-off-by: Dimitar Dobrev --- src/Generator/Generators/CLI/CLIMarshal.cs | 7 ++ .../Generators/CSharp/CSharpMarshal.cs | 79 +++++++++++-------- src/Generator/Passes/CheckAbiParameters.cs | 4 +- tests/CSharp/CSharp.Tests.cs | 2 +- tests/Common/Common.Tests.cs | 6 ++ 5 files changed, 62 insertions(+), 36 deletions(-) diff --git a/src/Generator/Generators/CLI/CLIMarshal.cs b/src/Generator/Generators/CLI/CLIMarshal.cs index 3d21a5de..ae391492 100644 --- a/src/Generator/Generators/CLI/CLIMarshal.cs +++ b/src/Generator/Generators/CLI/CLIMarshal.cs @@ -679,6 +679,13 @@ namespace CppSharp.Generators.CLI if (Context.Parameter.Type.IsReference()) VarPrefix.Write("&"); + else + { + Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, nullptr))"); + Context.Before.WriteLineIndent( + $@"throw gcnew ::System::ArgumentNullException(""{ + Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); + } } if (method != null diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 0ae0086d..dd54a28b 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -701,46 +701,61 @@ namespace CppSharp.Generators.CSharp @interface.IsInterface) paramInstance = $"{param}.__PointerTo{@interface.OriginalClass.Name}"; else - paramInstance = $@"{param}.{Helpers.InstanceIdentifier}"; - if (type.IsAddress()) + paramInstance = $"{param}.{Helpers.InstanceIdentifier}"; + + if (!type.IsAddress()) + { + Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); + Context.Before.WriteLineIndent( + $@"throw new global::System.ArgumentNullException(""{ + Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); + var realClass = @class.OriginalClass ?? @class; + var qualifiedIdentifier = typePrinter.PrintNative(realClass); + Context.Return.Write($"*({qualifiedIdentifier}*) {paramInstance}"); + return; + } + + Class decl; + if (type.TryGetClass(out decl) && decl.IsValueType) + { + Context.Return.Write(paramInstance); + return; + } + + if (type.IsPointer()) { - Class decl; - if (type.TryGetClass(out decl) && decl.IsValueType) + if (Context.Parameter.IsIndirect) + { + Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); + Context.Before.WriteLineIndent( + $@"throw new global::System.ArgumentNullException(""{ + Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");"); Context.Return.Write(paramInstance); + } else { - if (type.IsPointer()) - { - Context.Return.Write("{0}{1}", - method != null && method.OperatorKind == CXXOperatorKind.EqualEqual - ? string.Empty - : $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ", - paramInstance); - } - else - { - if (method == null || - // redundant for comparison operators, they are handled in a special way - (method.OperatorKind != CXXOperatorKind.EqualEqual && - method.OperatorKind != CXXOperatorKind.ExclaimEqual)) - { - Context.Before.WriteLine("if (ReferenceEquals({0}, null))", param); - Context.Before.WriteLineIndent( - "throw new global::System.ArgumentNullException(\"{0}\", " + - "\"Cannot be null because it is a C++ reference (&).\");", - param); - } - Context.Return.Write(paramInstance); - } + Context.Return.Write("{0}{1}", + method != null && method.OperatorKind == CXXOperatorKind.EqualEqual + ? string.Empty + : $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ", + paramInstance); } return; } - var realClass = @class.OriginalClass ?? @class; - var qualifiedIdentifier = typePrinter.PrintNative(realClass); - Context.Return.Write( - "ReferenceEquals({0}, null) ? new {1}() : *({1}*) {2}", - param, qualifiedIdentifier, paramInstance); + if (method == null || + // redundant for comparison operators, they are handled in a special way + (method.OperatorKind != CXXOperatorKind.EqualEqual && + method.OperatorKind != CXXOperatorKind.ExclaimEqual)) + { + Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))"); + Context.Before.WriteLineIndent( + $@"throw new global::System.ArgumentNullException(""{ + Context.Parameter.Name}"", ""Cannot be null because it is a C++ reference (&)."");", + param); + } + + Context.Return.Write(paramInstance); } private void MarshalValueClass() diff --git a/src/Generator/Passes/CheckAbiParameters.cs b/src/Generator/Passes/CheckAbiParameters.cs index eb454283..396a4b01 100644 --- a/src/Generator/Passes/CheckAbiParameters.cs +++ b/src/Generator/Passes/CheckAbiParameters.cs @@ -81,9 +81,7 @@ namespace CppSharp.Passes }); } - foreach (var param in from p in function.Parameters - where p.IsIndirect && !p.Type.Desugar().IsAddress() - select p) + foreach (var param in function.Parameters.Where(p => p.IsIndirect)) { param.QualifiedType = new QualifiedType(new PointerType(param.QualifiedType)); } diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index 0149ecf6..5f6b88a8 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -248,7 +248,7 @@ public unsafe class CSharpTests : GeneratorTestFixture methodsWithDefaultValues.DefaultEmptyEnum(); methodsWithDefaultValues.DefaultRefTypeBeforeOthers(); methodsWithDefaultValues.DefaultRefTypeAfterOthers(); - methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers(0, null); + methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers(); methodsWithDefaultValues.DefaultIntAssignedAnEnum(); methodsWithDefaultValues.defaultRefAssignedValue(); methodsWithDefaultValues.DefaultRefAssignedValue(); diff --git a/tests/Common/Common.Tests.cs b/tests/Common/Common.Tests.cs index 7d351726..65b9c455 100644 --- a/tests/Common/Common.Tests.cs +++ b/tests/Common/Common.Tests.cs @@ -772,6 +772,12 @@ public class CommonTests : GeneratorTestFixture } } + [Test] + public void TestPassingNullToValue() + { + Assert.Catch(() => new Bar((Foo) null)); + } + [Test] public void TestNonTrivialDtorInvocation() {