Browse Source

Make 'await' resolve as in the C# 5.0 language specification.

This means that the awaiter type must implement INotifyCompletion and can optionally implement ICriticalNotifyCompletion.
newNRvisualizers
erikkallen 13 years ago committed by Erik Källén
parent
commit
c615c9f730
  1. 3
      ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs
  2. 52
      ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs
  3. 1
      ICSharpCode.NRefactory.CSharp/Resolver/FindReferences.cs
  4. 33
      ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs
  5. 289
      ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs
  6. 22
      ICSharpCode.NRefactory/TypeSystem/KnownTypeReference.cs

3
ICSharpCode.NRefactory.CSharp/Resolver/AwaitResolveResult.cs

@ -31,7 +31,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -31,7 +31,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
public class AwaitResolveResult : ResolveResult
{
/// <summary>
/// The method representing the GetAwaiter() call. Can be null if the GetAwaiter method was not found.
/// The method representing the GetAwaiter() call. Can be an <see cref="InvocationResolveResult"/> or a <see cref="DynamicInvocationResolveResult"/>.
/// </summary>
public readonly ResolveResult GetAwaiterInvocation;
@ -47,6 +47,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -47,6 +47,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
/// <summary>
/// Method representing the OnCompleted method on the awaiter type. Can be null if the awaiter type or the method was not found, or when awaiting a dynamic expression.
/// This can also refer to an UnsafeOnCompleted method, if the awaiter type implements <c>System.Runtime.CompilerServices.ICriticalNotifyCompletion</c>.
/// </summary>
public readonly IMethod OnCompletedMethod;

52
ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs

@ -383,7 +383,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -383,7 +383,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
return UnaryOperatorResolveResult(new PointerType(expression.Type), op, expression);
case UnaryOperatorType.Await: {
ResolveResult getAwaiterMethodGroup = ResolveMemberAccess(expression, "GetAwaiter", EmptyList<IType>.Instance, NameLookupMode.InvocationTarget);
ResolveResult getAwaiterInvocation = ResolveInvocation(getAwaiterMethodGroup, new ResolveResult[0]);
ResolveResult getAwaiterInvocation = ResolveInvocation(getAwaiterMethodGroup, new ResolveResult[0], argumentNames: null, allowOptionalParameters: false);
var lookup = CreateMemberLookup();
IMethod getResultMethod;
@ -401,12 +401,21 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -401,12 +401,21 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
var isCompletedRR = lookup.Lookup(getAwaiterInvocation, "IsCompleted", EmptyList<IType>.Instance, false);
var isCompletedProperty = (isCompletedRR is MemberResolveResult ? ((MemberResolveResult)isCompletedRR).Member as IProperty : null);
if (isCompletedProperty != null && (!isCompletedProperty.ReturnType.IsKnownType(KnownTypeCode.Boolean) || !isCompletedProperty.CanGet))
isCompletedProperty = null;
var interfaceOnCompleted = compilation.FindType(KnownTypeCode.INotifyCompletion).GetMethods().FirstOrDefault(x => x.Name == "OnCompleted");
var interfaceUnsafeOnCompleted = compilation.FindType(KnownTypeCode.ICriticalNotifyCompletion).GetMethods().FirstOrDefault(x => x.Name == "UnsafeOnCompleted");
var onCompletedMethodGroup = lookup.Lookup(getAwaiterInvocation, "OnCompleted", EmptyList<IType>.Instance, true) as MethodGroupResolveResult;
IMethod onCompletedMethod = null;
if (onCompletedMethodGroup != null) {
var onCompletedOR = onCompletedMethodGroup.PerformOverloadResolution(compilation, new ResolveResult[] { new TypeResolveResult(compilation.FindType(new FullTypeName("System.Action"))) }, allowExtensionMethods: false, conversions: conversions);
onCompletedMethod = (onCompletedOR.FoundApplicableCandidate ? onCompletedOR.GetBestCandidateWithSubstitutedTypeArguments() as IMethod : null);
var candidates = getAwaiterInvocation.Type.GetMethods().Where(x => x.ImplementedInterfaceMembers.Contains(interfaceUnsafeOnCompleted)).ToList();
if (candidates.Count == 0) {
candidates = getAwaiterInvocation.Type.GetMethods().Where(x => x.ImplementedInterfaceMembers.Contains(interfaceOnCompleted)).ToList();
if (candidates.Count == 1)
onCompletedMethod = candidates[0];
}
else if (candidates.Count == 1) {
onCompletedMethod = candidates[0];
}
return new AwaitResolveResult(awaitResultType, getAwaiterInvocation, getAwaiterInvocation.Type, isCompletedProperty, onCompletedMethod, getResultMethod);
@ -1927,19 +1936,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -1927,19 +1936,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
}
}
/// <summary>
/// Resolves an invocation.
/// </summary>
/// <param name="target">The target of the invocation. Usually a MethodGroupResolveResult.</param>
/// <param name="arguments">
/// Arguments passed to the method.
/// The resolver may mutate this array to wrap elements in <see cref="ConversionResolveResult"/>s!
/// </param>
/// <param name="argumentNames">
/// The argument names. Pass the null string for positional arguments.
/// </param>
/// <returns>InvocationResolveResult or UnknownMethodResolveResult</returns>
public ResolveResult ResolveInvocation(ResolveResult target, ResolveResult[] arguments, string[] argumentNames = null)
private ResolveResult ResolveInvocation(ResolveResult target, ResolveResult[] arguments, string[] argumentNames, bool allowOptionalParameters)
{
// C# 4.0 spec: §7.6.5
@ -1971,7 +1968,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -1971,7 +1968,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
}
}
OverloadResolution or = mgrr.PerformOverloadResolution(compilation, arguments, argumentNames, checkForOverflow: checkForOverflow, conversions: conversions);
OverloadResolution or = mgrr.PerformOverloadResolution(compilation, arguments, argumentNames, checkForOverflow: checkForOverflow, conversions: conversions, allowOptionalParameters: allowOptionalParameters);
if (or.BestCandidate != null) {
if (or.BestCandidate.IsStatic && !or.IsExtensionMethodInvocation && !(mgrr.TargetResult is TypeResolveResult))
return or.CreateResolveResult(new TypeResolveResult(mgrr.TargetType));
@ -2005,6 +2002,23 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -2005,6 +2002,23 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
}
return ErrorResult;
}
/// <summary>
/// Resolves an invocation.
/// </summary>
/// <param name="target">The target of the invocation. Usually a MethodGroupResolveResult.</param>
/// <param name="arguments">
/// Arguments passed to the method.
/// The resolver may mutate this array to wrap elements in <see cref="ConversionResolveResult"/>s!
/// </param>
/// <param name="argumentNames">
/// The argument names. Pass the null string for positional arguments.
/// </param>
/// <returns>InvocationResolveResult or UnknownMethodResolveResult</returns>
public ResolveResult ResolveInvocation(ResolveResult target, ResolveResult[] arguments, string[] argumentNames = null)
{
return ResolveInvocation(target, arguments, argumentNames, allowOptionalParameters: true);
}
List<IParameter> CreateParameters(ResolveResult[] arguments, string[] argumentNames)
{

1
ICSharpCode.NRefactory.CSharp/Resolver/FindReferences.cs

@ -759,6 +759,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver @@ -759,6 +759,7 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver
case "GetAwaiter":
case "GetResult":
case "OnCompleted":
case "UnsafeOnCompleted":
specialNodeType = typeof(UnaryOperatorExpression);
break;
default:

33
ICSharpCode.NRefactory.Tests/CSharp/Resolver/FindReferencesTest.cs

@ -213,8 +213,11 @@ class Calls { @@ -213,8 +213,11 @@ class Calls {
#endregion
#region Await
#if NET_4_5
const string awaitTest = @"using System;
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
@ -268,6 +271,34 @@ public class C { @@ -268,6 +271,34 @@ public class C {
Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 3 && r is PropertyDeclaration));
Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 13 && r is UnaryOperatorExpression));
}
[Test]
public void UnsafeOnCompletedReferenceInAwaitExpressionIsFound() {
Init(@"using System;
class MyAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public void UnsafeOnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
public MyAwaiter GetAwaiter() { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = await x;
}
}");
var test = compilation.MainAssembly.TopLevelTypeDefinitions.Single(t => t.Name == "MyAwaiter");
var method = test.Methods.Single(m => m.Name == "UnsafeOnCompleted");
var actual = FindReferences(method).ToList();
Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 5 && r is MethodDeclaration));
Assert.IsTrue(actual.Any(r => r.StartLocation.Line == 14 && r is UnaryOperatorExpression));
}
#endif // NET_4_5
#endregion
}
}

289
ICSharpCode.NRefactory.Tests/CSharp/Resolver/UnaryOperatorTests.cs

@ -275,11 +275,12 @@ class Test { @@ -275,11 +275,12 @@ class Test {
Assert.AreEqual(unchecked( (ushort)~3 ), rr.ConstantValue);
}
#if NET_4_5
[Test]
public void Await() {
string program = @"
using System;
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
@ -322,7 +323,7 @@ public class C { @@ -322,7 +323,7 @@ public class C {
string program = @"
using System;
namespace N {
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
@ -363,11 +364,11 @@ namespace N { @@ -363,11 +364,11 @@ namespace N {
Assert.AreEqual("N.MyAwaiter.GetResult", rr.GetResultMethod.FullName);
}
[Test, Ignore("TODO: MS C# (at least the RC version) refuses to use default values in GetAwaiter(). I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")]
[Test]
public void GetAwaiterMethodWithDefaultArgumentCannotBeUsed() {
string program = @"
using System;
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
@ -383,24 +384,16 @@ public class C { @@ -383,24 +384,16 @@ public class C {
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsFalse(rr.IsError);
Assert.AreEqual(SpecialType.UnknownType, rr.Type);
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
Assert.IsTrue(rr.IsError);
Assert.IsTrue(rr.GetAwaiterInvocation.IsError);
Assert.AreEqual(rr.AwaiterType, SpecialType.UnknownType);
Assert.IsNull(rr.IsCompletedProperty);
Assert.IsNull(rr.OnCompletedMethod);
Assert.IsNull(rr.GetResultMethod);
}
[Test, Ignore("TODO: MS C# (at least the RC version) refuses to use default values in GetAwaiter(). I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")]
[Test, Ignore("TODO: MS C# refuses to use an extension method GetAwaiter() when there is an instance GetAwaiter() with only optional arguments. I do not know, however, if this is by design, and I could not find a simple, nice way to do the implementation")]
public void GetAwaiterMethodWithDefaultArgumentHidesExtensionMethodAndResultsInError() {
string program = @"
using System;
namespace N {
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
@ -420,7 +413,7 @@ namespace N { @@ -420,7 +413,7 @@ namespace N {
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsFalse(rr.IsError);
Assert.IsTrue(rr.IsError);
Assert.AreEqual(SpecialType.UnknownType, rr.Type);
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
Assert.IsTrue(rr.GetAwaiterInvocation.IsError);
@ -437,7 +430,7 @@ namespace N { @@ -437,7 +430,7 @@ namespace N {
string program = @"
using System;
namespace N {
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
@ -479,12 +472,35 @@ namespace N { @@ -479,12 +472,35 @@ namespace N {
Assert.AreEqual("N.MyAwaiter.GetResult", rr.GetResultMethod.FullName);
}
[Test]
public void GenericGetAwaiterResultsInError() {
string program = @"
using System;
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
public MyAwaiter GetAwaiter<T>() { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = $await x$;
}
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsTrue(rr.IsError);
}
[Test]
public void AwaiterWithNoSuitableGetResult() {
string program = @"
using System;
namespace N {
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult(int i) { return 0; }
@ -525,15 +541,61 @@ namespace N { @@ -525,15 +541,61 @@ namespace N {
Assert.IsNull(rr.GetResultMethod);
}
[Test]
public void AwaiterWithInaccessibleGetResult() {
string program = @"
using System;
namespace N {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
private int GetResult() { return 0; }
}
class MyAwaitable {
public static MyAwaiter GetAwaiter(int i) { return null; }
}
static class MyAwaitableExtensions {
public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = $await x$;
}
}
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsTrue(rr.IsError);
Assert.AreEqual(SpecialType.UnknownType, rr.Type);
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation;
Assert.IsFalse(rr.GetAwaiterInvocation.IsError);
Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count);
Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName);
Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count);
Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x");
Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName);
Assert.IsNotNull(rr.IsCompletedProperty);
Assert.AreEqual("N.MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName);
Assert.IsNotNull(rr.OnCompletedMethod);
Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName);
Assert.IsNull(rr.GetResultMethod);
}
[Test]
public void AwaiterWithNoIsCompletedProperty() {
string program = @"
using System;
namespace N {
class MyAwaiter {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted() { return false; }
public void OnCompleted(Action continuation) {}
public int GetResult(int i) { return 0; }
public int GetResult() { return 0; }
}
class MyAwaitable {
public static MyAwaiter GetAwaiter(int i) { return null; }
@ -551,7 +613,7 @@ namespace N { @@ -551,7 +613,7 @@ namespace N {
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsTrue(rr.IsError);
Assert.AreEqual(SpecialType.UnknownType, rr.Type);
Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32));
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation;
Assert.IsFalse(rr.GetAwaiterInvocation.IsError);
@ -567,16 +629,151 @@ namespace N { @@ -567,16 +629,151 @@ namespace N {
Assert.IsNotNull(rr.OnCompletedMethod);
Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName);
Assert.IsNull(rr.GetResultMethod);
Assert.IsNotNull(rr.GetResultMethod);
}
[Test]
public void AwaiterWithIsCompletedPropertyThatIsNotBoolean() {
string program = @"
using System;
namespace N {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public string IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
public static MyAwaiter GetAwaiter(int i) { return null; }
}
static class MyAwaitableExtensions {
public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = $await x$;
}
}
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsTrue(rr.IsError);
Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32));
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation;
Assert.IsFalse(rr.GetAwaiterInvocation.IsError);
Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count);
Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName);
Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count);
Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x");
Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName);
Assert.IsNull(rr.IsCompletedProperty);
Assert.IsNotNull(rr.OnCompletedMethod);
Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName);
Assert.IsNotNull(rr.GetResultMethod);
}
[Test]
public void AwaiterWithNoOnCompletedMethodWithSuitableSignature() {
public void AwaiterWithIsCompletedPropertyThatIsNotReadable() {
string program = @"
using System;
namespace N {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
public bool IsCompleted { set {} }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
public static MyAwaiter GetAwaiter(int i) { return null; }
}
static class MyAwaitableExtensions {
public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = $await x$;
}
}
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsTrue(rr.IsError);
Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32));
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation;
Assert.IsFalse(rr.GetAwaiterInvocation.IsError);
Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count);
Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName);
Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count);
Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x");
Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName);
Assert.IsNull(rr.IsCompletedProperty);
Assert.IsNotNull(rr.OnCompletedMethod);
Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName);
Assert.IsNotNull(rr.GetResultMethod);
}
[Test]
public void AwaiterWithIsCompletedPropertyThatIsNotAccessible() {
string program = @"
using System;
namespace N {
class MyAwaiter : System.Runtime.CompilerServices.INotifyCompletion {
private bool IsCompleted { get; set; }
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
public static MyAwaiter GetAwaiter(int i) { return null; }
}
static class MyAwaitableExtensions {
public static MyAwaiter GetAwaiter(this MyAwaitable x) { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = $await x$;
}
}
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsTrue(rr.IsError);
Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32));
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation;
Assert.IsFalse(rr.GetAwaiterInvocation.IsError);
Assert.AreEqual(1, getAwaiterInvocation.Arguments.Count);
Assert.AreEqual("N.MyAwaitableExtensions.GetAwaiter", getAwaiterInvocation.Member.FullName);
Assert.AreEqual(1, getAwaiterInvocation.Member.Parameters.Count);
Assert.IsTrue(getAwaiterInvocation.Arguments[0] is LocalResolveResult && ((LocalResolveResult)getAwaiterInvocation.Arguments[0]).Variable.Name == "x");
Assert.AreEqual("N.MyAwaiter", rr.AwaiterType.FullName);
Assert.IsNull(rr.IsCompletedProperty);
Assert.IsNotNull(rr.OnCompletedMethod);
Assert.AreEqual("N.MyAwaiter.OnCompleted", rr.OnCompletedMethod.FullName);
Assert.IsNotNull(rr.GetResultMethod);
}
[Test]
public void AwaiterThatDoesNotImplementINotifyCompletion() {
string program = @"
using System;
class MyAwaiter {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Func<int> continuation) {}
public void OnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
@ -611,6 +808,49 @@ public class C { @@ -611,6 +808,49 @@ public class C {
Assert.AreEqual("MyAwaiter.GetResult", rr.GetResultMethod.FullName);
}
[Test]
public void AwaiterThatImplementsICriticalNotifyCompletion() {
string program = @"
using System;
class MyAwaiter : System.Runtime.CompilerServices.ICriticalNotifyCompletion {
public bool IsCompleted { get { return false; } }
public void OnCompleted(Action continuation) {}
public void UnsafeOnCompleted(Action continuation) {}
public int GetResult() { return 0; }
}
class MyAwaitable {
public MyAwaiter GetAwaiter() { return null; }
public MyAwaiter GetAwaiter(int i) { return null; }
}
public class C {
public async void M() {
MyAwaitable x = null;
int i = $await x$;
}
}";
var rr = Resolve<AwaitResolveResult>(program);
Assert.IsFalse(rr.IsError);
Assert.IsTrue(rr.Type.IsKnownType(KnownTypeCode.Int32));
Assert.IsInstanceOf<CSharpInvocationResolveResult>(rr.GetAwaiterInvocation);
var getAwaiterInvocation = (CSharpInvocationResolveResult)rr.GetAwaiterInvocation;
Assert.IsFalse(rr.GetAwaiterInvocation.IsError);
Assert.AreEqual(0, getAwaiterInvocation.Arguments.Count);
Assert.AreEqual("MyAwaitable.GetAwaiter", getAwaiterInvocation.Member.FullName);
Assert.AreEqual(0, getAwaiterInvocation.Member.Parameters.Count);
Assert.AreEqual("MyAwaiter", rr.AwaiterType.FullName);
Assert.IsNotNull(rr.IsCompletedProperty);
Assert.AreEqual("MyAwaiter.IsCompleted", rr.IsCompletedProperty.FullName);
Assert.IsNotNull(rr.OnCompletedMethod);
Assert.AreEqual("MyAwaiter.UnsafeOnCompleted", rr.OnCompletedMethod.FullName);
Assert.IsNotNull(rr.GetResultMethod);
Assert.AreEqual("MyAwaiter.GetResult", rr.GetResultMethod.FullName);
}
[Test]
public void AwaitDynamic() {
string program = @"
@ -640,5 +880,6 @@ public class C { @@ -640,5 +880,6 @@ public class C {
Assert.IsNull(rr.OnCompletedMethod);
Assert.IsNull(rr.GetResultMethod);
}
#endif // NET_4_5
}
}

22
ICSharpCode.NRefactory/TypeSystem/KnownTypeReference.cs

@ -118,7 +118,11 @@ namespace ICSharpCode.NRefactory.TypeSystem @@ -118,7 +118,11 @@ namespace ICSharpCode.NRefactory.TypeSystem
/// <summary><c>System.Nullable{T}</c></summary>
NullableOfT,
/// <summary><c>System.IDisposable</c></summary>
IDisposable
IDisposable,
/// <summary><c>System.Runtime.CompilerServices.INotifyCompletion</c></summary>
INotifyCompletion,
/// <summary><c>System.Runtime.CompilerServices.ICriticalNotifyCompletion</c></summary>
ICriticalNotifyCompletion,
}
/// <summary>
@ -127,7 +131,7 @@ namespace ICSharpCode.NRefactory.TypeSystem @@ -127,7 +131,7 @@ namespace ICSharpCode.NRefactory.TypeSystem
[Serializable]
public sealed class KnownTypeReference : ITypeReference
{
internal const int KnownTypeCodeCount = (int)KnownTypeCode.IDisposable + 1;
internal const int KnownTypeCodeCount = (int)KnownTypeCode.ICriticalNotifyCompletion + 1;
static readonly KnownTypeReference[] knownTypeReferences = new KnownTypeReference[KnownTypeCodeCount] {
null, // None
@ -174,6 +178,8 @@ namespace ICSharpCode.NRefactory.TypeSystem @@ -174,6 +178,8 @@ namespace ICSharpCode.NRefactory.TypeSystem
new KnownTypeReference(KnownTypeCode.TaskOfT, "System.Threading.Tasks", "Task", 1, baseType: KnownTypeCode.Task),
new KnownTypeReference(KnownTypeCode.NullableOfT, "System", "Nullable", 1, baseType: KnownTypeCode.ValueType),
new KnownTypeReference(KnownTypeCode.IDisposable, "System", "IDisposable"),
new KnownTypeReference(KnownTypeCode.INotifyCompletion, "System.Runtime.CompilerServices", "INotifyCompletion"),
new KnownTypeReference(KnownTypeCode.ICriticalNotifyCompletion, "System.Runtime.CompilerServices", "ICriticalNotifyCompletion"),
};
/// <summary>
@ -389,7 +395,17 @@ namespace ICSharpCode.NRefactory.TypeSystem @@ -389,7 +395,17 @@ namespace ICSharpCode.NRefactory.TypeSystem
/// Gets a type reference pointing to the <c>System.IDisposable</c> type.
/// </summary>
public static readonly KnownTypeReference IDisposable = Get(KnownTypeCode.IDisposable);
/// <summary>
/// Gets a type reference pointing to the <c>System.Runtime.CompilerServices.INotifyCompletion</c> type.
/// </summary>
public static readonly KnownTypeReference INotifyCompletion = Get(KnownTypeCode.INotifyCompletion);
/// <summary>
/// Gets a type reference pointing to the <c>System.Runtime.CompilerServices.ICriticalNotifyCompletion</c> type.
/// </summary>
public static readonly KnownTypeReference ICriticalNotifyCompletion = Get(KnownTypeCode.ICriticalNotifyCompletion);
readonly KnownTypeCode knownTypeCode;
readonly string namespaceName;
readonly string name;

Loading…
Cancel
Save