diff --git a/src/AST/ClassLayout.cs b/src/AST/ClassLayout.cs index 3721dda4..1377f37e 100644 --- a/src/AST/ClassLayout.cs +++ b/src/AST/ClassLayout.cs @@ -99,8 +99,11 @@ namespace CppSharp.AST Size = classLayout.Size; DataSize = classLayout.DataSize; VFTables.AddRange(classLayout.VFTables); - Layout = new VTableLayout(); - Layout.Components.AddRange(classLayout.Layout.Components); + if (classLayout.Layout != null) + { + Layout = new VTableLayout(); + Layout.Components.AddRange(classLayout.Layout.Components); + } } /// diff --git a/src/AST/Function.cs b/src/AST/Function.cs index f924e2a4..04a20a0a 100644 --- a/src/AST/Function.cs +++ b/src/AST/Function.cs @@ -162,5 +162,33 @@ namespace CppSharp.AST public Type Type { get { return ReturnType.Type; } } public QualifiedType QualifiedType { get { return ReturnType; } } + + public virtual QualifiedType GetFunctionType() + { + var functionType = new FunctionType + { + CallingConvention = this.CallingConvention, + ReturnType = this.ReturnType + }; + functionType.Parameters.AddRange(Parameters); + ReplaceIndirectReturnParamWithRegular(functionType); + var pointerType = new PointerType { QualifiedPointee = new QualifiedType(functionType) }; + return new QualifiedType(pointerType); + } + + private static void ReplaceIndirectReturnParamWithRegular(FunctionType functionType) + { + for (int i = functionType.Parameters.Count - 1; i >= 0; i--) + { + var parameter = functionType.Parameters[i]; + if (parameter.Kind == ParameterKind.IndirectReturnType) + { + var ptrType = new PointerType { QualifiedPointee = new QualifiedType(parameter.Type) }; + var retParam = new Parameter { Name = parameter.Name, QualifiedType = new QualifiedType(ptrType) }; + functionType.Parameters.RemoveAt(i); + functionType.Parameters.Insert(i, retParam); + } + } + } } } \ No newline at end of file diff --git a/src/AST/Method.cs b/src/AST/Method.cs index adc7c4c3..74285471 100644 --- a/src/AST/Method.cs +++ b/src/AST/Method.cs @@ -119,5 +119,22 @@ namespace CppSharp.AST public bool IsMoveConstructor; public MethodConversionKind Conversion { get; set; } + + public override QualifiedType GetFunctionType() + { + var qualifiedType = base.GetFunctionType(); + if (!IsStatic) + { + FunctionType functionType; + qualifiedType.Type.IsPointerTo(out functionType); + var instance = new Parameter + { + Name = "instance", + QualifiedType = new QualifiedType(new BuiltinType(PrimitiveType.IntPtr)) + }; + functionType.Parameters.Insert(0, instance); + } + return qualifiedType; + } } } \ No newline at end of file diff --git a/src/AST/Type.cs b/src/AST/Type.cs index 84b78663..421852e8 100644 --- a/src/AST/Type.cs +++ b/src/AST/Type.cs @@ -181,6 +181,16 @@ namespace CppSharp.AST return Type.Equals(type.Type) && Qualifiers.Equals(type.Qualifiers); } + public static bool operator ==(QualifiedType left, QualifiedType right) + { + return left.Equals(right); + } + + public static bool operator !=(QualifiedType left, QualifiedType right) + { + return !(left == right); + } + public override int GetHashCode() { return base.GetHashCode(); diff --git a/src/Generator/AST/VTables.cs b/src/Generator/AST/VTables.cs index fa6cb5cf..edd24d44 100644 --- a/src/Generator/AST/VTables.cs +++ b/src/Generator/AST/VTables.cs @@ -1,5 +1,9 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Text; +using CppSharp.Generators; +using CppSharp.Generators.CSharp; namespace CppSharp.AST { @@ -78,5 +82,19 @@ namespace CppSharp.AST throw new NotSupportedException(); } + + public static int GetVTableIndex(INamedDecl method, Class @class) + { + switch (@class.Layout.ABI) + { + case CppAbi.Microsoft: + return (from table in @class.Layout.VFTables + let j = table.Layout.Components.FindIndex(m => m.Method == method) + where j >= 0 + select j).First(); + default: + return @class.Layout.Layout.Components.FindIndex(m => m.Method == method); + } + } } } diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index 5566a67a..1e153420 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -148,6 +148,8 @@ namespace CppSharp TranslationUnitPasses.AddPass(new CheckIgnoredDeclsPass()); TranslationUnitPasses.AddPass(new CheckFlagEnumsPass()); TranslationUnitPasses.AddPass(new CheckDuplicatedNamesPass()); + if (Options.GenerateAbstractImpls) + TranslationUnitPasses.AddPass(new GenerateAbstractImplementationsPass()); } public void ProcessCode() @@ -261,6 +263,7 @@ namespace CppSharp public bool GenerateFunctionTemplates; public bool GeneratePartialClasses; public bool GenerateVirtualTables; + public bool GenerateAbstractImpls; public bool GenerateInternalImports; public string IncludePrefix; public bool WriteOnlyWhenChanged; @@ -277,6 +280,8 @@ namespace CppSharp { get { return GeneratorKind == LanguageGeneratorKind.CLI; } } + + public bool Is32Bit { get { return true; } } } public class InvalidOptionException : Exception diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index fd976038..6d414ef3 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -216,7 +216,8 @@ namespace CppSharp.Generators.CSharp instance = copy; } - if (@class.IsRefType) + if (@class.IsRefType && + (!Context.Driver.Options.GenerateAbstractImpls || !@class.IsAbstract)) { var instanceName = Generator.GeneratedIdentifier("instance"); if (VarSuffix > 0) @@ -252,7 +253,10 @@ namespace CppSharp.Generators.CSharp instance = instanceName; } - Context.Return.Write("new {0}({1})", QualifiedIdentifier(@class), + Context.Return.Write("new {0}({1})", + QualifiedIdentifier(@class) + + (Context.Driver.Options.GenerateAbstractImpls && @class.IsAbstract ? + "Internal" : ""), instance); return true; diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 82b8b709..83c11c47 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -3,7 +3,9 @@ using System.Collections.Generic; using System.Globalization; using System.IO; using System.Linq; +using System.Text; using CppSharp.AST; +using CppSharp.Utils; using Type = CppSharp.AST.Type; namespace CppSharp.Generators.CSharp @@ -56,6 +58,19 @@ namespace CppSharp.Generators.CSharp { get { return Generator.GeneratedIdentifier("Instance"); } } + + public static string GetAccess(Class @class) + { + switch (@class.Access) + { + case AccessSpecifier.Private: + return "internal "; + case AccessSpecifier.Protected: + return "protected "; + default: + return "public "; + } + } } public class CSharpBlockKind @@ -402,9 +417,6 @@ namespace CppSharp.Generators.CSharp if (method.IsSynthetized) return; - if (method.IsPure) - return; - if (method.IsProxy) return; @@ -611,7 +623,11 @@ namespace CppSharp.Generators.CSharp if (@class.IsUnion) WriteLine("[StructLayout(LayoutKind.Explicit)]"); - Write("public unsafe "); + Write(Helpers.GetAccess(@class)); + Write("unsafe "); + + if (Driver.Options.GenerateAbstractImpls && @class.IsAbstract) + Write("abstract "); if (Options.GeneratePartialClasses) Write("partial "); @@ -1346,23 +1362,29 @@ namespace CppSharp.Generators.CSharp private void GenerateNativeConstructor(Class @class) { PushBlock(CSharpBlockKind.Method); - WriteLine("internal {0}({1}.Internal* native)", SafeIdentifier(@class.Name), - @class.Name); + string className = @class.Name; + string safeIdentifier = SafeIdentifier(className); + if (@class.Access == AccessSpecifier.Private && className.EndsWith("Internal")) + { + className = className.Substring(0, + safeIdentifier.LastIndexOf("Internal", StringComparison.Ordinal)); + } + WriteLine("internal {0}({1}.Internal* native)", safeIdentifier, + className); WriteLineIndent(": this(new global::System.IntPtr(native))"); WriteStartBraceIndent(); WriteCloseBraceIndent(); PopBlock(NewLineKind.BeforeNextBlock); PushBlock(CSharpBlockKind.Method); - WriteLine("internal {0}({1}.Internal native)", SafeIdentifier(@class.Name), - @class.Name); + WriteLine("internal {0}({1}.Internal native)", safeIdentifier, className); WriteLineIndent(": this(&native)"); WriteStartBraceIndent(); WriteCloseBraceIndent(); PopBlock(NewLineKind.BeforeNextBlock); PushBlock(CSharpBlockKind.Method); - WriteLine("internal {0}(global::System.IntPtr native){1}", SafeIdentifier(@class.Name), + WriteLine("internal {0}(global::System.IntPtr native){1}", safeIdentifier, @class.IsValueType ? " : this()" : string.Empty); var hasBaseClass = @class.HasBaseClass && @class.BaseClass.IsRefType; @@ -1459,9 +1481,11 @@ namespace CppSharp.Generators.CSharp PushBlock(CSharpBlockKind.Method); GenerateDeclarationCommon(method); - Write("public "); + Write(Driver.Options.GenerateAbstractImpls && + @class.IsAbstract && method.IsConstructor ? "protected " : "public "); - if (method.IsVirtual && !method.IsOverride) + if (method.IsVirtual && !method.IsOverride && + (!Driver.Options.GenerateAbstractImpls || !method.IsPure)) Write("virtual "); var isBuiltinOperator = method.IsOperator && @@ -1473,6 +1497,9 @@ namespace CppSharp.Generators.CSharp if (method.IsOverride) Write("override "); + if (Driver.Options.GenerateAbstractImpls && method.IsPure) + Write("abstract "); + var functionName = GetFunctionIdentifier(method); if (method.IsConstructor || method.IsDestructor) @@ -1482,7 +1509,15 @@ namespace CppSharp.Generators.CSharp GenerateMethodParameters(method); - WriteLine(")"); + Write(")"); + + if (Driver.Options.GenerateAbstractImpls && method.IsPure) + { + Write(";"); + PopBlock(NewLineKind.BeforeNextBlock); + return; + } + NewLine(); if (method.Kind == CXXMethodKind.Constructor) GenerateClassConstructorBase(@class, method); @@ -1502,6 +1537,10 @@ namespace CppSharp.Generators.CSharp { GenerateOperator(method, @class); } + else if (method.IsOverride && method.IsSynthetized) + { + GenerateVirtualTableFunctionCall(method, @class); + } else { GenerateInternalFunctionCall(method); @@ -1529,6 +1568,37 @@ namespace CppSharp.Generators.CSharp PopBlock(NewLineKind.BeforeNextBlock); } + private void GenerateVirtualTableFunctionCall(Function method, Class @class) + { + string delegateId; + Write(GetVirtualCallDelegate(method, @class, Driver.Options.Is32Bit, out delegateId)); + GenerateFunctionCall(delegateId, method.Parameters, method); + } + + public static string GetVirtualCallDelegate(INamedDecl method, Class @class, + bool is32Bit, out string delegateId) + { + var virtualCallBuilder = new StringBuilder(); + virtualCallBuilder.AppendFormat("void* vtable = *((void**) {0}.ToPointer());", + Helpers.InstanceIdentifier); + virtualCallBuilder.AppendLine(); + + var i = VTables.GetVTableIndex(method, @class); + + virtualCallBuilder.AppendFormat( + "void* slot = *((void**) vtable + {0} * {1});", i, is32Bit ? 4 : 8); + virtualCallBuilder.AppendLine(); + + string @delegate = method.Name + "Delegate"; + delegateId = Generator.GeneratedIdentifier(@delegate); + + virtualCallBuilder.AppendFormat( + "var {1} = ({0}) Marshal.GetDelegateForFunctionPointer(new IntPtr(slot), typeof({0}));", + @delegate, delegateId); + virtualCallBuilder.AppendLine(); + return virtualCallBuilder.ToString(); + } + private void GenerateOperator(Method method, Class @class) { if (method.IsSynthetized) @@ -1882,9 +1952,11 @@ namespace CppSharp.Generators.CSharp PushBlock(CSharpBlockKind.Typedef); WriteLine("[UnmanagedFunctionPointerAttribute(CallingConvention.{0})]", Helpers.ToCSharpCallConv(functionType.CallingConvention)); + TypePrinter.PushContext(CSharpTypePrinterContextKind.Native); WriteLine("public {0};", string.Format(TypePrinter.VisitDelegate(functionType).Type, SafeIdentifier(typedef.Name))); + TypePrinter.PopContext(); PopBlock(NewLineKind.BeforeNextBlock); } else if (typedef.Type.IsEnumType()) @@ -1974,7 +2046,7 @@ namespace CppSharp.Generators.CSharp public void GenerateInternalFunction(Function function) { - if (!function.IsProcessed || function.ExplicityIgnored) + if (!function.IsProcessed || function.ExplicityIgnored || function.IsPure) return; if (function.OriginalFunction != null) diff --git a/src/Generator/Passes/FindSymbolsPass.cs b/src/Generator/Passes/FindSymbolsPass.cs index 7d9ee99e..e7146d08 100644 --- a/src/Generator/Passes/FindSymbolsPass.cs +++ b/src/Generator/Passes/FindSymbolsPass.cs @@ -11,7 +11,9 @@ namespace CppSharp.Passes return false; var mangledDecl = decl as IMangledDecl; - if (mangledDecl != null && !VisitMangledDeclaration(mangledDecl)) + var method = decl as Method; + if (mangledDecl != null && !(method != null && method.IsPure) && + !VisitMangledDeclaration(mangledDecl)) { decl.ExplicityIgnored = true; return false; diff --git a/src/Generator/Passes/GenerateAbstractImplementationsPass.cs b/src/Generator/Passes/GenerateAbstractImplementationsPass.cs new file mode 100644 index 00000000..3a614eff --- /dev/null +++ b/src/Generator/Passes/GenerateAbstractImplementationsPass.cs @@ -0,0 +1,188 @@ +using System.Collections.Generic; +using System.Linq; +using CppSharp.AST; +using CppSharp.Utils; + +namespace CppSharp.Passes +{ + /// + /// This pass generates internal classes that implement abstract classes. + /// When the return type of a function is abstract, these internal classes provide - + /// since the real type cannot be resolved while binding - an allocatable class that supports proper polymorphism. + /// + public class GenerateAbstractImplementationsPass : TranslationUnitPass + { + /// + /// Collects all internal implementations in a unit to be added at the end because the unit cannot be changed while it's being iterated though. + /// + private readonly List internalImpls = new List(); + + public override bool VisitTranslationUnit(TranslationUnit unit) + { + bool result = base.VisitTranslationUnit(unit); + unit.Classes.AddRange(internalImpls); + internalImpls.Clear(); + return result; + } + + public override bool VisitClassDecl(Class @class) + { + if (@class.CompleteDeclaration != null) + return VisitClassDecl(@class.CompleteDeclaration as Class); + + if (!VisitDeclaration(@class) || AlreadyVisited(@class)) + return false; + + if (@class.IsAbstract) + internalImpls.Add(AddInternalImplementation(@class)); + return base.VisitClassDecl(@class); + } + + private Class AddInternalImplementation(Class @class) + { + var internalImpl = GetInternalImpl(@class); + + var abstractMethods = GetRelevantAbstractMethods(@class); + + foreach (var abstractMethod in abstractMethods) + { + internalImpl.Methods.Add(new Method(abstractMethod)); + var @delegate = new TypedefDecl + { + Name = abstractMethod.Name + "Delegate", + QualifiedType = abstractMethod.GetFunctionType(), + IgnoreFlags = abstractMethod.IgnoreFlags + }; + internalImpl.Typedefs.Add(@delegate); + } + + internalImpl.Layout = new ClassLayout(@class.Layout); + FillVTable(@class, abstractMethods, internalImpl); + + foreach (var method in internalImpl.Methods) + { + method.IsPure = false; + method.IsOverride = true; + method.IsSynthetized = true; + } + return internalImpl; + } + + private static Class GetInternalImpl(Declaration @class) + { + var internalImpl = new Class + { + Name = @class.Name + "Internal", + Access = AccessSpecifier.Private, + Namespace = @class.Namespace + }; + var @base = new BaseClassSpecifier { Type = new TagType(@class) }; + internalImpl.Bases.Add(@base); + return internalImpl; + } + + private static List GetRelevantAbstractMethods(Class @class) + { + var abstractMethods = GetAbstractMethods(@class); + var overriddenMethods = GetOverriddenMethods(@class); + var paramTypeCmp = new ParameterTypeComparer(); + for (int i = abstractMethods.Count - 1; i >= 0; i--) + { + var @abstract = abstractMethods[i]; + if (overriddenMethods.Find(m => m.Name == @abstract.Name && + m.ReturnType == @abstract.ReturnType && + m.Parameters.Count == @abstract.Parameters.Count && + m.Parameters.SequenceEqual(@abstract.Parameters, paramTypeCmp)) != null) + { + abstractMethods.RemoveAt(i); + } + } + return abstractMethods; + } + + private static List GetAbstractMethods(Class @class) + { + var abstractMethods = @class.Methods.Where(m => m.IsPure).ToList(); + foreach (var @base in @class.Bases) + abstractMethods.AddRange(GetAbstractMethods(@base.Class)); + return abstractMethods; + } + + private static List GetOverriddenMethods(Class @class) + { + var abstractMethods = @class.Methods.Where(m => m.IsOverride).ToList(); + foreach (var @base in @class.Bases) + abstractMethods.AddRange(GetOverriddenMethods(@base.Class)); + return abstractMethods; + } + + private void FillVTable(Class @class, IList abstractMethods, Class internalImplementation) + { + switch (Driver.Options.Abi) + { + case CppAbi.Microsoft: + CreateVTableMS(@class, abstractMethods, internalImplementation); + break; + default: + CreateVTableItanium(@class, abstractMethods, internalImplementation); + break; + } + } + + private static void CreateVTableMS(Class @class, + IList abstractMethods, Class internalImplementation) + { + var vTables = GetVTables(@class); + for (int i = 0; i < abstractMethods.Count; i++) + { + for (int j = 0; j < vTables.Count; j++) + { + var vTable = vTables[j]; + var k = vTable.Layout.Components.FindIndex(v => v.Method == abstractMethods[i]); + if (k >= 0) + { + var vTableComponent = vTable.Layout.Components[k]; + vTableComponent.Declaration = internalImplementation.Methods[i]; + vTable.Layout.Components[k] = vTableComponent; + vTables[j] = vTable; + } + } + } + internalImplementation.Layout.VFTables.Clear(); + internalImplementation.Layout.VFTables.AddRange(vTables); + } + + private static void CreateVTableItanium(Class @class, + IList abstractMethods, Class internalImplementation) + { + var vTableComponents = GetVTableComponents(@class); + for (int i = 0; i < abstractMethods.Count; i++) + { + var j = vTableComponents.FindIndex(v => v.Method == abstractMethods[i]); + var vTableComponent = vTableComponents[j]; + vTableComponent.Declaration = internalImplementation.Methods[i]; + vTableComponents[j] = vTableComponent; + } + internalImplementation.Layout.Layout.Components.Clear(); + internalImplementation.Layout.Layout.Components.AddRange(vTableComponents); + } + + private static List GetVTableComponents(Class @class) + { + var vTableComponents = new List( + @class.Layout.Layout.Components); + foreach (var @base in @class.Bases) + vTableComponents.AddRange(GetVTableComponents(@base.Class)); + return vTableComponents; + } + + private static List GetVTables(Class @class) + { + var vTables = new List( + @class.Layout.VFTables); + foreach (var @base in @class.Bases) + vTables.AddRange(GetVTables(@base.Class)); + return vTables; + } + } +} diff --git a/src/Generator/Utils/ParameterTypeComparer.cs b/src/Generator/Utils/ParameterTypeComparer.cs new file mode 100644 index 00000000..a8740cdc --- /dev/null +++ b/src/Generator/Utils/ParameterTypeComparer.cs @@ -0,0 +1,18 @@ +using System.Collections.Generic; +using CppSharp.AST; + +namespace CppSharp.Utils +{ + public class ParameterTypeComparer : IEqualityComparer + { + public bool Equals(Parameter x, Parameter y) + { + return x.QualifiedType == y.QualifiedType; + } + + public int GetHashCode(Parameter obj) + { + return obj.Type.GetHashCode(); + } + } +} diff --git a/tests/Basic/Basic.Tests.cs b/tests/Basic/Basic.Tests.cs index f2092642..9bc6fc5d 100644 --- a/tests/Basic/Basic.Tests.cs +++ b/tests/Basic/Basic.Tests.cs @@ -85,5 +85,15 @@ public class BasicTests Foo2 result = foo2 << 3; Assert.That(result.C, Is.EqualTo(16)); } + + [Test, Ignore] + public void TestAbstractReturnType() + { + var returnsAbstractFoo = new ReturnsAbstractFoo(); + var abstractFoo = returnsAbstractFoo.getFoo(); + Assert.AreEqual(abstractFoo.pureFunction(), 5); + Assert.AreEqual(abstractFoo.pureFunction1(), 10); + Assert.AreEqual(abstractFoo.pureFunction2(), 15); + } } \ No newline at end of file diff --git a/tests/Basic/Basic.cpp b/tests/Basic/Basic.cpp index b19ccff5..bb9269e4 100644 --- a/tests/Basic/Basic.cpp +++ b/tests/Basic/Basic.cpp @@ -121,6 +121,26 @@ Bar indirectReturn() return Bar(); } +int ImplementsAbstractFoo::pureFunction() +{ + return 5; +} + +int ImplementsAbstractFoo::pureFunction1() +{ + return 10; +} + +int ImplementsAbstractFoo::pureFunction2() +{ + return 15; +} + +const AbstractFoo& ReturnsAbstractFoo::getFoo() +{ + return i; +} + void DefaultParameters::Foo(int a, int b) { } @@ -135,4 +155,4 @@ void DefaultParameters::Bar() const void DefaultParameters::Bar() { -} \ No newline at end of file +} diff --git a/tests/Basic/Basic.h b/tests/Basic/Basic.h index d85e5a51..1ea349f0 100644 --- a/tests/Basic/Basic.h +++ b/tests/Basic/Basic.h @@ -80,6 +80,31 @@ public: Hello* RetNull(); }; +class DLL_API AbstractFoo +{ +public: + virtual int pureFunction() = 0; + virtual int pureFunction1() = 0; + virtual int pureFunction2() = 0; +}; + +class DLL_API ImplementsAbstractFoo : public AbstractFoo +{ +public: + virtual int pureFunction(); + virtual int pureFunction1(); + virtual int pureFunction2(); +}; + +class DLL_API ReturnsAbstractFoo +{ +public: + const AbstractFoo& getFoo(); + +private: + ImplementsAbstractFoo i; +}; + DLL_API Bar operator-(const Bar &); DLL_API Bar operator+(const Bar &, const Bar &); @@ -112,4 +137,4 @@ struct DLL_API DefaultParameters void Bar() const; void Bar(); -}; \ No newline at end of file +};