From 52c754c4de61851288e9012a7667ce4a57272050 Mon Sep 17 00:00:00 2001 From: Dimitar Dobrev Date: Sat, 23 Dec 2017 23:34:20 +0200 Subject: [PATCH] Extended the multiple inheritance to work for templates. Signed-off-by: Dimitar Dobrev --- .../Generators/CSharp/CSharpMarshal.cs | 12 +++-- .../Generators/CSharp/CSharpSources.cs | 15 ++----- .../CSharp/CSharpSourcesExtensions.cs | 9 ++-- .../Passes/MultipleInheritancePass.cs | 44 ++++++++++++++----- .../Passes/TrimSpecializationsPass.cs | 20 ++++++--- tests/CSharp/CSharp.Tests.cs | 10 +++++ tests/CSharp/CSharpTemplates.cpp | 8 ++++ tests/CSharp/CSharpTemplates.h | 7 +++ 8 files changed, 89 insertions(+), 36 deletions(-) diff --git a/src/Generator/Generators/CSharp/CSharpMarshal.cs b/src/Generator/Generators/CSharp/CSharpMarshal.cs index 4b5a1981..7a150857 100644 --- a/src/Generator/Generators/CSharp/CSharpMarshal.cs +++ b/src/Generator/Generators/CSharp/CSharpMarshal.cs @@ -275,12 +275,13 @@ namespace CppSharp.Generators.CSharp Context.Before.WriteLine($"var {ptrName} = {Context.ReturnVarName};"); var specialization = decl.Namespace as ClassTemplateSpecialization; + Type returnType = Context.ReturnType.Type.Desugar(); + var finalType = (returnType.GetFinalPointee() ?? returnType).Desugar(); var res = string.Format( "{0} == IntPtr.Zero? null : {1}({2}) Marshal.GetDelegateForFunctionPointer({0}, typeof({2}))", ptrName, - specialization == null ? string.Empty : - $@"({specialization.TemplatedDecl.TemplatedClass.Typedefs.First( - t => t.Name == decl.Name).Visit(this.typePrinter)}) (object) ", + finalType.IsDependent ? $@"({specialization.TemplatedDecl.TemplatedClass.Typedefs.First( + t => t.Name == decl.Name).Visit(this.typePrinter)}) (object) " : string.Empty, typedef); Context.Return.Write(res); return true; @@ -359,7 +360,10 @@ namespace CppSharp.Generators.CSharp public override bool VisitTemplateParameterSubstitutionType(TemplateParameterSubstitutionType param, TypeQualifiers quals) { - Context.Return.Write($"({param.ReplacedParameter.Parameter.Name}) (object) "); + Type returnType = Context.ReturnType.Type.Desugar(); + Type finalType = (returnType.GetFinalPointee() ?? returnType).Desugar(); + if (finalType.IsDependent) + Context.Return.Write($"({param.ReplacedParameter.Parameter.Name}) (object) "); if (param.Replacement.Type.Desugar().IsPointerToPrimitiveType()) Context.Return.Write($"({CSharpTypePrinter.IntPtrType}) "); return base.VisitTemplateParameterSubstitutionType(param, quals); diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 1529c3e5..913a666d 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -680,7 +680,7 @@ namespace CppSharp.Generators.CSharp if (@class.NeedsBase) { - foreach (var @base in @class.Bases.Where(b => b.IsClass)) + foreach (var @base in @class.Bases.Where(b => b.IsClass && b.Class.IsGenerated)) { var typeMaps = new List(); var keys = new List(); @@ -868,7 +868,7 @@ namespace CppSharp.Generators.CSharp GenerateInternalFunctionCall(property.SetMethod, parameters, @void); } - private void GenerateFieldSetter(Field field, Class @class) + private void GenerateFieldSetter(Field field, Class @class, QualifiedType fieldType) { var param = new Parameter { @@ -1153,7 +1153,7 @@ namespace CppSharp.Generators.CSharp p.GetMethod.InstantiatedFrom == property.GetMethod); } - private void GenerateFieldGetter(Field field, Class @class) + private void GenerateFieldGetter(Field field, Class @class, QualifiedType returnType) { var name = @class.Layout.Fields.First(f => f.FieldPtr == field.OriginalPtr).Name; var ctx = new CSharpMarshalContext(Context) @@ -1163,7 +1163,7 @@ namespace CppSharp.Generators.CSharp ReturnVarName = $@"{(@class.IsValueType ? Helpers.InstanceField : $"(({TypePrinter.PrintNative(@class)}*) {Helpers.InstanceIdentifier})")}{ (@class.IsValueType ? "." : "->")}{SafeIdentifier(name)}", - ReturnType = field.QualifiedType + ReturnType = returnType }; ctx.PushMarshalKind(MarshalKind.NativeField); @@ -3167,13 +3167,6 @@ namespace CppSharp.Generators.CSharp if (function.IsPure) return; - if (function.OriginalFunction != null) - { - var @class = function.OriginalNamespace as Class; - if (@class != null && @class.IsInterface) - function = function.OriginalFunction; - } - PushBlock(BlockKind.InternalsClassMethod); WriteLine("[SuppressUnmanagedCodeSecurity]"); Write("[DllImport(\"{0}\", ", GetLibraryOf(function)); diff --git a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs index 6e9d4ee8..64e24e0b 100644 --- a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs +++ b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs @@ -31,7 +31,7 @@ namespace CppSharp.Generators.CSharp } public static void GenerateField(this CSharpSources gen, Class @class, - Field field, Action generate, bool isVoid) + Field field, Action generate, bool isVoid) { if (@class.IsDependent) { @@ -46,7 +46,7 @@ namespace CppSharp.Generators.CSharp gen.WriteStartBraceIndent(); var specializedField = specialization.Fields.First( f => f.OriginalName == field.OriginalName); - generate(specializedField, specialization); + generate(specializedField, specialization, field.QualifiedType); if (isVoid) gen.WriteLine("return;"); gen.WriteCloseBraceIndent(); @@ -58,12 +58,13 @@ namespace CppSharp.Generators.CSharp var specialization = @class.Specializations[0]; var specializedField = specialization.Fields.First( f => f.OriginalName == field.OriginalName); - generate(specializedField, specialization); + generate(specializedField, specialization, field.QualifiedType); } } else { - generate(field, @class.IsDependent ? @class.Specializations[0] : @class); + generate(field, @class.IsDependent ? @class.Specializations[0] : @class, + field.QualifiedType); } } diff --git a/src/Generator/Passes/MultipleInheritancePass.cs b/src/Generator/Passes/MultipleInheritancePass.cs index 6084340c..15b0f956 100644 --- a/src/Generator/Passes/MultipleInheritancePass.cs +++ b/src/Generator/Passes/MultipleInheritancePass.cs @@ -46,7 +46,7 @@ namespace CppSharp.Passes { var @base = @class.Bases[i]; var baseClass = @base.Class; - if (baseClass == null || baseClass.IsInterface) continue; + if (baseClass == null || baseClass.IsInterface || !baseClass.IsGenerated) continue; var @interface = GetInterface(baseClass); @class.Bases[i] = new BaseClassSpecifier(@base) { Type = new TagType(@interface) }; @@ -70,14 +70,35 @@ namespace CppSharp.Passes private Class GetNewInterface(string name, Class @base) { - var @interface = new Class - { - Name = name, - Namespace = @base.Namespace, - Access = @base.Access, - Type = ClassType.Interface, - OriginalClass = @base - }; + var specialization = @base as ClassTemplateSpecialization; + Class @interface; + if (specialization == null) + { + @interface = new Class(); + } + else + { + Class template = specialization.TemplatedDecl.TemplatedClass; + Class templatedInterface; + if (templatedInterfaces.ContainsKey(template)) + templatedInterface = templatedInterfaces[template]; + else + templatedInterfaces[template] = templatedInterface = GetInterface(template); + var specializedInterface = new ClassTemplateSpecialization(); + specializedInterface.Arguments.AddRange(specialization.Arguments); + specializedInterface.TemplatedDecl = new ClassTemplate { TemplatedDecl = templatedInterface }; + @interface = specializedInterface; + } + @interface.Name = name; + @interface.Namespace = @base.Namespace; + @interface.Access = @base.Access; + @interface.Type = ClassType.Interface; + @interface.OriginalClass = @base; + if (@base.IsTemplate) + { + @interface.IsDependent = true; + @interface.TemplateParameters.AddRange(@base.TemplateParameters); + } @interface.Bases.AddRange( from b in @base.Bases @@ -136,7 +157,8 @@ namespace CppSharp.Passes @base.Bases.Add(new BaseClassSpecifier { Type = new TagType(@interface) }); - interfaces.Add(@base, @interface); + if (specialization == null) + interfaces.Add(@base, @interface); return @interface; } @@ -204,5 +226,7 @@ namespace CppSharp.Passes foreach (var @base in @interface.Bases) ImplementInterfaceProperties(@class, @base.Class); } + + private readonly Dictionary templatedInterfaces = new Dictionary(); } } diff --git a/src/Generator/Passes/TrimSpecializationsPass.cs b/src/Generator/Passes/TrimSpecializationsPass.cs index a765049b..a70e5f98 100644 --- a/src/Generator/Passes/TrimSpecializationsPass.cs +++ b/src/Generator/Passes/TrimSpecializationsPass.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using CppSharp.AST; using CppSharp.AST.Extensions; +using CppSharp.Types; using CppSharp.Utils; namespace CppSharp.Passes @@ -78,7 +79,14 @@ namespace CppSharp.Passes if (!base.VisitDeclaration(field)) return false; - if (field.Access == AccessSpecifier.Private || + if (field.Access == AccessSpecifier.Private) + { + CheckForInternalSpecialization(field, field.Type); + return true; + } + + TypeMap typeMap; + if (!Context.TypeMaps.FindTypeMap(field.Type, out typeMap) && !ASTUtils.CheckTypeForSpecialization(field.Type, field, AddSpecialization, Context.TypeMaps)) CheckForInternalSpecialization(field, field.Type); @@ -200,12 +208,10 @@ namespace CppSharp.Passes foreach (var @base in @class.Bases.Where(b => b.IsClass)) { var specialization = @base.Class as ClassTemplateSpecialization; - if (specialization != null) - { - if (!ASTUtils.CheckTypeForSpecialization(@base.Type, @class, - AddSpecialization, Context.TypeMaps)) - CheckForInternalSpecialization(@class, @base.Type); - } + if (specialization != null && + !ASTUtils.CheckTypeForSpecialization(@base.Type, @class, + AddSpecialization, Context.TypeMaps)) + CheckForInternalSpecialization(@class, @base.Type); CheckBasesForSpecialization(@base.Class); } } diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index a3240bf1..63715be2 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -885,6 +885,16 @@ public unsafe class CSharpTests : GeneratorTestFixture } } + [Test] + public void TestSpecializationForSecondaryBase() + { + using (var hasSpecializationForSecondaryBase = new HasSpecializationForSecondaryBase()) + { + hasSpecializationForSecondaryBase.DependentValue = 5; + Assert.That(hasSpecializationForSecondaryBase.DependentValue, Is.EqualTo(5)); + } + } + [Test] public void TestAbstractImplementatonsInPrimaryAndSecondaryBases() { diff --git a/tests/CSharp/CSharpTemplates.cpp b/tests/CSharp/CSharpTemplates.cpp index 807ed568..5cef3c22 100644 --- a/tests/CSharp/CSharpTemplates.cpp +++ b/tests/CSharp/CSharpTemplates.cpp @@ -87,6 +87,14 @@ int HasVirtualTemplate::function() return v->function(); } +HasSpecializationForSecondaryBase::HasSpecializationForSecondaryBase() +{ +} + +HasSpecializationForSecondaryBase::~HasSpecializationForSecondaryBase() +{ +} + TemplateSpecializer::TemplateSpecializer() { } diff --git a/tests/CSharp/CSharpTemplates.h b/tests/CSharp/CSharpTemplates.h index a36468ef..46ad0552 100644 --- a/tests/CSharp/CSharpTemplates.h +++ b/tests/CSharp/CSharpTemplates.h @@ -425,6 +425,13 @@ private: HasDefaultTemplateArgument explicitSpecialization; }; +class DLL_API HasSpecializationForSecondaryBase : T1, DependentValueFields +{ +public: + HasSpecializationForSecondaryBase(); + ~HasSpecializationForSecondaryBase(); +}; + template class TemplateInAnotherUnit;