Browse Source

Restore removed specializations

All specializations which only use pointers as their type arguments need at most one internal representation since pointers are mapped to IntPtr. This was achieved by removing the unneeded specializations from their containing list. This was, however, a bug because specializations were thus removed not only as internal structures but in their entirety.

Signed-off-by: Dimitar Dobrev <dpldobrev@protonmail.com>
pull/1204/head
Dimitar Dobrev 6 years ago
parent
commit
3caa8c5da2
  1. 2
      src/AST/Type.cs
  2. 32
      src/Generator/Generators/CSharp/CSharpSources.cs
  3. 32
      src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs
  4. 4
      src/Generator/Generators/CSharp/CSharpTypePrinter.cs
  5. 14
      src/Generator/Passes/CheckDuplicatedNamesPass.cs
  6. 94
      src/Generator/Passes/DelegatesPass.cs
  7. 5
      src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs
  8. 14
      src/Generator/Passes/TrimSpecializationsPass.cs

2
src/AST/Type.cs

@ -551,7 +551,7 @@ namespace CppSharp.AST @@ -551,7 +551,7 @@ namespace CppSharp.AST
/// <summary>
/// Represents a template argument within a class template specialization.
/// </summary>
public struct TemplateArgument
public class TemplateArgument
{
/// The kind of template argument we're storing.
public enum ArgumentKind

32
src/Generator/Generators/CSharp/CSharpSources.cs

@ -312,7 +312,7 @@ namespace CppSharp.Generators.CSharp @@ -312,7 +312,7 @@ namespace CppSharp.Generators.CSharp
GenerateClassTemplateSpecializationsInternals(
nestedTemplate, nestedTemplate.Specializations);
foreach (var specialization in generated)
foreach (var specialization in generated.KeepSingleAllPointersSpecialization())
GenerateClassInternals(specialization);
foreach (var group in generated.SelectMany(s => s.Classes).Where(
@ -558,13 +558,31 @@ namespace CppSharp.Generators.CSharp @@ -558,13 +558,31 @@ namespace CppSharp.Generators.CSharp
functions.AddRange(GatherClassInternalFunctions(@base.Class, false));
var currentSpecialization = @class as ClassTemplateSpecialization;
Class template;
if (currentSpecialization != null &&
(template = currentSpecialization.TemplatedDecl.TemplatedClass)
.GetSpecializedClassesToGenerate().Count() == 1)
foreach (var specialization in template.Specializations.Where(s => s.IsGenerated))
GatherClassInternalFunctions(specialization, includeCtors, functions);
if (currentSpecialization != null)
{
Class template = currentSpecialization.TemplatedDecl.TemplatedClass;
IEnumerable<ClassTemplateSpecialization> specializations = null;
if (template.GetSpecializedClassesToGenerate().Count() == 1)
specializations = template.Specializations.Where(s => s.IsGenerated);
else
{
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
a.Type.Type?.Desugar().IsAddress() == true;
if (currentSpecialization.Arguments.All(allPointers))
{
specializations = template.Specializations.Where(
s => s.IsGenerated && s.Arguments.All(allPointers));
}
}
if (specializations != null)
{
foreach (var specialization in specializations)
GatherClassInternalFunctions(specialization, includeCtors, functions);
return functions;
}
}
GatherClassInternalFunctions(@class, includeCtors, functions);
return functions;

32
src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs

@ -23,12 +23,38 @@ namespace CppSharp.Generators.CSharp @@ -23,12 +23,38 @@ namespace CppSharp.Generators.CSharp
{
var printedClass = @class.Visit(gen.TypePrinter);
if (@class.IsDependent)
foreach (var specialization in @class.GetSpecializedClassesToGenerate(
).Where(s => s.IsGenerated))
{
IEnumerable<Class> specializations =
@class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated);
if (@class.IsTemplate)
specializations = specializations.KeepSingleAllPointersSpecialization();
foreach (var specialization in specializations)
gen.GenerateNativeConstructorByValue(specialization, printedClass.Type);
}
else
{
gen.GenerateNativeConstructorByValue(@class, printedClass.Type);
}
}
public static IEnumerable<Class> KeepSingleAllPointersSpecialization(
this IEnumerable<Class> specializations)
{
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
a.Type.Type?.Desugar().IsAddress() == true;
var groups = (from ClassTemplateSpecialization spec in specializations
group spec by spec.Arguments.All(allPointers)
into @group
select @group).ToList();
foreach (var group in groups)
{
if (group.Key)
yield return group.First();
else
foreach (var specialization in group)
yield return specialization;
}
}
public static void GenerateField(this CSharpSources gen, Class @class,
Field field, Action<Field, Class, QualifiedType> generate, bool isVoid)
@ -112,7 +138,7 @@ namespace CppSharp.Generators.CSharp @@ -112,7 +138,7 @@ namespace CppSharp.Generators.CSharp
Enumerable.Range(0, @class.TemplateParameters.Count).Select(
i =>
{
CppSharp.AST.Type type = specialization.Arguments[i].Type.Type.Desugar();
CppSharp.AST.Type type = specialization.Arguments[i].Type.Type;
return type.IsPointerToPrimitiveType() ?
$"__{@class.TemplateParameters[i].Name}.FullName == \"System.IntPtr\"" :
$"__{@class.TemplateParameters[i].Name}.IsAssignableFrom(typeof({type}))";

4
src/Generator/Generators/CSharp/CSharpTypePrinter.cs

@ -560,14 +560,14 @@ namespace CppSharp.Generators.CSharp @@ -560,14 +560,14 @@ namespace CppSharp.Generators.CSharp
{
if (a.Type.Type == null)
return a.Integral.ToString(CultureInfo.InvariantCulture);
var type = a.Type.Type.Desugar();
var type = a.Type.Type;
PrimitiveType pointee;
if (type.IsPointerToPrimitiveType(out pointee) && !type.IsConstCharString())
{
return $@"CppSharp.Runtime.Pointer<{(pointee == PrimitiveType.Void ? IntPtrType :
VisitPrimitiveType(pointee, new TypeQualifiers()).Type)}>";
}
return (type.IsPrimitiveType(PrimitiveType.Void)) ? "object" : type.Visit(this).Type;
return type.IsPrimitiveType(PrimitiveType.Void) ? "object" : type.Visit(this).Type;
}
public override TypePrinterResult VisitParameterDecl(Parameter parameter)

14
src/Generator/Passes/CheckDuplicatedNamesPass.cs

@ -159,10 +159,16 @@ namespace CppSharp.Passes @@ -159,10 +159,16 @@ namespace CppSharp.Passes
if (x.Kind != TemplateArgument.ArgumentKind.Type ||
y.Kind != TemplateArgument.ArgumentKind.Type)
return x.Equals(y);
return x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
ParameterTypeComparer.GeneratorKind).Equals(
y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
ParameterTypeComparer.GeneratorKind));
Type left = x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
ParameterTypeComparer.GeneratorKind);
Type right = y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
ParameterTypeComparer.GeneratorKind);
// consider Type and const Type the same
if (left.IsReference() && !left.IsPointerToPrimitiveType())
left = left.GetPointee();
if (right.IsReference() && !right.IsPointerToPrimitiveType())
right = right.GetPointee();
return left.Equals(right);
}
public int GetHashCode(TemplateArgument obj)

94
src/Generator/Passes/DelegatesPass.cs

@ -46,12 +46,38 @@ namespace CppSharp.Passes @@ -46,12 +46,38 @@ namespace CppSharp.Passes
return base.VisitClassDecl(@class);
}
public override bool VisitClassTemplateSpecializationDecl(ClassTemplateSpecialization specialization)
{
if (!base.VisitClassTemplateSpecializationDecl(specialization) ||
!specialization.IsGenerated || !specialization.TemplatedDecl.TemplatedDecl.IsGenerated)
return false;
foreach (TemplateArgument arg in specialization.Arguments.Where(
a => a.Kind == TemplateArgument.ArgumentKind.Type))
{
arg.Type = CheckForDelegate(arg.Type, specialization);
}
return true;
}
public override bool VisitMethodDecl(Method method)
{
if (!base.VisitMethodDecl(method) || !method.IsVirtual || method.Ignore)
return false;
method.FunctionType = CheckForDelegate(method.FunctionType, method);
var functionType = new FunctionType
{
CallingConvention = method.CallingConvention,
IsDependent = method.IsDependent,
ReturnType = method.ReturnType
};
functionType.Parameters.AddRange(
method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
method.FunctionType = CheckForDelegate(new QualifiedType(functionType),
method.Namespace, @private: true);
return true;
}
@ -61,7 +87,8 @@ namespace CppSharp.Passes @@ -61,7 +87,8 @@ namespace CppSharp.Passes
if (!base.VisitFunctionDecl(function) || function.Ignore)
return false;
function.ReturnType = CheckForDelegate(function.ReturnType, function);
function.ReturnType = CheckForDelegate(function.ReturnType,
function.Namespace);
return true;
}
@ -71,7 +98,8 @@ namespace CppSharp.Passes @@ -71,7 +98,8 @@ namespace CppSharp.Passes
parameter.Namespace.Ignore)
return false;
parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType, parameter);
parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType,
parameter.Namespace);
return true;
}
@ -81,7 +109,8 @@ namespace CppSharp.Passes @@ -81,7 +109,8 @@ namespace CppSharp.Passes
if (!base.VisitProperty(property))
return false;
property.QualifiedType = CheckForDelegate(property.QualifiedType, property);
property.QualifiedType = CheckForDelegate(property.QualifiedType,
property.Namespace);
return true;
}
@ -91,12 +120,14 @@ namespace CppSharp.Passes @@ -91,12 +120,14 @@ namespace CppSharp.Passes
if (!base.VisitFieldDecl(field))
return false;
field.QualifiedType = CheckForDelegate(field.QualifiedType, field);
field.QualifiedType = CheckForDelegate(field.QualifiedType,
field.Namespace);
return true;
}
private QualifiedType CheckForDelegate(QualifiedType type, ITypedDecl decl)
private QualifiedType CheckForDelegate(QualifiedType type,
DeclarationContext declarationContext, bool @private = false)
{
if (type.Type is TypedefType)
return type;
@ -109,22 +140,21 @@ namespace CppSharp.Passes @@ -109,22 +140,21 @@ namespace CppSharp.Passes
if (pointee is TypedefType)
return type;
var functionType = pointee.Desugar() as FunctionType;
if (functionType == null)
desugared = pointee.Desugar();
FunctionType functionType = desugared as FunctionType;
if (functionType == null && !desugared.IsPointerTo(out functionType))
return type;
TypedefDecl @delegate = GetDelegate(type, decl);
TypedefDecl @delegate = GetDelegate(functionType, declarationContext, @private);
return new QualifiedType(new TypedefType { Declaration = @delegate });
}
private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl)
private TypedefDecl GetDelegate(FunctionType functionType,
DeclarationContext declarationContext, bool @private = false)
{
FunctionType newFunctionType = GetNewFunctionType(typedDecl, type);
var delegateName = GetDelegateName(newFunctionType);
var access = typedDecl is Method ? AccessSpecifier.Private : AccessSpecifier.Public;
var decl = (Declaration) typedDecl;
Module module = decl.TranslationUnit.Module;
var delegateName = GetDelegateName(functionType);
var access = @private ? AccessSpecifier.Private : AccessSpecifier.Public;
Module module = declarationContext.TranslationUnit.Module;
var existingDelegate = delegates.Find(t => Match(t, delegateName, module));
if (existingDelegate != null)
{
@ -135,18 +165,18 @@ namespace CppSharp.Passes @@ -135,18 +165,18 @@ namespace CppSharp.Passes
// Check if there is an existing delegate with a different calling convention
if (((FunctionType) existingDelegate.Type.GetPointee()).CallingConvention ==
newFunctionType.CallingConvention)
functionType.CallingConvention)
return existingDelegate;
// Add a new delegate with the calling convention appended to its name
delegateName += '_' + newFunctionType.CallingConvention.ToString();
delegateName += '_' + functionType.CallingConvention.ToString();
existingDelegate = delegates.Find(t => Match(t, delegateName, module));
if (existingDelegate != null)
return existingDelegate;
}
var namespaceDelegates = GetDeclContextForDelegates(decl.Namespace);
var delegateType = new QualifiedType(new PointerType(new QualifiedType(newFunctionType)));
var namespaceDelegates = GetDeclContextForDelegates(declarationContext);
var delegateType = new QualifiedType(new PointerType(new QualifiedType(functionType)));
existingDelegate = new TypedefDecl
{
Access = access,
@ -160,30 +190,6 @@ namespace CppSharp.Passes @@ -160,30 +190,6 @@ namespace CppSharp.Passes
return existingDelegate;
}
private FunctionType GetNewFunctionType(ITypedDecl decl, QualifiedType type)
{
var functionType = new FunctionType();
var method = decl as Method;
if (method != null && method.FunctionType == type)
{
functionType.Parameters.AddRange(
method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
functionType.CallingConvention = method.CallingConvention;
functionType.IsDependent = method.IsDependent;
functionType.ReturnType = method.ReturnType;
}
else
{
var funcTypeParam = (FunctionType) decl.Type.Desugar().GetFinalPointee().Desugar();
functionType = new FunctionType(funcTypeParam);
}
for (int i = 0; i < functionType.Parameters.Count; i++)
functionType.Parameters[i].Name = $"_{i}";
return functionType;
}
private static bool Match(TypedefDecl t, string delegateName, Module module)
{
return t.Name == delegateName &&

5
src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs

@ -72,8 +72,11 @@ namespace CppSharp.Passes @@ -72,8 +72,11 @@ namespace CppSharp.Passes
foreach (var method in methodsWithDependentPointers.Where(
m => m.SynthKind == FunctionSynthKind.None))
{
var specializedMethod = specialization.Methods.First(
var specializedMethod = specialization.Methods.FirstOrDefault(
m => m.InstantiatedFrom == method);
if (specializedMethod == null)
continue;
Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod);
classExtensions.Methods.Add(extensionMethod);
extensionMethod.Namespace = classExtensions;

14
src/Generator/Passes/TrimSpecializationsPass.cs

@ -127,22 +127,10 @@ namespace CppSharp.Passes @@ -127,22 +127,10 @@ namespace CppSharp.Passes
s => !s.IsExplicitlyGenerated && internalSpecializations.Contains(s)))
specialization.GenerationKind = GenerationKind.Internal;
Func<TemplateArgument, bool> allPointers =
a => a.Type.Type != null && a.Type.Type.IsAddress();
var groups = (from specialization in template.Specializations
group specialization by specialization.Arguments.All(allPointers)
into @group
select @group).ToList();
foreach (var group in groups.Where(g => g.Key))
foreach (var specialization in group.Skip(1))
template.Specializations.Remove(specialization);
for (int i = template.Specializations.Count - 1; i >= 0; i--)
{
var specialization = template.Specializations[i];
if (specialization is ClassTemplatePartialSpecialization &&
!specialization.Arguments.All(allPointers))
if (specialization is ClassTemplatePartialSpecialization)
template.Specializations.RemoveAt(i);
}

Loading…
Cancel
Save