Browse Source

Guard against null for objects passed by value

Fixes #1228

Signed-off-by: Dimitar Dobrev <dpldobrev@protonmail.com>
remove-private-fields
Dimitar Dobrev 6 years ago
parent
commit
153a4d095b
  1. 7
      src/Generator/Generators/CLI/CLIMarshal.cs
  2. 79
      src/Generator/Generators/CSharp/CSharpMarshal.cs
  3. 4
      src/Generator/Passes/CheckAbiParameters.cs
  4. 2
      tests/CSharp/CSharp.Tests.cs
  5. 6
      tests/Common/Common.Tests.cs

7
src/Generator/Generators/CLI/CLIMarshal.cs

@ -679,6 +679,13 @@ namespace CppSharp.Generators.CLI @@ -679,6 +679,13 @@ namespace CppSharp.Generators.CLI
if (Context.Parameter.Type.IsReference())
VarPrefix.Write("&");
else
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, nullptr))");
Context.Before.WriteLineIndent(
$@"throw gcnew ::System::ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
}
}
if (method != null

79
src/Generator/Generators/CSharp/CSharpMarshal.cs

@ -701,46 +701,61 @@ namespace CppSharp.Generators.CSharp @@ -701,46 +701,61 @@ namespace CppSharp.Generators.CSharp
@interface.IsInterface)
paramInstance = $"{param}.__PointerTo{@interface.OriginalClass.Name}";
else
paramInstance = $@"{param}.{Helpers.InstanceIdentifier}";
if (type.IsAddress())
paramInstance = $"{param}.{Helpers.InstanceIdentifier}";
if (!type.IsAddress())
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
var realClass = @class.OriginalClass ?? @class;
var qualifiedIdentifier = typePrinter.PrintNative(realClass);
Context.Return.Write($"*({qualifiedIdentifier}*) {paramInstance}");
return;
}
Class decl;
if (type.TryGetClass(out decl) && decl.IsValueType)
{
Context.Return.Write(paramInstance);
return;
}
if (type.IsPointer())
{
Class decl;
if (type.TryGetClass(out decl) && decl.IsValueType)
if (Context.Parameter.IsIndirect)
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is passed by value."");");
Context.Return.Write(paramInstance);
}
else
{
if (type.IsPointer())
{
Context.Return.Write("{0}{1}",
method != null && method.OperatorKind == CXXOperatorKind.EqualEqual
? string.Empty
: $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ",
paramInstance);
}
else
{
if (method == null ||
// redundant for comparison operators, they are handled in a special way
(method.OperatorKind != CXXOperatorKind.EqualEqual &&
method.OperatorKind != CXXOperatorKind.ExclaimEqual))
{
Context.Before.WriteLine("if (ReferenceEquals({0}, null))", param);
Context.Before.WriteLineIndent(
"throw new global::System.ArgumentNullException(\"{0}\", " +
"\"Cannot be null because it is a C++ reference (&).\");",
param);
}
Context.Return.Write(paramInstance);
}
Context.Return.Write("{0}{1}",
method != null && method.OperatorKind == CXXOperatorKind.EqualEqual
? string.Empty
: $"ReferenceEquals({param}, null) ? global::System.IntPtr.Zero : ",
paramInstance);
}
return;
}
var realClass = @class.OriginalClass ?? @class;
var qualifiedIdentifier = typePrinter.PrintNative(realClass);
Context.Return.Write(
"ReferenceEquals({0}, null) ? new {1}() : *({1}*) {2}",
param, qualifiedIdentifier, paramInstance);
if (method == null ||
// redundant for comparison operators, they are handled in a special way
(method.OperatorKind != CXXOperatorKind.EqualEqual &&
method.OperatorKind != CXXOperatorKind.ExclaimEqual))
{
Context.Before.WriteLine($"if (ReferenceEquals({Context.Parameter.Name}, null))");
Context.Before.WriteLineIndent(
$@"throw new global::System.ArgumentNullException(""{
Context.Parameter.Name}"", ""Cannot be null because it is a C++ reference (&)."");",
param);
}
Context.Return.Write(paramInstance);
}
private void MarshalValueClass()

4
src/Generator/Passes/CheckAbiParameters.cs

@ -81,9 +81,7 @@ namespace CppSharp.Passes @@ -81,9 +81,7 @@ namespace CppSharp.Passes
});
}
foreach (var param in from p in function.Parameters
where p.IsIndirect && !p.Type.Desugar().IsAddress()
select p)
foreach (var param in function.Parameters.Where(p => p.IsIndirect))
{
param.QualifiedType = new QualifiedType(new PointerType(param.QualifiedType));
}

2
tests/CSharp/CSharp.Tests.cs

@ -248,7 +248,7 @@ public unsafe class CSharpTests : GeneratorTestFixture @@ -248,7 +248,7 @@ public unsafe class CSharpTests : GeneratorTestFixture
methodsWithDefaultValues.DefaultEmptyEnum();
methodsWithDefaultValues.DefaultRefTypeBeforeOthers();
methodsWithDefaultValues.DefaultRefTypeAfterOthers();
methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers(0, null);
methodsWithDefaultValues.DefaultRefTypeBeforeAndAfterOthers();
methodsWithDefaultValues.DefaultIntAssignedAnEnum();
methodsWithDefaultValues.defaultRefAssignedValue();
methodsWithDefaultValues.DefaultRefAssignedValue();

6
tests/Common/Common.Tests.cs

@ -772,6 +772,12 @@ public class CommonTests : GeneratorTestFixture @@ -772,6 +772,12 @@ public class CommonTests : GeneratorTestFixture
}
}
[Test]
public void TestPassingNullToValue()
{
Assert.Catch<ArgumentNullException>(() => new Bar((Foo) null));
}
[Test]
public void TestNonTrivialDtorInvocation()
{

Loading…
Cancel
Save