diff --git a/src/AST/Type.cs b/src/AST/Type.cs
index 3e2de59c..b3cf9d8c 100644
--- a/src/AST/Type.cs
+++ b/src/AST/Type.cs
@@ -551,7 +551,7 @@ namespace CppSharp.AST
///
/// Represents a template argument within a class template specialization.
///
- public struct TemplateArgument
+ public class TemplateArgument
{
/// The kind of template argument we're storing.
public enum ArgumentKind
diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs
index 798b1f0b..d0827499 100644
--- a/src/Generator/Generators/CSharp/CSharpSources.cs
+++ b/src/Generator/Generators/CSharp/CSharpSources.cs
@@ -312,7 +312,7 @@ namespace CppSharp.Generators.CSharp
GenerateClassTemplateSpecializationsInternals(
nestedTemplate, nestedTemplate.Specializations);
- foreach (var specialization in generated)
+ foreach (var specialization in generated.KeepSingleAllPointersSpecialization())
GenerateClassInternals(specialization);
foreach (var group in generated.SelectMany(s => s.Classes).Where(
@@ -558,14 +558,32 @@ namespace CppSharp.Generators.CSharp
functions.AddRange(GatherClassInternalFunctions(@base.Class, false));
var currentSpecialization = @class as ClassTemplateSpecialization;
- Class template;
- if (currentSpecialization != null &&
- (template = currentSpecialization.TemplatedDecl.TemplatedClass)
- .GetSpecializedClassesToGenerate().Count() == 1)
- foreach (var specialization in template.Specializations.Where(s => s.IsGenerated))
- GatherClassInternalFunctions(specialization, includeCtors, functions);
- else
- GatherClassInternalFunctions(@class, includeCtors, functions);
+ if (currentSpecialization != null)
+ {
+ Class template = currentSpecialization.TemplatedDecl.TemplatedClass;
+ IEnumerable specializations = null;
+ if (template.GetSpecializedClassesToGenerate().Count() == 1)
+ specializations = template.Specializations.Where(s => s.IsGenerated);
+ else
+ {
+ Func allPointers = (TemplateArgument a) =>
+ a.Type.Type?.Desugar().IsAddress() == true;
+ if (currentSpecialization.Arguments.All(allPointers))
+ {
+ specializations = template.Specializations.Where(
+ s => s.IsGenerated && s.Arguments.All(allPointers));
+ }
+ }
+
+ if (specializations != null)
+ {
+ foreach (var specialization in specializations)
+ GatherClassInternalFunctions(specialization, includeCtors, functions);
+ return functions;
+ }
+ }
+
+ GatherClassInternalFunctions(@class, includeCtors, functions);
return functions;
}
diff --git a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs
index 736c6779..fb337a41 100644
--- a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs
+++ b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs
@@ -23,11 +23,37 @@ namespace CppSharp.Generators.CSharp
{
var printedClass = @class.Visit(gen.TypePrinter);
if (@class.IsDependent)
- foreach (var specialization in @class.GetSpecializedClassesToGenerate(
- ).Where(s => s.IsGenerated))
+ {
+ IEnumerable specializations =
+ @class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated);
+ if (@class.IsTemplate)
+ specializations = specializations.KeepSingleAllPointersSpecialization();
+ foreach (var specialization in specializations)
gen.GenerateNativeConstructorByValue(specialization, printedClass.Type);
+ }
else
+ {
gen.GenerateNativeConstructorByValue(@class, printedClass.Type);
+ }
+ }
+
+ public static IEnumerable KeepSingleAllPointersSpecialization(
+ this IEnumerable specializations)
+ {
+ Func allPointers = (TemplateArgument a) =>
+ a.Type.Type?.Desugar().IsAddress() == true;
+ var groups = (from ClassTemplateSpecialization spec in specializations
+ group spec by spec.Arguments.All(allPointers)
+ into @group
+ select @group).ToList();
+ foreach (var group in groups)
+ {
+ if (group.Key)
+ yield return group.First();
+ else
+ foreach (var specialization in group)
+ yield return specialization;
+ }
}
public static void GenerateField(this CSharpSources gen, Class @class,
@@ -112,7 +138,7 @@ namespace CppSharp.Generators.CSharp
Enumerable.Range(0, @class.TemplateParameters.Count).Select(
i =>
{
- CppSharp.AST.Type type = specialization.Arguments[i].Type.Type.Desugar();
+ CppSharp.AST.Type type = specialization.Arguments[i].Type.Type;
return type.IsPointerToPrimitiveType() ?
$"__{@class.TemplateParameters[i].Name}.FullName == \"System.IntPtr\"" :
$"__{@class.TemplateParameters[i].Name}.IsAssignableFrom(typeof({type}))";
diff --git a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs
index 8dc19a0d..24321b7e 100644
--- a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs
+++ b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs
@@ -560,14 +560,14 @@ namespace CppSharp.Generators.CSharp
{
if (a.Type.Type == null)
return a.Integral.ToString(CultureInfo.InvariantCulture);
- var type = a.Type.Type.Desugar();
+ var type = a.Type.Type;
PrimitiveType pointee;
if (type.IsPointerToPrimitiveType(out pointee) && !type.IsConstCharString())
{
return $@"CppSharp.Runtime.Pointer<{(pointee == PrimitiveType.Void ? IntPtrType :
VisitPrimitiveType(pointee, new TypeQualifiers()).Type)}>";
}
- return (type.IsPrimitiveType(PrimitiveType.Void)) ? "object" : type.Visit(this).Type;
+ return type.IsPrimitiveType(PrimitiveType.Void) ? "object" : type.Visit(this).Type;
}
public override TypePrinterResult VisitParameterDecl(Parameter parameter)
diff --git a/src/Generator/Passes/CheckDuplicatedNamesPass.cs b/src/Generator/Passes/CheckDuplicatedNamesPass.cs
index 9e2bfe52..95f8eb93 100644
--- a/src/Generator/Passes/CheckDuplicatedNamesPass.cs
+++ b/src/Generator/Passes/CheckDuplicatedNamesPass.cs
@@ -159,10 +159,16 @@ namespace CppSharp.Passes
if (x.Kind != TemplateArgument.ArgumentKind.Type ||
y.Kind != TemplateArgument.ArgumentKind.Type)
return x.Equals(y);
- return x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
- ParameterTypeComparer.GeneratorKind).Equals(
- y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
- ParameterTypeComparer.GeneratorKind));
+ Type left = x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
+ ParameterTypeComparer.GeneratorKind);
+ Type right = y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps,
+ ParameterTypeComparer.GeneratorKind);
+ // consider Type and const Type the same
+ if (left.IsReference() && !left.IsPointerToPrimitiveType())
+ left = left.GetPointee();
+ if (right.IsReference() && !right.IsPointerToPrimitiveType())
+ right = right.GetPointee();
+ return left.Equals(right);
}
public int GetHashCode(TemplateArgument obj)
diff --git a/src/Generator/Passes/DelegatesPass.cs b/src/Generator/Passes/DelegatesPass.cs
index 73860e19..de1914d9 100644
--- a/src/Generator/Passes/DelegatesPass.cs
+++ b/src/Generator/Passes/DelegatesPass.cs
@@ -46,12 +46,38 @@ namespace CppSharp.Passes
return base.VisitClassDecl(@class);
}
+ public override bool VisitClassTemplateSpecializationDecl(ClassTemplateSpecialization specialization)
+ {
+ if (!base.VisitClassTemplateSpecializationDecl(specialization) ||
+ !specialization.IsGenerated || !specialization.TemplatedDecl.TemplatedDecl.IsGenerated)
+ return false;
+
+ foreach (TemplateArgument arg in specialization.Arguments.Where(
+ a => a.Kind == TemplateArgument.ArgumentKind.Type))
+ {
+ arg.Type = CheckForDelegate(arg.Type, specialization);
+ }
+
+ return true;
+ }
+
public override bool VisitMethodDecl(Method method)
{
if (!base.VisitMethodDecl(method) || !method.IsVirtual || method.Ignore)
return false;
- method.FunctionType = CheckForDelegate(method.FunctionType, method);
+ var functionType = new FunctionType
+ {
+ CallingConvention = method.CallingConvention,
+ IsDependent = method.IsDependent,
+ ReturnType = method.ReturnType
+ };
+
+ functionType.Parameters.AddRange(
+ method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
+
+ method.FunctionType = CheckForDelegate(new QualifiedType(functionType),
+ method.Namespace, @private: true);
return true;
}
@@ -61,7 +87,8 @@ namespace CppSharp.Passes
if (!base.VisitFunctionDecl(function) || function.Ignore)
return false;
- function.ReturnType = CheckForDelegate(function.ReturnType, function);
+ function.ReturnType = CheckForDelegate(function.ReturnType,
+ function.Namespace);
return true;
}
@@ -71,7 +98,8 @@ namespace CppSharp.Passes
parameter.Namespace.Ignore)
return false;
- parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType, parameter);
+ parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType,
+ parameter.Namespace);
return true;
}
@@ -81,7 +109,8 @@ namespace CppSharp.Passes
if (!base.VisitProperty(property))
return false;
- property.QualifiedType = CheckForDelegate(property.QualifiedType, property);
+ property.QualifiedType = CheckForDelegate(property.QualifiedType,
+ property.Namespace);
return true;
}
@@ -91,12 +120,14 @@ namespace CppSharp.Passes
if (!base.VisitFieldDecl(field))
return false;
- field.QualifiedType = CheckForDelegate(field.QualifiedType, field);
+ field.QualifiedType = CheckForDelegate(field.QualifiedType,
+ field.Namespace);
return true;
}
- private QualifiedType CheckForDelegate(QualifiedType type, ITypedDecl decl)
+ private QualifiedType CheckForDelegate(QualifiedType type,
+ DeclarationContext declarationContext, bool @private = false)
{
if (type.Type is TypedefType)
return type;
@@ -109,22 +140,21 @@ namespace CppSharp.Passes
if (pointee is TypedefType)
return type;
- var functionType = pointee.Desugar() as FunctionType;
- if (functionType == null)
+ desugared = pointee.Desugar();
+ FunctionType functionType = desugared as FunctionType;
+ if (functionType == null && !desugared.IsPointerTo(out functionType))
return type;
- TypedefDecl @delegate = GetDelegate(type, decl);
+ TypedefDecl @delegate = GetDelegate(functionType, declarationContext, @private);
return new QualifiedType(new TypedefType { Declaration = @delegate });
}
- private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl)
+ private TypedefDecl GetDelegate(FunctionType functionType,
+ DeclarationContext declarationContext, bool @private = false)
{
- FunctionType newFunctionType = GetNewFunctionType(typedDecl, type);
-
- var delegateName = GetDelegateName(newFunctionType);
- var access = typedDecl is Method ? AccessSpecifier.Private : AccessSpecifier.Public;
- var decl = (Declaration) typedDecl;
- Module module = decl.TranslationUnit.Module;
+ var delegateName = GetDelegateName(functionType);
+ var access = @private ? AccessSpecifier.Private : AccessSpecifier.Public;
+ Module module = declarationContext.TranslationUnit.Module;
var existingDelegate = delegates.Find(t => Match(t, delegateName, module));
if (existingDelegate != null)
{
@@ -135,18 +165,18 @@ namespace CppSharp.Passes
// Check if there is an existing delegate with a different calling convention
if (((FunctionType) existingDelegate.Type.GetPointee()).CallingConvention ==
- newFunctionType.CallingConvention)
+ functionType.CallingConvention)
return existingDelegate;
// Add a new delegate with the calling convention appended to its name
- delegateName += '_' + newFunctionType.CallingConvention.ToString();
+ delegateName += '_' + functionType.CallingConvention.ToString();
existingDelegate = delegates.Find(t => Match(t, delegateName, module));
if (existingDelegate != null)
return existingDelegate;
}
- var namespaceDelegates = GetDeclContextForDelegates(decl.Namespace);
- var delegateType = new QualifiedType(new PointerType(new QualifiedType(newFunctionType)));
+ var namespaceDelegates = GetDeclContextForDelegates(declarationContext);
+ var delegateType = new QualifiedType(new PointerType(new QualifiedType(functionType)));
existingDelegate = new TypedefDecl
{
Access = access,
@@ -160,30 +190,6 @@ namespace CppSharp.Passes
return existingDelegate;
}
- private FunctionType GetNewFunctionType(ITypedDecl decl, QualifiedType type)
- {
- var functionType = new FunctionType();
- var method = decl as Method;
- if (method != null && method.FunctionType == type)
- {
- functionType.Parameters.AddRange(
- method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true));
- functionType.CallingConvention = method.CallingConvention;
- functionType.IsDependent = method.IsDependent;
- functionType.ReturnType = method.ReturnType;
- }
- else
- {
- var funcTypeParam = (FunctionType) decl.Type.Desugar().GetFinalPointee().Desugar();
- functionType = new FunctionType(funcTypeParam);
- }
-
- for (int i = 0; i < functionType.Parameters.Count; i++)
- functionType.Parameters[i].Name = $"_{i}";
-
- return functionType;
- }
-
private static bool Match(TypedefDecl t, string delegateName, Module module)
{
return t.Name == delegateName &&
diff --git a/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs b/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs
index b55c7cee..3bee97f7 100644
--- a/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs
+++ b/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs
@@ -72,8 +72,11 @@ namespace CppSharp.Passes
foreach (var method in methodsWithDependentPointers.Where(
m => m.SynthKind == FunctionSynthKind.None))
{
- var specializedMethod = specialization.Methods.First(
+ var specializedMethod = specialization.Methods.FirstOrDefault(
m => m.InstantiatedFrom == method);
+ if (specializedMethod == null)
+ continue;
+
Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod);
classExtensions.Methods.Add(extensionMethod);
extensionMethod.Namespace = classExtensions;
diff --git a/src/Generator/Passes/TrimSpecializationsPass.cs b/src/Generator/Passes/TrimSpecializationsPass.cs
index a70e5f98..b9024654 100644
--- a/src/Generator/Passes/TrimSpecializationsPass.cs
+++ b/src/Generator/Passes/TrimSpecializationsPass.cs
@@ -127,22 +127,10 @@ namespace CppSharp.Passes
s => !s.IsExplicitlyGenerated && internalSpecializations.Contains(s)))
specialization.GenerationKind = GenerationKind.Internal;
- Func allPointers =
- a => a.Type.Type != null && a.Type.Type.IsAddress();
- var groups = (from specialization in template.Specializations
- group specialization by specialization.Arguments.All(allPointers)
- into @group
- select @group).ToList();
-
- foreach (var group in groups.Where(g => g.Key))
- foreach (var specialization in group.Skip(1))
- template.Specializations.Remove(specialization);
-
for (int i = template.Specializations.Count - 1; i >= 0; i--)
{
var specialization = template.Specializations[i];
- if (specialization is ClassTemplatePartialSpecialization &&
- !specialization.Arguments.All(allPointers))
+ if (specialization is ClassTemplatePartialSpecialization)
template.Specializations.RemoveAt(i);
}