diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/TypeInferenceTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/TypeInferenceTests.cs index 8a64a63df4..22ecb07833 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/TypeInferenceTests.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/TypeInferenceTests.cs @@ -14,7 +14,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver [TestFixture] public class TypeInferenceTests { - CommonTypeInference cti = new CommonTypeInference(CecilLoaderTests.Mscorlib, new Conversions(CecilLoaderTests.Mscorlib)); + TypeInference ti = new TypeInference(CecilLoaderTests.Mscorlib); IType[] Resolve(params Type[] types) { @@ -23,10 +23,25 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver r[i] = types[i].ToTypeReference().Resolve(CecilLoaderTests.Mscorlib); Assert.AreNotSame(r[i], SharedTypes.UnknownType); } - Array.Sort(r, (a,b)=>a.ReflectionName.CompareTo(b.ReflectionName)); return r; } + IType ResolveType(params Type[] type) + { + return type.Single().ToTypeReference().Resolve(CecilLoaderTests.Mscorlib); + } + + [Test] + public void ListOfShortAndInt() + { + Assert.AreEqual( + ResolveType(typeof(IList)), + ti.FindTypeInBounds(Resolve(typeof(List), typeof(List)), Resolve())); + } + + + + /* IType[] CommonBaseTypes(params Type[] types) { return cti.CommonBaseTypes(Resolve(types)).OrderBy(r => r.ReflectionName).ToArray(); @@ -41,8 +56,8 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver public void ListOfStringAndObject() { Assert.AreEqual( - Resolve(typeof(IList), typeof(IEnumerable)), - CommonBaseTypes(typeof(List), typeof(List))); + ResolveType(typeof(IList), typeof(IEnumerable)), + ti.FindTypeInBounds(Resolve(), Resolve(typeof(List), typeof(List)))); } [Test] @@ -61,14 +76,6 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver CommonBaseTypes(typeof(short), typeof(int))); } - [Test] - public void ListOfShortAndInt() - { - Assert.AreEqual( - Resolve(typeof(IList)), - CommonBaseTypes(typeof(List), typeof(List))); - } - [Test] public void StringAndVersion() { @@ -101,12 +108,12 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver CommonSubTypes(typeof(IEnumerable), typeof(IEnumerable))); } - [Test, Ignore("currently implementation is broken for anything nontrivial...")] + [Test] public void CommonSubTypeIEnumerableClonableIEnumerableComparableList() { Assert.AreEqual( Resolve(typeof(List), typeof(List)), CommonSubTypes(typeof(IEnumerable), typeof(IEnumerable), typeof(IList))); - } + }*/ } } diff --git a/ICSharpCode.NRefactory/CSharp/Resolver/TypeInference.cs b/ICSharpCode.NRefactory/CSharp/Resolver/TypeInference.cs index 925c0e14e1..aa2fa34e52 100644 --- a/ICSharpCode.NRefactory/CSharp/Resolver/TypeInference.cs +++ b/ICSharpCode.NRefactory/CSharp/Resolver/TypeInference.cs @@ -10,6 +10,18 @@ using ICSharpCode.NRefactory.TypeSystem.Implementation; namespace ICSharpCode.NRefactory.CSharp.Resolver { + public enum TypeInferenceAlgorithm + { + /// + /// C# 4.0 type inference. + /// + CSharp40, + /// + /// Improved algorithm (not part of any specification) using FindTypeInBounds. + /// + Improved + } + /// /// Implements C# 4.0 Type Inference (§7.5.2). /// @@ -26,6 +38,13 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } #endregion + #region Properties + /// + /// Gets/Sets the type inference algorithm used. + /// + public TypeInferenceAlgorithm Algorithm { get; set; } + #endregion + TP[] typeParameters; IType[] parameterTypes; ResolveResult[] arguments; @@ -373,6 +392,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } #endregion + #region MakeExplicitParameterTypeInference (§7.5.2.7) void MakeExplicitParameterTypeInference(ResolveResult e, IType t) { // C# 4.0 spec: §7.5.2.7 Explicit parameter type inferences @@ -387,11 +407,12 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } }*/ } + #endregion - #region MakeExactInference + #region MakeExactInference (§7.5.2.8) /// /// Make exact inference from U to V. - /// C# 4.0 spec: 7.5.2.8 Exact inferences + /// C# 4.0 spec: §7.5.2.8 Exact inferences /// void MakeExactInference(IType U, IType V) { @@ -446,7 +467,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } #endregion - #region MakeLowerBoundInference + #region MakeLowerBoundInference (§7.5.2.9) /// /// Make lower bound inference from U to V. /// C# 4.0 spec: §7.5.2.9 Lower-bound inferences @@ -529,7 +550,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } #endregion - #region MakeUpperBoundInference + #region MakeUpperBoundInference (§7.5.2.10) /// /// Make upper bound inference from U to V. /// C# 4.0 spec: §7.5.2.10 Upper-bound inferences @@ -740,13 +761,81 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver /// /// Finds a type that satisfies the given lower and upper bounds. /// - public IType FindTypeInBounds(IEnumerable lowerBounds, IEnumerable upperBounds) + public IType FindTypeInBounds(IList lowerBounds, IList upperBounds) { if (lowerBounds == null) throw new ArgumentNullException("lowerBounds"); if (upperBounds == null) throw new ArgumentNullException("upperBounds"); - throw new NotImplementedException(); + + // Finds a type X so that "LB <: X <: UB" + + List candidateTypeDefinitions; + if (lowerBounds.Count > 0) { + // Find candidates by using the lower bounds: + var hashSet = new HashSet(lowerBounds[0].GetAllBaseTypeDefinitions(context)); + for (int i = 1; i < lowerBounds.Count; i++) { + hashSet.IntersectWith(lowerBounds[i].GetAllBaseTypeDefinitions(context)); + } + candidateTypeDefinitions = hashSet.ToList(); + } else { + // Find candidates by looking at all classes in the project: + candidateTypeDefinitions = context.GetAllClasses().ToList(); + } + + // Now filter out candidates that violate the upper bounds: + foreach (IType ub in upperBounds) { + ITypeDefinition ubDef = ub.GetDefinition(); + if (ubDef != null) { + candidateTypeDefinitions.RemoveAll(c => !c.IsDerivedFrom(ubDef, context)); + } + } + + List candidateTypes = new List(); + foreach (ITypeDefinition candidateDef in candidateTypeDefinitions) { + // determine the type parameters for the candidate: + IType candidate; + if (candidateDef.TypeParameterCount == 0) { + candidate = candidateDef; + } else { + bool success; + IType[] result = InferTypeArgumentsFromBounds( + candidateDef.TypeParameters, + new ParameterizedType(candidateDef, candidateDef.TypeParameters), + lowerBounds, upperBounds, + out success); + if (success) { + candidate = new ParameterizedType(candidateDef, result); + } else { + continue; + } + } + + if (lowerBounds.Count > 0) { + // if there were lower bounds, we aim for the most specific candidate: + + // if this candidate isn't made redundant by an existing, more specific candidate: + if (!candidateTypes.Any(c => c.GetDefinition().IsDerivedFrom(candidateDef, context))) { + // remove all existing candidates made redundant by this candidate: + candidateTypes.RemoveAll(c => candidateDef.IsDerivedFrom(c.GetDefinition(), context)); + // add new candidate + candidateTypes.Add(candidate); + } + } else { + // if there only were upper bounds, we aim for the least specific candidate: + + // if this candidate isn't made redundant by an existing, less specific candidate: + if (!candidateTypes.Any(c => candidateDef.IsDerivedFrom(c.GetDefinition(), context))) { + // remove all existing candidates made redundant by this candidate: + candidateTypes.RemoveAll(c => c.GetDefinition().IsDerivedFrom(candidateDef, context)); + // add new candidate + candidateTypes.Add(candidate); + } + } + } + // return any of the candidates (prefer non-interfaces) + return candidateTypes.FirstOrDefault(c => c.GetDefinition().ClassType != ClassType.Interface) + ?? candidateTypes.FirstOrDefault() ?? SharedTypes.UnknownType; } #endregion diff --git a/ICSharpCode.NRefactory/TypeSystem/ExtensionMethods.cs b/ICSharpCode.NRefactory/TypeSystem/ExtensionMethods.cs index 1f0d7fc9ef..c4529f2b69 100644 --- a/ICSharpCode.NRefactory/TypeSystem/ExtensionMethods.cs +++ b/ICSharpCode.NRefactory/TypeSystem/ExtensionMethods.cs @@ -62,11 +62,24 @@ namespace ICSharpCode.NRefactory.TypeSystem /// /// This is equivalent to type.GetAllBaseTypes().Select(t => t.GetDefinition()).Where(d => d != null).Distinct(). /// - public static IEnumerable GetAllBaseTypeDefinitions(this ITypeDefinition type, ITypeResolveContext context) + public static IEnumerable GetAllBaseTypeDefinitions(this IType type, ITypeResolveContext context) { + if (type == null) + throw new ArgumentNullException("type"); + if (context == null) + throw new ArgumentNullException("context"); + HashSet typeDefinitions = new HashSet(); - typeDefinitions.Add(type); - return TreeTraversal.PreOrder(type, t => t.GetBaseTypes(context).Select(b => b.GetDefinition()).Where(d => d != null && typeDefinitions.Add(d))); + Func> recursion = + t => t.GetBaseTypes(context).Select(b => b.GetDefinition()).Where(d => d != null && typeDefinitions.Add(d)); + + ITypeDefinition typeDef = type as ITypeDefinition; + if (typeDef != null) { + typeDefinitions.Add(typeDef); + return TreeTraversal.PreOrder(typeDef, recursion); + } else { + return TreeTraversal.PreOrder(recursion(type), recursion); + } } ///