diff --git a/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj b/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj index 37925b3369..428ae02e16 100644 --- a/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj +++ b/ICSharpCode.NRefactory.CSharp/ICSharpCode.NRefactory.CSharp.csproj @@ -337,6 +337,7 @@ + diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs b/ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs new file mode 100644 index 0000000000..da758c4339 --- /dev/null +++ b/ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs @@ -0,0 +1,80 @@ +// Copyright (c) AlphaSierraPapa for the SharpDevelop Team +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this +// software and associated documentation files (the "Software"), to deal in the Software +// without restriction, including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +// to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +// PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +// FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using ICSharpCode.NRefactory.Semantics; +using ICSharpCode.NRefactory.TypeSystem; + +namespace ICSharpCode.NRefactory.CSharp.Resolver +{ + /// + /// Represents the result of an await expression. + /// + public class AwaitResolveResult : ResolveResult + { + /// + /// The method representing the GetAwaiter() call. Can be null if the GetAwaiter method was not found. + /// + public readonly ResolveResult GetAwaiterInvocation; + + /// + /// Awaiter type. Will not be null (but can be UnknownType). + /// + public readonly IType AwaiterType; + + /// + /// Property representing the IsCompleted property on the awaiter type. Can be null if the awaiter type or the property was not found, or when awaiting a dynamic expression. + /// + public readonly IProperty IsCompletedProperty; + + /// + /// Method representing the OnCompleted method on the awaiter type. Can be null if the awaiter type or the method was not found, or when awaiting a dynamic expression. + /// + public readonly IMethod OnCompletedMethod; + + /// + /// Method representing the GetResult method on the awaiter type. Can be null if the awaiter type or the method was not found, or when awaiting a dynamic expression. + /// + public readonly IMethod GetResultMethod; + + public AwaitResolveResult(IType resultType, ResolveResult getAwaiterInvocation, IType awaiterType, IProperty isCompletedProperty, IMethod onCompletedMethod, IMethod getResultMethod) + : base(resultType) + { + if (awaiterType == null) + throw new ArgumentNullException("awaiterType"); + if (getAwaiterInvocation == null) + throw new ArgumentNullException("getAwaiterInvocation"); + this.GetAwaiterInvocation = getAwaiterInvocation; + this.AwaiterType = awaiterType; + this.IsCompletedProperty = isCompletedProperty; + this.OnCompletedMethod = onCompletedMethod; + this.GetResultMethod = getResultMethod; + } + + public override bool IsError { + get { return this.GetAwaiterInvocation.IsError || (AwaiterType.Kind != TypeKind.Dynamic && (this.IsCompletedProperty == null || this.OnCompletedMethod == null || this.GetResultMethod == null)); } + } + + public override IEnumerable GetChildResults() { + return new[] { GetAwaiterInvocation }; + } + } +} diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs b/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs index 71dc976d16..7470dd5f41 100644 --- a/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs +++ b/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs @@ -360,8 +360,14 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver #region ResolveUnaryOperator method public ResolveResult ResolveUnaryOperator(UnaryOperatorType op, ResolveResult expression) { - if (expression.Type.Kind == TypeKind.Dynamic) - return UnaryOperatorResolveResult(SpecialType.Dynamic, op, expression); + if (expression.Type.Kind == TypeKind.Dynamic) { + if (op == UnaryOperatorType.Await) { + return new AwaitResolveResult(SpecialType.Dynamic, new DynamicInvocationResolveResult(new DynamicMemberResolveResult(expression, "GetAwaiter"), DynamicInvocationType.Invocation, EmptyList.Instance), SpecialType.Dynamic, null, null, null); + } + else { + return UnaryOperatorResolveResult(SpecialType.Dynamic, op, expression); + } + } // C# 4.0 spec: §7.3.3 Unary operator overload resolution string overloadableOperatorName = GetOverloadableOperatorName(op); @@ -375,17 +381,37 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver return ErrorResult; case UnaryOperatorType.AddressOf: return UnaryOperatorResolveResult(new PointerType(expression.Type), op, expression); - case UnaryOperatorType.Await: + case UnaryOperatorType.Await: { ResolveResult getAwaiterMethodGroup = ResolveMemberAccess(expression, "GetAwaiter", EmptyList.Instance, NameLookupMode.InvocationTarget); ResolveResult getAwaiterInvocation = ResolveInvocation(getAwaiterMethodGroup, new ResolveResult[0]); - var getResultMethodGroup = CreateMemberLookup().Lookup(getAwaiterInvocation, "GetResult", EmptyList.Instance, true) as MethodGroupResolveResult; + + var lookup = CreateMemberLookup(); + IMethod getResultMethod; + IType awaitResultType; + var getResultMethodGroup = lookup.Lookup(getAwaiterInvocation, "GetResult", EmptyList.Instance, true) as MethodGroupResolveResult; if (getResultMethodGroup != null) { - var or = getResultMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[0], allowExtensionMethods: false, conversions: conversions); - IType awaitResultType = or.GetBestCandidateWithSubstitutedTypeArguments().ReturnType; - return UnaryOperatorResolveResult(awaitResultType, UnaryOperatorType.Await, expression); - } else { - return UnaryOperatorResolveResult(SpecialType.UnknownType, UnaryOperatorType.Await, expression); + var getResultOR = getResultMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[0], allowExtensionMethods: false, conversions: conversions); + getResultMethod = getResultOR.FoundApplicableCandidate ? getResultOR.GetBestCandidateWithSubstitutedTypeArguments() as IMethod : null; + awaitResultType = getResultMethod != null ? getResultMethod.ReturnType : SpecialType.UnknownType; + } + else { + getResultMethod = null; + awaitResultType = SpecialType.UnknownType; } + + var isCompletedRR = lookup.Lookup(getAwaiterInvocation, "IsCompleted", EmptyList.Instance, false); + var isCompletedProperty = (isCompletedRR is MemberResolveResult ? ((MemberResolveResult)isCompletedRR).Member as IProperty : null); + + var onCompletedMethodGroup = lookup.Lookup(getAwaiterInvocation, "OnCompleted", EmptyList.Instance, true) as MethodGroupResolveResult; + IMethod onCompletedMethod = null; + if (onCompletedMethodGroup != null) { + var onCompletedOR = onCompletedMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[] { new TypeResolveResult(compilation.FindType(new FullTypeName("System.Action"))) }, allowExtensionMethods: false, conversions: conversions); + onCompletedMethod = (onCompletedOR.FoundApplicableCandidate ? onCompletedOR.GetBestCandidateWithSubstitutedTypeArguments() as IMethod : null); + } + + return new AwaitResolveResult(awaitResultType, getAwaiterInvocation, getAwaiterInvocation.Type, isCompletedProperty, onCompletedMethod, getResultMethod); + } + default: throw new ArgumentException("Invalid value for UnaryOperatorType", "op"); } diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs index a87aa6d0f6..2039011f9d 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs @@ -274,5 +274,371 @@ class Test { Assert.IsFalse(rr.IsError); Assert.AreEqual(unchecked( (ushort)~3 ), rr.ConstantValue); } + + [Test] + public void Await() { + string program = @" +using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter() { return null; } + public MyAwaiter GetAwaiter(int i) { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("MyAwaitable.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(0, getAwaiterInvocation.Member.Parameters.Count); + + Assert.AreEqual("MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test] + public void AwaitWhenGetAwaiterIsAnExtensionMethod() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } + } + class MyAwaitable { + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("N.MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test, Ignore("TODO: MS C# (at least the RC version) refuses to use default values in GetAwaiter(). I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")] + public void GetAwaiterMethodWithDefaultArgumentCannotBeUsed() { + string program = @" +using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter(int i = 0) { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + Assert.IsTrue(rr.GetAwaiterInvocation.IsError); + + Assert.AreEqual(rr.AwaiterType, SpecialType.UnknownType); + + Assert.IsNull(rr.IsCompletedProperty); + Assert.IsNull(rr.OnCompletedMethod); + Assert.IsNull(rr.GetResultMethod); + } + + [Test, Ignore("TODO: MS C# (at least the RC version) refuses to use default values in GetAwaiter(). I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")] + public void GetAwaiterMethodWithDefaultArgumentHidesExtensionMethodAndResultsInError() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } + } + class MyAwaitable { + public MyAwaiter GetAwaiter(int i = 0) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + Assert.IsTrue(rr.GetAwaiterInvocation.IsError); + + Assert.AreEqual(rr.AwaiterType, SpecialType.UnknownType); + + Assert.IsNull(rr.IsCompletedProperty); + Assert.IsNull(rr.OnCompletedMethod); + Assert.IsNull(rr.GetResultMethod); + } + + [Test] + public void GetAwaiterMethodWithArgumentDoesNotHideExtensionMethod() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult() { return 0; } + } + class MyAwaitable { + public static MyAwaiter GetAwaiter(int i) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("N.MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test] + public void AwaiterWithNoSuitableGetResult() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Action continuation) {} + public int GetResult(int i) { return 0; } + } + class MyAwaitable { + public static MyAwaiter GetAwaiter(int i) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsTrue(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNull(rr.GetResultMethod); + } + + [Test] + public void AwaiterWithNoIsCompletedProperty() { + string program = @" +using System; +namespace N { + class MyAwaiter { + public bool IsCompleted() { return false; } + public void OnCompleted(Action continuation) {} + public int GetResult(int i) { return 0; } + } + class MyAwaitable { + public static MyAwaiter GetAwaiter(int i) { return null; } + } + static class MyAwaitableExtensions { + public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; } + } + public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } + } +}"; + + var rr = Resolve(program); + Assert.IsTrue(rr.IsError); + Assert.AreEqual(SpecialType.UnknownType, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count); + Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x"); + + Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNull(rr.IsCompletedProperty); + + Assert.IsNotNull(rr.OnCompletedMethod); + Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName); + + Assert.IsNull(rr.GetResultMethod); + } + + [Test] + public void AwaiterWithNoOnCompletedMethodWithSuitableSignature() { + string program = @" +using System; +class MyAwaiter { + public bool IsCompleted { get { return false; } } + public void OnCompleted(Func continuation) {} + public int GetResult() { return 0; } +} +class MyAwaitable { + public MyAwaiter GetAwaiter() { return null; } + public MyAwaiter GetAwaiter(int i) { return null; } +} +public class C { + public async void M() { + MyAwaitable x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsTrue(rr.IsError); + Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32)); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count); + Assert.AreEqual("MyAwaitable.GetAwaiter", getAwaiterInvocation.Member.FullName); + Assert.AreEqual(0, getAwaiterInvocation.Member.Parameters.Count); + + Assert.AreEqual("MyAwaiter", rr.AwaiterType.FullName); + + Assert.IsNotNull(rr.IsCompletedProperty); + Assert.AreEqual("MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName); + + Assert.IsNull(rr.OnCompletedMethod); + + Assert.IsNotNull(rr.GetResultMethod); + Assert.AreEqual("MyAwaiter.GetResult", rr.GetResultMethod.FullName); + } + + [Test] + public void AwaitDynamic() { + string program = @" +public class C { + public async void M() { + dynamic x = null; + int i = $await x$; + } +}"; + + var rr = Resolve(program); + Assert.IsFalse(rr.IsError); + Assert.AreEqual(SpecialType.Dynamic, rr.Type); + Assert.IsInstanceOf(rr.GetAwaiterInvocation); + var getAwaiterInvocation = (DynamicInvocationResolveResult)rr.GetAwaiterInvocation; + Assert.IsFalse(rr.GetAwaiterInvocation.IsError); + Assert.AreEqual(DynamicInvocationType.Invocation, getAwaiterInvocation.InvocationType); + Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count); + Assert.IsInstanceOf(getAwaiterInvocation.Target); + var target = (DynamicMemberResolveResult)getAwaiterInvocation.Target; + Assert.IsInstanceOf(target.Target); + Assert.AreEqual("GetAwaiter", target.Member); + + Assert.AreEqual(SpecialType.Dynamic, rr.AwaiterType); + + Assert.IsNull(rr.IsCompletedProperty); + Assert.IsNull(rr.OnCompletedMethod); + Assert.IsNull(rr.GetResultMethod); + } } }