Browse Source

Fix #3577: Properly infer the switch governing type and preserve conversions

pull/3595/head
Siegfried Pammer 3 months ago
parent
commit
126e870a5a
  1. 17
      ICSharpCode.Decompiler.Tests/TestCases/Pretty/Switch.cs
  2. 116
      ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs
  3. 11
      ICSharpCode.Decompiler/IL/Transforms/SwitchOnNullableTransform.cs
  4. 7
      ICSharpCode.Decompiler/Util/LongSet.cs

17
ICSharpCode.Decompiler.Tests/TestCases/Pretty/Switch.cs

@ -1746,5 +1746,22 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty @@ -1746,5 +1746,22 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty
}
}
#endif
#if ROSLYN
public static int Issue3577(int what)
{
int result = 0;
switch ((long)what)
{
case 1L:
result = 1;
break;
case 2L:
result = 2;
break;
}
return result;
}
#endif
}
}

116
ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs

@ -3983,63 +3983,108 @@ namespace ICSharpCode.Decompiler.CSharp @@ -3983,63 +3983,108 @@ namespace ICSharpCode.Decompiler.CSharp
internal (TranslatedExpression, IType, StringToInt) TranslateSwitchValue(SwitchInstruction inst, bool isExpressionContext)
{
TranslatedExpression value;
IType type;
IType governingType;
// prepare expression and expected type
// first try to guess a governing type
if (inst.Value is StringToInt strToInt)
{
value = Translate(strToInt.Argument);
type = strToInt.ExpectedType ?? compilation.FindType(KnownTypeCode.String);
governingType = strToInt.ExpectedType ?? compilation.FindType(KnownTypeCode.String);
}
else
{
strToInt = null;
value = Translate(inst.Value);
type = inst.Type ?? value.Type;
}
governingType = inst.Type ?? value.Type;
// find and unwrap the input type
IType inputType = value.Type;
if (value.Expression is CastExpression && value.ResolveResult is ConversionResolveResult crr)
{
inputType = crr.Input.Type;
// validate the governing type
if (inst.Value.ResultType == StackType.I8)
{
if (governingType.GetStackType() != StackType.I8)
governingType = FindType(StackType.I8, governingType.GetSign());
}
else if (inst.Value.ResultType == StackType.I4)
{
if (governingType.GetStackType() != StackType.I4)
governingType = FindType(StackType.I4, governingType.GetSign());
if (governingType.IsSmallIntegerType())
{
var defaultSection = inst.GetDefaultSection();
int bits = 8 * governingType.GetSize();
int minValue = governingType.GetSign() == Sign.Unsigned ? 0 : -(1 << (bits - 1));
int maxValue = governingType.GetSign() == Sign.Unsigned ? (1 << bits) - 1 : (1 << (bits - 1)) - 1;
foreach (var section in inst.Sections)
{
if (section == defaultSection)
continue;
LongInterval interval = section.Labels.ContainingInterval();
if (interval.Start < minValue || interval.InclusiveEnd > maxValue)
{
// governing type is too small to hold all case values
governingType = FindType(StackType.I4, Sign.Signed);
break;
}
}
}
}
else
{
Debug.Assert(inst.Value.ResultType == StackType.O);
Debug.Assert(inst.IsLifted);
Debug.Assert(inst.Type == governingType);
}
}
inputType = NullableType.GetUnderlyingType(inputType).GetEnumUnderlyingType();
// check input/underlying type for compatibility
bool allowImplicitConversion;
if (IsCompatibleWithSwitch(inputType) || (strToInt != null && inputType.Equals(type)))
if (isExpressionContext)
{
allowImplicitConversion = !isExpressionContext;
value = value.ConvertTo(governingType, this, allowImplicitConversion: false);
}
else
{
var applicableImplicitConversionOperators = inputType.GetMethods(IsCompatibleImplicitConversionOperator).ToArray();
switch (applicableImplicitConversionOperators.Length)
value = value.ConvertTo(governingType, this, allowImplicitConversion: true);
var csharpGoverningType = GetCSharpSwitchGoverningType(value.Type);
if (!csharpGoverningType.Equals(governingType))
{
case 0:
allowImplicitConversion = !isExpressionContext;
break;
case 1:
allowImplicitConversion = !isExpressionContext;
// TODO validate
break;
default:
allowImplicitConversion = false;
break;
value = value.ConvertTo(governingType, this, allowImplicitConversion: false);
}
}
value = value.ConvertTo(type, this, allowImplicitConversion: allowImplicitConversion);
var caseType = strToInt != null
? compilation.FindType(KnownTypeCode.String)
: type;
: governingType;
return (value, caseType, strToInt);
}
static IType GetCSharpSwitchGoverningType(IType type)
{
if (IsCompatibleWithSwitch(type))
return type;
var applicableImplicitConversionOperators = type.GetMethods(IsImplicitConversionOperator)
.Where(m => IsCompatibleWithSwitch(m.ReturnType))
.ToArray();
if (applicableImplicitConversionOperators.Length != 1)
return type;
return applicableImplicitConversionOperators[0].ReturnType;
static bool IsImplicitConversionOperator(IMethod operatorMethod)
{
if (!operatorMethod.IsOperator)
return false;
if (operatorMethod.Name != "op_Implicit")
return false;
if (operatorMethod.Parameters.Count != 1)
return false;
return true;
}
static bool IsCompatibleWithSwitch(IType type)
{
return type.IsKnownType(KnownTypeCode.SByte)
type = NullableType.GetUnderlyingType(type);
return type.IsKnownType(KnownTypeCode.Boolean)
|| type.IsKnownType(KnownTypeCode.SByte)
|| type.IsKnownType(KnownTypeCode.Byte)
|| type.IsKnownType(KnownTypeCode.Int16)
|| type.IsKnownType(KnownTypeCode.UInt16)
@ -4050,17 +4095,6 @@ namespace ICSharpCode.Decompiler.CSharp @@ -4050,17 +4095,6 @@ namespace ICSharpCode.Decompiler.CSharp
|| type.IsKnownType(KnownTypeCode.Char)
|| type.IsKnownType(KnownTypeCode.String);
}
bool IsCompatibleImplicitConversionOperator(IMethod operatorMethod)
{
if (!operatorMethod.IsOperator)
return false;
if (operatorMethod.Name != "op_Implicit")
return false;
if (operatorMethod.Parameters.Count != 1)
return false;
return IsCompatibleWithSwitch(operatorMethod.ReturnType);
}
}
protected internal override TranslatedExpression VisitSwitchInstruction(SwitchInstruction inst, TranslationContext context)

11
ICSharpCode.Decompiler/IL/Transforms/SwitchOnNullableTransform.cs

@ -16,13 +16,11 @@ @@ -16,13 +16,11 @@
// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
using System;
using System.Collections.Generic;
using System.Linq;
using ICSharpCode.Decompiler.IL.ControlFlow;
using ICSharpCode.Decompiler.TypeSystem;
using ICSharpCode.Decompiler.Util;
namespace ICSharpCode.Decompiler.IL.Transforms
{
@ -117,14 +115,16 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -117,14 +115,16 @@ namespace ICSharpCode.Decompiler.IL.Transforms
return false;
if (!(switchBlock.Instructions[0] is SwitchInstruction switchInst))
return false;
newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, new LdLoc(switchValueVar));
var nullableType = ((Call)getHasValue).Method.DeclaringType;
newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, new LdLoc(switchValueVar), nullableType);
return true;
}
static SwitchInstruction BuildLiftedSwitch(Block nullCaseBlock, SwitchInstruction switchInst, ILInstruction switchValue)
static SwitchInstruction BuildLiftedSwitch(Block nullCaseBlock, SwitchInstruction switchInst, ILInstruction switchValue, IType nullableType)
{
SwitchInstruction newSwitch = new SwitchInstruction(switchValue);
newSwitch.IsLifted = true;
newSwitch.Type = nullableType;
newSwitch.Sections.AddRange(switchInst.Sections);
newSwitch.Sections.Add(new SwitchSection { Body = new Branch(nullCaseBlock), HasNullLabel = true });
return newSwitch;
@ -192,7 +192,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms @@ -192,7 +192,8 @@ namespace ICSharpCode.Decompiler.IL.Transforms
switchValue = new LdLoc(v).WithILRange(target);
else
switchValue = new LdObj(target, ((CallInstruction)getHasValue).Method.DeclaringType);
newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, switchValue);
var nullableType = ((Call)getHasValue).Method.DeclaringType;
newSwitch = BuildLiftedSwitch(nullCaseBlock, switchInst, switchValue, nullableType);
return true;
}
}

7
ICSharpCode.Decompiler/Util/LongSet.cs

@ -345,6 +345,13 @@ namespace ICSharpCode.Decompiler.Util @@ -345,6 +345,13 @@ namespace ICSharpCode.Decompiler.Util
get { return Intervals.SelectMany(i => i.Range()); }
}
public LongInterval ContainingInterval()
{
if (IsEmpty)
return default;
return new LongInterval(Intervals[0].Start, Intervals[Intervals.Length - 1].End);
}
public override string ToString()
{
return string.Join(",", Intervals);

Loading…
Cancel
Save