Browse Source

Merge pull request #128 from ddobrev/master

Added a pass to move functions to an appropriate existing class if possible
pull/129/head
João Matos 12 years ago
parent
commit
786ceb60e2
  1. 49
      src/AST/ASTContext.cs
  2. 8
      src/AST/Namespace.cs
  3. 1
      src/Generator/Driver.cs
  4. 60
      src/Generator/Passes/MoveFunctionToClassPass.cs
  5. 6
      tests/Basic/Basic.Tests.cs
  6. 5
      tests/Basic/Basic.cpp
  7. 7
      tests/Basic/Basic.h

49
src/AST/ASTContext.cs

@ -1,6 +1,7 @@ @@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace CppSharp.AST
{
@ -50,61 +51,45 @@ namespace CppSharp.AST @@ -50,61 +51,45 @@ namespace CppSharp.AST
/// Finds an existing enum in the library modules.
public IEnumerable<Enumeration> FindEnum(string name)
{
foreach (var module in TranslationUnits)
{
var type = module.FindEnum(name);
if (type != null) yield return type;
}
return TranslationUnits.Select(
module => module.FindEnum(name)).Where(type => type != null);
}
/// Finds the complete declaration of an enum.
public Enumeration FindCompleteEnum(string name)
{
foreach (var @enum in FindEnum(name))
if (!@enum.IsIncomplete)
return @enum;
return null;
return FindEnum(name).FirstOrDefault(@enum => !@enum.IsIncomplete);
}
/// Finds an existing struct/class in the library modules.
public IEnumerable<Class> FindClass(string name, bool create = false)
public IEnumerable<Class> FindClass(string name, bool create = false,
bool ignoreCase = false)
{
foreach (var module in TranslationUnits)
{
var type = module.FindClass(name);
if (type != null) yield return type;
}
return TranslationUnits.Select(
module => module.FindClass(name,
ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal))
.Where(type => type != null);
}
/// Finds the complete declaration of a class.
public Class FindCompleteClass(string name)
public Class FindCompleteClass(string name, bool ignoreCase = false)
{
foreach (var @class in FindClass(name))
if (!@class.IsIncomplete)
return @class;
return null;
return FindClass(name, ignoreCase: ignoreCase).FirstOrDefault(
@class => !@class.IsIncomplete);
}
/// Finds an existing function in the library modules.
public IEnumerable<Function> FindFunction(string name)
{
foreach (var module in TranslationUnits)
{
var type = module.FindFunction(name);
if (type != null) yield return type;
}
return TranslationUnits.Select(module => module.FindFunction(name))
.Where(type => type != null);
}
/// Finds an existing typedef in the library modules.
public IEnumerable<TypedefDecl> FindTypedef(string name)
{
foreach (var module in TranslationUnits)
{
var type = module.FindTypedef(name);
if (type != null) yield return type;
}
return TranslationUnits.Select(module => module.FindTypedef(name))
.Where(type => type != null);
}
/// Finds an existing declaration by name.

8
src/AST/Namespace.cs

@ -168,17 +168,17 @@ namespace CppSharp.AST @@ -168,17 +168,17 @@ namespace CppSharp.AST
return @class;
}
public Class FindClass(string name)
public Class FindClass(string name,
StringComparison stringComparison = StringComparison.Ordinal)
{
if (string.IsNullOrEmpty(name)) return null;
var entries = name.Split(new string[] { "::" },
var entries = name.Split(new[] { "::" },
StringSplitOptions.RemoveEmptyEntries).ToList();
if (entries.Count <= 1)
{
var @class = Classes.Find(e => e.Name.Equals(name));
return @class;
return Classes.Find(e => e.Name.Equals(name, stringComparison));
}
var className = entries[entries.Count - 1];

1
src/Generator/Driver.cs

@ -201,6 +201,7 @@ namespace CppSharp @@ -201,6 +201,7 @@ namespace CppSharp
TranslationUnitPasses.AddPass(new FindSymbolsPass());
TranslationUnitPasses.AddPass(new MoveOperatorToClassPass());
TranslationUnitPasses.AddPass(new MoveFunctionToClassPass());
TranslationUnitPasses.AddPass(new CheckAmbiguousFunctions());
TranslationUnitPasses.AddPass(new CheckOperatorsOverloadsPass());
TranslationUnitPasses.AddPass(new CheckVirtualOverrideReturnCovariance());

60
src/Generator/Passes/MoveFunctionToClassPass.cs

@ -0,0 +1,60 @@ @@ -0,0 +1,60 @@
using System.Linq;
using CppSharp.AST;
namespace CppSharp.Passes
{
/// <summary>
/// Moves a function to a class, if any, named after the function's header.
/// </summary>
public class MoveFunctionToClassPass : TranslationUnitPass
{
public override bool VisitFunctionDecl(Function function)
{
if (AlreadyVisited(function) || function.Ignore || function.Namespace is Class)
return base.VisitFunctionDecl(function);
Class @class = FindClassToMoveFunctionTo(function.Namespace);
if (@class != null)
{
MoveFunction(function, @class);
}
return base.VisitFunctionDecl(function);
}
private Class FindClassToMoveFunctionTo(INamedDecl @namespace)
{
TranslationUnit unit = @namespace as TranslationUnit;
if (unit == null)
{
return Driver.ASTContext.FindClass(
@namespace.Name, ignoreCase: true).FirstOrDefault();
}
return Driver.ASTContext.FindCompleteClass(
unit.FileNameWithoutExtension.ToLowerInvariant(), true);
}
private static void MoveFunction(Function function, Class @class)
{
var method = new Method(function)
{
Namespace = @class,
IsStatic = true
};
if (method.OperatorKind != CXXOperatorKind.None)
{
var param = function.Parameters[0];
Class type;
if (!FunctionToInstanceMethodPass.GetClassParameter(param, out type))
return;
method.Kind = CXXMethodKind.Operator;
method.SynthKind = FunctionSynthKind.NonMemberOperator;
method.OriginalFunction = null;
}
function.ExplicityIgnored = true;
@class.Methods.Add(method);
}
}
}

6
tests/Basic/Basic.Tests.cs

@ -123,6 +123,12 @@ public class BasicTests @@ -123,6 +123,12 @@ public class BasicTests
Assert.That(foo.GetANSI(), Is.EqualTo("ANSI"));
}
[Test]
public void TestMoveFunctionToClass()
{
Assert.That(basic.test(new basic()), Is.EqualTo(5));
}
[Test, Ignore]
public void TestConversionOperator()
{

5
tests/Basic/Basic.cpp

@ -203,3 +203,8 @@ void DefaultParameters::Bar() const @@ -203,3 +203,8 @@ void DefaultParameters::Bar() const
void DefaultParameters::Bar()
{
}
int test(basic& s)
{
return 5;
}

7
tests/Basic/Basic.h

@ -178,3 +178,10 @@ class Base @@ -178,3 +178,10 @@ class Base
class Derived : public Base<Derived>
{
};
class DLL_API basic
{
};
DLL_API int test(basic& s);

Loading…
Cancel
Save