diff --git a/src/AST/Type.cs b/src/AST/Type.cs index 3e2de59c..b3cf9d8c 100644 --- a/src/AST/Type.cs +++ b/src/AST/Type.cs @@ -551,7 +551,7 @@ namespace CppSharp.AST /// /// Represents a template argument within a class template specialization. /// - public struct TemplateArgument + public class TemplateArgument { /// The kind of template argument we're storing. public enum ArgumentKind diff --git a/src/Generator/Generators/CSharp/CSharpSources.cs b/src/Generator/Generators/CSharp/CSharpSources.cs index 798b1f0b..d0827499 100644 --- a/src/Generator/Generators/CSharp/CSharpSources.cs +++ b/src/Generator/Generators/CSharp/CSharpSources.cs @@ -312,7 +312,7 @@ namespace CppSharp.Generators.CSharp GenerateClassTemplateSpecializationsInternals( nestedTemplate, nestedTemplate.Specializations); - foreach (var specialization in generated) + foreach (var specialization in generated.KeepSingleAllPointersSpecialization()) GenerateClassInternals(specialization); foreach (var group in generated.SelectMany(s => s.Classes).Where( @@ -558,14 +558,32 @@ namespace CppSharp.Generators.CSharp functions.AddRange(GatherClassInternalFunctions(@base.Class, false)); var currentSpecialization = @class as ClassTemplateSpecialization; - Class template; - if (currentSpecialization != null && - (template = currentSpecialization.TemplatedDecl.TemplatedClass) - .GetSpecializedClassesToGenerate().Count() == 1) - foreach (var specialization in template.Specializations.Where(s => s.IsGenerated)) - GatherClassInternalFunctions(specialization, includeCtors, functions); - else - GatherClassInternalFunctions(@class, includeCtors, functions); + if (currentSpecialization != null) + { + Class template = currentSpecialization.TemplatedDecl.TemplatedClass; + IEnumerable specializations = null; + if (template.GetSpecializedClassesToGenerate().Count() == 1) + specializations = template.Specializations.Where(s => s.IsGenerated); + else + { + Func allPointers = (TemplateArgument a) => + a.Type.Type?.Desugar().IsAddress() == true; + if (currentSpecialization.Arguments.All(allPointers)) + { + specializations = template.Specializations.Where( + s => s.IsGenerated && s.Arguments.All(allPointers)); + } + } + + if (specializations != null) + { + foreach (var specialization in specializations) + GatherClassInternalFunctions(specialization, includeCtors, functions); + return functions; + } + } + + GatherClassInternalFunctions(@class, includeCtors, functions); return functions; } diff --git a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs index 736c6779..fb337a41 100644 --- a/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs +++ b/src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs @@ -23,11 +23,37 @@ namespace CppSharp.Generators.CSharp { var printedClass = @class.Visit(gen.TypePrinter); if (@class.IsDependent) - foreach (var specialization in @class.GetSpecializedClassesToGenerate( - ).Where(s => s.IsGenerated)) + { + IEnumerable specializations = + @class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated); + if (@class.IsTemplate) + specializations = specializations.KeepSingleAllPointersSpecialization(); + foreach (var specialization in specializations) gen.GenerateNativeConstructorByValue(specialization, printedClass.Type); + } else + { gen.GenerateNativeConstructorByValue(@class, printedClass.Type); + } + } + + public static IEnumerable KeepSingleAllPointersSpecialization( + this IEnumerable specializations) + { + Func allPointers = (TemplateArgument a) => + a.Type.Type?.Desugar().IsAddress() == true; + var groups = (from ClassTemplateSpecialization spec in specializations + group spec by spec.Arguments.All(allPointers) + into @group + select @group).ToList(); + foreach (var group in groups) + { + if (group.Key) + yield return group.First(); + else + foreach (var specialization in group) + yield return specialization; + } } public static void GenerateField(this CSharpSources gen, Class @class, @@ -112,7 +138,7 @@ namespace CppSharp.Generators.CSharp Enumerable.Range(0, @class.TemplateParameters.Count).Select( i => { - CppSharp.AST.Type type = specialization.Arguments[i].Type.Type.Desugar(); + CppSharp.AST.Type type = specialization.Arguments[i].Type.Type; return type.IsPointerToPrimitiveType() ? $"__{@class.TemplateParameters[i].Name}.FullName == \"System.IntPtr\"" : $"__{@class.TemplateParameters[i].Name}.IsAssignableFrom(typeof({type}))"; diff --git a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs index 8dc19a0d..24321b7e 100644 --- a/src/Generator/Generators/CSharp/CSharpTypePrinter.cs +++ b/src/Generator/Generators/CSharp/CSharpTypePrinter.cs @@ -560,14 +560,14 @@ namespace CppSharp.Generators.CSharp { if (a.Type.Type == null) return a.Integral.ToString(CultureInfo.InvariantCulture); - var type = a.Type.Type.Desugar(); + var type = a.Type.Type; PrimitiveType pointee; if (type.IsPointerToPrimitiveType(out pointee) && !type.IsConstCharString()) { return $@"CppSharp.Runtime.Pointer<{(pointee == PrimitiveType.Void ? IntPtrType : VisitPrimitiveType(pointee, new TypeQualifiers()).Type)}>"; } - return (type.IsPrimitiveType(PrimitiveType.Void)) ? "object" : type.Visit(this).Type; + return type.IsPrimitiveType(PrimitiveType.Void) ? "object" : type.Visit(this).Type; } public override TypePrinterResult VisitParameterDecl(Parameter parameter) diff --git a/src/Generator/Passes/CheckDuplicatedNamesPass.cs b/src/Generator/Passes/CheckDuplicatedNamesPass.cs index 9e2bfe52..95f8eb93 100644 --- a/src/Generator/Passes/CheckDuplicatedNamesPass.cs +++ b/src/Generator/Passes/CheckDuplicatedNamesPass.cs @@ -159,10 +159,16 @@ namespace CppSharp.Passes if (x.Kind != TemplateArgument.ArgumentKind.Type || y.Kind != TemplateArgument.ArgumentKind.Type) return x.Equals(y); - return x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps, - ParameterTypeComparer.GeneratorKind).Equals( - y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps, - ParameterTypeComparer.GeneratorKind)); + Type left = x.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps, + ParameterTypeComparer.GeneratorKind); + Type right = y.Type.Type.GetMappedType(ParameterTypeComparer.TypeMaps, + ParameterTypeComparer.GeneratorKind); + // consider Type and const Type the same + if (left.IsReference() && !left.IsPointerToPrimitiveType()) + left = left.GetPointee(); + if (right.IsReference() && !right.IsPointerToPrimitiveType()) + right = right.GetPointee(); + return left.Equals(right); } public int GetHashCode(TemplateArgument obj) diff --git a/src/Generator/Passes/DelegatesPass.cs b/src/Generator/Passes/DelegatesPass.cs index 73860e19..de1914d9 100644 --- a/src/Generator/Passes/DelegatesPass.cs +++ b/src/Generator/Passes/DelegatesPass.cs @@ -46,12 +46,38 @@ namespace CppSharp.Passes return base.VisitClassDecl(@class); } + public override bool VisitClassTemplateSpecializationDecl(ClassTemplateSpecialization specialization) + { + if (!base.VisitClassTemplateSpecializationDecl(specialization) || + !specialization.IsGenerated || !specialization.TemplatedDecl.TemplatedDecl.IsGenerated) + return false; + + foreach (TemplateArgument arg in specialization.Arguments.Where( + a => a.Kind == TemplateArgument.ArgumentKind.Type)) + { + arg.Type = CheckForDelegate(arg.Type, specialization); + } + + return true; + } + public override bool VisitMethodDecl(Method method) { if (!base.VisitMethodDecl(method) || !method.IsVirtual || method.Ignore) return false; - method.FunctionType = CheckForDelegate(method.FunctionType, method); + var functionType = new FunctionType + { + CallingConvention = method.CallingConvention, + IsDependent = method.IsDependent, + ReturnType = method.ReturnType + }; + + functionType.Parameters.AddRange( + method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true)); + + method.FunctionType = CheckForDelegate(new QualifiedType(functionType), + method.Namespace, @private: true); return true; } @@ -61,7 +87,8 @@ namespace CppSharp.Passes if (!base.VisitFunctionDecl(function) || function.Ignore) return false; - function.ReturnType = CheckForDelegate(function.ReturnType, function); + function.ReturnType = CheckForDelegate(function.ReturnType, + function.Namespace); return true; } @@ -71,7 +98,8 @@ namespace CppSharp.Passes parameter.Namespace.Ignore) return false; - parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType, parameter); + parameter.QualifiedType = CheckForDelegate(parameter.QualifiedType, + parameter.Namespace); return true; } @@ -81,7 +109,8 @@ namespace CppSharp.Passes if (!base.VisitProperty(property)) return false; - property.QualifiedType = CheckForDelegate(property.QualifiedType, property); + property.QualifiedType = CheckForDelegate(property.QualifiedType, + property.Namespace); return true; } @@ -91,12 +120,14 @@ namespace CppSharp.Passes if (!base.VisitFieldDecl(field)) return false; - field.QualifiedType = CheckForDelegate(field.QualifiedType, field); + field.QualifiedType = CheckForDelegate(field.QualifiedType, + field.Namespace); return true; } - private QualifiedType CheckForDelegate(QualifiedType type, ITypedDecl decl) + private QualifiedType CheckForDelegate(QualifiedType type, + DeclarationContext declarationContext, bool @private = false) { if (type.Type is TypedefType) return type; @@ -109,22 +140,21 @@ namespace CppSharp.Passes if (pointee is TypedefType) return type; - var functionType = pointee.Desugar() as FunctionType; - if (functionType == null) + desugared = pointee.Desugar(); + FunctionType functionType = desugared as FunctionType; + if (functionType == null && !desugared.IsPointerTo(out functionType)) return type; - TypedefDecl @delegate = GetDelegate(type, decl); + TypedefDecl @delegate = GetDelegate(functionType, declarationContext, @private); return new QualifiedType(new TypedefType { Declaration = @delegate }); } - private TypedefDecl GetDelegate(QualifiedType type, ITypedDecl typedDecl) + private TypedefDecl GetDelegate(FunctionType functionType, + DeclarationContext declarationContext, bool @private = false) { - FunctionType newFunctionType = GetNewFunctionType(typedDecl, type); - - var delegateName = GetDelegateName(newFunctionType); - var access = typedDecl is Method ? AccessSpecifier.Private : AccessSpecifier.Public; - var decl = (Declaration) typedDecl; - Module module = decl.TranslationUnit.Module; + var delegateName = GetDelegateName(functionType); + var access = @private ? AccessSpecifier.Private : AccessSpecifier.Public; + Module module = declarationContext.TranslationUnit.Module; var existingDelegate = delegates.Find(t => Match(t, delegateName, module)); if (existingDelegate != null) { @@ -135,18 +165,18 @@ namespace CppSharp.Passes // Check if there is an existing delegate with a different calling convention if (((FunctionType) existingDelegate.Type.GetPointee()).CallingConvention == - newFunctionType.CallingConvention) + functionType.CallingConvention) return existingDelegate; // Add a new delegate with the calling convention appended to its name - delegateName += '_' + newFunctionType.CallingConvention.ToString(); + delegateName += '_' + functionType.CallingConvention.ToString(); existingDelegate = delegates.Find(t => Match(t, delegateName, module)); if (existingDelegate != null) return existingDelegate; } - var namespaceDelegates = GetDeclContextForDelegates(decl.Namespace); - var delegateType = new QualifiedType(new PointerType(new QualifiedType(newFunctionType))); + var namespaceDelegates = GetDeclContextForDelegates(declarationContext); + var delegateType = new QualifiedType(new PointerType(new QualifiedType(functionType))); existingDelegate = new TypedefDecl { Access = access, @@ -160,30 +190,6 @@ namespace CppSharp.Passes return existingDelegate; } - private FunctionType GetNewFunctionType(ITypedDecl decl, QualifiedType type) - { - var functionType = new FunctionType(); - var method = decl as Method; - if (method != null && method.FunctionType == type) - { - functionType.Parameters.AddRange( - method.GatherInternalParams(Context.ParserOptions.IsItaniumLikeAbi, true)); - functionType.CallingConvention = method.CallingConvention; - functionType.IsDependent = method.IsDependent; - functionType.ReturnType = method.ReturnType; - } - else - { - var funcTypeParam = (FunctionType) decl.Type.Desugar().GetFinalPointee().Desugar(); - functionType = new FunctionType(funcTypeParam); - } - - for (int i = 0; i < functionType.Parameters.Count; i++) - functionType.Parameters[i].Name = $"_{i}"; - - return functionType; - } - private static bool Match(TypedefDecl t, string delegateName, Module module) { return t.Name == delegateName && diff --git a/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs b/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs index b55c7cee..3bee97f7 100644 --- a/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs +++ b/src/Generator/Passes/SpecializationMethodsWithDependentPointersPass.cs @@ -72,8 +72,11 @@ namespace CppSharp.Passes foreach (var method in methodsWithDependentPointers.Where( m => m.SynthKind == FunctionSynthKind.None)) { - var specializedMethod = specialization.Methods.First( + var specializedMethod = specialization.Methods.FirstOrDefault( m => m.InstantiatedFrom == method); + if (specializedMethod == null) + continue; + Method extensionMethod = GetExtensionMethodForDependentPointer(specializedMethod); classExtensions.Methods.Add(extensionMethod); extensionMethod.Namespace = classExtensions; diff --git a/src/Generator/Passes/TrimSpecializationsPass.cs b/src/Generator/Passes/TrimSpecializationsPass.cs index a70e5f98..b9024654 100644 --- a/src/Generator/Passes/TrimSpecializationsPass.cs +++ b/src/Generator/Passes/TrimSpecializationsPass.cs @@ -127,22 +127,10 @@ namespace CppSharp.Passes s => !s.IsExplicitlyGenerated && internalSpecializations.Contains(s))) specialization.GenerationKind = GenerationKind.Internal; - Func allPointers = - a => a.Type.Type != null && a.Type.Type.IsAddress(); - var groups = (from specialization in template.Specializations - group specialization by specialization.Arguments.All(allPointers) - into @group - select @group).ToList(); - - foreach (var group in groups.Where(g => g.Key)) - foreach (var specialization in group.Skip(1)) - template.Specializations.Remove(specialization); - for (int i = template.Specializations.Count - 1; i >= 0; i--) { var specialization = template.Specializations[i]; - if (specialization is ClassTemplatePartialSpecialization && - !specialization.Arguments.All(allPointers)) + if (specialization is ClassTemplatePartialSpecialization) template.Specializations.RemoveAt(i); }