// Copyright (c) 2017 Daniel Grunwald // // Permission is hereby granted, free of charge, to any person obtaining a copy of this // software and associated documentation files (the "Software"), to deal in the Software // without restriction, including without limitation the rights to use, copy, modify, merge, // publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons // to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or // substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, // INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR // PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE // FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR // OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using ICSharpCode.Decompiler.TypeSystem; using ICSharpCode.Decompiler.Util; namespace ICSharpCode.Decompiler.IL.Transforms { /// /// Nullable lifting gets run in two places: /// * the usual form looks at an if-else, and runs within the ExpressionTransforms. /// * the NullableLiftingBlockTransform handles the cases where Roslyn generates /// two 'ret' statements for the null/non-null cases of a lifted operator. /// /// The transform handles the following languages constructs: /// * lifted conversions /// * lifted unary and binary operators /// * lifted comparisons /// * the ?? operator with type Nullable{T} on the left-hand-side /// * the ?. operator (via NullPropagationTransform) /// struct NullableLiftingTransform { readonly ILTransformContext context; List nullableVars; public NullableLiftingTransform(ILTransformContext context) { this.context = context; this.nullableVars = null; } #region Run /// /// Main entry point into the normal code path of this transform. /// Called by expression transform. /// public bool Run(IfInstruction ifInst) { var lifted = Lift(ifInst, ifInst.TrueInst, ifInst.FalseInst); if (lifted != null) { ifInst.ReplaceWith(lifted); return true; } return false; } public bool RunStatements(Block block, int pos) { /// e.g.: // if (!condition) Block { // leave IL_0000 (default.value System.Nullable`1[[System.Int64]]) // } // leave IL_0000 (newobj .ctor(exprToLift)) if (pos != block.Instructions.Count - 2) return false; if (!(block.Instructions[pos] is IfInstruction ifInst)) return false; if (!(Block.Unwrap(ifInst.TrueInst) is Leave thenLeave)) return false; if (!ifInst.FalseInst.MatchNop()) return false; if (!(block.Instructions[pos + 1] is Leave elseLeave)) return false; if (elseLeave.TargetContainer != thenLeave.TargetContainer) return false; var lifted = Lift(ifInst, thenLeave.Value, elseLeave.Value); if (lifted != null) { thenLeave.Value = lifted; ifInst.ReplaceWith(thenLeave); block.Instructions.Remove(elseLeave); return true; } return false; } #endregion #region AnalyzeCondition bool AnalyzeCondition(ILInstruction condition) { if (MatchHasValueCall(condition, out ILVariable v)) { if (nullableVars == null) nullableVars = new List(); nullableVars.Add(v); return true; } else if (condition is BinaryNumericInstruction bitand) { if (!(bitand.Operator == BinaryNumericOperator.BitAnd && bitand.ResultType == StackType.I4)) return false; return AnalyzeCondition(bitand.Left) && AnalyzeCondition(bitand.Right); } return false; } #endregion #region Main lifting logic /// /// Main entry point for lifting; called by both the expression-transform /// and the block transform. /// ILInstruction Lift(IfInstruction ifInst, ILInstruction trueInst, ILInstruction falseInst) { ILInstruction condition = ifInst.Condition; while (condition.MatchLogicNot(out var arg)) { condition = arg; ExtensionMethods.Swap(ref trueInst, ref falseInst); } if (context.Settings.NullPropagation && !NullPropagationTransform.IsProtectedIfInst(ifInst)) { var nullPropagated = new NullPropagationTransform(context) .Run(condition, trueInst, falseInst, ifInst.ILRange); if (nullPropagated != null) return nullPropagated; } if (!context.Settings.LiftNullables) return null; if (AnalyzeCondition(condition)) { // (v1 != null && ... && vn != null) ? trueInst : falseInst // => normal lifting return LiftNormal(trueInst, falseInst, ilrange: ifInst.ILRange); } if (MatchCompOrDecimal(condition, out var comp)) { // This might be a C#-style lifted comparison // (C# checks the underlying value before checking the HasValue bits) if (comp.Kind.IsEqualityOrInequality()) { // for equality/inequality, the HasValue bits must also compare equal/inequal if (comp.Kind == ComparisonKind.Inequality) { // handle inequality by swapping one last time ExtensionMethods.Swap(ref trueInst, ref falseInst); } if (falseInst.MatchLdcI4(0)) { // (a.GetValueOrDefault() == b.GetValueOrDefault()) ? (a.HasValue == b.HasValue) : false // => a == b return LiftCSharpEqualityComparison(comp, ComparisonKind.Equality, trueInst) ?? LiftCSharpUserEqualityComparison(comp, ComparisonKind.Equality, trueInst); } else if (falseInst.MatchLdcI4(1)) { // (a.GetValueOrDefault() == b.GetValueOrDefault()) ? (a.HasValue != b.HasValue) : true // => a != b return LiftCSharpEqualityComparison(comp, ComparisonKind.Inequality, trueInst) ?? LiftCSharpUserEqualityComparison(comp, ComparisonKind.Inequality, trueInst); } else if (IsGenericNewPattern(comp.Left, comp.Right, trueInst, falseInst)) { // (default(T) == null) ? Activator.CreateInstance() : default(T) // => Activator.CreateInstance() return trueInst; } } else { // Not (in)equality, but one of < <= > >=. // Returns false unless all HasValue bits are true. if (falseInst.MatchLdcI4(0) && AnalyzeCondition(trueInst)) { // comp(lhs, rhs) ? (v1 != null && ... && vn != null) : false // => comp.lifted[C#](lhs, rhs) return LiftCSharpComparison(comp, comp.Kind); } else if (trueInst.MatchLdcI4(0) && AnalyzeCondition(falseInst)) { // comp(lhs, rhs) ? false : (v1 != null && ... && vn != null) return LiftCSharpComparison(comp, comp.Kind.Negate()); } } } ILVariable v; // Handle equality comparisons with bool?: if (MatchGetValueOrDefault(condition, out v) && NullableType.GetUnderlyingType(v.Type).IsKnownType(KnownTypeCode.Boolean)) { if (MatchHasValueCall(trueInst, v) && falseInst.MatchLdcI4(0)) { // v.GetValueOrDefault() ? v.HasValue : false // ==> v == true context.Step("NullableLiftingTransform: v == true", ifInst); return new Comp(ComparisonKind.Equality, ComparisonLiftingKind.CSharp, StackType.I4, Sign.None, new LdLoc(v) { ILRange = trueInst.ILRange }, new LdcI4(1) { ILRange = falseInst.ILRange } ) { ILRange = ifInst.ILRange }; } else if (trueInst.MatchLdcI4(0) && MatchHasValueCall(falseInst, v)) { // v.GetValueOrDefault() ? false : v.HasValue // ==> v == false context.Step("NullableLiftingTransform: v == false", ifInst); return new Comp(ComparisonKind.Equality, ComparisonLiftingKind.CSharp, StackType.I4, Sign.None, new LdLoc(v) { ILRange = falseInst.ILRange }, trueInst // LdcI4(0) ) { ILRange = ifInst.ILRange }; } else if (MatchNegatedHasValueCall(trueInst, v) && falseInst.MatchLdcI4(1)) { // v.GetValueOrDefault() ? !v.HasValue : true // ==> v != true context.Step("NullableLiftingTransform: v != true", ifInst); return new Comp(ComparisonKind.Inequality, ComparisonLiftingKind.CSharp, StackType.I4, Sign.None, new LdLoc(v) { ILRange = trueInst.ILRange }, falseInst // LdcI4(1) ) { ILRange = ifInst.ILRange }; } else if (trueInst.MatchLdcI4(1) && MatchNegatedHasValueCall(falseInst, v)) { // v.GetValueOrDefault() ? true : !v.HasValue // ==> v != false context.Step("NullableLiftingTransform: v != false", ifInst); return new Comp(ComparisonKind.Inequality, ComparisonLiftingKind.CSharp, StackType.I4, Sign.None, new LdLoc(v) { ILRange = falseInst.ILRange }, new LdcI4(0) { ILRange = trueInst.ILRange } ) { ILRange = ifInst.ILRange }; } } // Handle & and | on bool?: if (trueInst.MatchLdLoc(out v)) { if (MatchNullableCtor(falseInst, out var utype, out var arg) && utype.IsKnownType(KnownTypeCode.Boolean) && arg.MatchLdcI4(0)) { // condition ? v : (bool?)false // => condition & v context.Step("NullableLiftingTransform: 3vl.logic.and(bool, bool?)", ifInst); return new ThreeValuedLogicAnd(condition, trueInst) { ILRange = ifInst.ILRange }; } if (falseInst.MatchLdLoc(out var v2)) { // condition ? v : v2 if (MatchThreeValuedLogicConditionPattern(condition, out var nullable1, out var nullable2)) { // (nullable1.GetValueOrDefault() || (!nullable2.GetValueOrDefault() && !nullable1.HasValue)) ? v : v2 if (v == nullable1 && v2 == nullable2) { context.Step("NullableLiftingTransform: 3vl.logic.or(bool?, bool?)", ifInst); return new ThreeValuedLogicOr(trueInst, falseInst) { ILRange = ifInst.ILRange }; } else if (v == nullable2 && v2 == nullable1) { context.Step("NullableLiftingTransform: 3vl.logic.and(bool?, bool?)", ifInst); return new ThreeValuedLogicAnd(falseInst, trueInst) { ILRange = ifInst.ILRange }; } } } } else if (falseInst.MatchLdLoc(out v)) { if (MatchNullableCtor(trueInst, out var utype, out var arg) && utype.IsKnownType(KnownTypeCode.Boolean) && arg.MatchLdcI4(1)) { // condition ? (bool?)true : v // => condition | v context.Step("NullableLiftingTransform: 3vl.logic.or(bool, bool?)", ifInst); return new ThreeValuedLogicOr(condition, falseInst) { ILRange = ifInst.ILRange }; } } return null; } private bool IsGenericNewPattern(ILInstruction compLeft, ILInstruction compRight, ILInstruction trueInst, ILInstruction falseInst) { // (default(T) == null) ? Activator.CreateInstance() : default(T) return falseInst.MatchDefaultValue(out var type) && (trueInst is Call c && c.Method.FullName == "System.Activator.CreateInstance" && c.Method.TypeArguments.Count == 1) && type.Kind == TypeKind.TypeParameter && compLeft.MatchDefaultValue(out var type2) && type.Equals(type2) && compRight.MatchLdNull(); } private bool MatchThreeValuedLogicConditionPattern(ILInstruction condition, out ILVariable nullable1, out ILVariable nullable2) { // Try to match: nullable1.GetValueOrDefault() || (!nullable2.GetValueOrDefault() && !nullable1.HasValue) nullable1 = null; nullable2 = null; if (!condition.MatchLogicOr(out var lhs, out var rhs)) return false; if (!MatchGetValueOrDefault(lhs, out nullable1)) return false; if (!NullableType.GetUnderlyingType(nullable1.Type).IsKnownType(KnownTypeCode.Boolean)) return false; if (!rhs.MatchLogicAnd(out lhs, out rhs)) return false; if (!lhs.MatchLogicNot(out var arg)) return false; if (!MatchGetValueOrDefault(arg, out nullable2)) return false; if (!NullableType.GetUnderlyingType(nullable2.Type).IsKnownType(KnownTypeCode.Boolean)) return false; if (!rhs.MatchLogicNot(out arg)) return false; return MatchHasValueCall(arg, nullable1); } #endregion #region CSharpComp static bool MatchCompOrDecimal(ILInstruction inst, out CompOrDecimal result) { result = default(CompOrDecimal); result.Instruction = inst; if (inst is Comp comp && !comp.IsLifted) { result.Kind = comp.Kind; result.Left = comp.Left; result.Right = comp.Right; return true; } else if (inst is Call call && call.Method.IsOperator && call.Arguments.Count == 2 && !call.IsLifted) { switch (call.Method.Name) { case "op_Equality": result.Kind = ComparisonKind.Equality; break; case "op_Inequality": result.Kind = ComparisonKind.Inequality; break; case "op_LessThan": result.Kind = ComparisonKind.LessThan; break; case "op_LessThanOrEqual": result.Kind = ComparisonKind.LessThanOrEqual; break; case "op_GreaterThan": result.Kind = ComparisonKind.GreaterThan; break; case "op_GreaterThanOrEqual": result.Kind = ComparisonKind.GreaterThanOrEqual; break; default: return false; } result.Left = call.Arguments[0]; result.Right = call.Arguments[1]; return call.Method.DeclaringType.IsKnownType(KnownTypeCode.Decimal); } return false; } /// /// Represents either non-lifted IL `Comp` or a call to one of the (non-lifted) 6 comparison operators on `System.Decimal`. /// struct CompOrDecimal { public ILInstruction Instruction; public ComparisonKind Kind; public ILInstruction Left; public ILInstruction Right; public IType LeftExpectedType { get { if (Instruction is Call call) { return call.Method.Parameters[0].Type; } else { return SpecialType.UnknownType; } } } public IType RightExpectedType { get { if (Instruction is Call call) { return call.Method.Parameters[1].Type; } else { return SpecialType.UnknownType; } } } internal ILInstruction MakeLifted(ComparisonKind newComparisonKind, ILInstruction left, ILInstruction right) { if (Instruction is Comp comp) { return new Comp(newComparisonKind, ComparisonLiftingKind.CSharp, comp.InputType, comp.Sign, left, right) { ILRange = Instruction.ILRange }; } else if (Instruction is Call call) { IMethod method; if (newComparisonKind == Kind) { method = call.Method; } else if (newComparisonKind == ComparisonKind.Inequality && call.Method.Name == "op_Equality") { method = call.Method.DeclaringType.GetMethods(m => m.Name == "op_Inequality") .FirstOrDefault(m => ParameterListComparer.Instance.Equals(m.Parameters, call.Method.Parameters)); if (method == null) return null; } else { return null; } return new Call(CSharp.Resolver.CSharpOperators.LiftUserDefinedOperator(method)) { Arguments = { left, right }, ConstrainedTo = call.ConstrainedTo, ILRange = call.ILRange, ILStackWasEmpty = call.ILStackWasEmpty, IsTail = call.IsTail }; } else { return null; } } } #endregion #region Lift...Comparison ILInstruction LiftCSharpEqualityComparison(CompOrDecimal valueComp, ComparisonKind newComparisonKind, ILInstruction hasValueTest) { Debug.Assert(newComparisonKind.IsEqualityOrInequality()); bool hasValueTestNegated = false; while (hasValueTest.MatchLogicNot(out var arg)) { hasValueTest = arg; hasValueTestNegated = !hasValueTestNegated; } // The HasValue comparison must be the same operator as the Value comparison. if (hasValueTest is Comp hasValueComp) { // Comparing two nullables: HasValue comparison must be the same operator as the Value comparison if ((hasValueTestNegated ? hasValueComp.Kind.Negate() : hasValueComp.Kind) != newComparisonKind) return null; if (!MatchHasValueCall(hasValueComp.Left, out ILVariable leftVar)) return null; if (!MatchHasValueCall(hasValueComp.Right, out ILVariable rightVar)) return null; nullableVars = new List { leftVar }; var (left, leftBits) = DoLift(valueComp.Left); nullableVars[0] = rightVar; var (right, rightBits) = DoLift(valueComp.Right); if (left != null && right != null && leftBits[0] && rightBits[0] && SemanticHelper.IsPure(left.Flags) && SemanticHelper.IsPure(right.Flags) ) { context.Step("NullableLiftingTransform: C# (in)equality comparison", valueComp.Instruction); return valueComp.MakeLifted(newComparisonKind, left, right); } } else if (newComparisonKind == ComparisonKind.Equality && !hasValueTestNegated && MatchHasValueCall(hasValueTest, out ILVariable v)) { // Comparing nullable with non-nullable -> we can fall back to the normal comparison code. nullableVars = new List { v }; return LiftCSharpComparison(valueComp, newComparisonKind); } else if (newComparisonKind == ComparisonKind.Inequality && hasValueTestNegated && MatchHasValueCall(hasValueTest, out v)) { // Comparing nullable with non-nullable -> we can fall back to the normal comparison code. nullableVars = new List { v }; return LiftCSharpComparison(valueComp, newComparisonKind); } return null; } /// /// Lift a C# comparison. /// This method cannot be used for (in)equality comparisons where both sides are nullable /// (these special cases are handled in LiftCSharpEqualityComparison instead). /// /// The output instructions should evaluate to false when any of the nullableVars is null /// (except for newComparisonKind==Inequality, where this case should evaluate to true instead). /// Otherwise, the output instruction should evaluate to the same value as the input instruction. /// The output instruction should have the same side-effects (incl. exceptions being thrown) as the input instruction. /// This means unlike LiftNormal(), we cannot rely on the input instruction not being evaluated if /// a variable is null. /// ILInstruction LiftCSharpComparison(CompOrDecimal comp, ComparisonKind newComparisonKind) { var (left, right, bits) = DoLiftBinary(comp.Left, comp.Right, comp.LeftExpectedType, comp.RightExpectedType); // due to the restrictions on side effects, we only allow instructions that are pure after lifting. // (we can't check this before lifting due to the calls to GetValueOrDefault()) if (left != null && right != null && SemanticHelper.IsPure(left.Flags) && SemanticHelper.IsPure(right.Flags)) { if (!bits.All(0, nullableVars.Count)) { // don't lift if a nullableVar doesn't contribute to the result return null; } context.Step("NullableLiftingTransform: C# comparison", comp.Instruction); return comp.MakeLifted(newComparisonKind, left, right); } return null; } Call LiftCSharpUserEqualityComparison(CompOrDecimal hasValueComp, ComparisonKind newComparisonKind, ILInstruction nestedIfInst) { // User-defined equality operator: // if (comp(call get_HasValue(ldloca nullable1) == call get_HasValue(ldloca nullable2))) // if (logic.not(call get_HasValue(ldloca nullable))) // ldc.i4 1 // else // call op_Equality(call GetValueOrDefault(ldloca nullable1), call GetValueOrDefault(ldloca nullable2) // else // ldc.i4 0 // User-defined inequality operator: // if (comp(call get_HasValue(ldloca nullable1) != call get_HasValue(ldloca nullable2))) // ldc.i4 1 // else // if (call get_HasValue(ldloca nullable)) // call op_Inequality(call GetValueOrDefault(ldloca nullable1), call GetValueOrDefault(ldloca nullable2)) // else // ldc.i4 0 if (!MatchHasValueCall(hasValueComp.Left, out ILVariable nullable1)) return null; if (!MatchHasValueCall(hasValueComp.Right, out ILVariable nullable2)) return null; if (!nestedIfInst.MatchIfInstructionPositiveCondition(out var condition, out var trueInst, out var falseInst)) return null; if (!MatchHasValueCall(condition, out ILVariable nullable)) return null; if (nullable != nullable1 && nullable != nullable2) return null; if (!falseInst.MatchLdcI4(newComparisonKind == ComparisonKind.Equality ? 1 : 0)) return null; if (!(trueInst is Call call)) return null; if (!(call.Method.IsOperator && call.Arguments.Count == 2)) return null; if (call.Method.Name != (newComparisonKind == ComparisonKind.Equality ? "op_Equality" : "op_Inequality")) return null; var liftedOperator = CSharp.Resolver.CSharpOperators.LiftUserDefinedOperator(call.Method); if (liftedOperator == null) return null; nullableVars = new List { nullable1 }; var (left, leftBits) = DoLift(call.Arguments[0]); nullableVars[0] = nullable2; var (right, rightBits) = DoLift(call.Arguments[1]); if (left != null && right != null && leftBits[0] && rightBits[0] && SemanticHelper.IsPure(left.Flags) && SemanticHelper.IsPure(right.Flags) ) { context.Step("NullableLiftingTransform: C# user-defined (in)equality comparison", nestedIfInst); return new Call(liftedOperator) { Arguments = { left, right }, ConstrainedTo = call.ConstrainedTo, ILRange = call.ILRange, ILStackWasEmpty = call.ILStackWasEmpty, IsTail = call.IsTail, }; } return null; } #endregion #region LiftNormal / DoLift /// /// Performs nullable lifting. /// /// Produces a lifted instruction with semantics equivalent to: /// (v1 != null && ... && vn != null) ? trueInst : falseInst, /// where the v1,...,vn are the this.nullableVars. /// If lifting fails, returns null. /// ILInstruction LiftNormal(ILInstruction trueInst, ILInstruction falseInst, Interval ilrange) { if (trueInst.MatchIfInstructionPositiveCondition(out var nestedCondition, out var nestedTrue, out var nestedFalse)) { // Sometimes Roslyn generates pointless conditions like: // if (nullable.HasValue && (!nullable.HasValue || nullable.GetValueOrDefault() == b)) if (MatchHasValueCall(nestedCondition, out ILVariable v) && nullableVars.Contains(v)) { trueInst = nestedTrue; } } bool isNullCoalescingWithNonNullableFallback = false; if (!MatchNullableCtor(trueInst, out var utype, out var exprToLift)) { isNullCoalescingWithNonNullableFallback = true; utype = context.TypeSystem.FindType(trueInst.ResultType.ToKnownTypeCode()); exprToLift = trueInst; if (nullableVars.Count == 1 && exprToLift.MatchLdLoc(nullableVars[0])) { // v.HasValue ? ldloc v : fallback // => v ?? fallback context.Step("v.HasValue ? v : fallback => v ?? fallback", trueInst); return new NullCoalescingInstruction(NullCoalescingKind.Nullable, trueInst, falseInst) { UnderlyingResultType = NullableType.GetUnderlyingType(nullableVars[0].Type).GetStackType(), ILRange = ilrange }; } else if (trueInst is Call call && !call.IsLifted && CSharp.Resolver.CSharpOperators.IsComparisonOperator(call.Method) && falseInst.MatchLdcI4(call.Method.Name == "op_Inequality" ? 1 : 0)) { // (v1 != null && ... && vn != null) ? call op_LessThan(lhs, rhs) : ldc.i4(0) var liftedOperator = CSharp.Resolver.CSharpOperators.LiftUserDefinedOperator(call.Method); if ((call.Method.Name == "op_Equality" || call.Method.Name == "op_Inequality") && nullableVars.Count != 1) { // Equality is special (returns true if both sides are null), only handle it // in the normal code path if we're dealing with only a single nullable var // (comparing nullable with non-nullable). liftedOperator = null; } if (liftedOperator != null) { context.Step("Lift user-defined comparison operator", trueInst); var (left, right, bits) = DoLiftBinary(call.Arguments[0], call.Arguments[1], call.Method.Parameters[0].Type, call.Method.Parameters[1].Type); if (left != null && right != null && bits.All(0, nullableVars.Count)) { return new Call(liftedOperator) { Arguments = { left, right }, ConstrainedTo = call.ConstrainedTo, ILRange = call.ILRange, ILStackWasEmpty = call.ILStackWasEmpty, IsTail = call.IsTail }; } } } } ILInstruction lifted; if (nullableVars.Count == 1 && MatchGetValueOrDefault(exprToLift, nullableVars[0])) { // v.HasValue ? call GetValueOrDefault(ldloca v) : fallback // => conv.nop.lifted(ldloc v) ?? fallback // This case is handled separately from DoLift() because // that doesn't introduce nop-conversions. context.Step("v.HasValue ? v.GetValueOrDefault() : fallback => v ?? fallback", trueInst); var inputUType = NullableType.GetUnderlyingType(nullableVars[0].Type); lifted = new LdLoc(nullableVars[0]); if (!inputUType.Equals(utype) && utype.ToPrimitiveType() != PrimitiveType.None) { // While the ILAst allows implicit conversions between short and int // (because both map to I4); it does not allow implicit conversions // between short? and int? (structs of different types). // So use 'conv.nop.lifted' to allow the conversion. lifted = new Conv( lifted, inputUType.GetStackType(), inputUType.GetSign(), utype.ToPrimitiveType(), checkForOverflow: false, isLifted: true ) { ILRange = ilrange }; } } else { context.Step("NullableLiftingTransform.DoLift", trueInst); BitSet bits; (lifted, bits) = DoLift(exprToLift); if (lifted == null) { return null; } if (!bits.All(0, nullableVars.Count)) { // don't lift if a nullableVar doesn't contribute to the result return null; } Debug.Assert(lifted is ILiftableInstruction liftable && liftable.IsLifted && liftable.UnderlyingResultType == exprToLift.ResultType); } if (isNullCoalescingWithNonNullableFallback) { lifted = new NullCoalescingInstruction(NullCoalescingKind.NullableWithValueFallback, lifted, falseInst) { UnderlyingResultType = exprToLift.ResultType, ILRange = ilrange }; } else if (!MatchNull(falseInst, utype)) { // Normal lifting, but the falseInst isn't `default(utype?)` // => use the `??` operator to provide the fallback value. lifted = new NullCoalescingInstruction(NullCoalescingKind.Nullable, lifted, falseInst) { UnderlyingResultType = exprToLift.ResultType, ILRange = ilrange }; } return lifted; } /// /// Recursive function that lifts the specified instruction. /// The input instruction is expected to a subexpression of the trueInst /// (so that all nullableVars are guaranteed non-null within this expression). /// /// Creates a new lifted instruction without modifying the input instruction. /// On success, returns (new lifted instruction, bitset). /// If lifting fails, returns (null, null). /// /// The returned bitset specifies which nullableVars were considered "relevant" for this instruction. /// bitSet[i] == true means nullableVars[i] was relevant. /// /// The new lifted instruction will have equivalent semantics to the input instruction /// if all relevant variables are non-null [except that the result will be wrapped in a Nullable{T} struct]. /// If any relevant variable is null, the new instruction is guaranteed to evaluate to null /// without having any other effect. /// (ILInstruction, BitSet) DoLift(ILInstruction inst) { if (MatchGetValueOrDefault(inst, out ILVariable inputVar)) { // n.GetValueOrDefault() lifted => n. BitSet foundIndices = new BitSet(nullableVars.Count); for (int i = 0; i < nullableVars.Count; i++) { if (nullableVars[i] == inputVar) { foundIndices[i] = true; } } if (foundIndices.Any()) return (new LdLoc(inputVar) { ILRange = inst.ILRange }, foundIndices); else return (null, null); } else if (inst is Conv conv) { var (arg, bits) = DoLift(conv.Argument); if (arg != null) { if (conv.HasDirectFlag(InstructionFlags.MayThrow) && !bits.All(0, nullableVars.Count)) { // Cannot execute potentially-throwing instruction unless all // the nullableVars are arguments to the instruction // (thus causing it not to throw when any of them is null). return (null, null); } var newInst = new Conv(arg, conv.InputType, conv.InputSign, conv.TargetType, conv.CheckForOverflow, isLifted: true) { ILRange = conv.ILRange }; return (newInst, bits); } } else if (inst is BitNot bitnot) { var (arg, bits) = DoLift(bitnot.Argument); if (arg != null) { var newInst = new BitNot(arg, isLifted: true, stackType: bitnot.ResultType) { ILRange = bitnot.ILRange }; return (newInst, bits); } } else if (inst is BinaryNumericInstruction binary) { var (left, right, bits) = DoLiftBinary(binary.Left, binary.Right, SpecialType.UnknownType, SpecialType.UnknownType); if (left != null && right != null) { if (binary.HasDirectFlag(InstructionFlags.MayThrow) && !bits.All(0, nullableVars.Count)) { // Cannot execute potentially-throwing instruction unless all // the nullableVars are arguments to the instruction // (thus causing it not to throw when any of them is null). return (null, null); } var newInst = new BinaryNumericInstruction( binary.Operator, left, right, binary.LeftInputType, binary.RightInputType, binary.CheckForOverflow, binary.Sign, isLifted: true ) { ILRange = binary.ILRange }; return (newInst, bits); } } else if (inst is Comp comp && !comp.IsLifted && comp.Kind == ComparisonKind.Equality && MatchGetValueOrDefault(comp.Left, out ILVariable v) && NullableType.GetUnderlyingType(v.Type).IsKnownType(KnownTypeCode.Boolean) && comp.Right.MatchLdcI4(0) ) { // C# doesn't support ComparisonLiftingKind.ThreeValuedLogic, // except for operator! on bool?. var (arg, bits) = DoLift(comp.Left); Debug.Assert(arg != null); var newInst = new Comp(comp.Kind, ComparisonLiftingKind.ThreeValuedLogic, comp.InputType, comp.Sign, arg, comp.Right.Clone()) { ILRange = comp.ILRange }; return (newInst, bits); } else if (inst is Call call && call.Method.IsOperator) { // Lifted user-defined operators, except for comparison operators (as those return bool, not bool?) var liftedOperator = CSharp.Resolver.CSharpOperators.LiftUserDefinedOperator(call.Method); if (liftedOperator == null || !NullableType.IsNullable(liftedOperator.ReturnType)) return (null, null); ILInstruction[] newArgs; BitSet newBits; if (call.Arguments.Count == 1) { var (arg, bits) = DoLift(call.Arguments[0]); newArgs = new[] { arg }; newBits = bits; } else if (call.Arguments.Count == 2) { var (left, right, bits) = DoLiftBinary(call.Arguments[0], call.Arguments[1], call.Method.Parameters[0].Type, call.Method.Parameters[1].Type); newArgs = new[] { left, right }; newBits = bits; } else { return (null, null); } if (newBits == null || !newBits.All(0, nullableVars.Count)) { // all nullable vars must be involved when calling a method (side effect) return (null, null); } var newInst = new Call(liftedOperator) { ConstrainedTo = call.ConstrainedTo, IsTail = call.IsTail, ILStackWasEmpty = call.ILStackWasEmpty, ILRange = call.ILRange }; newInst.Arguments.AddRange(newArgs); return (newInst, newBits); } return (null, null); } (ILInstruction, ILInstruction, BitSet) DoLiftBinary(ILInstruction lhs, ILInstruction rhs, IType leftExpectedType, IType rightExpectedType) { var (left, leftBits) = DoLift(lhs); var (right, rightBits) = DoLift(rhs); if (left != null && right == null && SemanticHelper.IsPure(rhs.Flags)) { // Embed non-nullable pure expression in lifted expression. right = NewNullable(rhs.Clone(), rightExpectedType); } if (left == null && right != null && SemanticHelper.IsPure(lhs.Flags)) { // Embed non-nullable pure expression in lifted expression. left = NewNullable(lhs.Clone(), leftExpectedType); } if (left != null && right != null) { var bits = leftBits ?? rightBits; if (rightBits != null) bits.UnionWith(rightBits); return (left, right, bits); } else { return (null, null, null); } } private ILInstruction NewNullable(ILInstruction inst, IType underlyingType) { if (underlyingType == SpecialType.UnknownType) return inst; var nullable = context.TypeSystem.FindType(KnownTypeCode.NullableOfT).GetDefinition(); var ctor = nullable?.Methods.FirstOrDefault(m => m.IsConstructor && m.Parameters.Count == 1); if (ctor != null) { ctor = ctor.Specialize(new TypeParameterSubstitution(new[] { underlyingType }, null)); return new NewObj(ctor) { Arguments = { inst } }; } else { return inst; } } #endregion #region Match...Call /// /// Matches 'call get_HasValue(arg)' /// internal static bool MatchHasValueCall(ILInstruction inst, out ILInstruction arg) { arg = null; if (!(inst is Call call)) return false; if (call.Arguments.Count != 1) return false; if (call.Method.Name != "get_HasValue") return false; if (call.Method.DeclaringTypeDefinition?.KnownTypeCode != KnownTypeCode.NullableOfT) return false; arg = call.Arguments[0]; return true; } /// /// Matches 'call get_HasValue(ldloca v)' /// internal static bool MatchHasValueCall(ILInstruction inst, out ILVariable v) { if (MatchHasValueCall(inst, out ILInstruction arg)) { return arg.MatchLdLoca(out v); } v = null; return false; } /// /// Matches 'call get_HasValue(ldloca v)' /// internal static bool MatchHasValueCall(ILInstruction inst, ILVariable v) { return MatchHasValueCall(inst, out ILVariable v2) && v == v2; } /// /// Matches 'logic.not(call get_HasValue(ldloca v))' /// static bool MatchNegatedHasValueCall(ILInstruction inst, ILVariable v) { return inst.MatchLogicNot(out var arg) && MatchHasValueCall(arg, v); } /// /// Matches 'newobj Nullable{underlyingType}.ctor(arg)' /// internal static bool MatchNullableCtor(ILInstruction inst, out IType underlyingType, out ILInstruction arg) { underlyingType = null; arg = null; if (!(inst is NewObj newobj)) return false; if (!newobj.Method.IsConstructor || newobj.Arguments.Count != 1) return false; if (newobj.Method.DeclaringTypeDefinition?.KnownTypeCode != KnownTypeCode.NullableOfT) return false; arg = newobj.Arguments[0]; underlyingType = NullableType.GetUnderlyingType(newobj.Method.DeclaringType); return true; } /// /// Matches 'call Nullable{T}.GetValueOrDefault(arg)' /// internal static bool MatchGetValueOrDefault(ILInstruction inst, out ILInstruction arg) { arg = null; if (!(inst is Call call)) return false; if (call.Method.Name != "GetValueOrDefault" || call.Arguments.Count != 1) return false; if (call.Method.DeclaringTypeDefinition?.KnownTypeCode != KnownTypeCode.NullableOfT) return false; arg = call.Arguments[0]; return true; } /// /// Matches 'call Nullable{T}.GetValueOrDefault(ldloca v)' /// internal static bool MatchGetValueOrDefault(ILInstruction inst, out ILVariable v) { v = null; return MatchGetValueOrDefault(inst, out ILInstruction arg) && arg.MatchLdLoca(out v); } /// /// Matches 'call Nullable{T}.GetValueOrDefault(ldloca v)' /// internal static bool MatchGetValueOrDefault(ILInstruction inst, ILVariable v) { return MatchGetValueOrDefault(inst, out ILVariable v2) && v == v2; } static bool MatchNull(ILInstruction inst, out IType underlyingType) { underlyingType = null; if (inst.MatchDefaultValue(out IType type)) { underlyingType = NullableType.GetUnderlyingType(type); return NullableType.IsNullable(type); } underlyingType = null; return false; } static bool MatchNull(ILInstruction inst, IType underlyingType) { return MatchNull(inst, out var utype) && utype.Equals(underlyingType); } #endregion } class NullableLiftingStatementTransform : IStatementTransform { public void Run(Block block, int pos, StatementTransformContext context) { new NullableLiftingTransform(context).RunStatements(block, pos); } } }