diff --git a/src/CppParser/AST.cpp b/src/CppParser/AST.cpp index 43372707..11a826c0 100644 --- a/src/CppParser/AST.cpp +++ b/src/CppParser/AST.cpp @@ -637,7 +637,7 @@ DEF_STRING(Variable, Mangled) BaseClassSpecifier::BaseClassSpecifier() : Type(0), Offset(0) {} Field::Field() : Declaration(DeclarationKind::Field), Class(0), - IsBitField(false), BitWidth(0) {} + IsBitField(false), BitWidth(0), Offset(0) {} Field::~Field() {} diff --git a/src/CppParser/Parser.cpp b/src/CppParser/Parser.cpp index 6b3fdd19..beed28f1 100644 --- a/src/CppParser/Parser.cpp +++ b/src/CppParser/Parser.cpp @@ -1621,6 +1621,16 @@ Type* Parser::WalkType(clang::QualType QualType, clang::TypeLoc* TL, if (QualType.isNull()) return nullptr; + // we cannot get a location in some cases of template arguments + const RecordType* RT; + if (!(RT = QualType->getAs()) || + !dyn_cast(RT->getDecl()) || + (TL && !TL->isNull())) + { + C->getSema().RequireCompleteType( + TL && !TL->isNull() ? TL->getLocStart() : clang::SourceLocation(), QualType, 1); + } + const clang::Type* Type = QualType.getTypePtr(); if (DesugarType) @@ -1940,7 +1950,7 @@ Type* Parser::WalkType(clang::QualType QualType, clang::TypeLoc* TL, TST->Template = static_cast(WalkDeclaration( Name.getAsTemplateDecl(), 0, /*IgnoreSystemDecls=*/false)); if (TS->isSugared()) - TST->Desugared = WalkType(TS->desugar()); + TST->Desugared = WalkType(TS->desugar(), TL); TypeLoc UTL, ETL, ITL; @@ -2099,7 +2109,7 @@ Type* Parser::WalkType(clang::QualType QualType, clang::TypeLoc* TL, case clang::Type::Decltype: { auto DT = Type->getAs(); - Ty = WalkType(DT->getUnderlyingType()); + Ty = WalkType(DT->getUnderlyingType(), TL); break; } default: @@ -2236,7 +2246,7 @@ static bool CanCheckCodeGenInfo(clang::Sema& S, { if (auto MPT = Ty->getAs()) if (!MPT->isDependentType()) - S.RequireCompleteType(clang::SourceLocation(), clang::QualType(Ty, 0), 0); + S.RequireCompleteType(clang::SourceLocation(), clang::QualType(Ty, 0), 1); } return CheckCodeGenInfo; @@ -2283,8 +2293,6 @@ void Parser::WalkFunction(clang::FunctionDecl* FD, Function* F, HandlePreprocessedEntities(F, headRange, MacroLocation::FunctionHead); HandlePreprocessedEntities(F, FTL.getParensRange(), MacroLocation::FunctionParameters); - //auto bodyRange = clang::SourceRange(FTL.getRParenLoc(), FD->getLocEnd()); - //HandlePreprocessedEntities(F, bodyRange, MacroLocation::FunctionBody); } } @@ -2526,9 +2534,6 @@ Friend* Parser::WalkFriend(clang::FriendDecl *FD) F->Declaration = GetDeclarationFromFriend(FriendDecl); } - //auto TL = FD->getFriendType()->getTypeLoc(); - //F->QualifiedType = GetQualifiedType(VD->getType(), WalkType(FD->getFriendType(), &TL)); - NS->Friends.push_back(F); return F; @@ -2683,7 +2688,7 @@ AST::Expression* Parser::WalkExpression(clang::Expr* Expr) { auto CallExpr = cast(Expr); auto CallExpression = new AST::CallExpr(GetStringFromStatement(Expr), - WalkDeclaration(CallExpr->getCalleeDecl())); + CallExpr->getCalleeDecl() ? WalkDeclaration(CallExpr->getCalleeDecl()) : 0); for (auto arg : CallExpr->arguments()) { CallExpression->Arguments.push_back(WalkExpression(arg)); diff --git a/src/Generator.Tests/AST/TestAST.cs b/src/Generator.Tests/AST/TestAST.cs index 59beeff2..057a1356 100644 --- a/src/Generator.Tests/AST/TestAST.cs +++ b/src/Generator.Tests/AST/TestAST.cs @@ -226,7 +226,8 @@ namespace CppSharp.Generator.Tests.AST Assert.AreEqual(3, template.Specializations.Count); Assert.AreEqual(TemplateSpecializationKind.ExplicitInstantiationDefinition, template.Specializations[0].SpecializationKind); Assert.AreEqual(TemplateSpecializationKind.ExplicitInstantiationDefinition, template.Specializations[1].SpecializationKind); - Assert.AreEqual(TemplateSpecializationKind.Undeclared, template.Specializations[2].SpecializationKind); + // the instantian is not declared in the header but we force the completion the types in our parser + Assert.AreEqual(TemplateSpecializationKind.ImplicitInstantiation, template.Specializations[2].SpecializationKind); var typeDef = AstContext.FindTypedef("TestTemplateClassInt").FirstOrDefault(); Assert.IsNotNull(typeDef, "Couldn't find TestTemplateClassInt typedef."); var integerInst = typeDef.Type as TemplateSpecializationType; diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index 5db6b8ed..2fb83d13 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -257,6 +257,7 @@ namespace CppSharp { if (Options.GenerateInlines) TranslationUnitPasses.AddPass(new GenerateInlinesCodePass()); + TranslationUnitPasses.AddPass(new TrimSpecializationsPass()); TranslationUnitPasses.AddPass(new GenerateTemplatesCodePass()); } diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 514bf6ae..296e5dd3 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -66,6 +66,43 @@ namespace CppSharp.Generators.CSharp return "public "; } } + + public static StringBuilder FormatTypesStringForIdentifier(StringBuilder types) + { + return types.Replace("global::System.", string.Empty).Replace("*", "Ptr").Replace('.', '_'); + } + + public static string GetSuffixForInternal(ClassTemplateSpecialization templateSpecialization, + CSharpTypePrinter typePrinter) + { + return templateSpecialization == null ? string.Empty : + GetSuffixForInternal(templateSpecialization.TemplatedDecl.TemplatedDecl, + templateSpecialization.Arguments, typePrinter); + } + + public static string GetSuffixForInternal(Declaration template, + IEnumerable args, CSharpTypePrinter typePrinter) + { + if (((Class) template).Fields.All(f => !f.IsDependent || f.Type.IsAddress())) + return string.Empty; + + if (args.All(a => a.Type.Type != null && a.Type.Type.IsAddress())) + return "_Ptr"; + + // we don't want internals in the names of internals :) + typePrinter.PushContext(CSharpTypePrinterContextKind.Managed); + var suffix = new StringBuilder(); + foreach (var argType in from argType in args + where argType.Type.Type != null + select argType.Type.ToString()) + { + suffix.Append('_'); + suffix.Append(argType); + } + typePrinter.PopContext(); + FormatTypesStringForIdentifier(suffix); + return suffix.ToString(); + } } public class CSharpBlockKind @@ -191,16 +228,36 @@ namespace CppSharp.Generators.CSharp GenerateTypedef(typedef); } + var templateGroups = (from template in context.Templates.OfType() + where template.Specializations.Count > 0 + group template by context.Classes.Contains(template.TemplatedClass) + into @group + select @group).ToList(); + + if (templateGroups.Count > 0 && !templateGroups[0].Key) + foreach (var classTemplate in templateGroups[0]) + GenerateClassTemplateSpecializationInternal(classTemplate); + + var classTemplates = templateGroups.Where(g => g.Key).SelectMany(g => g).ToList(); + // Generate all the struct/class declarations. - foreach (var @class in context.Classes) + foreach (var @class in context.Classes.Where(c => !c.IsIncomplete)) { - if (@class.IsIncomplete) - continue; - if (@class.IsInterface) GenerateInterface(@class); else - GenerateClass(@class); + { + var classTemplate = classTemplates.FirstOrDefault(t => t.TemplatedDecl == @class); + if (classTemplate != null) + { + GenerateClassTemplateSpecializationInternal(classTemplate); + classTemplates.Remove(classTemplate); + } + else if (!@class.IsDependent) + { + GenerateClass(@class); + } + } } if (context.HasFunctions) @@ -253,6 +310,20 @@ namespace CppSharp.Generators.CSharp } } + private void GenerateClassTemplateSpecializationInternal(ClassTemplate classTemplate) + { + PushBlock(CSharpBlockKind.Namespace); + WriteLine("namespace {0}", classTemplate.Name); + WriteStartBraceIndent(); + if (classTemplate.TemplatedClass.Fields.Any(f => f.IsDependent && !f.Type.IsAddress())) + foreach (var specialization in classTemplate.Specializations) + GenerateClassInternals(specialization); + else + GenerateClassInternals(classTemplate.Specializations[0]); + WriteCloseBraceIndent(); + PopBlock(NewLineKind.BeforeNextBlock); + } + public void GenerateDeclarationCommon(Declaration decl) { if (decl.Comment != null) @@ -483,7 +554,7 @@ namespace CppSharp.Generators.CSharp if (@class.IsDynamic) GenerateVTablePointers(@class); GenerateClassFields(@class, @class, GenerateClassInternalsField, true); - if (@class.IsGenerated) + if (@class.IsGenerated && !(@class is ClassTemplateSpecialization)) { var functions = GatherClassInternalFunctions(@class); @@ -584,10 +655,13 @@ namespace CppSharp.Generators.CSharp { Write("public "); - if (@class != null && @class.NeedsBase && !@class.BaseClass.IsInterface) + if (@class != null && @class.NeedsBase && !@class.BaseClass.IsInterface && !@class.IsDependent) Write("new "); - WriteLine("partial struct Internal"); + var templateSpecialization = @class as ClassTemplateSpecialization; + var suffix = Helpers.GetSuffixForInternal(templateSpecialization, TypePrinter); + WriteLine("{0}partial struct Internal{1}", + templateSpecialization != null ? "unsafe " : string.Empty, suffix); } public static bool ShouldGenerateClassNativeField(Class @class) @@ -687,7 +761,7 @@ namespace CppSharp.Generators.CSharp var fieldTypePrinted = field.QualifiedType.CSharpType(TypePrinter); TypePrinter.PopMarshalKind(); - var fieldType = field.Type.IsAddress() ? + var fieldType = field.Type.Desugar().IsAddress() ? CSharpTypePrinter.IntPtrType : fieldTypePrinted.Type; var fieldName = safeIdentifier; @@ -2491,8 +2565,13 @@ namespace CppSharp.Generators.CSharp if (parameters == null) parameters = function.Parameters; + var templateSpecialization = function.Namespace as ClassTemplateSpecialization; + + string @namespace = templateSpecialization != null ? + (function.Namespace.OriginalName + '.') : string.Empty; + CheckArgumentRange(function); - var functionName = string.Format("Internal.{0}", + var functionName = string.Format("{0}Internal.{1}", @namespace, GetFunctionNativeIdentifier(function.OriginalFunction ?? function)); GenerateFunctionCall(functionName, parameters, function, returnType); } diff --git a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs index 15da956b..fcb2ea30 100644 --- a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs +++ b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs @@ -5,6 +5,7 @@ using CppSharp.AST.Extensions; using CppSharp.Types; using Type = CppSharp.AST.Type; using ParserTargetInfo = CppSharp.Parser.ParserTargetInfo; +using System.Linq; namespace CppSharp.Generators.CSharp { @@ -383,9 +384,11 @@ namespace CppSharp.Generators.CSharp TypeMap typeMap; if (!driver.TypeDatabase.FindTypeMap(template, out typeMap)) - return GetNestedQualifiedName(decl) + - (ContextKind == CSharpTypePrinterContextKind.Native - ? ".Internal" : string.Empty); + { + if (ContextKind != CSharpTypePrinterContextKind.Native) + return GetNestedQualifiedName(decl); + return GetTemplateSpecializationInternal(template); + } typeMap.Declaration = decl; typeMap.Type = template; @@ -408,6 +411,65 @@ namespace CppSharp.Generators.CSharp ".Internal" : string.Empty); } + private string GetTemplateSpecializationInternal(TemplateSpecializationType template) + { + var classTemplate = template.Template as ClassTemplate; + if (classTemplate != null) + { + foreach (var specialization in classTemplate.Specializations) + { + if (FoundMatchingSpecialization(template.Arguments, + specialization.Arguments, classTemplate.Parameters)) + { + return GetNestedQualifiedName(specialization.TemplatedDecl) + + ".Internal" + Helpers.GetSuffixForInternal( + template.Template.TemplatedDecl, specialization.Arguments, this); + } + } + } + var functionTemplate = (FunctionTemplate) template.Template; + foreach (var specialization in functionTemplate.Specializations) + { + if (FoundMatchingSpecialization(template.Arguments, + specialization.Arguments, functionTemplate.Parameters)) + { + return GetNestedQualifiedName(specialization.SpecializedFunction) + + ".Internal" + Helpers.GetSuffixForInternal( + template.Template.TemplatedDecl, specialization.Arguments, this); + } + } + var qualifiedName = GetNestedQualifiedName(template.Template.TemplatedDecl); + return qualifiedName + ".Internal" + + Helpers.GetSuffixForInternal(template.Template.TemplatedDecl, template.Arguments, this); + } + + private static bool FoundMatchingSpecialization( + IList templateTypeArguments, + IEnumerable templateSpecializationArguments, + IList templateParameters) + { + var usedTemplateArguments = new List(templateSpecializationArguments); + for (int i = usedTemplateArguments.Count - 1; i >= templateTypeArguments.Count; i--) + { + var templateParameter = templateParameters[i]; + var typeTemplateParameter = templateParameter as TypeTemplateParameter; + if (typeTemplateParameter != null && + typeTemplateParameter.DefaultArgument.Type != null) + { + usedTemplateArguments.RemoveAt(i); + continue; + } + var nonTypeTemplateParameter = templateParameter as NonTypeTemplateParameter; + if (nonTypeTemplateParameter != null && + nonTypeTemplateParameter.DefaultArgument != null) + { + usedTemplateArguments.RemoveAt(i); + continue; + } + } + return usedTemplateArguments.SequenceEqual(templateTypeArguments); + } + private string GetCSharpSignature(TypeMap typeMap) { Context.CSharpKind = ContextKind; diff --git a/src/Generator/Passes/DelegatesPass.cs b/src/Generator/Passes/DelegatesPass.cs index 2d125d73..e829a48f 100644 --- a/src/Generator/Passes/DelegatesPass.cs +++ b/src/Generator/Passes/DelegatesPass.cs @@ -145,30 +145,28 @@ namespace CppSharp.Passes private string GenerateDelegateSignature(IEnumerable @params, QualifiedType returnType) { - var typePrinter = new CSharpTypePrinter(Driver); - typePrinter.PushContext(CSharpTypePrinterContextKind.Native); + TypePrinter.PushContext(CSharpTypePrinterContextKind.Native); var typesBuilder = new StringBuilder(); if (!returnType.Type.IsPrimitiveType(PrimitiveType.Void)) { - typesBuilder.Insert(0, returnType.Type.CSharpType(typePrinter)); + typesBuilder.Insert(0, returnType.Type.CSharpType(TypePrinter)); typesBuilder.Append('_'); } foreach (var parameter in @params) { - typesBuilder.Append(parameter.CSharpType(typePrinter)); + typesBuilder.Append(parameter.CSharpType(TypePrinter)); typesBuilder.Append('_'); } if (typesBuilder.Length > 0) typesBuilder.Remove(typesBuilder.Length - 1, 1); - var delegateName = typesBuilder.Replace("global::System.", string.Empty).Replace( - "*", "Ptr").Replace('.', '_').ToString(); + var delegateName = Helpers.FormatTypesStringForIdentifier(typesBuilder).ToString(); if (returnType.Type.IsPrimitiveType(PrimitiveType.Void)) delegateName = "Action_" + delegateName; else delegateName = "Func_" + delegateName; - typePrinter.PopContext(); + TypePrinter.PopContext(); return delegateName; } diff --git a/src/Generator/Passes/TrimSpecializationsPass.cs b/src/Generator/Passes/TrimSpecializationsPass.cs new file mode 100644 index 00000000..173583a5 --- /dev/null +++ b/src/Generator/Passes/TrimSpecializationsPass.cs @@ -0,0 +1,28 @@ +using System.Linq; +using CppSharp.AST; +using CppSharp.AST.Extensions; + +namespace CppSharp.Passes +{ + public class TrimSpecializationsPass : TranslationUnitPass + { + public override bool VisitClassTemplateDecl(ClassTemplate template) + { + if (!base.VisitClassTemplateDecl(template) || + template.Specializations.Count == 0) + return false; + + var lastGroup = (from specialization in template.Specializations + group specialization by specialization.Arguments.All( + a => a.Type.Type != null && a.Type.Type.IsAddress()) into @group + select @group).Last(); + if (lastGroup.Key) + { + foreach (var specialization in lastGroup.Skip(1)) + template.Specializations.Remove(specialization); + } + + return true; + } + } +} diff --git a/tests/CSharp/CSharp.Tests.cs b/tests/CSharp/CSharp.Tests.cs index 62455ded..8dcd38ea 100644 --- a/tests/CSharp/CSharp.Tests.cs +++ b/tests/CSharp/CSharp.Tests.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Reflection; +using System.Runtime.InteropServices; using CppSharp.Utils; using CSharp; using NUnit.Framework; @@ -285,12 +286,6 @@ public class CSharpTests : GeneratorTestFixture Assert.That(res, Is.EqualTo(50)); } - [Test] - public void TestInnerClasses() - { - QMap.Iterator test_iter; - } - [Test] public void TestNativeToManagedMapWithForeignObjects() { @@ -507,4 +502,38 @@ public class CSharpTests : GeneratorTestFixture } Assert.IsTrue(VirtualDtorAddedInDerived.DtorCalled); } + + [Test] + public void TestTemplateInternals() + { + foreach (var internalType in new[] + { + typeof(CSharp.IndependentFields.Internal), + typeof(CSharp.DependentValueFields.Internal_int), + typeof(CSharp.DependentValueFields.Internal_float), + typeof(CSharp.DependentPointerFields.Internal), + typeof(CSharp.DependentValueFields.Internal_Ptr), + typeof(CSharp.HasDefaultTemplateArgument.Internal_int_IndependentFields) + }) + { + var independentFields = internalType.GetFields(); + Assert.That(independentFields.Length, Is.EqualTo(1)); + var fieldOffset = (FieldOffsetAttribute) independentFields[0].GetCustomAttribute(typeof(FieldOffsetAttribute)); + Assert.That(fieldOffset.Value, Is.EqualTo(0)); + } + foreach (var internalType in new[] + { + typeof(CSharp.TwoTemplateArgs.Internal_Ptr), + typeof(CSharp.TwoTemplateArgs.Internal_intPtr_int), + typeof(CSharp.TwoTemplateArgs.Internal_intPtr_float) + }) + { + var independentFields = internalType.GetFields(); + Assert.That(independentFields.Length, Is.EqualTo(2)); + var fieldOffsetKey = (FieldOffsetAttribute) independentFields[0].GetCustomAttribute(typeof(FieldOffsetAttribute)); + Assert.That(fieldOffsetKey.Value, Is.EqualTo(0)); + var fieldOffsetValue = (FieldOffsetAttribute) independentFields[1].GetCustomAttribute(typeof(FieldOffsetAttribute)); + Assert.That(fieldOffsetValue.Value, Is.EqualTo(Marshal.SizeOf(IntPtr.Zero))); + } + } } diff --git a/tests/CSharp/CSharpTemplates.cpp b/tests/CSharp/CSharpTemplates.cpp new file mode 100644 index 00000000..0be43c42 --- /dev/null +++ b/tests/CSharp/CSharpTemplates.cpp @@ -0,0 +1,17 @@ +#include "CSharpTemplates.h" + +TemplateSpecializer::TemplateSpecializer() +{ +} + +void TemplateSpecializer::completeSpecializationInParameter(DependentValueFields p1, + DependentValueFields p2, + DependentValueFields p3) +{ +} + +void TemplateSpecializer::completeSpecializationInParameter(TwoTemplateArgs p1, + TwoTemplateArgs p2, + TwoTemplateArgs p3) +{ +} diff --git a/tests/CSharp/CSharpTemplates.h b/tests/CSharp/CSharpTemplates.h new file mode 100644 index 00000000..6fc9a0a6 --- /dev/null +++ b/tests/CSharp/CSharpTemplates.h @@ -0,0 +1,63 @@ +#include "../Tests.h" + +class DLL_API T1 +{ +}; + +class DLL_API T2 +{ +}; + +template +class DLL_API IndependentFields +{ +private: + int field; +}; + +template +class DLL_API DependentValueFields +{ +private: + T field; +}; + +template +class DLL_API DependentPointerFields +{ +private: + T* field; +}; + +template +class TwoTemplateArgs +{ +private: + K key; + V value; +}; + +template > +class HasDefaultTemplateArgument +{ + T field; +}; + +class DLL_API TemplateSpecializer +{ +public: + TemplateSpecializer(); +private: + IndependentFields independentFields; + DependentValueFields dependentValueFields; + DependentPointerFields dependentPointerFields; + HasDefaultTemplateArgument hasDefaultTemplateArgument; + DependentValueFields dependentPointerFieldsT1; + DependentValueFields dependentPointerFieldsT2; + void completeSpecializationInParameter(DependentValueFields p1, + DependentValueFields p2, + DependentValueFields p3); + void completeSpecializationInParameter(TwoTemplateArgs p1, + TwoTemplateArgs p2, + TwoTemplateArgs p3); +};