diff --git a/src/Libraries/NRefactory/Project/Src/Visitors/LookupTableVisitor.cs b/src/Libraries/NRefactory/Project/Src/Visitors/LookupTableVisitor.cs index edfebb0a1e..447c480bef 100644 --- a/src/Libraries/NRefactory/Project/Src/Visitors/LookupTableVisitor.cs +++ b/src/Libraries/NRefactory/Project/Src/Visitors/LookupTableVisitor.cs @@ -140,49 +140,72 @@ namespace ICSharpCode.NRefactory.Visitors return base.VisitLambdaExpression(lambdaExpression, data); } + public override object VisitQueryExpression(QueryExpression queryExpression, object data) + { + endLocationStack.Push(GetQueryVariableEndScope(queryExpression)); + base.VisitQueryExpression(queryExpression, data); + endLocationStack.Pop(); + return null; + } + + Location GetQueryVariableEndScope(QueryExpression queryExpression) + { + return queryExpression.IntoClause.IsNull + ? queryExpression.EndLocation + : queryExpression.IntoClause.StartLocation; + } + public override object VisitQueryExpressionFromClause(QueryExpressionFromClause fromClause, object data) { - QueryExpression parentExpression = fromClause.Parent as QueryExpression; + AddVariable(fromClause.Type, fromClause.Identifier, + fromClause.StartLocation, CurrentEndLocation, + false, true, fromClause.InExpression, null); + return base.VisitQueryExpressionFromClause(fromClause, data); + } + + public override object VisitQueryExpressionIntoClause(QueryExpressionIntoClause intoClause, object data) + { + QueryExpression parentExpression = intoClause.Parent as QueryExpression; if (parentExpression != null) { - AddVariable(fromClause.Type, fromClause.Identifier, - parentExpression.StartLocation, parentExpression.EndLocation, - false, true, fromClause.InExpression, null); + Expression initializer = null; + var selectClause = parentExpression.SelectOrGroupClause as QueryExpressionSelectClause; + if (selectClause != null) { + initializer = selectClause.Projection; + } else { + var groupByClause = parentExpression.SelectOrGroupClause as QueryExpressionGroupClause; + if (groupByClause != null) + initializer = groupByClause.Projection; + } + AddVariable(null, intoClause.IntoIdentifier, + intoClause.StartLocation, GetQueryVariableEndScope(intoClause.ContinuedQuery), + false, false, initializer, null); } - return base.VisitQueryExpressionFromClause(fromClause, data); + return base.VisitQueryExpressionIntoClause(intoClause, data); } public override object VisitQueryExpressionJoinClause(QueryExpressionJoinClause joinClause, object data) { if (string.IsNullOrEmpty(joinClause.IntoIdentifier)) { - QueryExpression parentExpression = joinClause.Parent as QueryExpression; - if (parentExpression != null) { - AddVariable(joinClause.Type, joinClause.Identifier, - parentExpression.StartLocation, parentExpression.EndLocation, - false, true, joinClause.InExpression, null); - } + AddVariable(joinClause.Type, joinClause.Identifier, + joinClause.StartLocation, CurrentEndLocation, + false, true, joinClause.InExpression, null); } else { AddVariable(joinClause.Type, joinClause.Identifier, joinClause.StartLocation, joinClause.EndLocation, false, true, joinClause.InExpression, null); - QueryExpression parentExpression = joinClause.Parent as QueryExpression; - if (parentExpression != null) { - AddVariable(joinClause.Type, joinClause.IntoIdentifier, - parentExpression.StartLocation, parentExpression.EndLocation, - false, false, joinClause.InExpression, null); - } + AddVariable(joinClause.Type, joinClause.IntoIdentifier, + joinClause.StartLocation, CurrentEndLocation, + false, false, joinClause.InExpression, null); } return base.VisitQueryExpressionJoinClause(joinClause, data); } public override object VisitQueryExpressionLetClause(QueryExpressionLetClause letClause, object data) { - QueryExpression parentExpression = letClause.Parent as QueryExpression; - if (parentExpression != null) { - AddVariable(null, letClause.Identifier, - parentExpression.StartLocation, parentExpression.EndLocation, - false, false, letClause.Expression, null); - } + AddVariable(null, letClause.Identifier, + letClause.StartLocation, CurrentEndLocation, + false, false, letClause.Expression, null); return base.VisitQueryExpressionLetClause(letClause, data); } diff --git a/src/Main/Base/Test/NRefactoryResolverTests.cs b/src/Main/Base/Test/NRefactoryResolverTests.cs index 441d4b9e06..c39f2b8404 100644 --- a/src/Main/Base/Test/NRefactoryResolverTests.cs +++ b/src/Main/Base/Test/NRefactoryResolverTests.cs @@ -1886,6 +1886,30 @@ public class MyCollectionType : System.Collections.IEnumerable Assert.IsNotNull(mrr); Assert.AreEqual("Point.X", mrr.ResolvedMember.FullyQualifiedName); } + + [Test] + public void LinqQueryContinuationTest() + { + string program = @"using System; +class TestClass { + void Test(string[] input) { + var r = from x in input + select x.GetHashCode() into x + where x == 42 + select x * x; + + } +} +"; + LocalResolveResult lrr = Resolve(program, "x", 5, 11, ExpressionContext.Default); + Assert.AreEqual("System.String", lrr.ResolvedType.FullyQualifiedName); + lrr = Resolve(program, "x", 6, 10, ExpressionContext.Default); + Assert.AreEqual("System.Int32", lrr.ResolvedType.FullyQualifiedName); + + lrr = Resolve(program, "r", 8); + Assert.AreEqual("System.Collections.Generic.IEnumerable", lrr.ResolvedType.FullyQualifiedName); + Assert.AreEqual("System.Int32", lrr.ResolvedType.CastToConstructedReturnType().TypeArguments[0].FullyQualifiedName); + } #endregion [Test]