From 126e870a5a0bea7ab39f2db56f612523db0d1dd0 Mon Sep 17 00:00:00 2001 From: Siegfried Pammer Date: Tue, 14 Oct 2025 22:06:32 +0200 Subject: [PATCH] Fix #3577: Properly infer the switch governing type and preserve conversions --- .../TestCases/Pretty/Switch.cs | 17 +++ .../CSharp/ExpressionBuilder.cs | 116 +++++++++++------- .../Transforms/SwitchOnNullableTransform.cs | 11 +- ICSharpCode.Decompiler/Util/LongSet.cs | 7 ++ 4 files changed, 105 insertions(+), 46 deletions(-) diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/Switch.cs b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/Switch.cs index 605eafac1..ee72b37d6 100644 --- a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/Switch.cs +++ b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/Switch.cs @@ -1746,5 +1746,22 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty } } #endif + +#if ROSLYN + public static int Issue3577(int what) + { + int result = 0; + switch ((long)what) + { + case 1L: + result = 1; + break; + case 2L: + result = 2; + break; + } + return result; + } +#endif } } diff --git a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs index 7701ae7ce..350a8287c 100644 --- a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs @@ -3983,63 +3983,108 @@ namespace ICSharpCode.Decompiler.CSharp internal (TranslatedExpression, IType, StringToInt) TranslateSwitchValue(SwitchInstruction inst, bool isExpressionContext) { TranslatedExpression value; - IType type; + IType governingType; // prepare expression and expected type + // first try to guess a governing type if (inst.Value is StringToInt strToInt) { value = Translate(strToInt.Argument); - type = strToInt.ExpectedType ?? compilation.FindType(KnownTypeCode.String); + governingType = strToInt.ExpectedType ?? compilation.FindType(KnownTypeCode.String); } else { strToInt = null; value = Translate(inst.Value); - type = inst.Type ?? value.Type; - } + governingType = inst.Type ?? value.Type; - // find and unwrap the input type - IType inputType = value.Type; - if (value.Expression is CastExpression && value.ResolveResult is ConversionResolveResult crr) - { - inputType = crr.Input.Type; + // validate the governing type + if (inst.Value.ResultType == StackType.I8) + { + if (governingType.GetStackType() != StackType.I8) + governingType = FindType(StackType.I8, governingType.GetSign()); + } + else if (inst.Value.ResultType == StackType.I4) + { + if (governingType.GetStackType() != StackType.I4) + governingType = FindType(StackType.I4, governingType.GetSign()); + if (governingType.IsSmallIntegerType()) + { + var defaultSection = inst.GetDefaultSection(); + int bits = 8 * governingType.GetSize(); + int minValue = governingType.GetSign() == Sign.Unsigned ? 0 : -(1 << (bits - 1)); + int maxValue = governingType.GetSign() == Sign.Unsigned ? (1 << bits) - 1 : (1 << (bits - 1)) - 1; + foreach (var section in inst.Sections) + { + if (section == defaultSection) + continue; + LongInterval interval = section.Labels.ContainingInterval(); + if (interval.Start < minValue || interval.InclusiveEnd > maxValue) + { + // governing type is too small to hold all case values + governingType = FindType(StackType.I4, Sign.Signed); + break; + } + } + } + } + else + { + Debug.Assert(inst.Value.ResultType == StackType.O); + Debug.Assert(inst.IsLifted); + Debug.Assert(inst.Type == governingType); + } } - inputType = NullableType.GetUnderlyingType(inputType).GetEnumUnderlyingType(); - // check input/underlying type for compatibility - bool allowImplicitConversion; - if (IsCompatibleWithSwitch(inputType) || (strToInt != null && inputType.Equals(type))) + if (isExpressionContext) { - allowImplicitConversion = !isExpressionContext; + value = value.ConvertTo(governingType, this, allowImplicitConversion: false); } else { - var applicableImplicitConversionOperators = inputType.GetMethods(IsCompatibleImplicitConversionOperator).ToArray(); - switch (applicableImplicitConversionOperators.Length) + value = value.ConvertTo(governingType, this, allowImplicitConversion: true); + + var csharpGoverningType = GetCSharpSwitchGoverningType(value.Type); + if (!csharpGoverningType.Equals(governingType)) { - case 0: - allowImplicitConversion = !isExpressionContext; - break; - case 1: - allowImplicitConversion = !isExpressionContext; - // TODO validate - break; - default: - allowImplicitConversion = false; - break; + value = value.ConvertTo(governingType, this, allowImplicitConversion: false); } } - value = value.ConvertTo(type, this, allowImplicitConversion: allowImplicitConversion); - var caseType = strToInt != null ? compilation.FindType(KnownTypeCode.String) - : type; + : governingType; return (value, caseType, strToInt); + } + + static IType GetCSharpSwitchGoverningType(IType type) + { + if (IsCompatibleWithSwitch(type)) + return type; + + var applicableImplicitConversionOperators = type.GetMethods(IsImplicitConversionOperator) + .Where(m => IsCompatibleWithSwitch(m.ReturnType)) + .ToArray(); + if (applicableImplicitConversionOperators.Length != 1) + return type; + return applicableImplicitConversionOperators[0].ReturnType; + + static bool IsImplicitConversionOperator(IMethod operatorMethod) + { + if (!operatorMethod.IsOperator) + return false; + if (operatorMethod.Name != "op_Implicit") + return false; + if (operatorMethod.Parameters.Count != 1) + return false; + return true; + } static bool IsCompatibleWithSwitch(IType type) { - return type.IsKnownType(KnownTypeCode.SByte) + type = NullableType.GetUnderlyingType(type); + return type.IsKnownType(KnownTypeCode.Boolean) + || type.IsKnownType(KnownTypeCode.SByte) || type.IsKnownType(KnownTypeCode.Byte) || type.IsKnownType(KnownTypeCode.Int16) || type.IsKnownType(KnownTypeCode.UInt16) @@ -4050,17 +4095,6 @@ namespace ICSharpCode.Decompiler.CSharp || type.IsKnownType(KnownTypeCode.Char) || type.IsKnownType(KnownTypeCode.String); } - - bool IsCompatibleImplicitConversionOperator(IMethod operatorMethod) - { - if (!operatorMethod.IsOperator) - return false; - if (operatorMethod.Name != "op_Implicit") - return false; - if (operatorMethod.Parameters.Count != 1) - return false; - return IsCompatibleWithSwitch(operatorMethod.ReturnType); - } } protected internal override TranslatedExpression VisitSwitchInstruction(SwitchInstruction inst, TranslationContext context) diff --git a/ICSharpCode.Decompiler/IL/Transforms/SwitchOnNullableTransform.cs b/ICSharpCode.Decompiler/IL/Transforms/SwitchOnNullableTransform.cs index e800d92a7..e5181f578 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/SwitchOnNullableTransform.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/SwitchOnNullableTransform.cs @@ -16,13 +16,11 @@ // OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -using System; using System.Collections.Generic; using System.Linq; using ICSharpCode.Decompiler.IL.ControlFlow; using ICSharpCode.Decompiler.TypeSystem; -using ICSharpCode.Decompiler.Util; namespace ICSharpCode.Decompiler.IL.Transforms { @@ -117,14 +115,16 @@ namespace ICSharpCode.Decompiler.IL.Transforms return false; if (!(switchBlock.Instructions[0] is SwitchInstruction switchInst)) return false; - newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, new LdLoc(switchValueVar)); + var nullableType = ((Call)getHasValue).Method.DeclaringType; + newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, new LdLoc(switchValueVar), nullableType); return true; } - static SwitchInstruction BuildLiftedSwitch(Block nullCaseBlock, SwitchInstruction switchInst, ILInstruction switchValue) + static SwitchInstruction BuildLiftedSwitch(Block nullCaseBlock, SwitchInstruction switchInst, ILInstruction switchValue, IType nullableType) { SwitchInstruction newSwitch = new SwitchInstruction(switchValue); newSwitch.IsLifted = true; + newSwitch.Type = nullableType; newSwitch.Sections.AddRange(switchInst.Sections); newSwitch.Sections.Add(new SwitchSection { Body = new Branch(nullCaseBlock), HasNullLabel = true }); return newSwitch; @@ -192,7 +192,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms switchValue = new LdLoc(v).WithILRange(target); else switchValue = new LdObj(target, ((CallInstruction)getHasValue).Method.DeclaringType); - newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, switchValue); + var nullableType = ((Call)getHasValue).Method.DeclaringType; + newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, switchValue, nullableType); return true; } } diff --git a/ICSharpCode.Decompiler/Util/LongSet.cs b/ICSharpCode.Decompiler/Util/LongSet.cs index 3097f47e3..dc6caa9fa 100644 --- a/ICSharpCode.Decompiler/Util/LongSet.cs +++ b/ICSharpCode.Decompiler/Util/LongSet.cs @@ -345,6 +345,13 @@ namespace ICSharpCode.Decompiler.Util get { return Intervals.SelectMany(i => i.Range()); } } + public LongInterval ContainingInterval() + { + if (IsEmpty) + return default; + return new LongInterval(Intervals[0].Start, Intervals[Intervals.Length - 1].End); + } + public override string ToString() { return string.Join(",", Intervals);