Browse Source

Restore removed specializations

All specializations which only use pointers as their type arguments need at most one internal representation since pointers are mapped to IntPtr. This was achieved by removing the unneeded specializations from their containing list. This was, however, a bug because specializations were thus removed not only as internal structures but in their entirety.

Signed-off-by: Dimitar Dobrev <dpldobrev@protonmail.com>
pull/1204/head
Dimitar Dobrev 6 years ago
parent
commit
3caa8c5da2
  1. 2
      src/AST/Type.cs
  2. 32
      src/Generator/Generators/CSharp/CSharpSources.cs
  3. 32
      src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs
  4. 4
      src/Generator/Generators/CSharp/CSharpTypePrinter.cs
  5. 14
      src/Generator/Passes/CheckDuplicatedNamesPass.cs
  6. 94
      src/Generator/Passes/DelegatesPass.cs
  7. 5
      src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs
  8. 14
      src/Generator/Passes/TrimSpecializationsPass.cs

2
src/AST/Type.cs

@ -551,7 +551,7 @@ namespace CppSharp.AST
/// <summary> /// <summary>
/// Represents a template argument within a class template specialization. /// Represents a template argument within a class template specialization.
/// </summary> /// </summary>
public struct TemplateArgument public class TemplateArgument
{ {
/// The kind of template argument we're storing. /// The kind of template argument we're storing.
public enum ArgumentKind public enum ArgumentKind

32
src/Generator/Generators/CSharp/CSharpSources.cs

@ -312,7 +312,7 @@ namespace CppSharp.Generators.CSharp
GenerateClassTemplateSpecializationsInternals( GenerateClassTemplateSpecializationsInternals(
nestedTemplate, nestedTemplate.Specializations); nestedTemplate, nestedTemplate.Specializations);
foreach (var specialization in generated) foreach (var specialization in generated.KeepSingleAllPointersSpecialization())
GenerateClassInternals(specialization); GenerateClassInternals(specialization);
foreach (var group in generated.SelectMany(s => s.Classes).Where( foreach (var group in generated.SelectMany(s => s.Classes).Where(
@ -558,13 +558,31 @@ namespace CppSharp.Generators.CSharp
functions.AddRange(GatherClassInternalFunctions(@base.Class, false)); functions.AddRange(GatherClassInternalFunctions(@base.Class, false));
var currentSpecialization = @class as ClassTemplateSpecialization; var currentSpecialization = @class as ClassTemplateSpecialization;
Class template; if (currentSpecialization != null)
if (currentSpecialization != null && {
(template = currentSpecialization.TemplatedDecl.TemplatedClass) Class template = currentSpecialization.TemplatedDecl.TemplatedClass;
.GetSpecializedClassesToGenerate().Count() == 1) IEnumerable<ClassTemplateSpecialization> specializations = null;
foreach (var specialization in template.Specializations.Where(s => s.IsGenerated)) if (template.GetSpecializedClassesToGenerate().Count() == 1)
GatherClassInternalFunctions(specialization, includeCtors, functions); specializations = template.Specializations.Where(s => s.IsGenerated);
else else
{
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
a.Type.Type?.Desugar().IsAddress() == true;
if (currentSpecialization.Arguments.All(allPointers))
{
specializations = template.Specializations.Where(
s => s.IsGenerated && s.Arguments.All(allPointers));
}
}
if (specializations != null)
{
foreach (var specialization in specializations)
GatherClassInternalFunctions(specialization, includeCtors, functions);
return functions;
}
}
GatherClassInternalFunctions(@class, includeCtors, functions); GatherClassInternalFunctions(@class, includeCtors, functions);
return functions; return functions;

32
src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs

@ -23,12 +23,38 @@ namespace CppSharp.Generators.CSharp
{ {
var printedClass = @class.Visit(gen.TypePrinter); var printedClass = @class.Visit(gen.TypePrinter);
if (@class.IsDependent) if (@class.IsDependent)
foreach (var specialization in @class.GetSpecializedClassesToGenerate( {
).Where(s => s.IsGenerated)) IEnumerable<Class> specializations =
@class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated);
if (@class.IsTemplate)
specializations = specializations.KeepSingleAllPointersSpecialization();
foreach (var specialization in specializations)
gen.GenerateNativeConstructorByValue(specialization, printedClass.Type); gen.GenerateNativeConstructorByValue(specialization, printedClass.Type);
}
else else
{
gen.GenerateNativeConstructorByValue(@class, printedClass.Type); gen.GenerateNativeConstructorByValue(@class, printedClass.Type);
} }
}
public static IEnumerable<Class> KeepSingleAllPointersSpecialization(
this IEnumerable<Class> specializations)
{
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
a.Type.Type?.Desugar().IsAddress() == true;
var groups = (from ClassTemplateSpecialization spec in specializations
group spec by spec.Arguments.All(allPointers)
into @group
select @group).ToList();
foreach (var group in groups)
{
if (group.Key)
yield return group.First();
else
foreach (var specialization in group)
yield return specialization;
}
}
public static void GenerateField(this CSharpSources gen, Class @class, public static void GenerateField(this CSharpSources gen, Class @class,
Field field, Action<Field, Class, QualifiedType> generate, bool isVoid) Field field, Action<Field, Class, QualifiedType> generate, bool isVoid)
@ -112,7 +138,7 @@ namespace CppSharp.Generators.CSharp
Enumerable.Range(0, @class.TemplateParameters.Count).Select( Enumerable.Range(0, @class.TemplateParameters.Count).Select(
i => i =>
{ {
CppSharp.AST.Type type = specialization.Arguments[i].Type.Type.Desugar(); CppSharp.AST.Type type = specialization.Arguments[i].Type.Type;
return type.IsPointerToPrimitiveType() ? return type.IsPointerToPrimitiveType() ?
$"__{@class.TemplateParameters[i].Name}.FullName == \"System.IntPtr\"" : $"__{@class.TemplateParameters[i].Name}.FullName == \"System.IntPtr\"" :
$"__{@class.TemplateParameters[i].Name}.IsAssignableFrom(typeof({type}))"; $"__{@class.TemplateParameters[i].Name}.IsAssignableFrom(typeof({type}))";

4
src/Generator/Generators/CSharp/CSharpTypePrinter.cs

@ -560,14 +560,14 @@ namespace CppSharp.Generators.CSharp
{ {
if (a.Type.Type == null) if (a.Type.Type == null)
return a.Integral.ToString(CultureInfo.InvariantCulture); return a.Integral.ToString(CultureInfo.InvariantCulture);
var type = a.Type.Type.Desugar(); var type = a.Type.Type;
PrimitiveType pointee; PrimitiveType pointee;
if (type.IsPointerToPrimitiveType(out pointee) && !type.IsConstCharString()) if (type.IsPointerToPrimitiveType(out pointee) && !type.IsConstCharString())
{ {
return $@"CppSharp.Runtime.Pointer<{(pointee == PrimitiveType.Void ? IntPtrType : return $@"CppSharp.Runtime.Pointer<{(pointee == PrimitiveType.Void ? IntPtrType :
VisitPrimitiveType(pointee, new TypeQualifiers()).Type)}>"; VisitPrimitiveType(pointee, new TypeQualifiers()).Type)}>";
} }
return (type.IsPrimitiveType(PrimitiveType.Void)) ? "object" : type.Visit(this).Type; return type.IsPrimitiveType(PrimitiveType.Void) ? "object" : type.Visit(this).Type;
} }
public override TypePrinterResult VisitParameterDecl(Parameter parameter) public override TypePrinterResult VisitParameterDecl(Parameter parameter)

14
src/Generator/Passes/CheckDuplicatedNamesPass.cs

@ -159,10 +159,16 @@ namespace CppSharp.Passes
if (x.Kind != TemplateArgument.ArgumentKind.Type || if (x.Kind != TemplateArgument.ArgumentKind.Type ||
y.Kind != TemplateArgument.ArgumentKind.Type) y.Kind != TemplateArgument.ArgumentKind.Type)
return x.Equals(y); return x.Equals(y);
return x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps, Type left = x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
ParameterTypeComparer.GeneratorKind).Equals( ParameterTypeComparer.GeneratorKind);
y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps, Type right = y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
ParameterTypeComparer.GeneratorKind)); ParameterTypeComparer.GeneratorKind);
// consider Type and const Type the same
if (left.IsReference() && !left.IsPointerToPrimitiveType())
left = left.GetPointee();
if (right.IsReference() && !right.IsPointerToPrimitiveType())
right = right.GetPointee();
return left.Equals(right);
} }
public int GetHashCode(TemplateArgument obj) public int GetHashCode(TemplateArgument obj)

94
src/Generator/Passes/DelegatesPass.cs

@ -46,12 +46,38 @@ namespace CppSharp.Passes
return base.VisitClassDecl(@class); return base.VisitClassDecl(@class);
} }
public override bool VisitClassTemplateSpecializationDecl(ClassTemplateSpecialization specialization)
{
if (!base.VisitClassTemplateSpecializationDecl(specialization) ||
!specialization.IsGenerated || !specialization.TemplatedDecl.TemplatedDecl.IsGenerated)
return false;
foreach (TemplateArgument arg in specialization.Arguments.Where(
a => a.Kind == TemplateArgument.ArgumentKind.Type))
{
arg.Type = CheckForDelegate(arg.Type, specialization);
}
return true;
}
public override bool VisitMethodDecl(Method method) public override bool VisitMethodDecl(Method method)
{ {
if (!base.VisitMethodDecl(method) || !method.IsVirtual || method.Ignore) if (!base.VisitMethodDecl(method) || !method.IsVirtual || method.Ignore)
return false; return false;
method.FunctionType = CheckForDelegate(method.FunctionType, method); var functionType = new FunctionType
{
CallingConvention = method.CallingConvention,
IsDependent = method.IsDependent,
ReturnType = method.ReturnType
};
functionType.Parameters.AddRange(
method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
method.FunctionType = CheckForDelegate(new QualifiedType(functionType),
method.Namespace, @private: true);
return true; return true;
} }
@ -61,7 +87,8 @@ namespace CppSharp.Passes
if (!base.VisitFunctionDecl(function) || function.Ignore) if (!base.VisitFunctionDecl(function) || function.Ignore)
return false; return false;
function.ReturnType = CheckForDelegate(function.ReturnType, function); function.ReturnType = CheckForDelegate(function.ReturnType,
function.Namespace);
return true; return true;
} }
@ -71,7 +98,8 @@ namespace CppSharp.Passes
parameter.Namespace.Ignore) parameter.Namespace.Ignore)
return false; return false;
parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType, parameter); parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType,
parameter.Namespace);
return true; return true;
} }
@ -81,7 +109,8 @@ namespace CppSharp.Passes
if (!base.VisitProperty(property)) if (!base.VisitProperty(property))
return false; return false;
property.QualifiedType = CheckForDelegate(property.QualifiedType, property); property.QualifiedType = CheckForDelegate(property.QualifiedType,
property.Namespace);
return true; return true;
} }
@ -91,12 +120,14 @@ namespace CppSharp.Passes
if (!base.VisitFieldDecl(field)) if (!base.VisitFieldDecl(field))
return false; return false;
field.QualifiedType = CheckForDelegate(field.QualifiedType, field); field.QualifiedType = CheckForDelegate(field.QualifiedType,
field.Namespace);
return true; return true;
} }
private QualifiedType CheckForDelegate(QualifiedType type, ITypedDecl decl) private QualifiedType CheckForDelegate(QualifiedType type,
DeclarationContext declarationContext, bool @private = false)
{ {
if (type.Type is TypedefType) if (type.Type is TypedefType)
return type; return type;
@ -109,22 +140,21 @@ namespace CppSharp.Passes
if (pointee is TypedefType) if (pointee is TypedefType)
return type; return type;
var functionType = pointee.Desugar() as FunctionType; desugared = pointee.Desugar();
if (functionType == null) FunctionType functionType = desugared as FunctionType;
if (functionType == null && !desugared.IsPointerTo(out functionType))
return type; return type;
TypedefDecl @delegate = GetDelegate(type, decl); TypedefDecl @delegate = GetDelegate(functionType, declarationContext, @private);
return new QualifiedType(new TypedefType { Declaration = @delegate }); return new QualifiedType(new TypedefType { Declaration = @delegate });
} }
private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl) private TypedefDecl GetDelegate(FunctionType functionType,
DeclarationContext declarationContext, bool @private = false)
{ {
FunctionType newFunctionType = GetNewFunctionType(typedDecl, type); var delegateName = GetDelegateName(functionType);
var access = @private ? AccessSpecifier.Private : AccessSpecifier.Public;
var delegateName = GetDelegateName(newFunctionType); Module module = declarationContext.TranslationUnit.Module;
var access = typedDecl is Method ? AccessSpecifier.Private : AccessSpecifier.Public;
var decl = (Declaration) typedDecl;
Module module = decl.TranslationUnit.Module;
var existingDelegate = delegates.Find(t => Match(t, delegateName, module)); var existingDelegate = delegates.Find(t => Match(t, delegateName, module));
if (existingDelegate != null) if (existingDelegate != null)
{ {
@ -135,18 +165,18 @@ namespace CppSharp.Passes
// Check if there is an existing delegate with a different calling convention // Check if there is an existing delegate with a different calling convention
if (((FunctionType) existingDelegate.Type.GetPointee()).CallingConvention == if (((FunctionType) existingDelegate.Type.GetPointee()).CallingConvention ==
newFunctionType.CallingConvention) functionType.CallingConvention)
return existingDelegate; return existingDelegate;
// Add a new delegate with the calling convention appended to its name // Add a new delegate with the calling convention appended to its name
delegateName += '_' + newFunctionType.CallingConvention.ToString(); delegateName += '_' + functionType.CallingConvention.ToString();
existingDelegate = delegates.Find(t => Match(t, delegateName, module)); existingDelegate = delegates.Find(t => Match(t, delegateName, module));
if (existingDelegate != null) if (existingDelegate != null)
return existingDelegate; return existingDelegate;
} }
var namespaceDelegates = GetDeclContextForDelegates(decl.Namespace); var namespaceDelegates = GetDeclContextForDelegates(declarationContext);
var delegateType = new QualifiedType(new PointerType(new QualifiedType(newFunctionType))); var delegateType = new QualifiedType(new PointerType(new QualifiedType(functionType)));
existingDelegate = new TypedefDecl existingDelegate = new TypedefDecl
{ {
Access = access, Access = access,
@ -160,30 +190,6 @@ namespace CppSharp.Passes
return existingDelegate; return existingDelegate;
} }
private FunctionType GetNewFunctionType(ITypedDecl decl, QualifiedType type)
{
var functionType = new FunctionType();
var method = decl as Method;
if (method != null && method.FunctionType == type)
{
functionType.Parameters.AddRange(
method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
functionType.CallingConvention = method.CallingConvention;
functionType.IsDependent = method.IsDependent;
functionType.ReturnType = method.ReturnType;
}
else
{
var funcTypeParam = (FunctionType) decl.Type.Desugar().GetFinalPointee().Desugar();
functionType = new FunctionType(funcTypeParam);
}
for (int i = 0; i < functionType.Parameters.Count; i++)
functionType.Parameters[i].Name = $"_{i}";
return functionType;
}
private static bool Match(TypedefDecl t, string delegateName, Module module) private static bool Match(TypedefDecl t, string delegateName, Module module)
{ {
return t.Name == delegateName && return t.Name == delegateName &&

5
src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs

@ -72,8 +72,11 @@ namespace CppSharp.Passes
foreach (var method in methodsWithDependentPointers.Where( foreach (var method in methodsWithDependentPointers.Where(
m => m.SynthKind == FunctionSynthKind.None)) m => m.SynthKind == FunctionSynthKind.None))
{ {
var specializedMethod = specialization.Methods.First( var specializedMethod = specialization.Methods.FirstOrDefault(
m => m.InstantiatedFrom == method); m => m.InstantiatedFrom == method);
if (specializedMethod == null)
continue;
Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod); Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod);
classExtensions.Methods.Add(extensionMethod); classExtensions.Methods.Add(extensionMethod);
extensionMethod.Namespace = classExtensions; extensionMethod.Namespace = classExtensions;

14
src/Generator/Passes/TrimSpecializationsPass.cs

@ -127,22 +127,10 @@ namespace CppSharp.Passes
s => !s.IsExplicitlyGenerated && internalSpecializations.Contains(s))) s => !s.IsExplicitlyGenerated && internalSpecializations.Contains(s)))
specialization.GenerationKind = GenerationKind.Internal; specialization.GenerationKind = GenerationKind.Internal;
Func<TemplateArgument, bool> allPointers =
a => a.Type.Type != null && a.Type.Type.IsAddress();
var groups = (from specialization in template.Specializations
group specialization by specialization.Arguments.All(allPointers)
into @group
select @group).ToList();
foreach (var group in groups.Where(g => g.Key))
foreach (var specialization in group.Skip(1))
template.Specializations.Remove(specialization);
for (int i = template.Specializations.Count - 1; i >= 0; i--) for (int i = template.Specializations.Count - 1; i >= 0; i--)
{ {
var specialization = template.Specializations[i]; var specialization = template.Specializations[i];
if (specialization is ClassTemplatePartialSpecialization && if (specialization is ClassTemplatePartialSpecialization)
!specialization.Arguments.All(allPointers))
template.Specializations.RemoveAt(i); template.Specializations.RemoveAt(i);
} }

Loading…
Cancel
Save