diff --git a/ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs
new file mode 100644
index 0000000000..ab1eaf7d9b
--- /dev/null
+++ b/ICSharpCode.NRefactory.Tests/CSharp/OutputVisitorTests.cs
@@ -0,0 +1,53 @@
+// Copyright (c) AlphaSierraPapa for the SharpDevelop Team
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this
+// software and associated documentation files (the "Software"), to deal in the Software
+// without restriction, including without limitation the rights to use, copy, modify, merge,
+// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
+// to whom the Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+// DEALINGS IN THE SOFTWARE.
+
+using System;
+using System.IO;
+using NUnit.Framework;
+
+namespace ICSharpCode.NRefactory.CSharp
+{
+ [TestFixture]
+ public class OutputVisitorTests
+ {
+ void AssertOutput(string expected, Expression expr, CSharpFormattingOptions policy = null)
+ {
+ if (policy == null)
+ policy = new CSharpFormattingOptions();;
+ StringWriter w = new StringWriter();
+ w.NewLine = "\n";
+ expr.AcceptVisitor(new OutputVisitor(new TextWriterOutputFormatter(w) { IndentationString = "\t" }, policy), null);
+ Assert.AreEqual(expected.Replace("\r", ""), w.ToString());
+ }
+
+ [Test, Ignore("Incorrect whitespace")]
+ public void AssignmentInCollectionInitialize()
+ {
+ Expression expr = new ObjectCreateExpression {
+ Type = new SimpleType("List"),
+ Initializer = new ArrayInitializerExpression(
+ new ArrayInitializerExpression(
+ new AssignmentExpression(new IdentifierExpression("a"), new PrimitiveExpression(1))
+ )
+ )
+ };
+
+ AssertOutput("new List {\n {\n a = 1\n }\n}", expr);
+ }
+ }
+}
diff --git a/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj b/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj
index 614f24eb7e..d859360338 100644
--- a/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj
+++ b/ICSharpCode.NRefactory.Tests/ICSharpCode.NRefactory.Tests.csproj
@@ -57,6 +57,7 @@
+
diff --git a/ICSharpCode.NRefactory/Utils/FastSerializer.cs b/ICSharpCode.NRefactory/Utils/FastSerializer.cs
index 0bd3887c6d..9622b887c4 100644
--- a/ICSharpCode.NRefactory/Utils/FastSerializer.cs
+++ b/ICSharpCode.NRefactory/Utils/FastSerializer.cs
@@ -30,6 +30,24 @@ namespace ICSharpCode.NRefactory.Utils
{
public class FastSerializer
{
+ #region Properties
+ ///
+ /// Gets/Sets the serialization binder that is being used.
+ /// The default value is null, which will cause the FastSerializer to use the
+ /// full assembly and type names.
+ ///
+ public SerializationBinder SerializationBinder { get; set; }
+ #endregion
+
+ #region Constants
+ const int magic = 0x71D18A5D;
+
+ const byte Type_ReferenceType = 1;
+ const byte Type_ValueType = 2;
+ const byte Type_SZArray = 3;
+ const byte Type_ParameterizedType = 4;
+ #endregion
+
#region Serialization
sealed class ReferenceComparer : IEqualityComparer
{
@@ -44,17 +62,35 @@ namespace ICSharpCode.NRefactory.Utils
}
}
+ sealed class SerializationType
+ {
+ public readonly int ID;
+ public readonly Type Type;
+
+ public SerializationType(int iD, Type type)
+ {
+ this.ID = iD;
+ this.Type = type;
+ }
+
+ public ObjectScanner Scanner;
+ public ObjectWriter Writer;
+ public string TypeName;
+ public int AssemblyNameID;
+ }
+
sealed class SerializationContext
{
readonly Dictionary objectToID = new Dictionary(new ReferenceComparer());
readonly List instances = new List(); // index: object ID
- readonly List typeIDs = new List(); // index: object ID
- int stringTypeID = -1;
- int typeCountForObjects = 0;
+ readonly List objectTypes = new List(); // index: object ID
+ SerializationType stringType;
- readonly Dictionary typeToID = new Dictionary();
- readonly List types = new List(); // index: type ID
- readonly List writers = new List(); // index: type ID
+ readonly Dictionary typeMap = new Dictionary();
+ readonly List types = new List();
+
+ readonly Dictionary assemblyNameToID = new Dictionary();
+ readonly List assemblyNames = new List();
readonly FastSerializer fastSerializer;
public readonly BinaryWriter writer;
@@ -64,7 +100,7 @@ namespace ICSharpCode.NRefactory.Utils
this.fastSerializer = fastSerializer;
this.writer = writer;
instances.Add(null); // use object ID 0 for null
- typeIDs.Add(-1);
+ objectTypes.Add(null);
}
#region Scanning
@@ -84,26 +120,14 @@ namespace ICSharpCode.NRefactory.Utils
internal void Scan()
{
Log("Scanning...");
- List objectScanners = new List(); // index: type ID
// starting from 1, because index 0 is null
for (int i = 1; i < instances.Count; i++) {
object instance = instances[i];
ISerializable serializable = instance as ISerializable;
Type type = instance.GetType();
Log("Scan #{0}: {1}", i, type.Name);
- int typeID;
- if (!typeToID.TryGetValue(type, out typeID)) {
- typeID = types.Count;
- typeToID.Add(type, typeID);
- types.Add(type);
- Log("Registered type %{0}: {1}", typeID, type);
- if (type == typeof(string)) {
- stringTypeID = typeID;
- }
- objectScanners.Add(serializable != null ? null : fastSerializer.GetScanner(type));
- writers.Add(serializable != null ? serializationInfoWriter : fastSerializer.GetWriter(type));
- }
- typeIDs.Add(typeID);
+ SerializationType sType = MarkType(type);
+ objectTypes.Add(sType);
if (serializable != null) {
SerializationInfo info = new SerializationInfo(type, fastSerializer.formatterConverter);
serializable.GetObjectData(info, fastSerializer.streamingContext);
@@ -111,23 +135,74 @@ namespace ICSharpCode.NRefactory.Utils
foreach (SerializationEntry entry in info) {
Mark(entry.Value);
}
+ sType.Writer = serializationInfoWriter;
} else {
- objectScanners[typeID](this, instance);
+ ObjectScanner objectScanner = sType.Scanner;
+ if (objectScanner == null) {
+ objectScanner = fastSerializer.GetScanner(type);
+ sType.Scanner = objectScanner;
+ sType.Writer = fastSerializer.GetWriter(type);
+ }
+ objectScanner(this, instance);
}
}
}
#endregion
#region Scan Types
+ SerializationType MarkType(Type type)
+ {
+ SerializationType sType;
+ if (!typeMap.TryGetValue(type, out sType)) {
+ string assemblyName = null;
+ string typeName = null;
+ if (type.HasElementType) {
+ MarkType(type.GetElementType());
+ } else if (type.IsGenericType && !type.IsGenericTypeDefinition) {
+ MarkType(type.GetGenericTypeDefinition());
+ foreach (Type typeArg in type.GetGenericArguments())
+ MarkType(typeArg);
+ } else if (type.IsGenericParameter) {
+ throw new NotSupportedException();
+ } else {
+ var serializationBinder = fastSerializer.SerializationBinder;
+ if (serializationBinder != null) {
+ serializationBinder.BindToName(type, out assemblyName, out typeName);
+ } else {
+ assemblyName = type.Assembly.FullName;
+ typeName = type.FullName;
+ Debug.Assert(typeName != null);
+ }
+ }
+
+ sType = new SerializationType(typeMap.Count, type);
+ sType.TypeName = typeName;
+ if (assemblyName != null) {
+ if (!assemblyNameToID.TryGetValue(assemblyName, out sType.AssemblyNameID)) {
+ sType.AssemblyNameID = assemblyNames.Count;
+ assemblyNameToID.Add(assemblyName, sType.AssemblyNameID);
+ assemblyNames.Add(assemblyName);
+ Log("Registered assembly #{0}: {1}", sType.AssemblyNameID, assemblyName);
+ }
+ }
+ typeMap.Add(type, sType);
+ types.Add(sType);
+ Log("Registered type %{0}: {1}", sType.ID, type);
+ if (type == typeof(string)) {
+ stringType = sType;
+ }
+ }
+ return sType;
+ }
+
internal void ScanTypes()
{
- typeCountForObjects = types.Count;
for (int i = 0; i < types.Count; i++) {
- foreach (FieldInfo field in GetSerializableFields(types[i])) {
- if (!typeToID.ContainsKey(field.FieldType)) {
- typeToID.Add(field.FieldType, types.Count);
- types.Add(field.FieldType);
- }
+ Type type = types[i].Type;
+ if (type.IsGenericTypeDefinition || type.HasElementType)
+ continue;
+ foreach (FieldInfo field in GetSerializableFields(type)) {
+ MarkType(field.FieldType);
}
}
}
@@ -143,19 +218,65 @@ namespace ICSharpCode.NRefactory.Utils
writer.Write(id);
}
+ void WriteTypeID(Type type)
+ {
+ Debug.Assert(typeMap.ContainsKey(type));
+ int typeID = typeMap[type].ID;
+ if (types.Count <= ushort.MaxValue)
+ writer.Write((ushort)typeID);
+ else
+ writer.Write(typeID);
+ }
+
internal void Write()
{
Log("Writing...");
+ writer.Write(magic);
// Write out type information
- writer.Write(types.Count);
writer.Write(instances.Count);
- writer.Write(typeCountForObjects);
- writer.Write(stringTypeID);
- foreach (Type type in types) {
- writer.Write(type.AssemblyQualifiedName);
+ writer.Write(types.Count);
+ writer.Write(assemblyNames.Count);
+
+ foreach (string assemblyName in assemblyNames) {
+ writer.Write(assemblyName);
}
- foreach (Type type in types) {
- if (type.IsArray || type.IsPrimitive || typeof(ISerializable).IsAssignableFrom(type)) {
+
+ foreach (SerializationType sType in types) {
+ Type type = sType.Type;
+ if (type.HasElementType) {
+ if (type.IsArray) {
+ if (type.GetArrayRank() == 1)
+ writer.Write(Type_SZArray);
+ else
+ throw new NotSupportedException();
+ } else {
+ throw new NotSupportedException();
+ }
+ WriteTypeID(type.GetElementType());
+ } else if (type.IsGenericType && !type.IsGenericTypeDefinition) {
+ writer.Write(Type_ParameterizedType);
+ WriteTypeID(type.GetGenericTypeDefinition());
+ foreach (Type typeArg in type.GetGenericArguments()) {
+ WriteTypeID(typeArg);
+ }
+ } else {
+ if (type.IsValueType) {
+ writer.Write(Type_ValueType);
+ } else {
+ writer.Write(Type_ReferenceType);
+ }
+ if (assemblyNames.Count <= ushort.MaxValue)
+ writer.Write((ushort)sType.AssemblyNameID);
+ else
+ writer.Write(sType.AssemblyNameID);
+ writer.Write(sType.TypeName);
+ }
+ }
+ foreach (SerializationType sType in types) {
+ Type type = sType.Type;
+ if (type.IsGenericTypeDefinition || type.HasElementType)
+ continue;
+ if (type.IsPrimitive || typeof(ISerializable).IsAssignableFrom(type)) {
writer.Write(byte.MaxValue);
} else {
var fields = GetSerializableFields(type);
@@ -163,11 +284,7 @@ namespace ICSharpCode.NRefactory.Utils
throw new SerializationException("Too many fields.");
writer.Write((byte)fields.Count);
foreach (var field in fields) {
- int typeID = typeToID[field.FieldType];
- if (types.Count <= ushort.MaxValue)
- writer.Write((ushort)typeID);
- else
- writer.Write(typeID);
+ WriteTypeID(field.FieldType);
writer.Write(field.Name);
}
}
@@ -176,24 +293,24 @@ namespace ICSharpCode.NRefactory.Utils
// Write out information necessary to create the instances
// starting from 1, because index 0 is null
for (int i = 1; i < instances.Count; i++) {
- int typeID = typeIDs[i];
+ SerializationType sType = objectTypes[i];
if (types.Count <= ushort.MaxValue)
- writer.Write((ushort)typeID);
+ writer.Write((ushort)sType.ID);
else
- writer.Write(typeID);
- if (typeID == stringTypeID) {
+ writer.Write(sType.ID);
+ if (sType == stringType) {
// Strings are written to the output immediately
// - we can't create an empty string and fill it later
writer.Write((string)instances[i]);
- } else if (types[typeID].IsArray) {
+ } else if (sType.Type.IsArray) {
// For arrays, write down the length, because we need that to create the array instance
writer.Write(((Array)instances[i]).Length);
}
}
// Write out information necessary to fill data into the instances
for (int i = 1; i < instances.Count; i++) {
- Log("0x{2:x6}, Write #{0}: {1}", i, types[typeIDs[i]].Name, writer.BaseStream.Position);
- writers[typeIDs[i]](this, instances[i]);
+ Log("0x{2:x6}, Write #{0}: {1}", i, objectTypes[i].Type.Name, writer.BaseStream.Position);
+ objectTypes[i].Writer(this, instances[i]);
}
Log("Serialization done.");
}
@@ -652,7 +769,6 @@ namespace ICSharpCode.NRefactory.Utils
sealed class DeserializationContext
{
public Type[] Types; // index: type ID
- public ObjectReader[] ObjectReaders; // index: type ID
public object[] Objects; // index: object ID
@@ -667,7 +783,7 @@ namespace ICSharpCode.NRefactory.Utils
}
#region DeserializeTypeDescriptions
- internal int ReadFieldTypeID()
+ internal int ReadTypeID()
{
if (this.Types.Length <= ushort.MaxValue)
return Reader.ReadUInt16();
@@ -679,8 +795,10 @@ namespace ICSharpCode.NRefactory.Utils
{
for (int i = 0; i < this.Types.Length; i++) {
Type type = this.Types[i];
+ if (type.IsGenericTypeDefinition || type.HasElementType)
+ continue;
bool isCustomSerialization = typeof(ISerializable).IsAssignableFrom(type);
- bool typeIsSpecial = type.IsArray || type.IsPrimitive || isCustomSerialization;
+ bool typeIsSpecial = type.IsPrimitive || isCustomSerialization;
byte serializedFieldCount = Reader.ReadByte();
if (serializedFieldCount == byte.MaxValue) {
@@ -695,7 +813,7 @@ namespace ICSharpCode.NRefactory.Utils
if (availableFields.Count != serializedFieldCount)
throw new SerializationException("Number of fields on " + type.FullName + " has changed.");
for (int j = 0; j < serializedFieldCount; j++) {
- int fieldTypeID = ReadFieldTypeID();
+ int fieldTypeID = ReadTypeID();
string fieldName = Reader.ReadString();
FieldInfo fieldInfo = availableFields[j];
@@ -705,9 +823,6 @@ namespace ICSharpCode.NRefactory.Utils
throw new SerializationException(type.FullName + "." + fieldName + " was serialized as " + this.Types[fieldTypeID] + ", but now is " + fieldInfo.FieldType);
}
}
-
- if (i < this.ObjectReaders.Length && !isCustomSerialization)
- this.ObjectReaders[i] = fastSerializer.GetReader(type);
}
}
#endregion
@@ -722,23 +837,67 @@ namespace ICSharpCode.NRefactory.Utils
public object Deserialize(BinaryReader reader)
{
+ if (reader.ReadInt32() != magic)
+ throw new SerializationException("The data cannot be read by FastSerializer (unknown magic value)");
+
DeserializationContext context = new DeserializationContext();
context.Reader = reader;
- context.Types = new Type[reader.ReadInt32()];
context.Objects = new object[reader.ReadInt32()];
- context.ObjectReaders = new ObjectReader[reader.ReadInt32()];
- int stringTypeID = reader.ReadInt32();
+ context.Types = new Type[reader.ReadInt32()];
+ string[] assemblyNames = new string[reader.ReadInt32()];
+
+ for (int i = 0; i < assemblyNames.Length; i++) {
+ assemblyNames[i] = reader.ReadString();
+ }
+ int stringTypeID = -1;
for (int i = 0; i < context.Types.Length; i++) {
- string typeName = reader.ReadString();
- Type type = Type.GetType(typeName);
- if (type == null)
- throw new SerializationException("Could not find " + typeName);
- context.Types[i] = type;
+ byte typeKind = reader.ReadByte();
+ switch (typeKind) {
+ case Type_ReferenceType:
+ case Type_ValueType:
+ int assemblyID;
+ if (assemblyNames.Length <= ushort.MaxValue)
+ assemblyID = reader.ReadUInt16();
+ else
+ assemblyID = reader.ReadInt32();
+ string assemblyName = assemblyNames[assemblyID];
+ string typeName = reader.ReadString();
+ Type type;
+ if (SerializationBinder != null) {
+ type = SerializationBinder.BindToType(assemblyName, typeName);
+ } else {
+ type = Assembly.Load(assemblyName).GetType(typeName);
+ }
+ if (type == null)
+ throw new SerializationException("Could not find '" + typeName + "' in '" + assemblyName + "'");
+ if (typeKind == Type_ValueType && !type.IsValueType)
+ throw new SerializationException("Expected '" + typeName + "' to be a value type, but it is reference type");
+ if (typeKind == Type_ReferenceType && type.IsValueType)
+ throw new SerializationException("Expected '" + typeName + "' to be a reference type, but it is value type");
+ context.Types[i] = type;
+ if (type == typeof(string))
+ stringTypeID = i;
+ break;
+ case Type_SZArray:
+ context.Types[i] = context.Types[context.ReadTypeID()].MakeArrayType();
+ break;
+ case Type_ParameterizedType:
+ Type genericType = context.Types[context.ReadTypeID()];
+ int typeParameterCount = genericType.GetGenericArguments().Length;
+ Type[] typeArguments = new Type[typeParameterCount];
+ for (int j = 0; j < typeArguments.Length; j++) {
+ typeArguments[j] = context.Types[context.ReadTypeID()];
+ }
+ context.Types[i] = genericType.MakeGenericType(typeArguments);
+ break;
+ default:
+ throw new SerializationException("Unknown type kind");
+ }
}
context.DeserializeTypeDescriptions(this);
int[] typeIDByObjectID = new int[context.Objects.Length];
for (int i = 1; i < context.Objects.Length; i++) {
- int typeID = context.ReadFieldTypeID();
+ int typeID = context.ReadTypeID();
object instance;
if (typeID == stringTypeID) {
@@ -756,6 +915,7 @@ namespace ICSharpCode.NRefactory.Utils
typeIDByObjectID[i] = typeID;
}
List customDeserializatons = new List();
+ ObjectReader[] objectReaders = new ObjectReader[context.Types.Length]; // index: type ID
for (int i = 1; i < context.Objects.Length; i++) {
object instance = context.Objects[i];
int typeID = typeIDByObjectID[i];
@@ -773,7 +933,12 @@ namespace ICSharpCode.NRefactory.Utils
CustomDeserializationAction action = GetCustomDeserializationAction(type);
customDeserializatons.Add(new CustomDeserialization(instance, info, action));
} else {
- context.ObjectReaders[typeID](context, instance);
+ ObjectReader objectReader = objectReaders[typeID];
+ if (objectReader == null) {
+ objectReader = GetReader(context.Types[typeID]);
+ objectReaders[typeID] = objectReader;
+ }
+ objectReader(context, instance);
}
}
Log("File was read successfully, now running {0} custom deserializations...", customDeserializatons.Count);