Browse Source

Guard against null for objects passed by value

Fixes #1228

Signed-off-by: Dimitar Dobrev <dpldobrev@protonmail.com>
remove-private-fields
Dimitar Dobrev 6 years ago
parent
commit
153a4d095b
  1. 7
      src/Generator/Generators/CLI/CLIMarshal.cs
  2. 53
      src/Generator/Generators/CSharp/CSharpMarshal.cs
  3. 4
      src/Generator/Passes/CheckAbiParameters.cs
  4. 2
      tests/CSharp/CSharp.Tests.cs
  5. 6
      tests/Common/Common.Tests.cs

7
src/Generator/Generators/CLI/CLIMarshal.cs

@ -679,6 +679,13 @@ namespace CppSharp.Generators.CLI
if (Context.Parameter.Type.IsReference()) if (Context.Parameter.Type.IsReference())
VarPrefix.Write("&"); VarPrefix.Write("&");
else
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, nullptr))");
Context.Before.WriteLineIndent(
$@"throw gcnew ::System::ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
}
} }
if (method != null if (method != null

53
src/Generator/Generators/CSharp/CSharpMarshal.cs

@ -701,15 +701,38 @@ namespace CppSharp.Generators.CSharp
@interface.IsInterface) @interface.IsInterface)
paramInstance = $"{param}.__PointerTo{@interface.OriginalClass.Name}"; paramInstance = $"{param}.__PointerTo{@interface.OriginalClass.Name}";
else else
paramInstance = $@"{param}.{Helpers.InstanceIdentifier}"; paramInstance = $"{param}.{Helpers.InstanceIdentifier}";
if (type.IsAddress())
if (!type.IsAddress())
{ {
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
var realClass = @class.OriginalClass ?? @class;
var qualifiedIdentifier = typePrinter.PrintNative(realClass);
Context.Return.Write($"*({qualifiedIdentifier}*) {paramInstance}");
return;
}
Class decl; Class decl;
if (type.TryGetClass(out decl) && decl.IsValueType) if (type.TryGetClass(out decl) && decl.IsValueType)
Context.Return.Write(paramInstance);
else
{ {
Context.Return.Write(paramInstance);
return;
}
if (type.IsPointer()) if (type.IsPointer())
{
if (Context.Parameter.IsIndirect)
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
Context.Return.Write(paramInstance);
}
else
{ {
Context.Return.Write("{0}{1}", Context.Return.Write("{0}{1}",
method != null && method.OperatorKind == CXXOperatorKind.EqualEqual method != null && method.OperatorKind == CXXOperatorKind.EqualEqual
@ -717,30 +740,22 @@ namespace CppSharp.Generators.CSharp
: $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ", : $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ",
paramInstance); paramInstance);
} }
else return;
{ }
if (method == null || if (method == null ||
// redundant for comparison operators, they are handled in a special way // redundant for comparison operators, they are handled in a special way
(method.OperatorKind != CXXOperatorKind.EqualEqual && (method.OperatorKind != CXXOperatorKind.EqualEqual &&
method.OperatorKind != CXXOperatorKind.ExclaimEqual)) method.OperatorKind != CXXOperatorKind.ExclaimEqual))
{ {
Context.Before.WriteLine("if (ReferenceEquals({0}, null))", param); Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent( Context.Before.WriteLineIndent(
"throw new global::System.ArgumentNullException(\"{0}\", " + $@"throw new global::System.ArgumentNullException(""{
"\"Cannot be null because it is a C++ reference (&).\");", Context.Parameter.Name}"", ""Cannot be null because it is a C++ reference (&)."");",
param); param);
} }
Context.Return.Write(paramInstance);
}
}
return;
}
var realClass = @class.OriginalClass ?? @class; Context.Return.Write(paramInstance);
var qualifiedIdentifier = typePrinter.PrintNative(realClass);
Context.Return.Write(
"ReferenceEquals({0}, null) ? new {1}() : *({1}*) {2}",
param, qualifiedIdentifier, paramInstance);
} }
private void MarshalValueClass() private void MarshalValueClass()

4
src/Generator/Passes/CheckAbiParameters.cs

@ -81,9 +81,7 @@ namespace CppSharp.Passes
}); });
} }
foreach (var param in from p in function.Parameters foreach (var param in function.Parameters.Where(p => p.IsIndirect))
where p.IsIndirect && !p.Type.Desugar().IsAddress()
select p)
{ {
param.QualifiedType = new QualifiedType(new PointerType(param.QualifiedType)); param.QualifiedType = new QualifiedType(new PointerType(param.QualifiedType));
} }

2
tests/CSharp/CSharp.Tests.cs

@ -248,7 +248,7 @@ public unsafe class CSharpTests : GeneratorTestFixture
methodsWithDefaultValues.DefaultEmptyEnum(); methodsWithDefaultValues.DefaultEmptyEnum();
methodsWithDefaultValues.DefaultRefTypeBeforeOthers(); methodsWithDefaultValues.DefaultRefTypeBeforeOthers();
methodsWithDefaultValues.DefaultRefTypeAfterOthers(); methodsWithDefaultValues.DefaultRefTypeAfterOthers();
methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers(0, null); methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers();
methodsWithDefaultValues.DefaultIntAssignedAnEnum(); methodsWithDefaultValues.DefaultIntAssignedAnEnum();
methodsWithDefaultValues.defaultRefAssignedValue(); methodsWithDefaultValues.defaultRefAssignedValue();
methodsWithDefaultValues.DefaultRefAssignedValue(); methodsWithDefaultValues.DefaultRefAssignedValue();

6
tests/Common/Common.Tests.cs

@ -772,6 +772,12 @@ public class CommonTests : GeneratorTestFixture
} }
} }
[Test]
public void TestPassingNullToValue()
{
Assert.Catch<ArgumentNullException>(() => new Bar((Foo) null));
}
[Test] [Test]
public void TestNonTrivialDtorInvocation() public void TestNonTrivialDtorInvocation()
{ {

Loading…
Cancel
Save