diff --git a/src/Generator/Driver.cs b/src/Generator/Driver.cs index bf8083d4..8ca07a0f 100644 --- a/src/Generator/Driver.cs +++ b/src/Generator/Driver.cs @@ -258,6 +258,7 @@ namespace CppSharp public bool GenerateLibraryNamespace; public bool GenerateFunctionTemplates; public bool GeneratePartialClasses; + public bool GenerateVirtualTables; public string IncludePrefix; public bool WriteOnlyWhenChanged; public Func GenerateName; diff --git a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs index 3d55c8ce..5121fc87 100644 --- a/src/Generator/Generators/CSharp/CSharpTextTemplate.cs +++ b/src/Generator/Generators/CSharp/CSharpTextTemplate.cs @@ -80,6 +80,8 @@ namespace CppSharp.Generators.CSharp public const int Variable = FIRST + 12; public const int Property = FIRST + 13; public const int Field = FIRST + 14; + public const int VTableDelegate = FIRST + 16; + public const int Region = FIRST + 17; } public class CSharpTextTemplate : Template @@ -337,6 +339,9 @@ namespace CppSharp.Generators.CSharp GenerateClassMethods(@class); GenerateClassVariables(@class); GenerateClassProperties(@class); + + if (Options.GenerateVirtualTables && @class.IsDynamic) + GenerateVTable(@class); } WriteCloseBraceIndent(); @@ -356,6 +361,7 @@ namespace CppSharp.Generators.CSharp typePrinter.PushContext(CSharpTypePrinterContextKind.Native); GenerateClassFields(@class, isInternal: true); + GenerateVTablePointers(@class); var functions = GatherClassInternalFunctions(@class); @@ -877,6 +883,255 @@ namespace CppSharp.Generators.CSharp PopBlock(NewLineKind.BeforeNextBlock); } + #region Virtual Tables + + public void GenerateVTable(Class @class) + { + var entries = VTables.GatherVTableMethodEntries(@class); + entries = entries.Where(entry => !entry.Method.Ignore).ToList(); + + if (entries.Count == 0) + return; + + PushBlock(CSharpBlockKind.Region); + WriteLine("#region Virtual table interop"); + NewLine(); + + // Generate a delegate type for each method. + foreach (var entry in entries) + { + var method = entry.Method; + GenerateVTableMethodDelegates(@class, method); + } + + const string Dictionary = "System.Collections.Generic.Dictionary"; + + WriteLine("static IntPtr[] _OldVTables;"); + WriteLine("static IntPtr[] _NewVTables;"); + WriteLine("static IntPtr[] _Thunks;"); + WriteLine("static {0} _References;", Dictionary); + NewLine(); + + GenerateVTableClassSetup(@class, Dictionary, entries); + + WriteLine("#endregion"); + PopBlock(NewLineKind.BeforeNextBlock); + } + + private void GenerateVTableClassSetup(Class @class, string Dictionary, + List entries) + { + WriteLine("void SetupVTables(global::System.IntPtr instance)"); + WriteStartBraceIndent(); + + WriteLine("var native = (Internal*)instance.ToPointer();"); + NewLine(); + + WriteLine("if (_References == null)"); + WriteLineIndent("_References = new {0}();", Dictionary); + NewLine(); + + WriteLine("if (_References.ContainsKey(instance))"); + WriteLineIndent("return;"); + + NewLine(); + WriteLine("_References[instance] = new WeakReference(this);"); + NewLine(); + + // Save the original vftable pointers. + WriteLine("if (_OldVTables == null)"); + WriteStartBraceIndent(); + + var vftables = @class.Layout.VFTables; + WriteLine("_OldVTables = new IntPtr[{0}];", vftables.Count); + + var index = 0; + foreach (var vfptr in vftables) + { + WriteLine("_OldVTables[{0}] = native->vfptr{0};", index++); + } + + WriteCloseBraceIndent(); + NewLine(); + + // Get the _Thunks + WriteLine("if (_Thunks == null)"); + WriteStartBraceIndent(); + + WriteLine("_Thunks = new IntPtr[{0}];", entries.Count); + + index = 0; + foreach (var entry in entries) + { + var method = entry.Method; + var delegateName = GetVTableMethodDelegateName(method); + var delegateInstance = delegateName + "Instance"; + WriteLine("{0} += {1}Hook;", delegateInstance, delegateName); + WriteLine("_Thunks[{0}] = Marshal.GetFunctionPointerForDelegate({1});", + index++, delegateInstance); + } + + WriteCloseBraceIndent(); + NewLine(); + + // Allocate new vtables if there are none yet. + WriteLine("if (_NewVTables == null)"); + WriteStartBraceIndent(); + + WriteLine("_NewVTables = new IntPtr[{0}];", vftables.Count); + + index = 0; + foreach (var vfptr in vftables) + { + var size = vfptr.Layout.Components.Count; + WriteLine("var vfptr = Marshal.AllocHGlobal({0} * IntPtr.Size);", size); + WriteLine("_NewVTables[{0}] = vfptr;", index++); + + var entryIndex = 0; + foreach (var entry in vfptr.Layout.Components) + { + var offsetInBytes = VTables.GetVTableComponentIndex(@class, entry)*IntPtr.Size; + WriteLine("*(IntPtr*)(vfptr + {0}) = _Thunks[{1}];", offsetInBytes, entryIndex++); + } + } + + WriteCloseBraceIndent(); + NewLine(); + + // Set the previous delegate instances pointers in the object + // virtual table. + index = 0; + foreach (var vfptr in @class.Layout.VFTables) + WriteLine("native->vfptr{0} = _NewVTables[{0}];", index++); + + WriteCloseBraceIndent(); + NewLine(); + } + + private void GenerateVTableManagedCall(Method method) + { + if (method.IsDestructor) + { + WriteLine("target.Dispose();"); + return; + } + + var marshals = new List(); + foreach (var param in method.Parameters) + { + if (param.Ignore) + continue; + + if (param.Kind == ParameterKind.HiddenStructureReturn) + continue; + + var ctx = new CSharpMarshalContext(Driver) + { + ReturnType = param.QualifiedType, + ReturnVarName = SafeIdentifier(param.Name) + }; + + var marshal = new CSharpMarshalNativeToManagedPrinter(ctx); + param.Visit(marshal); + + if (!string.IsNullOrWhiteSpace(marshal.Context.SupportBefore)) + Write(marshal.Context.SupportBefore); + + marshals.Add(marshal.Context.Return); + } + + var hasReturn = !method.ReturnType.Type.IsPrimitiveType(PrimitiveType.Void) + && !method.HasHiddenStructParameter; + + if (hasReturn) + Write("var _ret = "); + + WriteLine("target.{0}({1});", method.Name, string.Join(", ", marshals)); + + // TODO: Handle hidden structure return types. + + if (hasReturn) + { + var param = new Parameter + { + Name = "_ret", + QualifiedType = method.ReturnType + }; + + // Marshal the managed result to native + var ctx = new CSharpMarshalContext(Driver) + { + ArgName = "_ret", + Parameter = param, + Function = method + }; + + var marshal = new CSharpMarshalManagedToNativePrinter(ctx); + method.ReturnType.Visit(marshal); + + if (!string.IsNullOrWhiteSpace(marshal.Context.SupportBefore)) + Write(marshal.Context.SupportBefore); + + WriteLine("return {0};", marshal.Context.Return); + } + } + + private void GenerateVTableMethodDelegates(Class @class, Method method) + { + PushBlock(CSharpBlockKind.VTableDelegate); + + WriteLine("[SuppressUnmanagedCodeSecurity]"); + WriteLine("[UnmanagedFunctionPointerAttribute(CallingConvention.{0})]", + Helpers.ToCSharpCallConv(method.CallingConvention)); + + CSharpTypePrinterResult retType; + var @params = GatherInternalParams(method, out retType); + + var delegateName = GetVTableMethodDelegateName(method); + WriteLine("delegate {0} {1}({2});", retType, delegateName, + string.Join(", ", @params)); + + WriteLine("static {0} {0}Instance;", delegateName); + NewLine(); + + WriteLine("static {0} {1}Hook({2})", retType, delegateName, + string.Join(", ", @params)); + WriteStartBraceIndent(); + + WriteLine("if (!_References.ContainsKey(instance))"); + WriteLineIndent("throw new Exception(\"No managed instance was found\");"); + NewLine(); + + WriteLine("var target = ({0}) _References[instance].Target;", @class.Name); + GenerateVTableManagedCall(method); + + WriteCloseBraceIndent(); + + PopBlock(NewLineKind.Always); + } + + public string GetVTableMethodDelegateName(Method method) + { + return string.Format("_{0}Delegate", GetFunctionIdentifier(method)); + } + + public void GenerateVTablePointers(Class @class) + { + var index = 0; + foreach (var info in @class.Layout.VFTables) + { + PushBlock(CSharpBlockKind.InternalsClassField); + + WriteLine("[FieldOffset({0})]", info.VFPtrFullOffset); + WriteLine("public global::System.IntPtr vfptr{0};", + info.VFPtrFullOffset, index++); + + PopBlock(NewLineKind.BeforeNextBlock); + } + } + + #endregion + #region Events private string delegateName; @@ -1072,7 +1327,11 @@ namespace CppSharp.Generators.CSharp if (@class.IsRefType) { if (ShouldGenerateClassNativeField(@class)) + { WriteLine("{0} = native;", Helpers.InstanceIdentifier); + if (Options.GenerateVirtualTables && @class.IsDynamic) + WriteLine("SetupVTables(_Instance);"); + } } else { @@ -1274,6 +1533,8 @@ namespace CppSharp.Generators.CSharp Write(", "); GenerateFunctionParams(@params); WriteLine(");"); + if (Options.GenerateVirtualTables && @class.IsDynamic) + WriteLine("SetupVTables(_Instance);"); } public void GenerateInternalFunctionCall(Function function,