Browse Source

Force compilation of all functions of specializations

Functions of template specializations can have their symbols compiled by having their addresses taken just like regular functions. This way we take just the necessary symbols compared to exporting entire templates which both compile useless symbols and skip actually needed ones.

Signed-off-by: Dimitar Dobrev <dpldobrev@protonmail.com>
pull/1265/head
Dimitar Dobrev 6 years ago
parent
commit
07b9e4ca10
  1. 1
      src/Generator/Driver.cs
  2. 31
      src/Generator/Passes/GenerateSymbolsPass.cs
  3. 109
      src/Generator/Passes/SymbolsCodeGenerator.cs

1
src/Generator/Driver.cs

@ -226,6 +226,7 @@ namespace CppSharp @@ -226,6 +226,7 @@ namespace CppSharp
if (Options.IsCSharpGenerator)
{
TranslationUnitPasses.AddPass(new TrimSpecializationsPass());
TranslationUnitPasses.AddPass(new CheckAmbiguousFunctions());
TranslationUnitPasses.AddPass(new GenerateSymbolsPass());
TranslationUnitPasses.AddPass(new CheckIgnoredDeclsPass());
}

31
src/Generator/Passes/GenerateSymbolsPass.cs

@ -15,7 +15,6 @@ namespace CppSharp.Passes @@ -15,7 +15,6 @@ namespace CppSharp.Passes
{
VisitOptions.VisitClassBases = false;
VisitOptions.VisitClassFields = false;
VisitOptions.VisitClassTemplateSpecializations = false;
VisitOptions.VisitEventParameters = false;
VisitOptions.VisitFunctionParameters = false;
VisitOptions.VisitFunctionReturnType = false;
@ -46,22 +45,6 @@ namespace CppSharp.Passes @@ -46,22 +45,6 @@ namespace CppSharp.Passes
foreach (var module in modules)
{
var symbolsCodeGenerator = symbolsCodeGenerators[module];
if (specializations.ContainsKey(module))
{
symbolsCodeGenerator.NewLine();
foreach (var specialization in specializations[module])
{
Func<Method, bool> exportable = m => !m.IsDependent &&
!m.IsImplicit && !m.IsDeleted && !m.IsDefaulted;
if (specialization.Methods.Any(m => m.IsInvalid && exportable(m)))
foreach (var method in specialization.Methods.Where(
m => m.IsGenerated && (m.InstantiatedFrom == null || m.InstantiatedFrom.IsGenerated) &&
exportable(m)))
symbolsCodeGenerator.VisitMethodDecl(method);
else
symbolsCodeGenerator.VisitClassTemplateSpecializationDecl(specialization);
}
}
var cpp = $"{module.SymbolsLibraryName}.{symbolsCodeGenerator.FileExtension}";
Directory.CreateDirectory(Options.OutputDir);
@ -133,10 +116,18 @@ namespace CppSharp.Passes @@ -133,10 +116,18 @@ namespace CppSharp.Passes
{
var mangled = function.Mangled;
var method = function as Method;
bool isInspecialization;
var declarationContext = function.Namespace;
do
{
isInspecialization = declarationContext is ClassTemplateSpecialization;
declarationContext = declarationContext.Namespace;
} while (!isInspecialization && declarationContext != null);
return function.IsGenerated && !function.IsDeleted &&
!function.IsDependent && !function.IsPure &&
(!string.IsNullOrEmpty(function.Body) || function.IsImplicit) &&
!(function.Namespace is ClassTemplateSpecialization) &&
!function.IsDependent && !function.IsPure && function.Namespace.IsGenerated &&
(!string.IsNullOrEmpty(function.Body) ||
isInspecialization || function.IsImplicit) &&
// we don't need symbols for virtual functions anyway
(method == null || (!method.IsVirtual && !method.IsSynthetized &&
(!method.IsConstructor || !((Class) method.Namespace).IsAbstract))) &&

109
src/Generator/Passes/SymbolsCodeGenerator.cs

@ -32,15 +32,13 @@ namespace CppSharp.Passes @@ -32,15 +32,13 @@ namespace CppSharp.Passes
NewLine();
}
public override bool VisitClassTemplateSpecializationDecl(ClassTemplateSpecialization specialization)
{
WriteLine($"template class {GetExporting()}{specialization.Visit(cppTypePrinter)};");
return true;
}
public override bool VisitMethodDecl(Method method)
{
if (method.Namespace is ClassTemplateSpecialization)
if (method.Namespace is ClassTemplateSpecialization specialization &&
(method.TranslationUnit.IsSystemHeader ||
((method.IsConstructor || method.IsDestructor) &&
!method.IsImplicit && !method.IsDefaulted && !method.IsPure &&
string.IsNullOrEmpty(method.Body))))
{
WriteLine($"template {GetExporting()}{method.Visit(cppTypePrinter)};");
return true;
@ -120,7 +118,7 @@ namespace CppSharp.Passes @@ -120,7 +118,7 @@ namespace CppSharp.Passes
if (method.Access == AccessSpecifier.Protected)
{
Write(GetDerivedType(@namespace, wrapper));
Write($"{wrapper}{@namespace}");
Write(wrapper + @namespace);
Write($@"({string.Join(", ", method.Parameters.Select(
p => cppTypePrinter.VisitParameter(p)))})");
WriteLine($": {@namespace}({@params}) {{}} }};");
@ -129,7 +127,7 @@ namespace CppSharp.Passes @@ -129,7 +127,7 @@ namespace CppSharp.Passes
}
else
{
Write($"extern \"C\" ");
Write("extern \"C\" ");
if (method.Namespace.Access == AccessSpecifier.Protected)
Write($@"{{ class {wrapper}{method.Namespace.Namespace.Name} : public {
method.Namespace.Namespace.Visit(cppTypePrinter)} ");
@ -168,7 +166,7 @@ namespace CppSharp.Passes @@ -168,7 +166,7 @@ namespace CppSharp.Passes
bool isProtected = method.Access == AccessSpecifier.Protected;
string @namespace = method.Namespace.Visit(cppTypePrinter);
if (isProtected)
Write(GetDerivedType(@namespace, wrapper));
Write($"class {wrapper} : public {@namespace} {{ public: ");
else
Write("extern \"C\" { ");
if (method.Namespace.Access == AccessSpecifier.Protected)
@ -176,14 +174,16 @@ namespace CppSharp.Passes @@ -176,14 +174,16 @@ namespace CppSharp.Passes
method.Namespace.Namespace.Visit(cppTypePrinter)} {{ ");
Write($"void {wrapper}");
if (isProtected)
Write("protected");
Write($@"({@namespace}* {Helpers.InstanceField}) {{ {
Helpers.InstanceField}->~{method.Namespace.OriginalName}(); }} }}");
Write("Protected");
string instance = Helpers.InstanceField;
Write($@"({(isProtected ? wrapper : @namespace)}* {
instance}) {{ delete {instance}; }} }};");
if (isProtected)
{
NewLine();
Write($@"void {wrapper}({@namespace} {Helpers.InstanceField}) {{ {
wrapper}{@namespace}::{wrapper}protected({Helpers.InstanceField}); }}");
Write($@"extern ""C"" {{ void {wrapper}({wrapper}* {instance}) {{ {
instance}->{wrapper}Protected({instance}); }} }}");
}
if (method.Namespace.Access == AccessSpecifier.Protected)
Write("; }");
@ -192,14 +192,14 @@ namespace CppSharp.Passes @@ -192,14 +192,14 @@ namespace CppSharp.Passes
private void TakeFunctionAddress(Function function)
{
//function = function.OriginalFunction ?? function;
string wrapper = GetWrapper(function);
string @namespace = function.Namespace.Visit(cppTypePrinter);
string @namespace = function.OriginalNamespace.Visit(cppTypePrinter);
if (function.Access == AccessSpecifier.Protected)
{
Write(GetDerivedType(@namespace, wrapper));
Write($"class {wrapper}{function.Namespace.Name} : public {@namespace} {{ public: ");
Write("static constexpr ");
}
string returnType = function.OriginalReturnType.Visit(cppTypePrinter);
string signature = GetSignature(function);
@ -209,15 +209,15 @@ namespace CppSharp.Passes @@ -209,15 +209,15 @@ namespace CppSharp.Passes
var method = function as Method;
if (function.Namespace.Access == AccessSpecifier.Protected)
Write($@"class {wrapper}{function.Namespace.Namespace.Name} : public {
Write($@"class {wrapper}{function.Namespace.Name} : public {
function.Namespace.Namespace.Visit(cppTypePrinter)} {{ ");
Write($@"{returnType} ({(method != null && !method.IsStatic ?
(@namespace + "::") : string.Empty)}*{wrapper}){signature}");
if (function.Access == AccessSpecifier.Protected)
{
Write($" = &{wrapper}{@namespace}::{functionName};");
Write($" = &{wrapper}{function.Namespace.Name}::{functionName};");
WriteLine(" };");
Write($"auto {wrapper}protected = {wrapper}{@namespace}::{wrapper};");
Write($"auto {wrapper}Protected = {wrapper}{function.Namespace.Name}::{wrapper};");
}
else
{
@ -251,25 +251,47 @@ namespace CppSharp.Passes @@ -251,25 +251,47 @@ namespace CppSharp.Passes
private string GetFunctionName(Function function, string @namespace)
{
return $@"{(function.Access == AccessSpecifier.Protected ||
string.IsNullOrEmpty(@namespace) ?
string.Empty : (@namespace + "::"))}{function.OriginalName}{
(function.SpecializationInfo == null ? string.Empty : $@"<{
string.Join(", ", function.SpecializationInfo.Arguments.Select(
a =>
{
switch (a.Kind)
{
case TemplateArgument.ArgumentKind.Type:
return a.Type.Visit(cppTypePrinter).Type;
case TemplateArgument.ArgumentKind.Declaration:
return a.Declaration.Visit(cppTypePrinter).Type;
case TemplateArgument.ArgumentKind.Integral:
return a.Integral.ToString(CultureInfo.InvariantCulture);
}
throw new System.ArgumentOutOfRangeException(
nameof(a.Kind), a.Kind, "Unsupported kind of template argument.");
}))}>")}";
var nameBuilder = new StringBuilder();
if (function.Access != AccessSpecifier.Protected &&
!string.IsNullOrEmpty(@namespace))
nameBuilder.Append(@namespace).Append("::");
bool isConversionToSpecialization =
(function.OperatorKind == CXXOperatorKind.Conversion ||
function.OperatorKind == CXXOperatorKind.ExplicitConversion) &&
function.OriginalReturnType.Type.Desugar(
).TryGetDeclaration(out ClassTemplateSpecialization specialization);
nameBuilder.Append(isConversionToSpecialization ?
"operator " : function.OriginalName);
if (function.SpecializationInfo != null)
nameBuilder.Append('<').Append(string.Join(", ",
GetTemplateArguments(function.SpecializationInfo.Arguments))).Append('>');
else if (isConversionToSpecialization)
nameBuilder.Append(function.OriginalReturnType.Visit(cppTypePrinter));
return nameBuilder.ToString();
}
private IEnumerable<string> GetTemplateArguments(
IEnumerable<TemplateArgument> templateArguments)
{
return templateArguments.Select(
a =>
{
switch (a.Kind)
{
case TemplateArgument.ArgumentKind.Type:
return a.Type.Visit(cppTypePrinter).Type;
case TemplateArgument.ArgumentKind.Declaration:
return a.Declaration.Visit(cppTypePrinter).Type;
case TemplateArgument.ArgumentKind.Integral:
return a.Integral.ToString(CultureInfo.InvariantCulture);
}
throw new System.ArgumentOutOfRangeException(
nameof(a.Kind), a.Kind, "Unsupported kind of template argument.");
});
}
private void WriteRedeclaration(Function function, string returnType,
@ -286,7 +308,7 @@ namespace CppSharp.Passes @@ -286,7 +308,7 @@ namespace CppSharp.Passes
Write(paramTypes);
if (functionType.ExceptionSpecType == ExceptionSpecType.BasicNoexcept)
Write(" noexcept");
WriteLine($";{string.Concat(parentsOpen.Select(p => " }"))}");
WriteLine($";{string.Concat(parentsOpen.Select(_ => " }"))}");
}
private static Stack<string> GenerateNamespace(Function function)
@ -301,7 +323,7 @@ namespace CppSharp.Passes @@ -301,7 +323,7 @@ namespace CppSharp.Passes
if (finalType.TryGetDeclaration(out declaration))
declarationContextsInSignature.Add(declaration.Namespace);
}
var nestedNamespace = declarationContextsInSignature.FirstOrDefault(d =>
var nestedNamespace = declarationContextsInSignature.Find(d =>
d.Namespace is Namespace && !(d.Namespace is TranslationUnit));
var parentsOpen = new Stack<string>();
if (nestedNamespace != null)
@ -320,7 +342,8 @@ namespace CppSharp.Passes @@ -320,7 +342,8 @@ namespace CppSharp.Passes
private CppTypePrinter cppTypePrinter = new CppTypePrinter
{
ScopeKind = TypePrintScopeKind.Qualified
ScopeKind = TypePrintScopeKind.Qualified,
ResolveTypedefs = true
};
private int functionCount;
}

Loading…
Cancel
Save