Browse Source

Refactor and extract common code in NAPI source generator.

instantiate-types-nested-templates
Joao Matos 5 years ago committed by João Matos
parent
commit
06d271806d
  1. 181
      src/Generator/Generators/NAPI/NAPIHelpers.cs
  2. 256
      src/Generator/Generators/NAPI/NAPISources.cs

181
src/Generator/Generators/NAPI/NAPIHelpers.cs

@ -1,9 +1,12 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Globalization;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using CppSharp.AST; using CppSharp.AST;
using CppSharp.AST.Extensions; using CppSharp.AST.Extensions;
using CppSharp.Generators.C; using CppSharp.Generators.C;
using CppSharp.Generators.NAPI;
using CppSharp.Passes; using CppSharp.Passes;
namespace CppSharp.Generators.Cpp namespace CppSharp.Generators.Cpp
@ -84,6 +87,184 @@ namespace CppSharp.Generators.Cpp
return; return;
} }
} }
public virtual MarshalPrinter<MarshalContext> GetMarshalManagedToNativePrinter(MarshalContext ctx)
{
return new NAPIMarshalManagedToNativePrinter(ctx);
}
public virtual MarshalPrinter<MarshalContext> GetMarshalNativeToManagedPrinter(MarshalContext ctx)
{
return new NAPIMarshalNativeToManagedPrinter(ctx);
}
public struct ParamMarshal
{
public string Name;
public string Prefix;
public Parameter Param;
public MarshalContext Context;
}
public virtual List<ParamMarshal> GenerateFunctionParamsMarshal(IEnumerable<Parameter> @params,
Function function = null)
{
var marshals = new List<ParamMarshal>();
var paramIndex = 0;
foreach (var param in @params)
{
marshals.Add(GenerateFunctionParamMarshal(param, paramIndex, function));
paramIndex++;
}
return marshals;
}
public virtual ParamMarshal GenerateFunctionParamMarshal(Parameter param, int paramIndex,
Function function = null)
{
var paramMarshal = new ParamMarshal { Name = param.Name, Param = param };
var argName = Generator.GeneratedIdentifier("arg") + paramIndex.ToString(CultureInfo.InvariantCulture);
Parameter effectiveParam = param;
var isRef = param.IsOut || param.IsInOut;
var paramType = param.Type;
// TODO: Use same name between generators
var typeArgName = $"args[{paramIndex}]";
if (this is QuickJSInvokes)
typeArgName = $"argv[{paramIndex}]";
var ctx = new MarshalContext(Context, CurrentIndentation)
{
Parameter = effectiveParam,
ParameterIndex = paramIndex,
ArgName = typeArgName,
Function = function
};
paramMarshal.Context = ctx;
var marshal = GetMarshalManagedToNativePrinter(ctx);
effectiveParam.Visit(marshal);
if (string.IsNullOrEmpty(marshal.Context.Return))
throw new Exception($"Cannot marshal argument of function '{function.QualifiedOriginalName}'");
if (isRef)
{
var type = paramType.Visit(CTypePrinter);
if (param.IsInOut)
{
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
{
Write(marshal.Context.Before);
NeedNewLine();
}
WriteLine($"{type} {argName} = {marshal.Context.Return};");
}
else
WriteLine($"{type} {argName};");
}
else
{
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
{
Write(marshal.Context.Before);
NeedNewLine();
}
NewLineIfNeeded();
WriteLine($"auto {marshal.Context.VarPrefix}{argName} = {marshal.Context.Return};");
paramMarshal.Prefix = marshal.Context.ArgumentPrefix;
}
paramMarshal.Name = argName;
return paramMarshal;
}
public virtual void GenerateFunctionCallReturnMarshal(Function function)
{
var ctx = new MarshalContext(Context, CurrentIndentation)
{
ArgName = Helpers.ReturnIdentifier,
ReturnVarName = Helpers.ReturnIdentifier,
ReturnType = function.ReturnType
};
// TODO: Move this into the marshaler
if (function.ReturnType.Type.Desugar().IsClass())
ctx.ArgName = $"&{ctx.ArgName}";
var marshal = GetMarshalNativeToManagedPrinter(ctx);
function.ReturnType.Visit(marshal);
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
{
Write(marshal.Context.Before);
NewLine();
}
WriteLine($"return {marshal.Context.Return};");
}
public virtual void GenerateFunctionParamsMarshalCleanups(List<ParamMarshal> @params)
{
var marshalers = new List<MarshalPrinter<MarshalContext>>();
PushBlock();
{
foreach (var paramInfo in @params)
{
var param = paramInfo.Param;
if (param.Usage != ParameterUsage.Out && param.Usage != ParameterUsage.InOut)
continue;
if (param.Type.IsPointer() && !param.Type.GetFinalPointee().IsPrimitiveType())
param.QualifiedType = new QualifiedType(param.Type.GetFinalPointee());
var ctx = new MarshalContext(Context, CurrentIndentation)
{
ArgName = paramInfo.Name,
ReturnVarName = paramInfo.Name,
ReturnType = param.QualifiedType
};
var marshal = GetMarshalNativeToManagedPrinter(ctx);
marshalers.Add(marshal);
param.Visit(marshal);
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
Write(marshal.Context.Before);
WriteLine($"{param.Name} = {marshal.Context.Return};");
}
}
PopBlock(NewLineKind.IfNotEmpty);
PushBlock();
{
foreach (var marshal in marshalers)
{
if (!string.IsNullOrWhiteSpace(marshal.Context.Cleanup))
Write(marshal.Context.Cleanup);
}
foreach (var marshal in @params)
{
if (!string.IsNullOrWhiteSpace(marshal.Context.Cleanup))
Write(marshal.Context.Cleanup);
}
}
PopBlock(NewLineKind.IfNotEmpty);
}
} }
/// <summary> /// <summary>

256
src/Generator/Generators/NAPI/NAPISources.cs

@ -11,6 +11,7 @@ using CppSharp.Generators.NAPI;
using CppSharp.Passes; using CppSharp.Passes;
using CppSharp.Utils.FSM; using CppSharp.Utils.FSM;
using static CppSharp.Generators.Cpp.NAPISources; using static CppSharp.Generators.Cpp.NAPISources;
using Type = CppSharp.AST.Type;
namespace CppSharp.Generators.Cpp namespace CppSharp.Generators.Cpp
{ {
@ -415,7 +416,6 @@ namespace CppSharp.Generators.Cpp
PushBlock(BlockKind.Function, function); PushBlock(BlockKind.Function, function);
GenerateFunctionCallback(@group); GenerateFunctionCallback(@group);
GenerateNativeCall(group);
NewLine(); NewLine();
// TODO: // TODO:
@ -440,7 +440,6 @@ namespace CppSharp.Generators.Cpp
PushBlock(BlockKind.Method); PushBlock(BlockKind.Method);
GenerateFunctionCallback(@group.OfType<Function>().ToList()); GenerateFunctionCallback(@group.OfType<Function>().ToList());
GenerateNativeCall(@group);
NewLine(); NewLine();
WriteLine("return _this;"); WriteLine("return _this;");
@ -449,11 +448,6 @@ namespace CppSharp.Generators.Cpp
PopBlock(NewLineKind.BeforeNextBlock); PopBlock(NewLineKind.BeforeNextBlock);
} }
public virtual void GenerateNativeCall(IEnumerable<Function> group)
{
}
public virtual void GenerateFunctionCallback(List<Function> @group) public virtual void GenerateFunctionCallback(List<Function> @group)
{ {
var function = @group.First(); var function = @group.First();
@ -479,15 +473,8 @@ namespace CppSharp.Generators.Cpp
NewLine(); NewLine();
// Handle the zero arguments case right away if one exists. // Handle the zero arguments case right away if one exists.
var zeroParamsOverload = @group.SingleOrDefault(f => f.Parameters.Count == 0); CheckZeroArguments(@group);
if (zeroParamsOverload != null && @group.Count > 1) NewLineIfNeeded();
{
var index = @group.FindIndex(f => f == zeroParamsOverload);
WriteLine($"if (argc == 0)");
WriteLineIndent($"goto overload{index};");
NewLine();
}
// Check if the arguments are in the expected range. // Check if the arguments are in the expected range.
CheckArgumentsRange(@group); CheckArgumentsRange(@group);
@ -496,16 +483,17 @@ namespace CppSharp.Generators.Cpp
var needsArguments = @group.Any(f => f.Parameters.Any(p => p.IsGenerated)); var needsArguments = @group.Any(f => f.Parameters.Any(p => p.IsGenerated));
if (needsArguments) if (needsArguments)
{ {
//WriteLine("void* data;");
WriteLine("status = napi_get_cb_info(env, info, &argc, args, nullptr, nullptr);"); WriteLine("status = napi_get_cb_info(env, info, &argc, args, nullptr, nullptr);");
WriteLine("assert(status == napi_ok);"); WriteLine("assert(status == napi_ok);");
NewLine(); NewLine();
// Next we need to disambiguate which overload to call based on: WriteLine("for (size_t i = 0; i < argc; i++)");
// 1. Number of arguments passed to the method WriteOpenBraceAndIndent();
// 2. Type of arguments {
WriteLine("status = napi_typeof(env, args[i], &types[i]);");
CheckArgumentsTypes(@group); WriteLine("assert(status == napi_ok);");
}
UnindentAndWriteCloseBrace();
NewLine(); NewLine();
} }
@ -530,6 +518,19 @@ namespace CppSharp.Generators.Cpp
{ {
var stateMachine = CalculateOverloadStates(@group); var stateMachine = CalculateOverloadStates(@group);
CheckArgumentsOverload(@group, stateMachine); CheckArgumentsOverload(@group, stateMachine);
// Error state.
Unindent();
WriteLine($"error:");
Indent();
WriteLine("status = napi_throw_type_error(env, nullptr, \"Unsupported argument type\");");
WriteLine("assert(status == napi_ok);");
NewLine();
WriteLine("return nullptr;");
NewLine();
GenerateOverloadCalls(@group, stateMachine); GenerateOverloadCalls(@group, stateMachine);
} }
else else
@ -547,7 +548,20 @@ namespace CppSharp.Generators.Cpp
} }
} }
private void CheckArgumentsRange(IEnumerable<Function> @group) public virtual void CheckZeroArguments(List<Function> @group)
{
var zeroParamsOverload = @group.SingleOrDefault(f => f.Parameters.Count == 0);
if (zeroParamsOverload == null || @group.Count <= 1)
return;
var index = @group.FindIndex(f => f == zeroParamsOverload);
WriteLine($"if (argc == 0)");
WriteLineIndent($"goto overload{index};");
NeedNewLine();
}
public virtual void CheckArgumentsRange(IEnumerable<Function> @group)
{ {
var enumerable = @group as List<Function> ?? @group.ToList(); var enumerable = @group as List<Function> ?? @group.ToList();
var (minArgs, maxArgs) = (enumerable.Min(m => m.Parameters.Count), var (minArgs, maxArgs) = (enumerable.Min(m => m.Parameters.Count),
@ -557,26 +571,17 @@ namespace CppSharp.Generators.Cpp
WriteLine($"if ({rangeCheck})"); WriteLine($"if ({rangeCheck})");
WriteOpenBraceAndIndent(); WriteOpenBraceAndIndent();
{
WriteLine("status = napi_throw_type_error(env, nullptr, \"Unsupported number of arguments\");"); WriteLine("status = napi_throw_type_error(env, nullptr, \"Unsupported number of arguments\");");
WriteLine("assert(status == napi_ok);"); WriteLine("assert(status == napi_ok);");
NewLine(); NewLine();
WriteLine("return nullptr;"); WriteLine("return nullptr;");
UnindentAndWriteCloseBrace();
} }
private void CheckArgumentsTypes(IEnumerable<Function> @group)
{
WriteLine("for (size_t i = 0; i < argc; i++)");
WriteOpenBraceAndIndent();
WriteLine("status = napi_typeof(env, args[i], &types[i]);");
WriteLine("assert(status == napi_ok);");
UnindentAndWriteCloseBrace(); UnindentAndWriteCloseBrace();
} }
private void CheckArgumentsOverload(IList<Function> @group, DFSM stateMachine) public virtual void CheckArgumentsOverload(IList<Function> @group, DFSM stateMachine)
{ {
var typeCheckStates = stateMachine.Q.Except(stateMachine.F).ToList(); var typeCheckStates = stateMachine.Q.Except(stateMachine.F).ToList();
var finalStates = stateMachine.F; var finalStates = stateMachine.F;
@ -610,12 +615,7 @@ namespace CppSharp.Generators.Cpp
var type = uniqueTypes[(int) transition.Symbol]; var type = uniqueTypes[(int) transition.Symbol];
var typeChecker = new NAPITypeCheckGen(paramIndex); var condition = GenerateTypeCheckForParameter(paramIndex, type);
type.Visit(typeChecker);
var condition = typeChecker.Generate();
if (string.IsNullOrWhiteSpace(condition))
throw new NotSupportedException();
WriteLine($"if ({condition})"); WriteLine($"if ({condition})");
@ -632,21 +632,21 @@ namespace CppSharp.Generators.Cpp
NeedNewLine(); NeedNewLine();
} }
NewLineIfNeeded(); NewLineIfNeeded();
}
// Error state. public virtual string GenerateTypeCheckForParameter(int paramIndex, Type type)
Unindent(); {
WriteLine($"error:"); var typeChecker = new NAPITypeCheckGen(paramIndex);
Indent(); type.Visit(typeChecker);
WriteLine("status = napi_throw_type_error(env, nullptr, \"Unsupported argument type\");"); var condition = typeChecker.Generate();
WriteLine("assert(status == napi_ok);"); if (string.IsNullOrWhiteSpace(condition))
NewLine(); throw new NotSupportedException();
WriteLine("return nullptr;"); return condition;
NewLine();
} }
private void GenerateOverloadCalls(IList<Function> @group, DFSM stateMachine) public virtual void GenerateOverloadCalls(IList<Function> @group, DFSM stateMachine)
{ {
// Final states. // Final states.
for (var i = 0; i < stateMachine.F.Count; i++) for (var i = 0; i < stateMachine.F.Count; i++)
@ -661,9 +661,9 @@ namespace CppSharp.Generators.Cpp
Indent(); Indent();
WriteOpenBraceAndIndent(); WriteOpenBraceAndIndent();
{
GenerateFunctionCall(function); GenerateFunctionCall(function);
}
UnindentAndWriteCloseBrace(); UnindentAndWriteCloseBrace();
NeedNewLine(); NeedNewLine();
} }
@ -672,9 +672,12 @@ namespace CppSharp.Generators.Cpp
public virtual void GenerateFunctionCall(Function function) public virtual void GenerateFunctionCall(Function function)
{ {
var @params = GenerateFunctionParamsMarshal(function.Parameters, function); var @params = GenerateFunctionParamsMarshal(function.Parameters, function);
var method = function as Method;
var isVoidReturn = function.ReturnType.Type.IsPrimitiveType(PrimitiveType.Void);
var needsReturn = !function.ReturnType.Type.IsPrimitiveType(PrimitiveType.Void); PushBlock();
if (needsReturn) {
if (!isVoidReturn)
{ {
CTypePrinter.PushContext(TypePrinterContextKind.Native); CTypePrinter.PushContext(TypePrinterContextKind.Native);
var returnType = function.ReturnType.Visit(CTypePrinter); var returnType = function.ReturnType.Visit(CTypePrinter);
@ -683,9 +686,7 @@ namespace CppSharp.Generators.Cpp
Write($"{returnType} {Helpers.ReturnIdentifier} = "); Write($"{returnType} {Helpers.ReturnIdentifier} = ");
} }
var method = function as Method;
var @class = function.Namespace as Class; var @class = function.Namespace as Class;
var property = method?.AssociatedDeclaration as Property; var property = method?.AssociatedDeclaration as Property;
var field = property?.Field; var field = property?.Field;
if (field != null) if (field != null)
@ -719,58 +720,23 @@ namespace CppSharp.Generators.Cpp
GenerateFunctionParams(@params); GenerateFunctionParams(@params);
WriteLine(");"); WriteLine(");");
} }
foreach(var paramInfo in @params)
{
var param = paramInfo.Param;
if(param.Usage != ParameterUsage.Out && param.Usage != ParameterUsage.InOut)
continue;
if (param.Type.IsPointer() && !param.Type.GetFinalPointee().IsPrimitiveType())
param.QualifiedType = new QualifiedType(param.Type.GetFinalPointee());
var ctx = new MarshalContext(Context, CurrentIndentation)
{
ArgName = paramInfo.Name,
ReturnVarName = paramInfo.Name,
ReturnType = param.QualifiedType
};
var marshal = GetMarshalNativeToManagedPrinter(ctx);
param.Visit(marshal);
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
Write(marshal.Context.Before);
WriteLine($"{param.Name} = {marshal.Context.Return};");
} }
PopBlock(NewLineKind.IfNotEmpty);
GenerateFunctionParamsMarshalCleanups(@params);
var isCtor = method != null && method.IsConstructor;
var needsReturn = !isVoidReturn || (this is QuickJSInvokes && !isCtor);
if (needsReturn) if (needsReturn)
{ {
NewLine(); NewLine();
GenerateFunctionCallReturnMarshal(function); GenerateFunctionCallReturnMarshal(function);
} }
}
public virtual void GenerateFunctionCallReturnMarshal(Function function)
{
var ctx = new MarshalContext(Context, CurrentIndentation)
{
ArgName = Helpers.ReturnIdentifier,
ReturnVarName = Helpers.ReturnIdentifier,
ReturnType = function.ReturnType
};
var marshal = GetMarshalNativeToManagedPrinter(ctx);
function.ReturnType.Visit(marshal);
if (!string.IsNullOrWhiteSpace(marshal.Context.Before)) if (isCtor)
{ {
Write(marshal.Context.Before); WriteLine("goto wrap;");
NewLine();
} }
WriteLine($"return {marshal.Context.Return};");
} }
public bool IsNativeFunctionOrStaticMethod(Function function) public bool IsNativeFunctionOrStaticMethod(Function function)
@ -788,98 +754,6 @@ namespace CppSharp.Generators.Cpp
return method.IsStatic || method.Conversion != MethodConversionKind.None; return method.IsStatic || method.Conversion != MethodConversionKind.None;
} }
public struct ParamMarshal
{
public string Name;
public string Prefix;
public Parameter Param;
}
public List<ParamMarshal> GenerateFunctionParamsMarshal(IEnumerable<Parameter> @params,
Function function = null)
{
var marshals = new List<ParamMarshal>();
var paramIndex = 0;
foreach (var param in @params)
{
marshals.Add(GenerateFunctionParamMarshal(param, paramIndex, function));
paramIndex++;
}
return marshals;
}
public virtual MarshalPrinter<MarshalContext> GetMarshalManagedToNativePrinter(MarshalContext ctx)
{
return new NAPIMarshalManagedToNativePrinter(ctx);
}
public virtual MarshalPrinter<MarshalContext> GetMarshalNativeToManagedPrinter(MarshalContext ctx)
{
return new NAPIMarshalNativeToManagedPrinter(ctx);
}
public virtual ParamMarshal GenerateFunctionParamMarshal(Parameter param, int paramIndex,
Function function = null)
{
var paramMarshal = new ParamMarshal { Name = param.Name, Param = param };
var argName = Generator.GeneratedIdentifier("arg") + paramIndex.ToString(CultureInfo.InvariantCulture);
Parameter effectiveParam = param;
var isRef = param.IsOut || param.IsInOut;
var paramType = param.Type;
var ctx = new MarshalContext(Context, CurrentIndentation)
{
Parameter = effectiveParam,
ParameterIndex = paramIndex,
ArgName = argName,
Function = function
};
var marshal = GetMarshalManagedToNativePrinter(ctx);
effectiveParam.Visit(marshal);
if (string.IsNullOrEmpty(marshal.Context.Return))
throw new Exception($"Cannot marshal argument of function '{function.QualifiedOriginalName}'");
if (isRef)
{
var type = paramType.Visit(CTypePrinter);
if (param.IsInOut)
{
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
{
Write(marshal.Context.Before);
NeedNewLine();
}
WriteLine($"{type} {argName} = {marshal.Context.Return};");
}
else
WriteLine($"{type} {argName};");
}
else
{
if (!string.IsNullOrWhiteSpace(marshal.Context.Before))
{
Write(marshal.Context.Before);
NeedNewLine();
}
NewLineIfNeeded();
WriteLine($"auto {marshal.Context.VarPrefix}{argName} = {marshal.Context.Return};");
paramMarshal.Prefix = marshal.Context.ArgumentPrefix;
}
paramMarshal.Name = argName;
return paramMarshal;
}
public void GenerateFunctionParams(List<ParamMarshal> @params) public void GenerateFunctionParams(List<ParamMarshal> @params)
{ {
var names = @params.Select(param => var names = @params.Select(param =>
@ -889,7 +763,7 @@ namespace CppSharp.Generators.Cpp
Write(string.Join(", ", names)); Write(string.Join(", ", names));
} }
private static DFSM CalculateOverloadStates(IEnumerable<Function> group) public static DFSM CalculateOverloadStates(IEnumerable<Function> group)
{ {
var functionGroup = group.ToList(); var functionGroup = group.ToList();

Loading…
Cancel
Save