diff --git a/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs b/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs index 62341ac9ec..fb5760f4c5 100644 --- a/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs +++ b/ICSharpCode.NRefactory.CSharp/Resolver/CSharpResolver.cs @@ -692,6 +692,16 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver } else if (lhsType is PointerType && rhsType is PointerType) { return BinaryOperatorResolveResult(compilation.FindType(KnownTypeCode.Boolean), lhs, op, rhs); } + if (op == BinaryOperatorType.Equality || op == BinaryOperatorType.InEquality) { + if (lhsType.Kind == TypeKind.Null && NullableType.IsNullable(rhs.Type) + || rhsType.Kind == TypeKind.Null && NullableType.IsNullable(lhs.Type)) + { + // ยง7.10.9 Equality operators and null + // "x == null", "null == x", "x != null" and "null != x" are valid + // even if the struct does not define operator ==. + return BinaryOperatorResolveResult(compilation.FindType(KnownTypeCode.Boolean), lhs, op, rhs); + } + } switch (op) { case BinaryOperatorType.Equality: methodGroup = operators.EqualityOperators; @@ -900,10 +910,10 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver TypeCode lhsCode = ReflectionHelper.GetTypeCode(NullableType.GetUnderlyingType(lhs.Type)); TypeCode rhsCode = ReflectionHelper.GetTypeCode(NullableType.GetUnderlyingType(rhs.Type)); // if one of the inputs is the null literal, promote that to the type of the other operand - if (isNullable && SpecialType.NullType.Equals(lhs.Type)) { + if (isNullable && SpecialType.NullType.Equals(lhs.Type) && rhsCode >= TypeCode.Boolean && rhsCode <= TypeCode.Decimal) { lhs = CastTo(rhsCode, isNullable, lhs, allowNullableConstants); lhsCode = rhsCode; - } else if (isNullable && SpecialType.NullType.Equals(rhs.Type)) { + } else if (isNullable && SpecialType.NullType.Equals(rhs.Type) && lhsCode >= TypeCode.Boolean && lhsCode <= TypeCode.Decimal) { rhs = CastTo(lhsCode, isNullable, rhs, allowNullableConstants); rhsCode = lhsCode; } @@ -1075,9 +1085,13 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver LiftedUserDefinedOperator LiftUserDefinedOperator(IMethod m) { - IType returnType = m.ReturnType; - if (!NullableType.IsNonNullableValueType(returnType)) - return null; // cannot lift this operator + if (IsComparisonOperator(m)) { + if (!m.ReturnType.Equals(compilation.FindType(KnownTypeCode.Boolean))) + return null; // cannot lift this operator + } else { + if (!NullableType.IsNonNullableValueType(m.ReturnType)) + return null; // cannot lift this operator + } for (int i = 0; i < m.Parameters.Count; i++) { if (!NullableType.IsNonNullableValueType(m.Parameters[i].Type)) return null; // cannot lift this operator @@ -1085,6 +1099,23 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver return new LiftedUserDefinedOperator(m); } + static bool IsComparisonOperator(IMethod m) + { + var type = OperatorDeclaration.GetOperatorType(m.Name); + if (type.HasValue) { + switch (type.Value) { + case OperatorType.Equality: + case OperatorType.Inequality: + case OperatorType.GreaterThan: + case OperatorType.LessThan: + case OperatorType.GreaterThanOrEqual: + case OperatorType.LessThanOrEqual: + return true; + } + } + return false; + } + sealed class LiftedUserDefinedOperator : SpecializedMethod, OverloadResolution.ILiftedOperator { internal readonly IParameterizedMember nonLiftedOperator; @@ -1094,6 +1125,9 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver EmptyList.Instance, new MakeNullableVisitor(nonLiftedMethod.Compilation)) { this.nonLiftedOperator = nonLiftedMethod; + // Comparison operators keep the 'bool' return type even when lifted. + if (IsComparisonOperator(nonLiftedMethod)) + this.returnType = nonLiftedMethod.ReturnType; } public IList NonLiftedParameters { diff --git a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/BinaryOperatorTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/BinaryOperatorTests.cs index ec952f48a8..6a016ec3f3 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/Resolver/BinaryOperatorTests.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/Resolver/BinaryOperatorTests.cs @@ -342,6 +342,9 @@ namespace ICSharpCode.NRefactory.CSharp.Resolver AssertType(typeof(bool), resolver.ResolveBinaryOperator( BinaryOperatorType.InEquality, MakeResult(typeof(int*)), MakeResult(typeof(uint*)))); + + AssertType(typeof(bool), resolver.ResolveBinaryOperator( + BinaryOperatorType.InEquality, MakeResult(typeof(bool?)), MakeConstant(null))); } [Test] @@ -579,5 +582,37 @@ class Test { Assert.IsNull(irr.UserDefinedOperatorMethod); Assert.AreEqual("System.Byte", irr.Type.ReflectionName); } + + [Test] + public void CompareNullableStructWithNullLiteral() + { + string program = @" +struct X { } +class Test { + static void Inc(X? x) { + var c = $x == null$; + } +}"; + var irr = Resolve(program); + Assert.IsFalse(irr.IsError); + Assert.AreEqual(compilation.FindType(KnownTypeCode.Boolean), irr.Type); + } + + [Test] + public void LiftedEqualityOperator() + { + string program = @" +struct X { + public static bool operator ==(X a, X b) {} +} +class Test { + static void Inc(X? x) { + var c = $x == x$; + } +}"; + var irr = Resolve(program); + Assert.IsFalse(irr.IsError); + Assert.AreEqual(compilation.FindType(KnownTypeCode.Boolean), irr.Type); + } } } diff --git a/ICSharpCode.NRefactory/TypeSystem/Implementation/SpecializedMember.cs b/ICSharpCode.NRefactory/TypeSystem/Implementation/SpecializedMember.cs index fb81ef7524..bf79177471 100644 --- a/ICSharpCode.NRefactory/TypeSystem/Implementation/SpecializedMember.cs +++ b/ICSharpCode.NRefactory/TypeSystem/Implementation/SpecializedMember.cs @@ -30,7 +30,7 @@ namespace ICSharpCode.NRefactory.TypeSystem.Implementation { readonly IType declaringType; readonly IMember memberDefinition; - IType returnType; + protected IType returnType; protected SpecializedMember(IType declaringType, IMember memberDefinition) {