diff --git a/src/Generator/Passes/MoveOperatorToClassPass.cs b/src/Generator/Passes/MoveOperatorToClassPass.cs index b35753b5..eddeadf1 100644 --- a/src/Generator/Passes/MoveOperatorToClassPass.cs +++ b/src/Generator/Passes/MoveOperatorToClassPass.cs @@ -1,6 +1,4 @@ -using System.Collections.Generic; -using System.Linq; -using CppSharp.AST; +using CppSharp.AST; namespace CppSharp.Passes { @@ -17,14 +15,19 @@ namespace CppSharp.Passes if (function.Ignore || !function.IsOperator) return false; - var param = function.Parameters[0]; + Class @class = null; + foreach (var param in function.Parameters) + { + FunctionToInstanceMethodPass.GetClassParameter( + param, out @class); + + if (@class != null) break; + } - Class @class; - if (!FunctionToInstanceMethodPass.GetClassParameter(param, out @class)) + if (@class == null) return false; // Create a new fake method so it acts as a static method. - var method = new Method(function) { Namespace = @class, diff --git a/tests/Basic/Basic.Tests.cs b/tests/Basic/Basic.Tests.cs index 5b203a20..fcc3f5b7 100644 --- a/tests/Basic/Basic.Tests.cs +++ b/tests/Basic/Basic.Tests.cs @@ -71,25 +71,6 @@ public class BasicTests Assert.That(hello.RetNull(), Is.Null); } - [Test] - public void TestUnaryOperator() - { - var bar = new Bar { A = 4, B = 7 }; - var barMinus = -bar; - Assert.That(barMinus.A, Is.EqualTo(-bar.A)); - Assert.That(barMinus.B, Is.EqualTo(-bar.B)); - } - - [Test] - public void TestBinaryOperator() - { - var bar = new Bar { A = 4, B = 7 }; - var bar1 = new Bar { A = 5, B = 10 }; - var barSum = bar + bar1; - Assert.That(barSum.A, Is.EqualTo(bar.A + bar1.A)); - Assert.That(barSum.B, Is.EqualTo(bar.B + bar1.B)); - } - [Test] public void TestAmbiguous() { @@ -123,6 +104,30 @@ public class BasicTests Assert.That(foo.GetANSI(), Is.EqualTo("ANSI")); } + [Test] + public void TestMoveOperatorToClass() + { + // Unary operator + var unary = new TestMoveOperatorToClass() { A = 4, B = 7 }; + var unaryMinus = -unary; + + Assert.That(unaryMinus.A, Is.EqualTo(-unary.A)); + Assert.That(unaryMinus.B, Is.EqualTo(-unary.B)); + + // Binary operator + var bin = new TestMoveOperatorToClass { A = 4, B = 7 }; + var bin1 = new TestMoveOperatorToClass { A = 5, B = 10 }; + var binSum = bin + bin1; + + Assert.That(binSum.A, Is.EqualTo(bin.A + bin1.A)); + Assert.That(binSum.B, Is.EqualTo(bin.B + bin1.B)); + + // Multiple argument operator + var multiArg = new TestMoveOperatorToClass { A = 4, B = 7 }; + var multiArgStar = multiArg * 2; + Assert.That(multiArgStar, Is.EqualTo(8)); + } + [Test] public void TestMoveFunctionToClass() { diff --git a/tests/Basic/Basic.cpp b/tests/Basic/Basic.cpp index 390b6c4b..c1677463 100644 --- a/tests/Basic/Basic.cpp +++ b/tests/Basic/Basic.cpp @@ -162,22 +162,6 @@ const wchar_t* wcharFunction(const wchar_t* constWideChar) return constWideChar; } -Bar operator-(const Bar& b) -{ - Bar nb; - nb.A = -b.A; - nb.B = -b.B; - return nb; -} - -Bar operator+(const Bar& b1, const Bar& b2) -{ - Bar b; - b.A = b1.A + b2.A; - b.B = b1.B + b2.B; - return b; -} - Bar indirectReturn() { return Bar(); diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index 21603797..7eecc2ef 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -139,9 +139,6 @@ private: ImplementsAbstractFoo i; }; -DLL_API Bar operator-(const Bar &); -DLL_API Bar operator+(const Bar &, const Bar &); - int DLL_API unsafeFunction(const Bar& ret, char* testForString, void (*foo)(int)); DLL_API Bar indirectReturn(); @@ -193,6 +190,36 @@ class DLL_API basic DLL_API int test(basic& s); +// Tests the MoveOperatorToClassPass +struct DLL_API TestMoveOperatorToClass +{ + TestMoveOperatorToClass() {} + int A; + int B; +}; + +DLL_API int operator *(TestMoveOperatorToClass klass, int b) +{ + return klass.A * b; +} + +DLL_API TestMoveOperatorToClass operator-(const TestMoveOperatorToClass& b) +{ + TestMoveOperatorToClass nb; + nb.A = -b.A; + nb.B = -b.B; + return nb; +} + +DLL_API TestMoveOperatorToClass operator+(const TestMoveOperatorToClass& b1, + const TestMoveOperatorToClass& b2) +{ + TestMoveOperatorToClass b; + b.A = b1.A + b2.A; + b.B = b1.B + b2.B; + return b; +} + // Tests delegates typedef int (*DelegateInGlobalNamespace)(int);