Browse Source

Merge pull request #53 from ddobrev/abstract_implementations

Abstract implementations
pull/57/merge
João Matos 13 years ago
parent
commit
a6e98649d5
  1. 7
      src/AST/ClassLayout.cs
  2. 28
      src/AST/Function.cs
  3. 17
      src/AST/Method.cs
  4. 10
      src/AST/Type.cs
  5. 18
      src/Generator/AST/VTables.cs
  6. 5
      src/Generator/Driver.cs
  7. 8
      src/Generator/Generators/CSharp/CSharpMarshal.cs
  8. 98
      src/Generator/Generators/CSharp/CSharpTextTemplate.cs
  9. 4
      src/Generator/Passes/FindSymbolsPass.cs
  10. 188
      src/Generator/Passes/GenerateAbstractImplementationsPass.cs
  11. 18
      src/Generator/Utils/ParameterTypeComparer.cs
  12. 10
      tests/Basic/Basic.Tests.cs
  13. 20
      tests/Basic/Basic.cpp
  14. 25
      tests/Basic/Basic.h

7
src/AST/ClassLayout.cs

@ -99,8 +99,11 @@ namespace CppSharp.AST @@ -99,8 +99,11 @@ namespace CppSharp.AST
Size = classLayout.Size;
DataSize = classLayout.DataSize;
VFTables.AddRange(classLayout.VFTables);
Layout = new VTableLayout();
Layout.Components.AddRange(classLayout.Layout.Components);
if (classLayout.Layout != null)
{
Layout = new VTableLayout();
Layout.Components.AddRange(classLayout.Layout.Components);
}
}
/// <summary>

28
src/AST/Function.cs

@ -162,5 +162,33 @@ namespace CppSharp.AST @@ -162,5 +162,33 @@ namespace CppSharp.AST
public Type Type { get { return ReturnType.Type; } }
public QualifiedType QualifiedType { get { return ReturnType; } }
public virtual QualifiedType GetFunctionType()
{
var functionType = new FunctionType
{
CallingConvention = this.CallingConvention,
ReturnType = this.ReturnType
};
functionType.Parameters.AddRange(Parameters);
ReplaceIndirectReturnParamWithRegular(functionType);
var pointerType = new PointerType { QualifiedPointee = new QualifiedType(functionType) };
return new QualifiedType(pointerType);
}
private static void ReplaceIndirectReturnParamWithRegular(FunctionType functionType)
{
for (int i = functionType.Parameters.Count - 1; i >= 0; i--)
{
var parameter = functionType.Parameters[i];
if (parameter.Kind == ParameterKind.IndirectReturnType)
{
var ptrType = new PointerType { QualifiedPointee = new QualifiedType(parameter.Type) };
var retParam = new Parameter { Name = parameter.Name, QualifiedType = new QualifiedType(ptrType) };
functionType.Parameters.RemoveAt(i);
functionType.Parameters.Insert(i, retParam);
}
}
}
}
}

17
src/AST/Method.cs

@ -119,5 +119,22 @@ namespace CppSharp.AST @@ -119,5 +119,22 @@ namespace CppSharp.AST
public bool IsMoveConstructor;
public MethodConversionKind Conversion { get; set; }
public override QualifiedType GetFunctionType()
{
var qualifiedType = base.GetFunctionType();
if (!IsStatic)
{
FunctionType functionType;
qualifiedType.Type.IsPointerTo(out functionType);
var instance = new Parameter
{
Name = "instance",
QualifiedType = new QualifiedType(new BuiltinType(PrimitiveType.IntPtr))
};
functionType.Parameters.Insert(0, instance);
}
return qualifiedType;
}
}
}

10
src/AST/Type.cs

@ -181,6 +181,16 @@ namespace CppSharp.AST @@ -181,6 +181,16 @@ namespace CppSharp.AST
return Type.Equals(type.Type) && Qualifiers.Equals(type.Qualifiers);
}
public static bool operator ==(QualifiedType left, QualifiedType right)
{
return left.Equals(right);
}
public static bool operator !=(QualifiedType left, QualifiedType right)
{
return !(left == right);
}
public override int GetHashCode()
{
return base.GetHashCode();

18
src/Generator/AST/VTables.cs

@ -1,5 +1,9 @@ @@ -1,5 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using CppSharp.Generators;
using CppSharp.Generators.CSharp;
namespace CppSharp.AST
{
@ -78,5 +82,19 @@ namespace CppSharp.AST @@ -78,5 +82,19 @@ namespace CppSharp.AST
throw new NotSupportedException();
}
public static int GetVTableIndex(INamedDecl method, Class @class)
{
switch (@class.Layout.ABI)
{
case CppAbi.Microsoft:
return (from table in @class.Layout.VFTables
let j = table.Layout.Components.FindIndex(m => m.Method == method)
where j >= 0
select j).First();
default:
return @class.Layout.Layout.Components.FindIndex(m => m.Method == method);
}
}
}
}

5
src/Generator/Driver.cs

@ -148,6 +148,8 @@ namespace CppSharp @@ -148,6 +148,8 @@ namespace CppSharp
TranslationUnitPasses.AddPass(new CheckIgnoredDeclsPass());
TranslationUnitPasses.AddPass(new CheckFlagEnumsPass());
TranslationUnitPasses.AddPass(new CheckDuplicatedNamesPass());
if (Options.GenerateAbstractImpls)
TranslationUnitPasses.AddPass(new GenerateAbstractImplementationsPass());
}
public void ProcessCode()
@ -261,6 +263,7 @@ namespace CppSharp @@ -261,6 +263,7 @@ namespace CppSharp
public bool GenerateFunctionTemplates;
public bool GeneratePartialClasses;
public bool GenerateVirtualTables;
public bool GenerateAbstractImpls;
public bool GenerateInternalImports;
public string IncludePrefix;
public bool WriteOnlyWhenChanged;
@ -277,6 +280,8 @@ namespace CppSharp @@ -277,6 +280,8 @@ namespace CppSharp
{
get { return GeneratorKind == LanguageGeneratorKind.CLI; }
}
public bool Is32Bit { get { return true; } }
}
public class InvalidOptionException : Exception

8
src/Generator/Generators/CSharp/CSharpMarshal.cs

@ -216,7 +216,8 @@ namespace CppSharp.Generators.CSharp @@ -216,7 +216,8 @@ namespace CppSharp.Generators.CSharp
instance = copy;
}
if (@class.IsRefType)
if (@class.IsRefType &&
(!Context.Driver.Options.GenerateAbstractImpls || !@class.IsAbstract))
{
var instanceName = Generator.GeneratedIdentifier("instance");
if (VarSuffix > 0)
@ -252,7 +253,10 @@ namespace CppSharp.Generators.CSharp @@ -252,7 +253,10 @@ namespace CppSharp.Generators.CSharp
instance = instanceName;
}
Context.Return.Write("new {0}({1})", QualifiedIdentifier(@class),
Context.Return.Write("new {0}({1})",
QualifiedIdentifier(@class) +
(Context.Driver.Options.GenerateAbstractImpls && @class.IsAbstract ?
"Internal" : ""),
instance);
return true;

98
src/Generator/Generators/CSharp/CSharpTextTemplate.cs

@ -3,7 +3,9 @@ using System.Collections.Generic; @@ -3,7 +3,9 @@ using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using CppSharp.AST;
using CppSharp.Utils;
using Type = CppSharp.AST.Type;
namespace CppSharp.Generators.CSharp
@ -56,6 +58,19 @@ namespace CppSharp.Generators.CSharp @@ -56,6 +58,19 @@ namespace CppSharp.Generators.CSharp
{
get { return Generator.GeneratedIdentifier("Instance"); }
}
public static string GetAccess(Class @class)
{
switch (@class.Access)
{
case AccessSpecifier.Private:
return "internal ";
case AccessSpecifier.Protected:
return "protected ";
default:
return "public ";
}
}
}
public class CSharpBlockKind
@ -402,9 +417,6 @@ namespace CppSharp.Generators.CSharp @@ -402,9 +417,6 @@ namespace CppSharp.Generators.CSharp
if (method.IsSynthetized)
return;
if (method.IsPure)
return;
if (method.IsProxy)
return;
@ -611,7 +623,11 @@ namespace CppSharp.Generators.CSharp @@ -611,7 +623,11 @@ namespace CppSharp.Generators.CSharp
if (@class.IsUnion)
WriteLine("[StructLayout(LayoutKind.Explicit)]");
Write("public unsafe ");
Write(Helpers.GetAccess(@class));
Write("unsafe ");
if (Driver.Options.GenerateAbstractImpls && @class.IsAbstract)
Write("abstract ");
if (Options.GeneratePartialClasses)
Write("partial ");
@ -1346,23 +1362,29 @@ namespace CppSharp.Generators.CSharp @@ -1346,23 +1362,29 @@ namespace CppSharp.Generators.CSharp
private void GenerateNativeConstructor(Class @class)
{
PushBlock(CSharpBlockKind.Method);
WriteLine("internal {0}({1}.Internal* native)", SafeIdentifier(@class.Name),
@class.Name);
string className = @class.Name;
string safeIdentifier = SafeIdentifier(className);
if (@class.Access == AccessSpecifier.Private && className.EndsWith("Internal"))
{
className = className.Substring(0,
safeIdentifier.LastIndexOf("Internal", StringComparison.Ordinal));
}
WriteLine("internal {0}({1}.Internal* native)", safeIdentifier,
className);
WriteLineIndent(": this(new global::System.IntPtr(native))");
WriteStartBraceIndent();
WriteCloseBraceIndent();
PopBlock(NewLineKind.BeforeNextBlock);
PushBlock(CSharpBlockKind.Method);
WriteLine("internal {0}({1}.Internal native)", SafeIdentifier(@class.Name),
@class.Name);
WriteLine("internal {0}({1}.Internal native)", safeIdentifier, className);
WriteLineIndent(": this(&native)");
WriteStartBraceIndent();
WriteCloseBraceIndent();
PopBlock(NewLineKind.BeforeNextBlock);
PushBlock(CSharpBlockKind.Method);
WriteLine("internal {0}(global::System.IntPtr native){1}", SafeIdentifier(@class.Name),
WriteLine("internal {0}(global::System.IntPtr native){1}", safeIdentifier,
@class.IsValueType ? " : this()" : string.Empty);
var hasBaseClass = @class.HasBaseClass && @class.BaseClass.IsRefType;
@ -1459,9 +1481,11 @@ namespace CppSharp.Generators.CSharp @@ -1459,9 +1481,11 @@ namespace CppSharp.Generators.CSharp
PushBlock(CSharpBlockKind.Method);
GenerateDeclarationCommon(method);
Write("public ");
Write(Driver.Options.GenerateAbstractImpls &&
@class.IsAbstract && method.IsConstructor ? "protected " : "public ");
if (method.IsVirtual && !method.IsOverride)
if (method.IsVirtual && !method.IsOverride &&
(!Driver.Options.GenerateAbstractImpls || !method.IsPure))
Write("virtual ");
var isBuiltinOperator = method.IsOperator &&
@ -1473,6 +1497,9 @@ namespace CppSharp.Generators.CSharp @@ -1473,6 +1497,9 @@ namespace CppSharp.Generators.CSharp
if (method.IsOverride)
Write("override ");
if (Driver.Options.GenerateAbstractImpls && method.IsPure)
Write("abstract ");
var functionName = GetFunctionIdentifier(method);
if (method.IsConstructor || method.IsDestructor)
@ -1482,7 +1509,15 @@ namespace CppSharp.Generators.CSharp @@ -1482,7 +1509,15 @@ namespace CppSharp.Generators.CSharp
GenerateMethodParameters(method);
WriteLine(")");
Write(")");
if (Driver.Options.GenerateAbstractImpls && method.IsPure)
{
Write(";");
PopBlock(NewLineKind.BeforeNextBlock);
return;
}
NewLine();
if (method.Kind == CXXMethodKind.Constructor)
GenerateClassConstructorBase(@class, method);
@ -1502,6 +1537,10 @@ namespace CppSharp.Generators.CSharp @@ -1502,6 +1537,10 @@ namespace CppSharp.Generators.CSharp
{
GenerateOperator(method, @class);
}
else if (method.IsOverride && method.IsSynthetized)
{
GenerateVirtualTableFunctionCall(method, @class);
}
else
{
GenerateInternalFunctionCall(method);
@ -1529,6 +1568,37 @@ namespace CppSharp.Generators.CSharp @@ -1529,6 +1568,37 @@ namespace CppSharp.Generators.CSharp
PopBlock(NewLineKind.BeforeNextBlock);
}
private void GenerateVirtualTableFunctionCall(Function method, Class @class)
{
string delegateId;
Write(GetVirtualCallDelegate(method, @class, Driver.Options.Is32Bit, out delegateId));
GenerateFunctionCall(delegateId, method.Parameters, method);
}
public static string GetVirtualCallDelegate(INamedDecl method, Class @class,
bool is32Bit, out string delegateId)
{
var virtualCallBuilder = new StringBuilder();
virtualCallBuilder.AppendFormat("void* vtable = *((void**) {0}.ToPointer());",
Helpers.InstanceIdentifier);
virtualCallBuilder.AppendLine();
var i = VTables.GetVTableIndex(method, @class);
virtualCallBuilder.AppendFormat(
"void* slot = *((void**) vtable + {0} * {1});", i, is32Bit ? 4 : 8);
virtualCallBuilder.AppendLine();
string @delegate = method.Name + "Delegate";
delegateId = Generator.GeneratedIdentifier(@delegate);
virtualCallBuilder.AppendFormat(
"var {1} = ({0}) Marshal.GetDelegateForFunctionPointer(new IntPtr(slot), typeof({0}));",
@delegate, delegateId);
virtualCallBuilder.AppendLine();
return virtualCallBuilder.ToString();
}
private void GenerateOperator(Method method, Class @class)
{
if (method.IsSynthetized)
@ -1882,9 +1952,11 @@ namespace CppSharp.Generators.CSharp @@ -1882,9 +1952,11 @@ namespace CppSharp.Generators.CSharp
PushBlock(CSharpBlockKind.Typedef);
WriteLine("[UnmanagedFunctionPointerAttribute(CallingConvention.{0})]",
Helpers.ToCSharpCallConv(functionType.CallingConvention));
TypePrinter.PushContext(CSharpTypePrinterContextKind.Native);
WriteLine("public {0};",
string.Format(TypePrinter.VisitDelegate(functionType).Type,
SafeIdentifier(typedef.Name)));
TypePrinter.PopContext();
PopBlock(NewLineKind.BeforeNextBlock);
}
else if (typedef.Type.IsEnumType())
@ -1974,7 +2046,7 @@ namespace CppSharp.Generators.CSharp @@ -1974,7 +2046,7 @@ namespace CppSharp.Generators.CSharp
public void GenerateInternalFunction(Function function)
{
if (!function.IsProcessed || function.ExplicityIgnored)
if (!function.IsProcessed || function.ExplicityIgnored || function.IsPure)
return;
if (function.OriginalFunction != null)

4
src/Generator/Passes/FindSymbolsPass.cs

@ -11,7 +11,9 @@ namespace CppSharp.Passes @@ -11,7 +11,9 @@ namespace CppSharp.Passes
return false;
var mangledDecl = decl as IMangledDecl;
if (mangledDecl != null && !VisitMangledDeclaration(mangledDecl))
var method = decl as Method;
if (mangledDecl != null && !(method != null && method.IsPure) &&
!VisitMangledDeclaration(mangledDecl))
{
decl.ExplicityIgnored = true;
return false;

188
src/Generator/Passes/GenerateAbstractImplementationsPass.cs

@ -0,0 +1,188 @@ @@ -0,0 +1,188 @@
using System.Collections.Generic;
using System.Linq;
using CppSharp.AST;
using CppSharp.Utils;
namespace CppSharp.Passes
{
/// <summary>
/// This pass generates internal classes that implement abstract classes.
/// When the return type of a function is abstract, these internal classes provide -
/// since the real type cannot be resolved while binding - an allocatable class that supports proper polymorphism.
/// </summary>
public class GenerateAbstractImplementationsPass : TranslationUnitPass
{
/// <summary>
/// Collects all internal implementations in a unit to be added at the end because the unit cannot be changed while it's being iterated though.
/// </summary>
private readonly List<Class> internalImpls = new List<Class>();
public override bool VisitTranslationUnit(TranslationUnit unit)
{
bool result = base.VisitTranslationUnit(unit);
unit.Classes.AddRange(internalImpls);
internalImpls.Clear();
return result;
}
public override bool VisitClassDecl(Class @class)
{
if (@class.CompleteDeclaration != null)
return VisitClassDecl(@class.CompleteDeclaration as Class);
if (!VisitDeclaration(@class) || AlreadyVisited(@class))
return false;
if (@class.IsAbstract)
internalImpls.Add(AddInternalImplementation(@class));
return base.VisitClassDecl(@class);
}
private Class AddInternalImplementation(Class @class)
{
var internalImpl = GetInternalImpl(@class);
var abstractMethods = GetRelevantAbstractMethods(@class);
foreach (var abstractMethod in abstractMethods)
{
internalImpl.Methods.Add(new Method(abstractMethod));
var @delegate = new TypedefDecl
{
Name = abstractMethod.Name + "Delegate",
QualifiedType = abstractMethod.GetFunctionType(),
IgnoreFlags = abstractMethod.IgnoreFlags
};
internalImpl.Typedefs.Add(@delegate);
}
internalImpl.Layout = new ClassLayout(@class.Layout);
FillVTable(@class, abstractMethods, internalImpl);
foreach (var method in internalImpl.Methods)
{
method.IsPure = false;
method.IsOverride = true;
method.IsSynthetized = true;
}
return internalImpl;
}
private static Class GetInternalImpl(Declaration @class)
{
var internalImpl = new Class
{
Name = @class.Name + "Internal",
Access = AccessSpecifier.Private,
Namespace = @class.Namespace
};
var @base = new BaseClassSpecifier { Type = new TagType(@class) };
internalImpl.Bases.Add(@base);
return internalImpl;
}
private static List<Method> GetRelevantAbstractMethods(Class @class)
{
var abstractMethods = GetAbstractMethods(@class);
var overriddenMethods = GetOverriddenMethods(@class);
var paramTypeCmp = new ParameterTypeComparer();
for (int i = abstractMethods.Count - 1; i >= 0; i--)
{
var @abstract = abstractMethods[i];
if (overriddenMethods.Find(m => m.Name == @abstract.Name &&
m.ReturnType == @abstract.ReturnType &&
m.Parameters.Count == @abstract.Parameters.Count &&
m.Parameters.SequenceEqual(@abstract.Parameters, paramTypeCmp)) != null)
{
abstractMethods.RemoveAt(i);
}
}
return abstractMethods;
}
private static List<Method> GetAbstractMethods(Class @class)
{
var abstractMethods = @class.Methods.Where(m => m.IsPure).ToList();
foreach (var @base in @class.Bases)
abstractMethods.AddRange(GetAbstractMethods(@base.Class));
return abstractMethods;
}
private static List<Method> GetOverriddenMethods(Class @class)
{
var abstractMethods = @class.Methods.Where(m => m.IsOverride).ToList();
foreach (var @base in @class.Bases)
abstractMethods.AddRange(GetOverriddenMethods(@base.Class));
return abstractMethods;
}
private void FillVTable(Class @class, IList<Method> abstractMethods, Class internalImplementation)
{
switch (Driver.Options.Abi)
{
case CppAbi.Microsoft:
CreateVTableMS(@class, abstractMethods, internalImplementation);
break;
default:
CreateVTableItanium(@class, abstractMethods, internalImplementation);
break;
}
}
private static void CreateVTableMS(Class @class,
IList<Method> abstractMethods, Class internalImplementation)
{
var vTables = GetVTables(@class);
for (int i = 0; i < abstractMethods.Count; i++)
{
for (int j = 0; j < vTables.Count; j++)
{
var vTable = vTables[j];
var k = vTable.Layout.Components.FindIndex(v => v.Method == abstractMethods[i]);
if (k >= 0)
{
var vTableComponent = vTable.Layout.Components[k];
vTableComponent.Declaration = internalImplementation.Methods[i];
vTable.Layout.Components[k] = vTableComponent;
vTables[j] = vTable;
}
}
}
internalImplementation.Layout.VFTables.Clear();
internalImplementation.Layout.VFTables.AddRange(vTables);
}
private static void CreateVTableItanium(Class @class,
IList<Method> abstractMethods, Class internalImplementation)
{
var vTableComponents = GetVTableComponents(@class);
for (int i = 0; i < abstractMethods.Count; i++)
{
var j = vTableComponents.FindIndex(v => v.Method == abstractMethods[i]);
var vTableComponent = vTableComponents[j];
vTableComponent.Declaration = internalImplementation.Methods[i];
vTableComponents[j] = vTableComponent;
}
internalImplementation.Layout.Layout.Components.Clear();
internalImplementation.Layout.Layout.Components.AddRange(vTableComponents);
}
private static List<VTableComponent> GetVTableComponents(Class @class)
{
var vTableComponents = new List<VTableComponent>(
@class.Layout.Layout.Components);
foreach (var @base in @class.Bases)
vTableComponents.AddRange(GetVTableComponents(@base.Class));
return vTableComponents;
}
private static List<VFTableInfo> GetVTables(Class @class)
{
var vTables = new List<VFTableInfo>(
@class.Layout.VFTables);
foreach (var @base in @class.Bases)
vTables.AddRange(GetVTables(@base.Class));
return vTables;
}
}
}

18
src/Generator/Utils/ParameterTypeComparer.cs

@ -0,0 +1,18 @@ @@ -0,0 +1,18 @@
using System.Collections.Generic;
using CppSharp.AST;
namespace CppSharp.Utils
{
public class ParameterTypeComparer : IEqualityComparer<Parameter>
{
public bool Equals(Parameter x, Parameter y)
{
return x.QualifiedType == y.QualifiedType;
}
public int GetHashCode(Parameter obj)
{
return obj.Type.GetHashCode();
}
}
}

10
tests/Basic/Basic.Tests.cs

@ -85,5 +85,15 @@ public class BasicTests @@ -85,5 +85,15 @@ public class BasicTests
Foo2 result = foo2 << 3;
Assert.That(result.C, Is.EqualTo(16));
}
[Test, Ignore]
public void TestAbstractReturnType()
{
var returnsAbstractFoo = new ReturnsAbstractFoo();
var abstractFoo = returnsAbstractFoo.getFoo();
Assert.AreEqual(abstractFoo.pureFunction(), 5);
Assert.AreEqual(abstractFoo.pureFunction1(), 10);
Assert.AreEqual(abstractFoo.pureFunction2(), 15);
}
}

20
tests/Basic/Basic.cpp

@ -121,6 +121,26 @@ Bar indirectReturn() @@ -121,6 +121,26 @@ Bar indirectReturn()
return Bar();
}
int ImplementsAbstractFoo::pureFunction()
{
return 5;
}
int ImplementsAbstractFoo::pureFunction1()
{
return 10;
}
int ImplementsAbstractFoo::pureFunction2()
{
return 15;
}
const AbstractFoo& ReturnsAbstractFoo::getFoo()
{
return i;
}
void DefaultParameters::Foo(int a, int b)
{
}

25
tests/Basic/Basic.h

@ -80,6 +80,31 @@ public: @@ -80,6 +80,31 @@ public:
Hello* RetNull();
};
class DLL_API AbstractFoo
{
public:
virtual int pureFunction() = 0;
virtual int pureFunction1() = 0;
virtual int pureFunction2() = 0;
};
class DLL_API ImplementsAbstractFoo : public AbstractFoo
{
public:
virtual int pureFunction();
virtual int pureFunction1();
virtual int pureFunction2();
};
class DLL_API ReturnsAbstractFoo
{
public:
const AbstractFoo& getFoo();
private:
ImplementsAbstractFoo i;
};
DLL_API Bar operator-(const Bar &);
DLL_API Bar operator+(const Bar &, const Bar &);

Loading…
Cancel
Save