From 42ce4ca6e10c6277540fda6136cc239353990c42 Mon Sep 17 00:00:00 2001 From: Daniel Grunwald Date: Thu, 1 Dec 2011 14:20:12 +0100 Subject: [PATCH] Fixed type inference in foreach when the collection type does not implement IEnumerable. --- .../Resolver/ResolveVisitor.cs | 33 +++- .../Resolver/LocalTypeInferenceTests.cs | 149 +++++++++++++++++- 2 files changed, 173 insertions(+), 9 deletions(-) diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/ResolveVisitor.cs b/ICSharpCode.NRefactory.CSharp/Resolver/ResolveVisitor.cs index 456c215b1c..f31bd685c8 100644 --- a/ICSharpCode.NRefactory.CSharp/Resolver/ResolveVisitor.cs +++ b/ICSharpCode.NRefactory.CSharp/Resolver/ResolveVisitor.cs @@ -2424,7 +2424,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver if (IsVar(foreachStatement.VariableType)) { if (navigator.Scan(foreachStatement.VariableType) == ResolveVisitorNavigationMode.Resolve) { IType collectionType = Resolve(foreachStatement.InExpression).Type; - IType elementType = GetElementType(collectionType, resolver.Compilation, false); + IType elementType = GetElementTypeFromCollection(collectionType); StoreCurrentState(foreachStatement.VariableType); StoreResult(foreachStatement.VariableType, new TypeResolveResult(elementType)); v = MakeVariable(elementType, foreachStatement.VariableNameToken); @@ -2581,7 +2581,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver { if (resolverEnabled && resolver.CurrentMember != null) { IType returnType = resolver.CurrentMember.ReturnType; - IType elementType = GetElementType(returnType, resolver.Compilation, true); + IType elementType = GetElementTypeFromIEnumerable(returnType, resolver.Compilation, true); ResolveAndProcessConversion(yieldStatement.Expression, elementType); } else { Scan(yieldStatement.Expression); @@ -2864,7 +2864,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver var result = visitor.Resolve(initializer).Type; if (isForEach) { - result = GetElementType(result, storedContext.Compilation, false); + result = visitor.GetElementTypeFromCollection(result); } this.resolvedType = result; }); @@ -2936,10 +2936,31 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } } */ - static IType GetElementType(IType result, ICompilation compilation, bool allowIEnumerator) + + IType GetElementTypeFromCollection(IType collectionType) + { + switch (collectionType.Kind) { + case TypeKind.Array: + return ((ArrayType)collectionType).ElementType; + case TypeKind.Dynamic: + return SpecialType.Dynamic; + } + var memberLookup = resolver.CreateMemberLookup(); + var getEnumeratorMethodGroup = memberLookup.Lookup(new ResolveResult(collectionType), "GetEnumerator", EmptyList.Instance, true) as MethodGroupResolveResult; + if (getEnumeratorMethodGroup != null) { + var or = getEnumeratorMethodGroup.PerformOverloadResolution(resolver.Compilation, new ResolveResult[0]); + if (or.FoundApplicableCandidate && !or.IsAmbiguous && !or.BestCandidate.IsStatic && or.BestCandidate.IsPublic) { + IType enumeratorType = or.BestCandidate.ReturnType; + return memberLookup.Lookup(new ResolveResult(enumeratorType), "Current", EmptyList.Instance, false).Type; + } + } + return GetElementTypeFromIEnumerable(collectionType, resolver.Compilation, false); + } + + static IType GetElementTypeFromIEnumerable(IType collectionType, ICompilation compilation, bool allowIEnumerator) { bool foundNonGenericIEnumerable = false; - foreach (IType baseType in result.GetAllBaseTypes()) { + foreach (IType baseType in collectionType.GetAllBaseTypes()) { ITypeDefinition baseTypeDef = baseType.GetDefinition(); if (baseTypeDef != null) { KnownTypeCode typeCode = baseTypeDef.KnownTypeCode; @@ -3124,7 +3145,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver { // This assumes queries are only used on IEnumerable. // We might want to look at the signature of a LINQ method (e.g. Select) instead. - return GetElementType(type, resolver.Compilation, false); + return GetElementTypeFromIEnumerable(type, resolver.Compilation, false); } sealed class QueryExpressionLambda : LambdaResolveResult diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/LocalTypeInferenceTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/LocalTypeInferenceTests.cs index d31f92e0c4..66ec024aea 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/LocalTypeInferenceTests.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/LocalTypeInferenceTests.cs @@ -17,6 +17,8 @@ // DEALINGS IN THE SOFTWARE. using System; +using System.Collections; +using System.Collections.Generic; using ICSharpCode.NRefactory.Semantics; using ICSharpCode.NRefactory.TypeSystem; using NUnit.Framework; @@ -61,11 +63,152 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver string program = @"using System; class TestClass { static void Main() { - var contact = {id = 54321}; - $contact$.ToString(); - } }"; + var contact = {id = 54321}; + $contact$.ToString(); + } }"; var lrr = Resolve(program); Assert.AreEqual(SpecialType.UnknownType, lrr.Type); } + + [Test] + public void Foreach_InferFromArrayType() + { + string program = @"using System; +class TestClass { + static void Method(int[] arr) { + foreach ($var$ x in arr) {} + } }"; + var rr = Resolve(program); + Assert.AreEqual("System.Int32", rr.Type.ReflectionName); + } + + [Test] + public void Foreach_InferFromDynamic() + { + string program = @"using System; +class TestClass { + static void Method(dynamic c) { + foreach ($var$ x in c) {} + } }"; + var rr = Resolve(program); + Assert.AreEqual(TypeKind.Dynamic, rr.Type.Kind); + } + + [Test] + public void Foreach_InferFromListOfInt() + { + string program = @"using System; +using System.Collections.Generic; +class TestClass { + static void Method(List c) { + foreach ($var$ x in c) {} + } }"; + var rr = Resolve(program); + Assert.AreEqual("System.Int32", rr.Type.ReflectionName); + } + + [Test] + public void Foreach_InferFromICollectionOfInt() + { + string program = @"using System; +using System.Collections.Generic; +class TestClass { + static void Method(ICollection c) { + foreach ($var$ x in c) {} + } }"; + var rr = Resolve(program); + Assert.AreEqual("System.Int32", rr.Type.ReflectionName); + } + + [Test] + public void Foreach_InferFromCustomCollection_WithoutIEnumerable() + { + string program = @"using System; +using System.Collections.Generic; +class TestClass { + static void Method(CustomCollection c) { + foreach ($var$ x in c) {} + } +} +class CustomCollection { + public MyEnumerator GetEnumerator() {} + public struct MyEnumerator { + public string Current { get { return null; } } + public bool MoveNext() { return false; } + } +} +"; + var rr = Resolve(program); + Assert.AreEqual("System.String", rr.Type.ReflectionName); + } + + [Test] + public void Foreach_InferFromCustomCollection_WithIEnumerableAndPublicGetEnumerator() + { + string program = @"using System; +using System.Collections.Generic; +class TestClass { + static void Method(CustomCollection c) { + foreach ($var$ x in c) {} + } +} +class CustomCollection : IEnumerable { + public MyEnumerator GetEnumerator() {} + public struct MyEnumerator { + public string Current { get { return null; } } + public bool MoveNext() { return false; } + } +} +"; + var rr = Resolve(program); + Assert.AreEqual("System.String", rr.Type.ReflectionName); + } + + [Test] + public void Foreach_InferFromCustomCollection_WithIEnumerableAndInternalGetEnumerator() + { + string program = @"using System; +using System.Collections.Generic; +class TestClass { + static void Method(CustomCollection c) { + foreach ($var$ x in c) {} + } +} +class CustomCollection : IEnumerable { + internal MyEnumerator GetEnumerator() {} + public struct MyEnumerator { + public string Current { get { return null; } } + public bool MoveNext() { return false; } + } +} +"; + var rr = Resolve(program); + Assert.AreEqual("System.Int32", rr.Type.ReflectionName); + } + + [Test] + public void Foreach_InferFromCustomCollection_WithIEnumerableAndGetEnumeratorExtensionMethod() + { + string program = @"using System; +using System.Collections.Generic; +class TestClass { + static void Method(CustomCollection c) { + foreach ($var$ x in c) {} + } +} +class CustomCollection : IEnumerable { + public struct MyEnumerator { + public string Current { get { return null; } } + public bool MoveNext() { return false; } + } +} +static class ExtMethods { + public static CustomCollection.MyEnumerator GetEnumerator(this CustomCollection c) { + throw new NotImplementedException(); + } +}"; + var rr = Resolve(program); + Assert.AreEqual("System.Int32", rr.Type.ReflectionName); + } } }