diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/NullPropagation.cs b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/NullPropagation.cs index 6d2191833..6e927db63 100644 --- a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/NullPropagation.cs +++ b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/NullPropagation.cs @@ -68,6 +68,48 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty { } } + + private class Container + { + public GenericStruct Other; + } + + private struct GenericStruct + { + public T1 Field1; + public T2 Field2; + public Container Other; + + public override string ToString() + { + return "(" + Field1?.ToString() + ", " + Field2?.ToString() + ")"; + } + + public int? GetTextLength() + { + return Field1?.ToString().Length + Field2?.ToString().Length + 4; + } + + public string Chain1() + { + return Other?.Other.Other?.Other.Field1?.ToString(); + } + + public string Chain2() + { + return Other?.Other.Other?.Other.Field1?.ToString()?.GetType().Name; + } + + public int? Test2() + { + return Field1?.ToString().Length ?? 42; + } + + public int? GetTextLengthNRE() + { + return (Field1?.ToString()).Length; + } + } public interface ITest { @@ -243,12 +285,10 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty return t?.Int(); } - // See also: https://github.com/icsharpcode/ILSpy/issues/1050 - // The C# compiler generates pretty weird code in this case. - //private static int? GenericRefUnconstrainedInt(ref T t) where T : ITest - //{ - // return t?.Int(); - //} + private static int? GenericRefUnconstrainedInt(ref T t) where T : ITest + { + return t?.Int(); + } private static int? GenericRefClassConstraintInt(ref T t) where T : class, ITest { diff --git a/ICSharpCode.Decompiler/IL/Transforms/NullPropagationTransform.cs b/ICSharpCode.Decompiler/IL/Transforms/NullPropagationTransform.cs index 37f8ffcd3..93d7f90f5 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/NullPropagationTransform.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/NullPropagationTransform.cs @@ -39,7 +39,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms } readonly ILTransformContext context; - + public NullPropagationTransform(ILTransformContext context) { this.context = context; @@ -59,6 +59,10 @@ namespace ICSharpCode.Decompiler.IL.Transforms /// nullable type, used by reference (comparison is 'call get_HasValue(ldloc(testedVar))') /// NullableByReference, + /// + /// unconstrained generic type (see the pattern described in TransformNullPropagationOnUnconstrainedGenericExpression) + /// + UnconstrainedType, } /// @@ -142,18 +146,46 @@ namespace ICSharpCode.Decompiler.IL.Transforms /// internal void RunStatements(Block block, int pos) { - var ifInst = block.Instructions[pos] as IfInstruction; - if (ifInst == null || !ifInst.FalseInst.MatchNop()) - return; - if (ifInst.Condition is Comp comp && comp.Kind == ComparisonKind.Inequality - && comp.Left.MatchLdLoc(out var testedVar) && comp.Right.MatchLdNull()) { - TryNullPropForVoidCall(testedVar, Mode.ReferenceType, ifInst.TrueInst as Block, ifInst); - } else if (NullableLiftingTransform.MatchHasValueCall(ifInst.Condition, out ILInstruction arg)) { - if (arg.MatchLdLoca(out testedVar)) { - TryNullPropForVoidCall(testedVar, Mode.NullableByValue, ifInst.TrueInst as Block, ifInst); - } else if (arg.MatchLdLoc(out testedVar)) { - TryNullPropForVoidCall(testedVar, Mode.NullableByReference, ifInst.TrueInst as Block, ifInst); + if (block.Instructions[pos] is IfInstruction ifInst && ifInst.FalseInst.MatchNop()) { + if (ifInst.Condition is Comp comp && comp.Kind == ComparisonKind.Inequality + && comp.Left.MatchLdLoc(out var testedVar) && comp.Right.MatchLdNull()) { + TryNullPropForVoidCall(testedVar, Mode.ReferenceType, ifInst.TrueInst as Block, ifInst); + } else if (NullableLiftingTransform.MatchHasValueCall(ifInst.Condition, out ILInstruction arg)) { + if (arg.MatchLdLoca(out testedVar)) { + TryNullPropForVoidCall(testedVar, Mode.NullableByValue, ifInst.TrueInst as Block, ifInst); + } else if (arg.MatchLdLoc(out testedVar)) { + TryNullPropForVoidCall(testedVar, Mode.NullableByReference, ifInst.TrueInst as Block, ifInst); + } + } + } + if (TransformNullPropagationOnUnconstrainedGenericExpression(block, pos, out var testedVariable, out var nonNullInst, out var nullInst, out var endBlock)) { + var parentInstruction = nonNullInst.Parent; + var replacement = TryNullPropagation(testedVariable, nonNullInst, nullInst, Mode.UnconstrainedType); + if (replacement == null) + return; + context.Step("TransformNullPropagationOnUnconstrainedGenericExpression", block); + switch (parentInstruction) { + case StLoc stloc: + stloc.Value = replacement; + break; + case Leave leave: + leave.Value = replacement; + break; + default: + // if this ever happens, the pattern checked by TransformNullPropagationOnUnconstrainedGenericExpression + // has changed, but this part of the code was not properly adjusted. + throw new NotSupportedException(); + } + // Remove the fallback conditions and blocks + block.Instructions.RemoveRange(pos + 1, 2); + // if the endBlock is only reachable through the current block, + // combine both blocks. + if (endBlock?.IncomingEdgeCount == 1) { + block.Instructions.AddRange(endBlock.Instructions); + block.Instructions.RemoveAt(pos + 2); + endBlock.Remove(); } + ILInlining.InlineIfPossible(block, pos, context); } } @@ -261,6 +293,9 @@ namespace ICSharpCode.Decompiler.IL.Transforms case Mode.NullableByReference: return NullableLiftingTransform.MatchGetValueOrDefault(inst, out ILInstruction arg) && arg.MatchLdLoc(testedVar); + case Mode.UnconstrainedType: + // unconstrained generic type (expect: ldloc(testedVar)) + return inst.MatchLdLoc(testedVar); default: throw new ArgumentOutOfRangeException(nameof(mode)); } @@ -288,6 +323,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms ILInstruction replacement; switch (mode) { case Mode.ReferenceType: + case Mode.UnconstrainedType: // Wrap varLoad in nullable.unwrap: replacement = new NullableUnwrap(varLoad.ResultType, varLoad, refInput: varLoad.ResultType == StackType.Ref); break; @@ -310,6 +346,147 @@ namespace ICSharpCode.Decompiler.IL.Transforms } oldParentChildren[oldChildIndex] = replacement; } + + // stloc target(targetInst) + // stloc defaultTemporary(default.value type) + // if (logic.not(comp.o(box `0(ldloc defaultTemporary) != ldnull))) Block fallbackBlock { + // stloc defaultTemporary(ldobj type(ldloc target)) + // stloc target(ldloca defaultTemporary) + // if (comp.o(ldloc defaultTemporary == ldnull)) Block fallbackBlock2 { + // stloc resultTemporary(nullInst) + // br endBlock + // } + // } + // stloc resultTemporary(constrained[type].call_instruction(ldloc target, ...)) + // br endBlock + // => + // stloc resultTemporary(nullable.rewrap(constrained[type].call_instruction(nullable.unwrap(targetInst), ...))) + // + // -or- + // + // stloc target(targetInst) + // stloc defaultTemporary(default.value type) + // if (logic.not(comp.o(box `0(ldloc defaultTemporary) != ldnull))) Block fallbackBlock { + // stloc defaultTemporary(ldobj type(ldloc target)) + // stloc target(ldloca defaultTemporary) + // if (comp.o(ldloc defaultTemporary == ldnull)) Block fallbackBlock2 { + // leave(nullInst) + // } + // } + // leave (constrained[type].call_instruction(ldloc target, ...)) + // => + // leave (nullable.rewrap(constrained[type].call_instruction(nullable.unwrap(targetInst), ...))) + private bool TransformNullPropagationOnUnconstrainedGenericExpression(Block block, int pos, + out ILVariable target, out ILInstruction nonNullInst, out ILInstruction nullInst, out Block endBlock) + { + target = null; + nonNullInst = null; + nullInst = null; + endBlock = null; + if (pos + 3 >= block.Instructions.Count) + return false; + // stloc target(...) + if (!block.Instructions[pos].MatchStLoc(out target, out _)) + return false; + if (!(target.Kind == VariableKind.StackSlot && target.LoadCount == 2 && target.StoreCount == 2)) + return false; + // stloc defaultTemporary(default.value type) + if (!(block.Instructions[pos + 1].MatchStLoc(out var defaultTemporary, out var defaultExpression) && defaultExpression.MatchDefaultValue(out var type))) + return false; + // In the above pattern the defaultTemporary variable is used two times in stloc and ldloc instructions and once in a ldloca instruction + if (!(defaultTemporary.Kind == VariableKind.Local && defaultTemporary.LoadCount == 2 && defaultTemporary.StoreCount == 2 && defaultTemporary.AddressCount == 1)) + return false; + // if (logic.not(comp.o(box `0(ldloc defaultTemporary) != ldnull))) Block fallbackBlock + if (!(block.Instructions[pos + 2].MatchIfInstruction(out var condition, out var fallbackBlock1) && condition.MatchCompEqualsNull(out var arg) && arg.MatchLdLoc(defaultTemporary))) + return false; + if (!MatchStLocResultTemporary(block, pos, type, target, defaultTemporary, fallbackBlock1, out nonNullInst, out nullInst, out endBlock) + && !MatchLeaveResult(block, pos, type, target, defaultTemporary, fallbackBlock1, out nonNullInst, out nullInst)) + return false; + return true; + } + + // stloc resultTemporary(constrained[type].call_instruction(ldloc target, ...)) + // br endBlock + private bool MatchStLocResultTemporary(Block block, int pos, IType type, ILVariable target, ILVariable defaultTemporary, ILInstruction fallbackBlock, out ILInstruction nonNullInst, out ILInstruction nullInst, out Block endBlock) + { + endBlock = null; + nonNullInst = null; + nullInst = null; + + if (pos + 4 >= block.Instructions.Count) + return false; + // stloc resultTemporary(constrained[type].call_instruction(ldloc target, ...)) + if (!(block.Instructions[pos + 3].MatchStLoc(out var resultTemporary, out nonNullInst))) + return false; + // br endBlock + if (!(block.Instructions[pos + 4].MatchBranch(out endBlock))) + return false; + // Analyze Block fallbackBlock + if (!(fallbackBlock is Block b && IsFallbackBlock(b, type, target, defaultTemporary, resultTemporary, endBlock, out nullInst))) + return false; + + return true; + } + + private bool MatchLeaveResult(Block block, int pos, IType type, ILVariable target, ILVariable defaultTemporary, ILInstruction fallbackBlock, out ILInstruction nonNullInst, out ILInstruction nullInst) + { + nonNullInst = null; + nullInst = null; + + // leave (constrained[type].call_instruction(ldloc target, ...)) + if (!(block.Instructions[pos + 3] is Leave leave && leave.IsLeavingFunction)) + return false; + nonNullInst = leave.Value; + // Analyze Block fallbackBlock + if (!(fallbackBlock is Block b && IsFallbackBlock(b, type, target, defaultTemporary, null, leave.TargetContainer, out nullInst))) + return false; + return true; + } + + // Block fallbackBlock { + // stloc defaultTemporary(ldobj type(ldloc target)) + // stloc target(ldloca defaultTemporary) + // if (comp.o(ldloc defaultTemporary == ldnull)) Block fallbackBlock { + // stloc resultTemporary(ldnull) + // br endBlock + // } + // } + private bool IsFallbackBlock(Block block, IType type, ILVariable target, ILVariable defaultTemporary, ILVariable resultTemporary, ILInstruction endBlockOrLeaveContainer, out ILInstruction nullInst) + { + nullInst = null; + if (!(block.Instructions.Count == 3)) + return false; + // stloc defaultTemporary(ldobj type(ldloc target)) + if (!(block.Instructions[0].MatchStLoc(defaultTemporary, out var value))) + return false; + if (!(value.MatchLdObj(out var inst, out var t) && type.Equals(t) && inst.MatchLdLoc(target))) + return false; + // stloc target(ldloca defaultTemporary) + if (!(block.Instructions[1].MatchStLoc(target, out var defaultAddress) && defaultAddress.MatchLdLoca(defaultTemporary))) + return false; + // if (comp.o(ldloc defaultTemporary == ldnull)) Block fallbackBlock + if (!(block.Instructions[2].MatchIfInstruction(out var condition, out var tmp) && condition.MatchCompEqualsNull(out var arg) && arg.MatchLdLoc(defaultTemporary))) + return false; + // Block fallbackBlock { + // stloc resultTemporary(nullInst) + // br endBlock + // } + var fallbackInst = Block.Unwrap(tmp); + if (fallbackInst is Block fallbackBlock && endBlockOrLeaveContainer is Block endBlock) { + if (!(fallbackBlock.Instructions.Count == 2)) + return false; + if (!fallbackBlock.Instructions[0].MatchStLoc(resultTemporary, out nullInst)) + return false; + if (!fallbackBlock.Instructions[1].MatchBranch(endBlock)) + return false; + } else { + if (!(fallbackInst is Leave fallbackLeave && endBlockOrLeaveContainer is BlockContainer leaveContainer + && fallbackLeave.TargetContainer == leaveContainer && !fallbackLeave.Value.MatchNop())) + return false; + nullInst = fallbackLeave.Value; + } + return true; + } } class NullPropagationStatementTransform : IStatementTransform