diff --git a/src/Generator/Generators/CodeTemplate.cs b/src/Generator/Generators/CodeTemplate.cs index 7cd1c1b1..55b4b026 100644 --- a/src/Generator/Generators/CodeTemplate.cs +++ b/src/Generator/Generators/CodeTemplate.cs @@ -6,13 +6,13 @@ namespace CppSharp.Generators { public abstract class CodeTemplate : BlockGenerator, IDeclVisitor { - public BindingContext Context { get; private set; } + public BindingContext Context { get; } - public DriverOptions Options { get { return Context.Options; } } + public DriverOptions Options => Context.Options; - public List TranslationUnits { get; private set; } + public List TranslationUnits { get; } - public TranslationUnit TranslationUnit { get { return TranslationUnits[0]; } } + public TranslationUnit TranslationUnit => TranslationUnits[0]; public abstract string FileExtension { get; } @@ -29,7 +29,7 @@ namespace CppSharp.Generators public abstract void Process(); - public new string Generate() + public override string Generate() { if (Options.IsCSharpGenerator && Options.CompileCode) return base.GenerateUnformatted(); diff --git a/src/Generator/Passes/GenerateInlinesCodePass.cs b/src/Generator/Passes/GenerateInlinesCodePass.cs index 5ebf3754..c52784eb 100644 --- a/src/Generator/Passes/GenerateInlinesCodePass.cs +++ b/src/Generator/Passes/GenerateInlinesCodePass.cs @@ -1,29 +1,83 @@ -using System.IO; -using System.Text; +using System.Collections.Generic; +using System.IO; +using System.Linq; using CppSharp.AST; namespace CppSharp.Passes { public class GenerateInlinesCodePass : TranslationUnitPass { + public GenerateInlinesCodePass() + { + VisitOptions.VisitClassBases = false; + VisitOptions.VisitClassFields = false; + VisitOptions.VisitEventParameters = false; + VisitOptions.VisitFunctionParameters = false; + VisitOptions.VisitFunctionReturnType = false; + VisitOptions.VisitNamespaceEnums = false; + VisitOptions.VisitNamespaceEvents = false; + VisitOptions.VisitNamespaceTemplates = false; + VisitOptions.VisitNamespaceTypedefs = false; + VisitOptions.VisitNamespaceVariables = false; + VisitOptions.VisitTemplateArguments = false; + } + public override bool VisitASTContext(ASTContext context) { - WriteInlinesIncludes(); - return true; + var result = base.VisitASTContext(context); + WriteInlines(); + return result; } - private void WriteInlinesIncludes() + private void WriteInlines() { - foreach (var module in Options.Modules) + foreach (var module in Options.Modules.Where(m => inlinesCodeGenerators.ContainsKey(m))) { - var cppBuilder = new StringBuilder(); - foreach (var header in module.Headers) - cppBuilder.AppendFormat("#include <{0}>\n", header); - var cpp = string.Format("{0}.cpp", module.InlinesLibraryName); + var inlinesCodeGenerator = inlinesCodeGenerators[module]; + var cpp = $"{module.InlinesLibraryName}.{inlinesCodeGenerator.FileExtension}"; Directory.CreateDirectory(Options.OutputDir); var path = Path.Combine(Options.OutputDir, cpp); - File.WriteAllText(path, cppBuilder.ToString()); + File.WriteAllText(path, inlinesCodeGenerator.Generate()); } } + + public override bool VisitFunctionDecl(Function function) + { + if (!base.VisitFunctionDecl(function) || !NeedsSymbol(function)) + return false; + + InlinesCodeGenerator inlinesCodeGenerator; + var module = function.TranslationUnit.Module; + if (inlinesCodeGenerators.ContainsKey(module)) + inlinesCodeGenerator = inlinesCodeGenerators[module]; + else + { + inlinesCodeGenerators[module] = inlinesCodeGenerator = + new InlinesCodeGenerator(Context, module.Units); + inlinesCodeGenerator.Process(); + } + + if (module == Options.SystemModule) + return false; + + return function.Visit(inlinesCodeGenerator); + } + + private bool NeedsSymbol(Function function) + { + var mangled = function.Mangled; + var method = function as Method; + return function.IsGenerated && !function.IsDeleted && !function.IsDependent && + !function.IsPure && (!string.IsNullOrEmpty(function.Body) || function.IsImplicit) && + // we don't need symbols for virtual functions anyway + (method == null || (!method.IsVirtual && !method.IsSynthetized && + (!method.IsConstructor || !((Class) method.Namespace).IsAbstract))) && + // we cannot handle nested anonymous types + (!(function.Namespace is Class) || !string.IsNullOrEmpty(function.Namespace.OriginalName)) && + !Context.Symbols.FindSymbol(ref mangled); + } + + private Dictionary inlinesCodeGenerators = + new Dictionary(); } } diff --git a/src/Generator/Passes/InlinesCodeGenerator.cs b/src/Generator/Passes/InlinesCodeGenerator.cs new file mode 100644 index 00000000..202b9d5b --- /dev/null +++ b/src/Generator/Passes/InlinesCodeGenerator.cs @@ -0,0 +1,233 @@ +using System.Collections.Generic; +using System.Linq; +using System.Text; +using CppSharp.AST; +using CppSharp.AST.Extensions; +using CppSharp.Generators; + +namespace CppSharp.Passes +{ + public class InlinesCodeGenerator : CodeTemplate + { + public override string FileExtension => "cpp"; + + public InlinesCodeGenerator(BindingContext context, IEnumerable units) + : base(context, units) + { + } + + public override void Process() + { + foreach (var header in TranslationUnit.Module.Headers) + WriteLine($"#include <{header}>"); + NewLine(); + } + + public override bool VisitMethodDecl(Method method) + { + if (method.IsConstructor) + { + WrapConstructor(method); + return true; + } + if (method.IsDestructor) + { + WrapDestructor(method); + return true; + } + return this.VisitFunctionDecl(method); + } + + public override bool VisitFunctionDecl(Function function) + { + TakeFunctionAddress(function); + return true; + } + + private string GetWrapper(Module module) + { + var inlinesLibraryName = new StringBuilder(module.InlinesLibraryName); + for (int i = 0; i < inlinesLibraryName.Length; i++) + if (!char.IsLetterOrDigit(inlinesLibraryName[i])) + inlinesLibraryName[i] = '_'; + return $"{inlinesLibraryName}{++functionCount}"; + } + + private static string GetDerivedType(string @namespace, string wrapper) + { + return $@"class {wrapper}{@namespace} : public {@namespace} {{ public: "; + } + + private void WrapConstructor(Method method) + { + string wrapper = GetWrapper(method.TranslationUnit.Module); + if (Options.CheckSymbols) + method.Mangled = wrapper; + + int i = 0; + foreach (var param in method.Parameters.Where( + p => string.IsNullOrEmpty(p.OriginalName))) + param.Name = "_" + i++; + var @params = string.Join(", ", method.Parameters.Select(CastIfRVReference)); + var signature = string.Join(", ", method.GatherInternalParams( + Context.ParserOptions.IsItaniumLikeAbi).Select( + p => cppTypePrinter.VisitParameter(p))); + + string @namespace = method.Namespace.Visit(cppTypePrinter); + if (method.Access == AccessSpecifier.Protected) + { + Write(GetDerivedType(@namespace, wrapper)); + Write($"{wrapper}{@namespace}"); + Write($@"({(string.Join(", ", method.Parameters.Select( + p => cppTypePrinter.VisitParameter(p))))})"); + WriteLine($": {@namespace}({@params}) {{}} }};"); + Write($"extern \"C\" {{ void {wrapper}({signature}) "); + WriteLine($"{{ new (instance) {wrapper}{@namespace}({@params}); }} }}"); + } + else + { + Write($"extern \"C\" {{ void {wrapper}({signature}) "); + WriteLine($"{{ new (instance) {@namespace}({@params}); }} }}"); + } + + foreach (var param in method.Parameters.Where(p => + string.IsNullOrEmpty(p.OriginalName))) + param.Name = param.OriginalName; + } + + private string CastIfRVReference(Parameter p) + { + var pointer = p.Type.Desugar() as PointerType; + if (pointer == null || + pointer.Modifier != PointerType.TypeModifier.RVReference) + return p.Name; + + return $@"({pointer.Visit( + cppTypePrinter, p.QualifiedType.Qualifiers)}) {p.Name}"; + } + + private void WrapDestructor(Method method) + { + string wrapper = GetWrapper(method.TranslationUnit.Module); + if (Options.CheckSymbols) + method.Mangled = wrapper; + + bool isProtected = method.Access == AccessSpecifier.Protected; + string @namespace = method.Namespace.Visit(cppTypePrinter); + if (isProtected) + Write($"{GetDerivedType(@namespace, wrapper)}"); + else + Write("extern \"C\" { "); + Write($"void {wrapper}"); + if (isProtected) + Write("protected"); + WriteLine($@"({@namespace}* instance) {{ instance->~{ + method.Namespace.OriginalName}(); }} }}"); + if (isProtected) + WriteLine($@"void {wrapper}({@namespace} instance) {{ { + wrapper}{@namespace}::{wrapper}protected(instance); }}"); + } + + private void TakeFunctionAddress(Function function) + { + string wrapper = GetWrapper(function.TranslationUnit.Module); + string @namespace = function.Namespace.Visit(cppTypePrinter); + if (function.Access == AccessSpecifier.Protected) + { + Write(GetDerivedType(@namespace, wrapper)); + Write("static constexpr "); + } + + string returnType = function.OriginalReturnType.Visit(cppTypePrinter); + bool ambiguity = function.Namespace is TranslationUnit || + function.Namespace.GetOverloads(function).Count() > 1 || + function.FriendKind != FriendKind.None; + string signature = ambiguity ? GetSignature(function) : string.Empty; + + string functionName = GetFunctionName(function, @namespace); + if (function.FriendKind != FriendKind.None) + WriteRedeclaration(function, returnType, signature, functionName); + + var method = function as Method; + if (ambiguity) + Write($@"{returnType} ({(method != null && !method.IsStatic ? + (@namespace + "::") : string.Empty)}*{wrapper}){signature}"); + else + Write($@"auto {wrapper}"); + Write($@" = &{functionName};"); + if (function.Access == AccessSpecifier.Protected) + { + WriteLine(" };"); + Write($"auto {wrapper}protected = {wrapper}{@namespace}::{wrapper};"); + } + NewLine(); + } + + private string GetSignature(Function function) + { + var method = function as Method; + + var paramTypes = string.Join(", ", function.Parameters.Where( + p => p.Kind == ParameterKind.Regular).Select( + p => cppTypePrinter.VisitParameterDecl(p))); + + var variadicType = function.IsVariadic ? + (function.Parameters.Where( + p => p.Kind == ParameterKind.Regular).Any() ? ", ..." : "...") : + string.Empty; + + var @const = method != null && method.IsConst ? " const" : string.Empty; + + var refQualifier = method == null || method.RefQualifier == RefQualifier.None ? + string.Empty : (method.RefQualifier == RefQualifier.LValue ? " &" : " &&"); + + return $@"({paramTypes}{variadicType}){@const}{refQualifier}"; + } + + private string GetFunctionName(Function function, string @namespace) + { + return $@"{(function.Access == AccessSpecifier.Protected || + string.IsNullOrEmpty(@namespace) ? + string.Empty : (@namespace + "::"))}{function.OriginalName}{ + (function.SpecializationInfo == null ? string.Empty : $@"<{ + string.Join(", ", function.SpecializationInfo.Arguments.Select( + a => a.Type.Visit(cppTypePrinter)))}>")}"; + } + + private void WriteRedeclaration(Function function, string returnType, + string paramTypes, string functionName) + { + var parentsOpen = new Stack(); + var parentsClose = new StringBuilder(); + if (function.Namespace is Namespace && + function.Namespace.Namespace is Namespace && + !(function.Namespace.Namespace is TranslationUnit)) + { + var @namespace = function.Namespace; + while (!(@namespace is TranslationUnit)) + { + parentsOpen.Push($"namespace {@namespace.OriginalName} {{ "); + parentsClose.Append(" }"); + @namespace = @namespace.Namespace; + } + } + var functionType = (FunctionType) function.FunctionType.Type; + Write($@"{string.Join(string.Empty, parentsOpen)}"); + if (function.IsConstExpr) + Write("constexpr "); + Write(returnType); + Write(" "); + Write(parentsOpen.Any() ? function.OriginalName : functionName); + Write(paramTypes); + if (functionType.ExceptionSpecType == ExceptionSpecType.BasicNoexcept) + Write(" noexcept"); + WriteLine($";{parentsClose}"); + } + + private CppTypePrinter cppTypePrinter = new CppTypePrinter + { + PrintScopeKind = CppTypePrintScopeKind.Qualified + }; + private int functionCount; + } +} diff --git a/src/Generator/Utils/BlockGenerator.cs b/src/Generator/Utils/BlockGenerator.cs index 1370a70a..44e69602 100644 --- a/src/Generator/Utils/BlockGenerator.cs +++ b/src/Generator/Utils/BlockGenerator.cs @@ -288,7 +288,7 @@ namespace CppSharp ActiveBlock = RootBlock; } - public string Generate() + public virtual string Generate() { return RootBlock.Generate(); }