Browse Source

Fixed type inference in foreach when the collection type does not implement IEnumerable.

newNRvisualizers
Daniel Grunwald 14 years ago
parent
commit
42ce4ca6e1
  1. 33
      ICSharpCode.NRefactory.CSharp/Resolver/ResolveVisitor.cs
  2. 149
      ICSharpCode.NRefactory.Tests/CSharp/Resolver/LocalTypeInferenceTests.cs

33
ICSharpCode.NRefactory.CSharp/Resolver/ResolveVisitor.cs

@ -2424,7 +2424,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -2424,7 +2424,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
if (IsVar(foreachStatement.VariableType)) {
if (navigator.Scan(foreachStatement.VariableType) == ResolveVisitorNavigationMode.Resolve) {
IType collectionType = Resolve(foreachStatement.InExpression).Type;
IType elementType = GetElementType(collectionType, resolver.Compilation, false);
IType elementType = GetElementTypeFromCollection(collectionType);
StoreCurrentState(foreachStatement.VariableType);
StoreResult(foreachStatement.VariableType, new TypeResolveResult(elementType));
v = MakeVariable(elementType, foreachStatement.VariableNameToken);
@ -2581,7 +2581,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -2581,7 +2581,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
{
if (resolverEnabled && resolver.CurrentMember != null) {
IType returnType = resolver.CurrentMember.ReturnType;
IType elementType = GetElementType(returnType, resolver.Compilation, true);
IType elementType = GetElementTypeFromIEnumerable(returnType, resolver.Compilation, true);
ResolveAndProcessConversion(yieldStatement.Expression, elementType);
} else {
Scan(yieldStatement.Expression);
@ -2864,7 +2864,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -2864,7 +2864,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
var result = visitor.Resolve(initializer).Type;
if (isForEach) {
result = GetElementType(result, storedContext.Compilation, false);
result = visitor.GetElementTypeFromCollection(result);
}
this.resolvedType = result;
});
@ -2936,10 +2936,31 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -2936,10 +2936,31 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
}
}
*/
static IType GetElementType(IType result, ICompilation compilation, bool allowIEnumerator)
IType GetElementTypeFromCollection(IType collectionType)
{
switch (collectionType.Kind) {
case TypeKind.Array:
return ((ArrayType)collectionType).ElementType;
case TypeKind.Dynamic:
return SpecialType.Dynamic;
}
var memberLookup = resolver.CreateMemberLookup();
var getEnumeratorMethodGroup = memberLookup.Lookup(new ResolveResult(collectionType), "GetEnumerator", EmptyList<IType>.Instance, true) as MethodGroupResolveResult;
if (getEnumeratorMethodGroup != null) {
var or = getEnumeratorMethodGroup.PerformOverloadResolution(resolver.Compilation, new ResolveResult[0]);
if (or.FoundApplicableCandidate && !or.IsAmbiguous && !or.BestCandidate.IsStatic && or.BestCandidate.IsPublic) {
IType enumeratorType = or.BestCandidate.ReturnType;
return memberLookup.Lookup(new ResolveResult(enumeratorType), "Current", EmptyList<IType>.Instance, false).Type;
}
}
return GetElementTypeFromIEnumerable(collectionType, resolver.Compilation, false);
}
static IType GetElementTypeFromIEnumerable(IType collectionType, ICompilation compilation, bool allowIEnumerator)
{
bool foundNonGenericIEnumerable = false;
foreach (IType baseType in result.GetAllBaseTypes()) {
foreach (IType baseType in collectionType.GetAllBaseTypes()) {
ITypeDefinition baseTypeDef = baseType.GetDefinition();
if (baseTypeDef != null) {
KnownTypeCode typeCode = baseTypeDef.KnownTypeCode;
@ -3124,7 +3145,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -3124,7 +3145,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
{
// This assumes queries are only used on IEnumerable.
// We might want to look at the signature of a LINQ method (e.g. Select) instead.
return GetElementType(type, resolver.Compilation, false);
return GetElementTypeFromIEnumerable(type, resolver.Compilation, false);
}
sealed class QueryExpressionLambda : LambdaResolveResult

149
ICSharpCode.NRefactory.Tests/CSharp/Resolver/LocalTypeInferenceTests.cs

@ -17,6 +17,8 @@ @@ -17,6 +17,8 @@
// DEALINGS IN THE SOFTWARE.
using System;
using System.Collections;
using System.Collections.Generic;
using ICSharpCode.NRefactory.Semantics;
using ICSharpCode.NRefactory.TypeSystem;
using NUnit.Framework;
@ -61,11 +63,152 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -61,11 +63,152 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
string program = @"using System;
class TestClass {
static void Main() {
var contact = {id = 54321};
$contact$.ToString();
} }";
var contact = {id = 54321};
$contact$.ToString();
} }";
var lrr = Resolve<LocalResolveResult>(program);
Assert.AreEqual(SpecialType.UnknownType, lrr.Type);
}
[Test]
public void Foreach_InferFromArrayType()
{
string program = @"using System;
class TestClass {
static void Method(int[] arr) {
foreach ($var$ x in arr) {}
} }";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.Int32", rr.Type.ReflectionName);
}
[Test]
public void Foreach_InferFromDynamic()
{
string program = @"using System;
class TestClass {
static void Method(dynamic c) {
foreach ($var$ x in c) {}
} }";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual(TypeKind.Dynamic, rr.Type.Kind);
}
[Test]
public void Foreach_InferFromListOfInt()
{
string program = @"using System;
using System.Collections.Generic;
class TestClass {
static void Method(List<int> c) {
foreach ($var$ x in c) {}
} }";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.Int32", rr.Type.ReflectionName);
}
[Test]
public void Foreach_InferFromICollectionOfInt()
{
string program = @"using System;
using System.Collections.Generic;
class TestClass {
static void Method(ICollection<int> c) {
foreach ($var$ x in c) {}
} }";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.Int32", rr.Type.ReflectionName);
}
[Test]
public void Foreach_InferFromCustomCollection_WithoutIEnumerable()
{
string program = @"using System;
using System.Collections.Generic;
class TestClass {
static void Method(CustomCollection c) {
foreach ($var$ x in c) {}
}
}
class CustomCollection {
public MyEnumerator GetEnumerator() {}
public struct MyEnumerator {
public string Current { get { return null; } }
public bool MoveNext() { return false; }
}
}
";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.String", rr.Type.ReflectionName);
}
[Test]
public void Foreach_InferFromCustomCollection_WithIEnumerableAndPublicGetEnumerator()
{
string program = @"using System;
using System.Collections.Generic;
class TestClass {
static void Method(CustomCollection c) {
foreach ($var$ x in c) {}
}
}
class CustomCollection : IEnumerable<int> {
public MyEnumerator GetEnumerator() {}
public struct MyEnumerator {
public string Current { get { return null; } }
public bool MoveNext() { return false; }
}
}
";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.String", rr.Type.ReflectionName);
}
[Test]
public void Foreach_InferFromCustomCollection_WithIEnumerableAndInternalGetEnumerator()
{
string program = @"using System;
using System.Collections.Generic;
class TestClass {
static void Method(CustomCollection c) {
foreach ($var$ x in c) {}
}
}
class CustomCollection : IEnumerable<int> {
internal MyEnumerator GetEnumerator() {}
public struct MyEnumerator {
public string Current { get { return null; } }
public bool MoveNext() { return false; }
}
}
";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.Int32", rr.Type.ReflectionName);
}
[Test]
public void Foreach_InferFromCustomCollection_WithIEnumerableAndGetEnumeratorExtensionMethod()
{
string program = @"using System;
using System.Collections.Generic;
class TestClass {
static void Method(CustomCollection c) {
foreach ($var$ x in c) {}
}
}
class CustomCollection : IEnumerable<int> {
public struct MyEnumerator {
public string Current { get { return null; } }
public bool MoveNext() { return false; }
}
}
static class ExtMethods {
public static CustomCollection.MyEnumerator GetEnumerator(this CustomCollection c) {
throw new NotImplementedException();
}
}";
var rr = Resolve<TypeResolveResult>(program);
Assert.AreEqual("System.Int32", rr.Type.ReflectionName);
}
}
}

Loading…
Cancel
Save