diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs b/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs
index c4bc7796b..cd8f4af05 100644
--- a/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs
+++ b/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs
@@ -70,6 +70,7 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness
char* ptr2 = stackalloc char[100];
for (int i = 0; i < count; i++) {
ptr[i] = (char)i;
+ ptr2[i] = '\0';
}
return this.UsePointer((double*)ptr);
}
@@ -77,7 +78,7 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness
public unsafe string StackAllocStruct(int count)
{
SimpleStruct* s = stackalloc SimpleStruct[checked(count * 2)];
- SimpleStruct* p = stackalloc SimpleStruct[10];
+ SimpleStruct* _ = stackalloc SimpleStruct[10];
return this.UsePointer(&s->Y);
}
diff --git a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs
index 1bea33cf8..4673733d9 100644
--- a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs
+++ b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs
@@ -229,55 +229,38 @@ namespace ICSharpCode.Decompiler.CSharp
protected internal override TranslatedExpression VisitLocAlloc(LocAlloc inst, TranslationContext context)
{
- IType elementType;
- TranslatedExpression countExpression = TranslatePointerArgument(inst.Argument, context, out elementType);
+ TranslatedExpression countExpression;
+ PointerType pointerType;
+ if (inst.Argument.MatchBinaryNumericInstruction(BinaryNumericOperator.Mul, out var left, out var right)
+ && right.UnwrapConv(ConversionKind.SignExtend).UnwrapConv(ConversionKind.ZeroExtend).MatchSizeOf(out var elementType))
+ {
+ // Determine the element type from the sizeof
+ countExpression = Translate(left.UnwrapConv(ConversionKind.ZeroExtend));
+ pointerType = new PointerType(elementType);
+ } else {
+ // Determine the element type from the expected pointer type in this context
+ pointerType = context.TypeHint as PointerType;
+ if (pointerType != null && GetPointerArithmeticOffset(
+ inst.Argument, Translate(inst.Argument),
+ pointerType, checkForOverflow: true,
+ unwrapZeroExtension: true
+ ) is TranslatedExpression offset)
+ {
+ countExpression = offset;
+ elementType = pointerType.ElementType;
+ } else {
+ elementType = compilation.FindType(KnownTypeCode.Byte);
+ pointerType = new PointerType(elementType);
+ countExpression = Translate(inst.Argument);
+ }
+ }
countExpression = countExpression.ConvertTo(compilation.FindType(KnownTypeCode.Int32), this);
- if (elementType == null)
- elementType = compilation.FindType(KnownTypeCode.Byte);
return new StackAllocExpression {
Type = ConvertType(elementType),
CountExpression = countExpression
}.WithILInstruction(inst).WithRR(new ResolveResult(new PointerType(elementType)));
}
- ///
- /// Translate the argument of an operation that deals with pointers:
- /// * undoes the implicit multiplication with `sizeof(elementType)` and returns `elementType`
- /// * on failure, translates the whole expression and returns `elementType = null`.
- ///
- TranslatedExpression TranslatePointerArgument(ILInstruction countExpr, TranslationContext context, out IType elementType)
- {
- ILInstruction left;
- ILInstruction right;
- if (countExpr.MatchBinaryNumericInstruction(BinaryNumericOperator.Mul, out left, out right)
- && right.UnwrapConv(ConversionKind.SignExtend).UnwrapConv(ConversionKind.ZeroExtend).MatchSizeOf(out elementType))
- {
- return Translate(left);
- }
-
- var pointerTypeHint = context.TypeHint as PointerType;
- if (pointerTypeHint == null) {
- elementType = null;
- return Translate(countExpr);
- }
- ResolveResult sizeofRR = resolver.ResolveSizeOf(pointerTypeHint.ElementType);
- if (!(sizeofRR.IsCompileTimeConstant && sizeofRR.ConstantValue is int)) {
- elementType = null;
- return Translate(countExpr);
- }
- int typeSize = (int)sizeofRR.ConstantValue;
-
- if (countExpr.MatchBinaryNumericInstruction(BinaryNumericOperator.Mul, out left, out right)
- && right.UnwrapConv(ConversionKind.SignExtend).UnwrapConv(ConversionKind.ZeroExtend).MatchLdcI4(typeSize))
- {
- elementType = pointerTypeHint.ElementType;
- return Translate(left);
- }
-
- elementType = null;
- return Translate(countExpr);
- }
-
protected internal override TranslatedExpression VisitLdcI4(LdcI4 inst, TranslationContext context)
{
return new PrimitiveExpression(inst.Value)
@@ -739,7 +722,8 @@ namespace ICSharpCode.Decompiler.CSharp
} else {
return null;
}
- TranslatedExpression offsetExpr = GetPointerArithmeticOffset() ?? FallBackToBytePointer();
+ TranslatedExpression offsetExpr = GetPointerArithmeticOffset(byteOffsetInst, byteOffsetExpr, pointerType, inst.CheckForOverflow)
+ ?? FallBackToBytePointer();
if (!offsetExpr.Type.IsCSharpPrimitiveIntegerType()) {
// pointer arithmetic accepts all primitive integer types, but no enums etc.
StackType targetType = offsetExpr.Type.GetStackType() == StackType.I4 ? StackType.I4 : StackType.I8;
@@ -765,39 +749,6 @@ namespace ICSharpCode.Decompiler.CSharp
pointerType, BinaryOperatorExpression.GetLinqNodeType(operatorType, inst.CheckForOverflow),
left.ResolveResult, right.ResolveResult));
- TranslatedExpression? GetPointerArithmeticOffset()
- {
- if (byteOffsetInst is Conv conv && conv.InputType == StackType.I8 && conv.ResultType == StackType.I) {
- byteOffsetInst = conv.Argument;
- }
- int? elementSize = ComputeSizeOf(pointerType.ElementType);
- if (elementSize == 1) {
- return byteOffsetExpr;
- } else if (byteOffsetInst is BinaryNumericInstruction mul && mul.Operator == BinaryNumericOperator.Mul) {
- if (mul.CheckForOverflow != inst.CheckForOverflow)
- return null;
- if (mul.IsLifted)
- return null;
- if (elementSize > 0 && mul.Right.MatchLdcI(elementSize.Value)) {
- return Translate(mul.Left);
- } else if (mul.Right.UnwrapConv(ConversionKind.SignExtend) is SizeOf sizeOf && sizeOf.Type.Equals(pointerType.ElementType)) {
- return Translate(mul.Left);
- }
- } else if (byteOffsetInst.MatchLdcI(out long val)) {
- // If the offset is a constant, it's possible that the compiler
- // constant-folded the multiplication.
- if (elementSize > 0 && (val % elementSize == 0) && val > 0) {
- val /= elementSize.Value;
- if (val <= int.MaxValue) {
- return new PrimitiveExpression((int)val)
- .WithILInstruction(byteOffsetInst)
- .WithRR(new ConstantResolveResult(compilation.FindType(KnownTypeCode.Int32), val));
- }
- }
- }
- return null;
- }
-
TranslatedExpression FallBackToBytePointer()
{
pointerType = new PointerType(compilation.FindType(KnownTypeCode.Byte));
@@ -805,6 +756,44 @@ namespace ICSharpCode.Decompiler.CSharp
}
}
+ TranslatedExpression? GetPointerArithmeticOffset(ILInstruction byteOffsetInst, TranslatedExpression byteOffsetExpr,
+ PointerType pointerType, bool checkForOverflow, bool unwrapZeroExtension = false)
+ {
+ if (byteOffsetInst is Conv conv && conv.InputType == StackType.I8 && conv.ResultType == StackType.I) {
+ byteOffsetInst = conv.Argument;
+ }
+ int? elementSize = ComputeSizeOf(pointerType.ElementType);
+ if (elementSize == 1) {
+ return byteOffsetExpr;
+ } else if (byteOffsetInst is BinaryNumericInstruction mul && mul.Operator == BinaryNumericOperator.Mul) {
+ if (mul.CheckForOverflow != checkForOverflow)
+ return null;
+ if (mul.IsLifted)
+ return null;
+ if (elementSize > 0 && mul.Right.MatchLdcI(elementSize.Value)
+ || mul.Right.UnwrapConv(ConversionKind.SignExtend) is SizeOf sizeOf && sizeOf.Type.Equals(pointerType.ElementType))
+ {
+ var countOffsetInst = mul.Left;
+ if (unwrapZeroExtension) {
+ countOffsetInst = countOffsetInst.UnwrapConv(ConversionKind.ZeroExtend);
+ }
+ return Translate(countOffsetInst);
+ }
+ } else if (byteOffsetInst.MatchLdcI(out long val)) {
+ // If the offset is a constant, it's possible that the compiler
+ // constant-folded the multiplication.
+ if (elementSize > 0 && (val % elementSize == 0) && val > 0) {
+ val /= elementSize.Value;
+ if (val <= int.MaxValue) {
+ return new PrimitiveExpression((int)val)
+ .WithILInstruction(byteOffsetInst)
+ .WithRR(new ConstantResolveResult(compilation.FindType(KnownTypeCode.Int32), val));
+ }
+ }
+ }
+ return null;
+ }
+
///
/// Called for divisions, detect and handles the code pattern:
/// div(sub(a, b), sizeof(T))