diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs b/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs index c4bc7796b..cd8f4af05 100644 --- a/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs +++ b/ICSharpCode.Decompiler.Tests/TestCases/Correctness/UnsafeCode.cs @@ -70,6 +70,7 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness char* ptr2 = stackalloc char[100]; for (int i = 0; i < count; i++) { ptr[i] = (char)i; + ptr2[i] = '\0'; } return this.UsePointer((double*)ptr); } @@ -77,7 +78,7 @@ namespace ICSharpCode.Decompiler.Tests.TestCases.Correctness public unsafe string StackAllocStruct(int count) { SimpleStruct* s = stackalloc SimpleStruct[checked(count * 2)]; - SimpleStruct* p = stackalloc SimpleStruct[10]; + SimpleStruct* _ = stackalloc SimpleStruct[10]; return this.UsePointer(&s->Y); } diff --git a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs index 1bea33cf8..4673733d9 100644 --- a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs @@ -229,55 +229,38 @@ namespace ICSharpCode.Decompiler.CSharp protected internal override TranslatedExpression VisitLocAlloc(LocAlloc inst, TranslationContext context) { - IType elementType; - TranslatedExpression countExpression = TranslatePointerArgument(inst.Argument, context, out elementType); + TranslatedExpression countExpression; + PointerType pointerType; + if (inst.Argument.MatchBinaryNumericInstruction(BinaryNumericOperator.Mul, out var left, out var right) + && right.UnwrapConv(ConversionKind.SignExtend).UnwrapConv(ConversionKind.ZeroExtend).MatchSizeOf(out var elementType)) + { + // Determine the element type from the sizeof + countExpression = Translate(left.UnwrapConv(ConversionKind.ZeroExtend)); + pointerType = new PointerType(elementType); + } else { + // Determine the element type from the expected pointer type in this context + pointerType = context.TypeHint as PointerType; + if (pointerType != null && GetPointerArithmeticOffset( + inst.Argument, Translate(inst.Argument), + pointerType, checkForOverflow: true, + unwrapZeroExtension: true + ) is TranslatedExpression offset) + { + countExpression = offset; + elementType = pointerType.ElementType; + } else { + elementType = compilation.FindType(KnownTypeCode.Byte); + pointerType = new PointerType(elementType); + countExpression = Translate(inst.Argument); + } + } countExpression = countExpression.ConvertTo(compilation.FindType(KnownTypeCode.Int32), this); - if (elementType == null) - elementType = compilation.FindType(KnownTypeCode.Byte); return new StackAllocExpression { Type = ConvertType(elementType), CountExpression = countExpression }.WithILInstruction(inst).WithRR(new ResolveResult(new PointerType(elementType))); } - /// - /// Translate the argument of an operation that deals with pointers: - /// * undoes the implicit multiplication with `sizeof(elementType)` and returns `elementType` - /// * on failure, translates the whole expression and returns `elementType = null`. - /// - TranslatedExpression TranslatePointerArgument(ILInstruction countExpr, TranslationContext context, out IType elementType) - { - ILInstruction left; - ILInstruction right; - if (countExpr.MatchBinaryNumericInstruction(BinaryNumericOperator.Mul, out left, out right) - && right.UnwrapConv(ConversionKind.SignExtend).UnwrapConv(ConversionKind.ZeroExtend).MatchSizeOf(out elementType)) - { - return Translate(left); - } - - var pointerTypeHint = context.TypeHint as PointerType; - if (pointerTypeHint == null) { - elementType = null; - return Translate(countExpr); - } - ResolveResult sizeofRR = resolver.ResolveSizeOf(pointerTypeHint.ElementType); - if (!(sizeofRR.IsCompileTimeConstant && sizeofRR.ConstantValue is int)) { - elementType = null; - return Translate(countExpr); - } - int typeSize = (int)sizeofRR.ConstantValue; - - if (countExpr.MatchBinaryNumericInstruction(BinaryNumericOperator.Mul, out left, out right) - && right.UnwrapConv(ConversionKind.SignExtend).UnwrapConv(ConversionKind.ZeroExtend).MatchLdcI4(typeSize)) - { - elementType = pointerTypeHint.ElementType; - return Translate(left); - } - - elementType = null; - return Translate(countExpr); - } - protected internal override TranslatedExpression VisitLdcI4(LdcI4 inst, TranslationContext context) { return new PrimitiveExpression(inst.Value) @@ -739,7 +722,8 @@ namespace ICSharpCode.Decompiler.CSharp } else { return null; } - TranslatedExpression offsetExpr = GetPointerArithmeticOffset() ?? FallBackToBytePointer(); + TranslatedExpression offsetExpr = GetPointerArithmeticOffset(byteOffsetInst, byteOffsetExpr, pointerType, inst.CheckForOverflow) + ?? FallBackToBytePointer(); if (!offsetExpr.Type.IsCSharpPrimitiveIntegerType()) { // pointer arithmetic accepts all primitive integer types, but no enums etc. StackType targetType = offsetExpr.Type.GetStackType() == StackType.I4 ? StackType.I4 : StackType.I8; @@ -765,39 +749,6 @@ namespace ICSharpCode.Decompiler.CSharp pointerType, BinaryOperatorExpression.GetLinqNodeType(operatorType, inst.CheckForOverflow), left.ResolveResult, right.ResolveResult)); - TranslatedExpression? GetPointerArithmeticOffset() - { - if (byteOffsetInst is Conv conv && conv.InputType == StackType.I8 && conv.ResultType == StackType.I) { - byteOffsetInst = conv.Argument; - } - int? elementSize = ComputeSizeOf(pointerType.ElementType); - if (elementSize == 1) { - return byteOffsetExpr; - } else if (byteOffsetInst is BinaryNumericInstruction mul && mul.Operator == BinaryNumericOperator.Mul) { - if (mul.CheckForOverflow != inst.CheckForOverflow) - return null; - if (mul.IsLifted) - return null; - if (elementSize > 0 && mul.Right.MatchLdcI(elementSize.Value)) { - return Translate(mul.Left); - } else if (mul.Right.UnwrapConv(ConversionKind.SignExtend) is SizeOf sizeOf && sizeOf.Type.Equals(pointerType.ElementType)) { - return Translate(mul.Left); - } - } else if (byteOffsetInst.MatchLdcI(out long val)) { - // If the offset is a constant, it's possible that the compiler - // constant-folded the multiplication. - if (elementSize > 0 && (val % elementSize == 0) && val > 0) { - val /= elementSize.Value; - if (val <= int.MaxValue) { - return new PrimitiveExpression((int)val) - .WithILInstruction(byteOffsetInst) - .WithRR(new ConstantResolveResult(compilation.FindType(KnownTypeCode.Int32), val)); - } - } - } - return null; - } - TranslatedExpression FallBackToBytePointer() { pointerType = new PointerType(compilation.FindType(KnownTypeCode.Byte)); @@ -805,6 +756,44 @@ namespace ICSharpCode.Decompiler.CSharp } } + TranslatedExpression? GetPointerArithmeticOffset(ILInstruction byteOffsetInst, TranslatedExpression byteOffsetExpr, + PointerType pointerType, bool checkForOverflow, bool unwrapZeroExtension = false) + { + if (byteOffsetInst is Conv conv && conv.InputType == StackType.I8 && conv.ResultType == StackType.I) { + byteOffsetInst = conv.Argument; + } + int? elementSize = ComputeSizeOf(pointerType.ElementType); + if (elementSize == 1) { + return byteOffsetExpr; + } else if (byteOffsetInst is BinaryNumericInstruction mul && mul.Operator == BinaryNumericOperator.Mul) { + if (mul.CheckForOverflow != checkForOverflow) + return null; + if (mul.IsLifted) + return null; + if (elementSize > 0 && mul.Right.MatchLdcI(elementSize.Value) + || mul.Right.UnwrapConv(ConversionKind.SignExtend) is SizeOf sizeOf && sizeOf.Type.Equals(pointerType.ElementType)) + { + var countOffsetInst = mul.Left; + if (unwrapZeroExtension) { + countOffsetInst = countOffsetInst.UnwrapConv(ConversionKind.ZeroExtend); + } + return Translate(countOffsetInst); + } + } else if (byteOffsetInst.MatchLdcI(out long val)) { + // If the offset is a constant, it's possible that the compiler + // constant-folded the multiplication. + if (elementSize > 0 && (val % elementSize == 0) && val > 0) { + val /= elementSize.Value; + if (val <= int.MaxValue) { + return new PrimitiveExpression((int)val) + .WithILInstruction(byteOffsetInst) + .WithRR(new ConstantResolveResult(compilation.FindType(KnownTypeCode.Int32), val)); + } + } + } + return null; + } + /// /// Called for divisions, detect and handles the code pattern: /// div(sub(a, b), sizeof(T))