From 66f51bff3a30b50c7f12868baf3b0bbd75e95232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20K=C3=A4ll=C3=A9n?= Date: Tue, 25 Sep 2012 16:41:47 +0200 Subject: [PATCH 1/2] Added a separate AwaitResolveResult --- .../ICSharpCode.NRefactory.CSharp.csproj | 1 + .../Resolver/AwaitResolveResult.cs | 80 ++++ .../Resolver/CSharpResolver.cs | 44 ++- .../CSharp/Resolver/UnaryOperatorTests.cs | 366 ++++++++++++++++++ 4 files changed, 482 insertions(+), 9 deletions(-) create mode 100644 ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs diff --git a/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj b/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj index 37925b3369..428ae02e16 100644 --- a/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj +++ b/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj @@ -337,6 +337,7 @@ + diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs b/ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs new file mode 100644 index 0000000000..da758c4339 --- /dev/null +++ b/ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs @@ -0,0 +1,80 @@ +// Copyright (c) AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using ICSharpCode.NRefactory.Semantics; +using ICSharpCode.NRefactory.TypeSystem; + +namespace ICSharpCode.NRefactory.CSharp.Resolver +{ + /// + /// Represents the result of an await expression. + /// + public class AwaitResolveResult : ResolveResult + { + /// + /// The method representing the GetAwaiter() call. Can be null if the GetAwaiter method was not found. + /// + public readonly ResolveResult GetAwaiterInvocation; + + /// + /// Awaiter type. Will not be null (but can be UnknownType). + /// + public readonly IType AwaiterType; + + /// + /// Property representing the IsCompleted property on the awaiter type. Can be null if the awaiter type or the property was not found, or when awaiting a dynamic expression. + /// + public readonly IProperty IsCompletedProperty; + + /// + /// Method representing the OnCompleted method on the awaiter type. Can be null if the awaiter type or the method was not found, or when awaiting a dynamic expression. + /// + public readonly IMethod OnCompletedMethod; + + /// + /// Method representing the GetResult method on the awaiter type. Can be null if the awaiter type or the method was not found, or when awaiting a dynamic expression. + /// + public readonly IMethod GetResultMethod; + + public AwaitResolveResult(IType resultType, ResolveResult getAwaiterInvocation, IType awaiterType, IProperty isCompletedProperty, IMethod onCompletedMethod, IMethod getResultMethod) + : base(resultType) + { + if (awaiterType == null) + throw new ArgumentNullException("awaiterType"); + if (getAwaiterInvocation == null) + throw new ArgumentNullException("getAwaiterInvocation"); + this.GetAwaiterInvocation = getAwaiterInvocation; + this.AwaiterType = awaiterType; + this.IsCompletedProperty = isCompletedProperty; + this.OnCompletedMethod = onCompletedMethod; + this.GetResultMethod = getResultMethod; + } + + public override bool IsError { + get { return this.GetAwaiterInvocation.IsError || (AwaiterType.Kind != TypeKind.Dynamic && (this.IsCompletedProperty == null || this.OnCompletedMethod == null || this.GetResultMethod == null)); } + } + + public override IEnumerable GetChildResults() { + return new[] { GetAwaiterInvocation }; + } + } +} diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs b/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs index 71dc976d16..7470dd5f41 100644 --- a/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs +++ b/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs @@ -360,8 +360,14 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver #region ResolveUnaryOperator method public ResolveResult ResolveUnaryOperator(UnaryOperatorType op, ResolveResult expression) { - if (expression.Type.Kind == TypeKind.Dynamic) - return UnaryOperatorResolveResult(SpecialType.Dynamic, op, expression); + if (expression.Type.Kind == TypeKind.Dynamic) { + if (op == UnaryOperatorType.Await) { + return new AwaitResolveResult(SpecialType.Dynamic, new DynamicInvocationResolveResult(new DynamicMemberResolveResult(expression, "GetAwaiter"), DynamicInvocationType.Invocation, EmptyList.Instance), SpecialType.Dynamic, null, null, null); + } + else { + return UnaryOperatorResolveResult(SpecialType.Dynamic, op, expression); + } + } // C# 4.0 spec: §7.3.3 Unary operator overload resolution string overloadableOperatorName = GetOverloadableOperatorName(op); @@ -375,17 +381,37 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver return ErrorResult; case UnaryOperatorType.AddressOf: return UnaryOperatorResolveResult(new PointerType(expression.Type), op, expression); - case UnaryOperatorType.Await: + case UnaryOperatorType.Await: { ResolveResult getAwaiterMethodGroup = ResolveMemberAccess(expression, "GetAwaiter", EmptyList.Instance, NameLookupMode.InvocationTarget); ResolveResult getAwaiterInvocation = ResolveInvocation(getAwaiterMethodGroup, new ResolveResult[0]); - var getResultMethodGroup = CreateMemberLookup().Lookup(getAwaiterInvocation, "GetResult", EmptyList.Instance, true) as MethodGroupResolveResult; + + var lookup = CreateMemberLookup(); + IMethod getResultMethod; + IType awaitResultType; + var getResultMethodGroup = lookup.Lookup(getAwaiterInvocation, "GetResult", EmptyList.Instance, true) as MethodGroupResolveResult; if (getResultMethodGroup != null) { - var or = getResultMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[0], allowExtensionMethods: false, conversions: conversions); - IType awaitResultType = or.GetBestCandidateWithSubstitutedTypeArguments().ReturnType; - return UnaryOperatorResolveResult(awaitResultType, UnaryOperatorType.Await, expression); - } else { - return UnaryOperatorResolveResult(SpecialType.UnknownType, UnaryOperatorType.Await, expression); + var getResultOR = getResultMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[0], allowExtensionMethods: false, conversions: conversions); + getResultMethod = getResultOR.FoundApplicableCandidate ? getResultOR.GetBestCandidateWithSubstitutedTypeArguments() as IMethod : null; + awaitResultType = getResultMethod != null ? getResultMethod.ReturnType : SpecialType.UnknownType; + } + else { + getResultMethod = null; + awaitResultType = SpecialType.UnknownType; } + + var isCompletedRR = lookup.Lookup(getAwaiterInvocation, "IsCompleted", EmptyList.Instance, false); + var isCompletedProperty = (isCompletedRR is MemberResolveResult ? ((MemberResolveResult)isCompletedRR).Member as IProperty : null); + + var onCompletedMethodGroup = lookup.Lookup(getAwaiterInvocation, "OnCompleted", EmptyList.Instance, true) as MethodGroupResolveResult; + IMethod onCompletedMethod = null; + if (onCompletedMethodGroup != null) { + var onCompletedOR = onCompletedMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[] { new TypeResolveResult(compilation.FindType(new FullTypeName("System.Action"))) }, allowExtensionMethods: false, conversions: conversions); + onCompletedMethod = (onCompletedOR.FoundApplicableCandidate ? onCompletedOR.GetBestCandidateWithSubstitutedTypeArguments() as IMethod : null); + } + + return new AwaitResolveResult(awaitResultType, getAwaiterInvocation, getAwaiterInvocation.Type, isCompletedProperty, onCompletedMethod, getResultMethod); + } + default: throw new ArgumentException("Invalid value for UnaryOperatorType", "op"); } diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs index a87aa6d0f6..2039011f9d 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs @@ -274,5 +274,371 @@ class Test { Assert.IsFalse(rr.IsError); Assert.AreEqual(unchecked( (ushort)~3 ), rr.ConstantValue); } + + [Test] + public void Await() { + string program = @" +using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter() { return null; } + public MyAwaiter GetAwaiter(int i) { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("MyAwaitable.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(0, getAwaiterInvocation.Member.Parameters.Count); + + Assert.AreEqual("MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test] + public void AwaitWhenGetAwaiterIsAnExtensionMethod() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } + } + class MyAwaitable { + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("N.MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test, Ignore("TODO: MS C# (at least the RC version) refuses to use default values in GetAwaiter(). I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")] + public void GetAwaiterMethodWithDefaultArgumentCannotBeUsed() { + string program = @" +using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter(int i = 0) { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + Assert.IsTrue(rr.GetAwaiterInvocation.IsError); + + Assert.AreEqual(rr.AwaiterType, SpecialType.UnknownType); + + Assert.IsNull(rr.IsCompletedProperty); + Assert.IsNull(rr.OnCompletedMethod); + Assert.IsNull(rr.GetResultMethod); + } + + [Test, Ignore("TODO: MS C# (at least the RC version) refuses to use default values in GetAwaiter(). I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")] + public void GetAwaiterMethodWithDefaultArgumentHidesExtensionMethodAndResultsInError() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } + } + class MyAwaitable { + public MyAwaiter GetAwaiter(int i = 0) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + Assert.IsTrue(rr.GetAwaiterInvocation.IsError); + + Assert.AreEqual(rr.AwaiterType, SpecialType.UnknownType); + + Assert.IsNull(rr.IsCompletedProperty); + Assert.IsNull(rr.OnCompletedMethod); + Assert.IsNull(rr.GetResultMethod); + } + + [Test] + public void GetAwaiterMethodWithArgumentDoesNotHideExtensionMethod() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } + } + class MyAwaitable { + public static MyAwaiter GetAwaiter(int i) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("N.MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test] + public void AwaiterWithNoSuitableGetResult() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult(int i) { return 0; } + } + class MyAwaitable { + public static MyAwaiter GetAwaiter(int i) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsTrue(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNull(rr.GetResultMethod); + } + + [Test] + public void AwaiterWithNoIsCompletedProperty() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted() { return false; } + public void OnCompleted(Action continuation) {} + public int GetResult(int i) { return 0; } + } + class MyAwaitable { + public static MyAwaiter GetAwaiter(int i) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsTrue(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNull(rr.IsCompletedProperty); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNull(rr.GetResultMethod); + } + + [Test] + public void AwaiterWithNoOnCompletedMethodWithSuitableSignature() { + string program = @" +using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Func continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter() { return null; } + public MyAwaiter GetAwaiter(int i) { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsTrue(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("MyAwaitable.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(0, getAwaiterInvocation.Member.Parameters.Count); + + Assert.AreEqual("MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNull(rr.OnCompletedMethod); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test] + public void AwaitDynamic() { + string program = @" +public class C { + public async void M() { + dynamic x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.AreEqual(SpecialType.Dynamic, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (DynamicInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(DynamicInvocationType.Invocation, getAwaiterInvocation.InvocationType); + Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count); + Assert.IsInstanceOf(getAwaiterInvocation.Target); + var target = (DynamicMemberResolveResult)getAwaiterInvocation.Target; + Assert.IsInstanceOf(target.Target); + Assert.AreEqual("GetAwaiter", target.Member); + + Assert.AreEqual(SpecialType.Dynamic, rr.AwaiterType); + + Assert.IsNull(rr.IsCompletedProperty); + Assert.IsNull(rr.OnCompletedMethod); + Assert.IsNull(rr.GetResultMethod); + } } } From 8d5536e2f6a22a741cfba4d72a7053877bc8a3c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20K=C3=A4ll=C3=A9n?= Date: Tue, 25 Sep 2012 17:06:04 +0200 Subject: [PATCH 2/2] Handle await expressions in find references. --- .../Resolver/FindReferences.cs | 43 ++++++++++++++ .../CSharp/Resolver/FindReferencesTest.cs | 58 +++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/FindReferences.cs b/ICSharpCode.NRefactory.CSharp/Resolver/FindReferences.cs index c269f581d1..15382b0e8a 100644 --- a/ICSharpCode.NRefactory.CSharp/Resolver/FindReferences.cs +++ b/ICSharpCode.NRefactory.CSharp/Resolver/FindReferences.cs @@ -244,6 +244,8 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver scope = FindMemberReferences(entity, m => new FindPropertyReferences((IProperty)m)); if (entity.Name == "Current") additionalScope = FindEnumeratorCurrentReferences((IProperty)entity); + else if (entity.Name == "IsCompleted") + additionalScope = FindAwaiterIsCompletedReferences((IProperty)entity); break; case EntityType.Event: scope = FindMemberReferences(entity, m => new FindEventReferences((IEvent)m)); @@ -661,6 +663,15 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver return imported != null ? new FindEnumeratorCurrentReferencesNavigator(imported) : null; }); } + + SearchScope FindAwaiterIsCompletedReferences(IProperty property) + { + return new SearchScope( + delegate(ICompilation compilation) { + IProperty imported = compilation.Import(property); + return imported != null ? new FindAwaiterIsCompletedReferencesNavigator(imported) : null; + }); + } sealed class FindEnumeratorCurrentReferencesNavigator : FindReferenceNavigator { @@ -682,6 +693,27 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver return ferr != null && ferr.CurrentProperty != null && findReferences.IsMemberMatch(property, ferr.CurrentProperty, true); } } + + sealed class FindAwaiterIsCompletedReferencesNavigator : FindReferenceNavigator + { + IProperty property; + + public FindAwaiterIsCompletedReferencesNavigator(IProperty property) + { + this.property = property; + } + + internal override bool CanMatch(AstNode node) + { + return node is UnaryOperatorExpression; + } + + internal override bool IsMatch(ResolveResult rr) + { + AwaitResolveResult arr = rr as AwaitResolveResult; + return arr != null && arr.IsCompletedProperty != null && findReferences.IsMemberMatch(property, arr.IsCompletedProperty, true); + } + } #endregion #region Find Method References @@ -724,6 +756,11 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver case "MoveNext": specialNodeType = typeof(ForeachStatement); break; + case "GetAwaiter": + case "GetResult": + case "OnCompleted": + specialNodeType = typeof(UnaryOperatorExpression); + break; default: specialNodeType = null; break; @@ -794,6 +831,12 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver return IsMatch(ferr.GetEnumeratorCall) || (ferr.MoveNextMethod != null && findReferences.IsMemberMatch(method, ferr.MoveNextMethod, true)); } + var arr = rr as AwaitResolveResult; + if (arr != null) { + return IsMatch(arr.GetAwaiterInvocation) + || (arr.GetResultMethod != null && findReferences.IsMemberMatch(method, arr.GetResultMethod, true)) + || (arr.OnCompletedMethod != null && findReferences.IsMemberMatch(method, arr.OnCompletedMethod, true)); + } } var mrr = rr as MemberResolveResult; return mrr != null && findReferences.IsMemberMatch(method, mrr.Member, mrr.IsVirtualCall); diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs index 87cf3cc3cf..6fc077d250 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs @@ -211,5 +211,63 @@ class Calls { Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 9 && r is InvocationExpression)); } #endregion + + #region Await + const string awaitTest = @"using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter() { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = await x; + } +}"; + + [Test] + public void GetAwaiterReferenceInAwaitExpressionIsFound() { + Init(awaitTest); + var test = compilation.MainAssembly.TopLevelTypeDefinitions.Single(t => t.Name == "MyAwaitable"); + var method = test.Methods.Single(m => m.Name == "GetAwaiter"); + var actual = FindReferences(method).ToList(); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 8 && r is MethodDeclaration)); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 13 && r is UnaryOperatorExpression)); + } + + [Test] + public void GetResultReferenceInAwaitExpressionIsFound() { + Init(awaitTest); + var test = compilation.MainAssembly.TopLevelTypeDefinitions.Single(t => t.Name == "MyAwaiter"); + var method = test.Methods.Single(m => m.Name == "GetResult"); + var actual = FindReferences(method).ToList(); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 5 && r is MethodDeclaration)); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 13 && r is UnaryOperatorExpression)); + } + + [Test] + public void OnCompletedReferenceInAwaitExpressionIsFound() { + Init(awaitTest); + var test = compilation.MainAssembly.TopLevelTypeDefinitions.Single(t => t.Name == "MyAwaiter"); + var method = test.Methods.Single(m => m.Name == "OnCompleted"); + var actual = FindReferences(method).ToList(); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 4 && r is MethodDeclaration)); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 13 && r is UnaryOperatorExpression)); + } + + [Test] + public void IsCompletedReferenceInAwaitExpressionIsFound() { + Init(awaitTest); + var test = compilation.MainAssembly.TopLevelTypeDefinitions.Single(t => t.Name == "MyAwaiter"); + var property = test.Properties.Single(m => m.Name == "IsCompleted"); + var actual = FindReferences(property).ToList(); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 3 && r is PropertyDeclaration)); + Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 13 && r is UnaryOperatorExpression)); + } + #endregion } }