From 12d7961e39dd0ff9794184572d209ac5d6d58707 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Wed, 17 Aug 2011 16:53:05 +0200 Subject: [PATCH] Add SerializationBinder-support to FastSerializer. --- .../CSharp/OutputVisitorTests.cs | 53 ++++ .../ICSharpCode.NRefactory.Tests.csproj | 1 + .../Utils/FastSerializer.cs | 293 ++++++++++++++---- 3 files changed, 283 insertions(+), 64 deletions(-) create mode 100644 ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs diff --git a/ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs new file mode 100644 index 0000000000..ab1eaf7d9b --- /dev/null +++ b/ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs @@ -0,0 +1,53 @@ +// Copyright (c) AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.IO; +using NUnit.Framework; + +namespace ICSharpCode.NRefactory.CSharp +{ + [TestFixture] + public class OutputVisitorTests + { + void AssertOutput(string expected, Expression expr, CSharpFormattingOptions policy = null) + { + if (policy == null) + policy = new CSharpFormattingOptions();; + StringWriter w = new StringWriter(); + w.NewLine = "\n"; + expr.AcceptVisitor(new OutputVisitor(new TextWriterOutputFormatter(w) { IndentationString = "\t" }, policy), null); + Assert.AreEqual(expected.Replace("\r", ""), w.ToString()); + } + + [Test, Ignore("Incorrect whitespace")] + public void AssignmentInCollectionInitialize() + { + Expression expr = new ObjectCreateExpression { + Type = new SimpleType("List"), + Initializer = new ArrayInitializerExpression( + new ArrayInitializerExpression( + new AssignmentExpression(new IdentifierExpression("a"), new PrimitiveExpression(1)) + ) + ) + }; + + AssertOutput("new List {\n {\n a = 1\n }\n}", expr); + } + } +} diff --git a/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj b/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj index 614f24eb7e..d859360338 100644 --- a/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj +++ b/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj @@ -57,6 +57,7 @@ + diff --git a/ICSharpCode.NRefactory/Utils/FastSerializer.cs b/ICSharpCode.NRefactory/Utils/FastSerializer.cs index 0bd3887c6d..9622b887c4 100644 --- a/ICSharpCode.NRefactory/Utils/FastSerializer.cs +++ b/ICSharpCode.NRefactory/Utils/FastSerializer.cs @@ -30,6 +30,24 @@ namespace ICSharpCode.NRefactory.Utils { public class FastSerializer { + #region Properties + /// + /// Gets/Sets the serialization binder that is being used. + /// The default value is null, which will cause the FastSerializer to use the + /// full assembly and type names. + /// + public SerializationBinder SerializationBinder { get; set; } + #endregion + + #region Constants + const int magic = 0x71D18A5D; + + const byte Type_ReferenceType = 1; + const byte Type_ValueType = 2; + const byte Type_SZArray = 3; + const byte Type_ParameterizedType = 4; + #endregion + #region Serialization sealed class ReferenceComparer : IEqualityComparer { @@ -44,17 +62,35 @@ namespace ICSharpCode.NRefactory.Utils } } + sealed class SerializationType + { + public readonly int ID; + public readonly Type Type; + + public SerializationType(int iD, Type type) + { + this.ID = iD; + this.Type = type; + } + + public ObjectScanner Scanner; + public ObjectWriter Writer; + public string TypeName; + public int AssemblyNameID; + } + sealed class SerializationContext { readonly Dictionary objectToID = new Dictionary(new ReferenceComparer()); readonly List instances = new List(); // index: object ID - readonly List typeIDs = new List(); // index: object ID - int stringTypeID = -1; - int typeCountForObjects = 0; + readonly List objectTypes = new List(); // index: object ID + SerializationType stringType; - readonly Dictionary typeToID = new Dictionary(); - readonly List types = new List(); // index: type ID - readonly List writers = new List(); // index: type ID + readonly Dictionary typeMap = new Dictionary(); + readonly List types = new List(); + + readonly Dictionary assemblyNameToID = new Dictionary(); + readonly List assemblyNames = new List(); readonly FastSerializer fastSerializer; public readonly BinaryWriter writer; @@ -64,7 +100,7 @@ namespace ICSharpCode.NRefactory.Utils this.fastSerializer = fastSerializer; this.writer = writer; instances.Add(null); // use object ID 0 for null - typeIDs.Add(-1); + objectTypes.Add(null); } #region Scanning @@ -84,26 +120,14 @@ namespace ICSharpCode.NRefactory.Utils internal void Scan() { Log("Scanning..."); - List objectScanners = new List(); // index: type ID // starting from 1, because index 0 is null for (int i = 1; i < instances.Count; i++) { object instance = instances[i]; ISerializable serializable = instance as ISerializable; Type type = instance.GetType(); Log("Scan #{0}: {1}", i, type.Name); - int typeID; - if (!typeToID.TryGetValue(type, out typeID)) { - typeID = types.Count; - typeToID.Add(type, typeID); - types.Add(type); - Log("Registered type %{0}: {1}", typeID, type); - if (type == typeof(string)) { - stringTypeID = typeID; - } - objectScanners.Add(serializable != null ? null : fastSerializer.GetScanner(type)); - writers.Add(serializable != null ? serializationInfoWriter : fastSerializer.GetWriter(type)); - } - typeIDs.Add(typeID); + SerializationType sType = MarkType(type); + objectTypes.Add(sType); if (serializable != null) { SerializationInfo info = new SerializationInfo(type, fastSerializer.formatterConverter); serializable.GetObjectData(info, fastSerializer.streamingContext); @@ -111,23 +135,74 @@ namespace ICSharpCode.NRefactory.Utils foreach (SerializationEntry entry in info) { Mark(entry.Value); } + sType.Writer = serializationInfoWriter; } else { - objectScanners[typeID](this, instance); + ObjectScanner objectScanner = sType.Scanner; + if (objectScanner == null) { + objectScanner = fastSerializer.GetScanner(type); + sType.Scanner = objectScanner; + sType.Writer = fastSerializer.GetWriter(type); + } + objectScanner(this, instance); } } } #endregion #region Scan Types + SerializationType MarkType(Type type) + { + SerializationType sType; + if (!typeMap.TryGetValue(type, out sType)) { + string assemblyName = null; + string typeName = null; + if (type.HasElementType) { + MarkType(type.GetElementType()); + } else if (type.IsGenericType && !type.IsGenericTypeDefinition) { + MarkType(type.GetGenericTypeDefinition()); + foreach (Type typeArg in type.GetGenericArguments()) + MarkType(typeArg); + } else if (type.IsGenericParameter) { + throw new NotSupportedException(); + } else { + var serializationBinder = fastSerializer.SerializationBinder; + if (serializationBinder != null) { + serializationBinder.BindToName(type, out assemblyName, out typeName); + } else { + assemblyName = type.Assembly.FullName; + typeName = type.FullName; + Debug.Assert(typeName != null); + } + } + + sType = new SerializationType(typeMap.Count, type); + sType.TypeName = typeName; + if (assemblyName != null) { + if (!assemblyNameToID.TryGetValue(assemblyName, out sType.AssemblyNameID)) { + sType.AssemblyNameID = assemblyNames.Count; + assemblyNameToID.Add(assemblyName, sType.AssemblyNameID); + assemblyNames.Add(assemblyName); + Log("Registered assembly #{0}: {1}", sType.AssemblyNameID, assemblyName); + } + } + typeMap.Add(type, sType); + types.Add(sType); + Log("Registered type %{0}: {1}", sType.ID, type); + if (type == typeof(string)) { + stringType = sType; + } + } + return sType; + } + internal void ScanTypes() { - typeCountForObjects = types.Count; for (int i = 0; i < types.Count; i++) { - foreach (FieldInfo field in GetSerializableFields(types[i])) { - if (!typeToID.ContainsKey(field.FieldType)) { - typeToID.Add(field.FieldType, types.Count); - types.Add(field.FieldType); - } + Type type = types[i].Type; + if (type.IsGenericTypeDefinition || type.HasElementType) + continue; + foreach (FieldInfo field in GetSerializableFields(type)) { + MarkType(field.FieldType); } } } @@ -143,19 +218,65 @@ namespace ICSharpCode.NRefactory.Utils writer.Write(id); } + void WriteTypeID(Type type) + { + Debug.Assert(typeMap.ContainsKey(type)); + int typeID = typeMap[type].ID; + if (types.Count <= ushort.MaxValue) + writer.Write((ushort)typeID); + else + writer.Write(typeID); + } + internal void Write() { Log("Writing..."); + writer.Write(magic); // Write out type information - writer.Write(types.Count); writer.Write(instances.Count); - writer.Write(typeCountForObjects); - writer.Write(stringTypeID); - foreach (Type type in types) { - writer.Write(type.AssemblyQualifiedName); + writer.Write(types.Count); + writer.Write(assemblyNames.Count); + + foreach (string assemblyName in assemblyNames) { + writer.Write(assemblyName); } - foreach (Type type in types) { - if (type.IsArray || type.IsPrimitive || typeof(ISerializable).IsAssignableFrom(type)) { + + foreach (SerializationType sType in types) { + Type type = sType.Type; + if (type.HasElementType) { + if (type.IsArray) { + if (type.GetArrayRank() == 1) + writer.Write(Type_SZArray); + else + throw new NotSupportedException(); + } else { + throw new NotSupportedException(); + } + WriteTypeID(type.GetElementType()); + } else if (type.IsGenericType && !type.IsGenericTypeDefinition) { + writer.Write(Type_ParameterizedType); + WriteTypeID(type.GetGenericTypeDefinition()); + foreach (Type typeArg in type.GetGenericArguments()) { + WriteTypeID(typeArg); + } + } else { + if (type.IsValueType) { + writer.Write(Type_ValueType); + } else { + writer.Write(Type_ReferenceType); + } + if (assemblyNames.Count <= ushort.MaxValue) + writer.Write((ushort)sType.AssemblyNameID); + else + writer.Write(sType.AssemblyNameID); + writer.Write(sType.TypeName); + } + } + foreach (SerializationType sType in types) { + Type type = sType.Type; + if (type.IsGenericTypeDefinition || type.HasElementType) + continue; + if (type.IsPrimitive || typeof(ISerializable).IsAssignableFrom(type)) { writer.Write(byte.MaxValue); } else { var fields = GetSerializableFields(type); @@ -163,11 +284,7 @@ namespace ICSharpCode.NRefactory.Utils throw new SerializationException("Too many fields."); writer.Write((byte)fields.Count); foreach (var field in fields) { - int typeID = typeToID[field.FieldType]; - if (types.Count <= ushort.MaxValue) - writer.Write((ushort)typeID); - else - writer.Write(typeID); + WriteTypeID(field.FieldType); writer.Write(field.Name); } } @@ -176,24 +293,24 @@ namespace ICSharpCode.NRefactory.Utils // Write out information necessary to create the instances // starting from 1, because index 0 is null for (int i = 1; i < instances.Count; i++) { - int typeID = typeIDs[i]; + SerializationType sType = objectTypes[i]; if (types.Count <= ushort.MaxValue) - writer.Write((ushort)typeID); + writer.Write((ushort)sType.ID); else - writer.Write(typeID); - if (typeID == stringTypeID) { + writer.Write(sType.ID); + if (sType == stringType) { // Strings are written to the output immediately // - we can't create an empty string and fill it later writer.Write((string)instances[i]); - } else if (types[typeID].IsArray) { + } else if (sType.Type.IsArray) { // For arrays, write down the length, because we need that to create the array instance writer.Write(((Array)instances[i]).Length); } } // Write out information necessary to fill data into the instances for (int i = 1; i < instances.Count; i++) { - Log("0x{2:x6}, Write #{0}: {1}", i, types[typeIDs[i]].Name, writer.BaseStream.Position); - writers[typeIDs[i]](this, instances[i]); + Log("0x{2:x6}, Write #{0}: {1}", i, objectTypes[i].Type.Name, writer.BaseStream.Position); + objectTypes[i].Writer(this, instances[i]); } Log("Serialization done."); } @@ -652,7 +769,6 @@ namespace ICSharpCode.NRefactory.Utils sealed class DeserializationContext { public Type[] Types; // index: type ID - public ObjectReader[] ObjectReaders; // index: type ID public object[] Objects; // index: object ID @@ -667,7 +783,7 @@ namespace ICSharpCode.NRefactory.Utils } #region DeserializeTypeDescriptions - internal int ReadFieldTypeID() + internal int ReadTypeID() { if (this.Types.Length <= ushort.MaxValue) return Reader.ReadUInt16(); @@ -679,8 +795,10 @@ namespace ICSharpCode.NRefactory.Utils { for (int i = 0; i < this.Types.Length; i++) { Type type = this.Types[i]; + if (type.IsGenericTypeDefinition || type.HasElementType) + continue; bool isCustomSerialization = typeof(ISerializable).IsAssignableFrom(type); - bool typeIsSpecial = type.IsArray || type.IsPrimitive || isCustomSerialization; + bool typeIsSpecial = type.IsPrimitive || isCustomSerialization; byte serializedFieldCount = Reader.ReadByte(); if (serializedFieldCount == byte.MaxValue) { @@ -695,7 +813,7 @@ namespace ICSharpCode.NRefactory.Utils if (availableFields.Count != serializedFieldCount) throw new SerializationException("Number of fields on " + type.FullName + " has changed."); for (int j = 0; j < serializedFieldCount; j++) { - int fieldTypeID = ReadFieldTypeID(); + int fieldTypeID = ReadTypeID(); string fieldName = Reader.ReadString(); FieldInfo fieldInfo = availableFields[j]; @@ -705,9 +823,6 @@ namespace ICSharpCode.NRefactory.Utils throw new SerializationException(type.FullName + "." + fieldName + " was serialized as " + this.Types[fieldTypeID] + ", but now is " + fieldInfo.FieldType); } } - - if (i < this.ObjectReaders.Length && !isCustomSerialization) - this.ObjectReaders[i] = fastSerializer.GetReader(type); } } #endregion @@ -722,23 +837,67 @@ namespace ICSharpCode.NRefactory.Utils public object Deserialize(BinaryReader reader) { + if (reader.ReadInt32() != magic) + throw new SerializationException("The data cannot be read by FastSerializer (unknown magic value)"); + DeserializationContext context = new DeserializationContext(); context.Reader = reader; - context.Types = new Type[reader.ReadInt32()]; context.Objects = new object[reader.ReadInt32()]; - context.ObjectReaders = new ObjectReader[reader.ReadInt32()]; - int stringTypeID = reader.ReadInt32(); + context.Types = new Type[reader.ReadInt32()]; + string[] assemblyNames = new string[reader.ReadInt32()]; + + for (int i = 0; i < assemblyNames.Length; i++) { + assemblyNames[i] = reader.ReadString(); + } + int stringTypeID = -1; for (int i = 0; i < context.Types.Length; i++) { - string typeName = reader.ReadString(); - Type type = Type.GetType(typeName); - if (type == null) - throw new SerializationException("Could not find " + typeName); - context.Types[i] = type; + byte typeKind = reader.ReadByte(); + switch (typeKind) { + case Type_ReferenceType: + case Type_ValueType: + int assemblyID; + if (assemblyNames.Length <= ushort.MaxValue) + assemblyID = reader.ReadUInt16(); + else + assemblyID = reader.ReadInt32(); + string assemblyName = assemblyNames[assemblyID]; + string typeName = reader.ReadString(); + Type type; + if (SerializationBinder != null) { + type = SerializationBinder.BindToType(assemblyName, typeName); + } else { + type = Assembly.Load(assemblyName).GetType(typeName); + } + if (type == null) + throw new SerializationException("Could not find '" + typeName + "' in '" + assemblyName + "'"); + if (typeKind == Type_ValueType && !type.IsValueType) + throw new SerializationException("Expected '" + typeName + "' to be a value type, but it is reference type"); + if (typeKind == Type_ReferenceType && type.IsValueType) + throw new SerializationException("Expected '" + typeName + "' to be a reference type, but it is value type"); + context.Types[i] = type; + if (type == typeof(string)) + stringTypeID = i; + break; + case Type_SZArray: + context.Types[i] = context.Types[context.ReadTypeID()].MakeArrayType(); + break; + case Type_ParameterizedType: + Type genericType = context.Types[context.ReadTypeID()]; + int typeParameterCount = genericType.GetGenericArguments().Length; + Type[] typeArguments = new Type[typeParameterCount]; + for (int j = 0; j < typeArguments.Length; j++) { + typeArguments[j] = context.Types[context.ReadTypeID()]; + } + context.Types[i] = genericType.MakeGenericType(typeArguments); + break; + default: + throw new SerializationException("Unknown type kind"); + } } context.DeserializeTypeDescriptions(this); int[] typeIDByObjectID = new int[context.Objects.Length]; for (int i = 1; i < context.Objects.Length; i++) { - int typeID = context.ReadFieldTypeID(); + int typeID = context.ReadTypeID(); object instance; if (typeID == stringTypeID) { @@ -756,6 +915,7 @@ namespace ICSharpCode.NRefactory.Utils typeIDByObjectID[i] = typeID; } List customDeserializatons = new List(); + ObjectReader[] objectReaders = new ObjectReader[context.Types.Length]; // index: type ID for (int i = 1; i < context.Objects.Length; i++) { object instance = context.Objects[i]; int typeID = typeIDByObjectID[i]; @@ -773,7 +933,12 @@ namespace ICSharpCode.NRefactory.Utils CustomDeserializationAction action = GetCustomDeserializationAction(type); customDeserializatons.Add(new CustomDeserialization(instance, info, action)); } else { - context.ObjectReaders[typeID](context, instance); + ObjectReader objectReader = objectReaders[typeID]; + if (objectReader == null) { + objectReader = GetReader(context.Types[typeID]); + objectReaders[typeID] = objectReader; + } + objectReader(context, instance); } } Log("File was read successfully, now running {0} custom deserializations...", customDeserializatons.Count);