Browse Source

Fixed 4 bugs related to custom comparison operators.

1. Missing Equals to complement operator ==;
2. Endless recursion when having a == and comparing to null;
3. Crash when having a == and comparing a null pointer to sth else;
4. Incompilable code with a type derived from a struct with operators.

Signed-off-by: Dimitar Dobrev <dpldobrev@yahoo.com>
pull/395/head
Dimitar Dobrev 11 years ago committed by Joao Matos
parent
commit
e6223a1c4c
  1. 1
      src/AST/Declaration.cs
  2. 32
      src/Generator/Generators/CLI/CLIHeadersTemplate.cs
  3. 34
      src/Generator/Generators/CLI/CLISourcesTemplate.cs
  4. 6
      src/Generator/Generators/CSharp/CSharpMarshal.cs
  5. 58
      src/Generator/Generators/CSharp/CSharpTextTemplate.cs
  6. 6
      tests/Basic/Basic.Tests.cs
  7. 10
      tests/Basic/Basic.cpp
  8. 4
      tests/Basic/Basic.h

1
src/AST/Declaration.cs

@ -265,7 +265,6 @@ namespace CppSharp.AST @@ -265,7 +265,6 @@ namespace CppSharp.AST
set { if (value) ExplicitlyIgnore(); }
}
[Obsolete("Replace set by ExplicitlyIgnore(). Replace get by GenerationKind == GenerationKind.None.")]
public virtual bool Ignore
{
get { return GenerationKind == GenerationKind.None; }

32
src/Generator/Generators/CLI/CLIHeadersTemplate.cs

@ -263,7 +263,7 @@ namespace CppSharp.Generators.CLI @@ -263,7 +263,7 @@ namespace CppSharp.Generators.CLI
GenerateClassProperties(@class);
GenerateClassEvents(@class);
GenerateClassMethods(@class);
GenerateClassMethods(@class.Methods);
if (Options.GenerateFunctionTemplates)
GenerateClassGenericMethods(@class);
@ -487,16 +487,21 @@ namespace CppSharp.Generators.CLI @@ -487,16 +487,21 @@ namespace CppSharp.Generators.CLI
}
}
public void GenerateClassMethods(Class @class)
public void GenerateClassMethods(List<Method> methods)
{
if (methods.Count == 0)
return;
PushIndent();
var @class = (Class) methods[0].Namespace;
if (@class.IsValueType)
foreach (var @base in @class.Bases.Where(b => b.IsClass && !b.Class.Ignore))
GenerateClassMethods(@base.Class);
GenerateClassMethods(@base.Class.Methods.Where(m => !m.IsOperator).ToList());
var staticMethods = new List<Method>();
foreach (var method in @class.Methods)
foreach (var method in methods)
{
if (ASTUtils.CheckIgnoreMethod(method, Options))
continue;
@ -710,9 +715,28 @@ namespace CppSharp.Generators.CLI @@ -710,9 +715,28 @@ namespace CppSharp.Generators.CLI
WriteLine(";");
if (method.OperatorKind == CXXOperatorKind.EqualEqual)
{
GenerateEquals(method, (Class) method.Namespace);
}
PopBlock(NewLineKind.BeforeNextBlock);
}
private void GenerateEquals(Function method, Class @class)
{
Class leftHandSide;
Class rightHandSide;
if (method.Parameters[0].Type.SkipPointerRefs().TryGetClass(out leftHandSide) &&
leftHandSide.OriginalPtr == @class.OriginalPtr &&
method.Parameters[1].Type.SkipPointerRefs().TryGetClass(out rightHandSide) &&
rightHandSide.OriginalPtr == @class.OriginalPtr)
{
NewLine();
WriteLine("virtual bool Equals(::System::Object^ obj) override;");
}
}
public bool GenerateTypedef(TypedefDecl typedef)
{
if (!typedef.IsGenerated)

34
src/Generator/Generators/CLI/CLISourcesTemplate.cs

@ -212,7 +212,7 @@ namespace CppSharp.Generators.CLI @@ -212,7 +212,7 @@ namespace CppSharp.Generators.CLI
foreach (var @base in @class.Bases.Where(b => b.IsClass && !b.Class.Ignore))
GenerateClassMethods(@base.Class, realOwner);
foreach (var method in @class.Methods)
foreach (var method in @class.Methods.Where(m => @class == realOwner || !m.IsOperator))
{
if (ASTUtils.CheckIgnoreMethod(method, Options))
continue;
@ -775,9 +775,41 @@ namespace CppSharp.Generators.CLI @@ -775,9 +775,41 @@ namespace CppSharp.Generators.CLI
PopBlock();
WriteCloseBraceIndent();
if (method.OperatorKind == CXXOperatorKind.EqualEqual)
{
GenerateEquals(method, @class);
}
PopBlock(NewLineKind.Always);
}
private void GenerateEquals(Function method, Class @class)
{
Class leftHandSide;
Class rightHandSide;
if (method.Parameters[0].Type.SkipPointerRefs().TryGetClass(out leftHandSide) &&
leftHandSide.OriginalPtr == @class.OriginalPtr &&
method.Parameters[1].Type.SkipPointerRefs().TryGetClass(out rightHandSide) &&
rightHandSide.OriginalPtr == @class.OriginalPtr)
{
NewLine();
var qualifiedIdentifier = QualifiedIdentifier(@class);
WriteLine("bool {0}::Equals(::System::Object^ obj)", qualifiedIdentifier);
WriteStartBraceIndent();
if (@class.IsRefType)
{
WriteLine("return this == dynamic_cast<{0}^>(obj);", qualifiedIdentifier);
}
else
{
WriteLine("if (!(obj is {0})) return false;", @class.Name);
WriteLine("return this == ({0}) obj;", @class.Name);
}
WriteCloseBraceIndent();
}
}
private void GenerateValueTypeConstructorCall(Method method, Class @class)
{
var names = new List<string>();

6
src/Generator/Generators/CSharp/CSharpMarshal.cs

@ -568,7 +568,11 @@ namespace CppSharp.Generators.CSharp @@ -568,7 +568,11 @@ namespace CppSharp.Generators.CSharp
if (type.TryGetClass(out decl) && decl.IsValueType)
Context.Return.Write("{0}.{1}", param, Helpers.InstanceIdentifier);
else
Context.Return.Write("{0} == ({2}) null ? global::System.IntPtr.Zero : {0}.{1}", param,
Context.Return.Write("{0}{1}.{2}",
method != null && method.OperatorKind == CXXOperatorKind.EqualEqual
? string.Empty
: string.Format("ReferenceEquals({0}, null) ? global::System.IntPtr.Zero : ", param),
param,
Helpers.InstanceIdentifier, type);
return;
}

58
src/Generator/Generators/CSharp/CSharpTextTemplate.cs

@ -358,7 +358,7 @@ namespace CppSharp.Generators.CSharp @@ -358,7 +358,7 @@ namespace CppSharp.Generators.CSharp
if (@class.IsUnion)
GenerateUnionFields(@class);
GenerateClassMethods(@class);
GenerateClassMethods(@class.Methods);
GenerateClassVariables(@class);
GenerateClassProperties(@class);
@ -1133,14 +1133,19 @@ namespace CppSharp.Generators.CSharp @@ -1133,14 +1133,19 @@ namespace CppSharp.Generators.CSharp
return false;
}
public void GenerateClassMethods(Class @class)
public void GenerateClassMethods(IList<Method> methods)
{
if (methods.Count == 0)
return;
var @class = (Class) methods[0].Namespace;
if (@class.IsValueType)
foreach (var @base in @class.Bases.Where(b => b.IsClass && !b.Class.Ignore))
GenerateClassMethods(@base.Class);
GenerateClassMethods(@base.Class.Methods.Where(m => !m.IsOperator).ToList());
var staticMethods = new List<Method>();
foreach (var method in @class.Methods)
foreach (var method in methods)
{
if (ASTUtils.CheckIgnoreMethod(method, Options))
continue;
@ -2112,7 +2117,7 @@ namespace CppSharp.Generators.CSharp @@ -2112,7 +2117,7 @@ namespace CppSharp.Generators.CSharp
string.Join(", ",
method.Parameters.Where(
p => p.Kind == ParameterKind.Regular).Select(
p => p.GenerationKind == GenerationKind.None ? p.DefaultArgument.String : p.Name)));
p => p.Ignore ? p.DefaultArgument.String : p.Name)));
}
if (Driver.Options.GenerateAbstractImpls && method.IsPure)
@ -2142,7 +2147,7 @@ namespace CppSharp.Generators.CSharp @@ -2142,7 +2147,7 @@ namespace CppSharp.Generators.CSharp
string.Join(", ",
method.Parameters.Where(
p => p.Kind == ParameterKind.Regular).Select(
p => p.GenerationKind == GenerationKind.None ? p.DefaultArgument.String : p.Name)));
p => p.Ignore ? p.DefaultArgument.String : p.Name)));
}
goto SkipImpl;
}
@ -2185,9 +2190,40 @@ namespace CppSharp.Generators.CSharp @@ -2185,9 +2190,40 @@ namespace CppSharp.Generators.CSharp
SkipImpl:
WriteCloseBraceIndent();
if (method.OperatorKind == CXXOperatorKind.EqualEqual)
{
GenerateEquals(method, @class);
}
PopBlock(NewLineKind.BeforeNextBlock);
}
private void GenerateEquals(Function method, Class @class)
{
Class leftHandSide;
Class rightHandSide;
if (method.Parameters[0].Type.SkipPointerRefs().TryGetClass(out leftHandSide) &&
leftHandSide.OriginalPtr == @class.OriginalPtr &&
method.Parameters[1].Type.SkipPointerRefs().TryGetClass(out rightHandSide) &&
rightHandSide.OriginalPtr == @class.OriginalPtr)
{
NewLine();
WriteLine("public override bool Equals(object obj)");
WriteStartBraceIndent();
if (@class.IsRefType)
{
WriteLine("return this == obj as {0};", @class.Name);
}
else
{
WriteLine("if (!(obj is {0})) return false;", @class.Name);
WriteLine("return this == ({0}) obj;", @class.Name);
}
WriteCloseBraceIndent();
}
}
private void CheckArgumentRange(Function method)
{
if (Driver.Options.MarshalCharAsManagedChar)
@ -2266,6 +2302,14 @@ namespace CppSharp.Generators.CSharp @@ -2266,6 +2302,14 @@ namespace CppSharp.Generators.CSharp
return;
}
if (method.OperatorKind == CXXOperatorKind.EqualEqual)
{
WriteLine("bool {0}Null = ReferenceEquals({0}, null);", method.Parameters[0].Name);
WriteLine("bool {0}Null = ReferenceEquals({0}, null);", method.Parameters[1].Name);
WriteLine("if ({0}Null || {1}Null)", method.Parameters[0].Name, method.Parameters[1].Name);
WriteLineIndent("return {0}Null && {1}Null;", method.Parameters[0].Name, method.Parameters[1].Name);
}
GenerateInternalFunctionCall(method);
}
@ -2407,7 +2451,7 @@ namespace CppSharp.Generators.CSharp @@ -2407,7 +2451,7 @@ namespace CppSharp.Generators.CSharp
{
if (operatorParam != null)
{
names.Insert(instanceIndex, operatorParam.Name + "." + Helpers.InstanceIdentifier);
names.Insert(instanceIndex, @params[0].Name);
}
else
{

6
tests/Basic/Basic.Tests.cs

@ -458,5 +458,11 @@ public class BasicTests : GeneratorTestFixture @@ -458,5 +458,11 @@ public class BasicTests : GeneratorTestFixture
{
new InvokesInternalCtorAmbiguity().InvokeInternalCtor();
}
[Test]
public void TestEqualityOperator()
{
Assert.AreEqual(new Foo { A = 5, B = 5.5f }, new Foo { A = 5, B = 5.5f });
}
}

10
tests/Basic/Basic.cpp

@ -23,6 +23,11 @@ void Foo::TakesTypedefedPtr(FooPtr date) @@ -23,6 +23,11 @@ void Foo::TakesTypedefedPtr(FooPtr date)
{
}
bool Foo::operator ==(const Foo& other) const
{
return A == other.A && B == other.B;
}
Foo2::Foo2() {}
Foo2 Foo2::operator<<(signed int i)
@ -62,6 +67,11 @@ Bar* Bar::returnPointerToValueType() @@ -62,6 +67,11 @@ Bar* Bar::returnPointerToValueType()
return this;
}
bool Bar::operator ==(const Bar& other) const
{
return A == other.A && B == other.B;
}
Bar2::Nested::operator int() const
{
return 300;

4
tests/Basic/Basic.h

@ -45,6 +45,8 @@ public: @@ -45,6 +45,8 @@ public:
typedef Foo* FooPtr;
void TakesTypedefedPtr(FooPtr date);
bool operator ==(const Foo& other) const;
};
struct DLL_API Bar
@ -61,6 +63,8 @@ struct DLL_API Bar @@ -61,6 +63,8 @@ struct DLL_API Bar
float B;
Bar* returnPointerToValueType();
bool operator ==(const Bar& other) const;
};
class DLL_API Foo2 : public Foo

Loading…
Cancel
Save