diff --git a/src/AST/Function.cs b/src/AST/Function.cs index f924e2a4..a9043167 100644 --- a/src/AST/Function.cs +++ b/src/AST/Function.cs @@ -162,5 +162,30 @@ namespace CppSharp.AST public Type Type { get { return ReturnType.Type; } } public QualifiedType QualifiedType { get { return ReturnType; } } + + public virtual QualifiedType GetFunctionType() + { + var functionType = new FunctionType(); + functionType.CallingConvention = CallingConvention; + functionType.ReturnType = ReturnType; + functionType.Parameters.AddRange(Parameters); + for (int i = functionType.Parameters.Count - 1; i >= 0; i--) + { + var parameter = functionType.Parameters[i]; + if (parameter.Kind == ParameterKind.IndirectReturnType) + { + var retParam = new Parameter(); + retParam.Name = parameter.Name; + var ptrType = new PointerType(); + ptrType.QualifiedPointee = new QualifiedType(parameter.Type); + retParam.QualifiedType = new QualifiedType(ptrType); + functionType.Parameters.RemoveAt(i); + functionType.Parameters.Insert(i, retParam); + } + } + var pointerType = new PointerType(); + pointerType.QualifiedPointee = new QualifiedType(functionType); + return new QualifiedType(pointerType); + } } } \ No newline at end of file diff --git a/src/AST/Method.cs b/src/AST/Method.cs index adc7c4c3..64d2d897 100644 --- a/src/AST/Method.cs +++ b/src/AST/Method.cs @@ -119,5 +119,21 @@ namespace CppSharp.AST public bool IsMoveConstructor; public MethodConversionKind Conversion { get; set; } + + public override QualifiedType GetFunctionType() + { + var qualifiedType = base.GetFunctionType(); + FunctionType functionType; + qualifiedType.Type.IsPointerTo(out functionType); + if (!IsStatic) + { + var instance = new Parameter(); + instance.Name = "instance"; + instance.QualifiedType = new QualifiedType( + new BuiltinType(PrimitiveType.IntPtr)); + functionType.Parameters.Insert(0, instance); + } + return qualifiedType; + } } } \ No newline at end of file diff --git a/src/Generator/Passes/AbstractImplementationsPass.cs b/src/Generator/Passes/AbstractImplementationsPass.cs index 4321a892..bcc25ef3 100644 --- a/src/Generator/Passes/AbstractImplementationsPass.cs +++ b/src/Generator/Passes/AbstractImplementationsPass.cs @@ -32,73 +32,65 @@ namespace CppSharp.Passes private Class AddInternalImplementation(Class @class) { - var internalImplementation = new Class(); - internalImplementation.Name = @class.Name + "Internal"; - internalImplementation.Access = AccessSpecifier.Private; - internalImplementation.Namespace = @class.Namespace; + var internalImpl = GetInternalImpl(@class); + + var abstractMethods = GetRelevantAbstractMethods(@class); + + foreach (var abstractMethod in abstractMethods) + { + internalImpl.Methods.Add(new Method(abstractMethod)); + var @delegate = new TypedefDecl { Name = abstractMethod.Name + "Delegate" }; + @delegate.QualifiedType = abstractMethod.GetFunctionType(); + @delegate.IgnoreFlags = abstractMethod.IgnoreFlags; + internalImpl.Typedefs.Add(@delegate); + } + + internalImpl.Layout = new ClassLayout(@class.Layout); + FillVTable(@class, abstractMethods, internalImpl); + + foreach (var method in internalImpl.Methods) + { + method.IsPure = false; + method.IsOverride = true; + method.IsSynthetized = true; + } + return internalImpl; + } + + private static Class GetInternalImpl(Declaration @class) + { + var internalImpl = new Class(); + internalImpl.Name = @class.Name + "Internal"; + internalImpl.Access = AccessSpecifier.Private; + internalImpl.Namespace = @class.Namespace; var @base = new BaseClassSpecifier { Type = new TagType(@class) }; - internalImplementation.Bases.Add(@base); + internalImpl.Bases.Add(@base); + return internalImpl; + } + + private static List GetRelevantAbstractMethods(Class @class) + { var abstractMethods = GetAbstractMethods(@class); var overriddenMethods = GetOverriddenMethods(@class); - var parameterTypeComparer = new ParameterTypeComparer(); + var paramTypeCmp = new ParameterTypeComparer(); for (int i = abstractMethods.Count - 1; i >= 0; i--) { - Method @abstract = abstractMethods[i]; + var @abstract = abstractMethods[i]; if (overriddenMethods.Find(m => m.Name == @abstract.Name && m.ReturnType.Type == @abstract.ReturnType.Type && m.Parameters.Count == @abstract.Parameters.Count && - m.Parameters.SequenceEqual(@abstract.Parameters, parameterTypeComparer)) != null) + m.Parameters.SequenceEqual(@abstract.Parameters, paramTypeCmp)) != null) { abstractMethods.RemoveAt(i); } } - foreach (Method abstractMethod in abstractMethods) - { - internalImplementation.Methods.Add(new Method(abstractMethod)); - var @delegate = new TypedefDecl { Name = abstractMethod.Name + "Delegate" }; - var functionType = new FunctionType(); - functionType.CallingConvention = abstractMethod.CallingConvention; - functionType.ReturnType = abstractMethod.ReturnType; - var instance = new Parameter(); - instance.Name = "instance"; - instance.QualifiedType = new QualifiedType(new BuiltinType(PrimitiveType.IntPtr)); - functionType.Parameters.Add(instance); - functionType.Parameters.AddRange(abstractMethod.Parameters); - for (int i = functionType.Parameters.Count - 1; i >= 0; i--) - { - var parameter = functionType.Parameters[i]; - if (parameter.Kind == ParameterKind.IndirectReturnType) - { - var retParam = new Parameter(); - retParam.Name = parameter.Name; - var ptrType = new PointerType(); - ptrType.QualifiedPointee = new QualifiedType(parameter.Type); - retParam.QualifiedType = new QualifiedType(ptrType); - functionType.Parameters.RemoveAt(i); - functionType.Parameters.Insert(i, retParam); - } - } - var pointerType = new PointerType(); - pointerType.QualifiedPointee = new QualifiedType(functionType); - @delegate.QualifiedType = new QualifiedType(pointerType); - @delegate.IgnoreFlags = abstractMethod.IgnoreFlags; - internalImplementation.Typedefs.Add(@delegate); - } - internalImplementation.Layout = new ClassLayout(@class.Layout); - FillVTable(@class, abstractMethods, internalImplementation); - foreach (Method method in internalImplementation.Methods) - { - method.IsPure = false; - method.IsOverride = true; - method.IsSynthetized = true; - } - return internalImplementation; + return abstractMethods; } private static List GetAbstractMethods(Class @class) { var abstractMethods = @class.Methods.Where(m => m.IsPure).ToList(); - foreach (BaseClassSpecifier @base in @class.Bases) + foreach (var @base in @class.Bases) abstractMethods.AddRange(GetAbstractMethods(@base.Class)); return abstractMethods; } @@ -106,7 +98,7 @@ namespace CppSharp.Passes private static List GetOverriddenMethods(Class @class) { var abstractMethods = @class.Methods.Where(m => m.IsOverride).ToList(); - foreach (BaseClassSpecifier @base in @class.Bases) + foreach (var @base in @class.Bases) abstractMethods.AddRange(GetOverriddenMethods(@base.Class)); return abstractMethods; } @@ -132,11 +124,11 @@ namespace CppSharp.Passes { for (int j = 0; j < vTables.Count; j++) { - VFTableInfo vTable = vTables[j]; + var vTable = vTables[j]; var k = vTable.Layout.Components.FindIndex(v => v.Method == abstractMethods[i]); if (k >= 0) { - VTableComponent vTableComponent = vTable.Layout.Components[k]; + var vTableComponent = vTable.Layout.Components[k]; vTableComponent.Declaration = internalImplementation.Methods[i]; vTable.Layout.Components[k] = vTableComponent; vTables[j] = vTable; @@ -154,7 +146,7 @@ namespace CppSharp.Passes for (int i = 0; i < abstractMethods.Count; i++) { var j = vTableComponents.FindIndex(v => v.Method == abstractMethods[i]); - VTableComponent vTableComponent = vTableComponents[j]; + var vTableComponent = vTableComponents[j]; vTableComponent.Declaration = internalImplementation.Methods[i]; vTableComponents[j] = vTableComponent; } @@ -164,18 +156,18 @@ namespace CppSharp.Passes private static List GetVTableComponents(Class @class) { - List vTableComponents = new List( + var vTableComponents = new List( @class.Layout.Layout.Components); - foreach (BaseClassSpecifier @base in @class.Bases) + foreach (var @base in @class.Bases) vTableComponents.AddRange(GetVTableComponents(@base.Class)); return vTableComponents; } private static List GetVTables(Class @class) { - List vTables = new List( + var vTables = new List( @class.Layout.VFTables); - foreach (BaseClassSpecifier @base in @class.Bases) + foreach (var @base in @class.Bases) vTables.AddRange(GetVTables(@base.Class)); return vTables; }