Browse Source

Handle pointer arithmetic.

pull/100/head
Daniel Grunwald 14 years ago
parent
commit
bc229df848
  1. 92
      ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs
  2. 24
      ICSharpCode.Decompiler/Ast/Transforms/IntroduceUnsafeModifier.cs
  3. 49
      ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs
  4. 81
      ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs
  5. 9
      ICSharpCode.Decompiler/Tests/UnsafeCode.cs

92
ICSharpCode.Decompiler/Ast/AstMethodBodyBuilder.cs

@ -4,7 +4,7 @@ using System.Collections.Generic; @@ -4,7 +4,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using ICSharpCode.Decompiler.Ast.Transforms;
using ICSharpCode.Decompiler.ILAst;
using ICSharpCode.NRefactory.CSharp;
using ICSharpCode.NRefactory.Utils;
@ -225,9 +225,36 @@ namespace ICSharpCode.Decompiler.Ast @@ -225,9 +225,36 @@ namespace ICSharpCode.Decompiler.Ast
switch(byteCode.Code) {
#region Arithmetic
case ILCode.Add: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Add, arg2);
case ILCode.Add_Ovf: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Add, arg2);
case ILCode.Add_Ovf_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Add, arg2);
case ILCode.Add:
case ILCode.Add_Ovf:
case ILCode.Add_Ovf_Un:
{
if (byteCode.InferredType is PointerType) {
if (byteCode.Arguments[0].ExpectedType is PointerType) {
arg2 = DivideBySize(arg2, ((PointerType)byteCode.InferredType).ElementType);
return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Add, arg2)
.WithAnnotation(IntroduceUnsafeModifier.PointerArithmeticAnnotation);
} else if (byteCode.Arguments[1].ExpectedType is PointerType) {
arg1 = DivideBySize(arg1, ((PointerType)byteCode.InferredType).ElementType);
return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Add, arg2)
.WithAnnotation(IntroduceUnsafeModifier.PointerArithmeticAnnotation);
}
}
return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Add, arg2);
}
case ILCode.Sub:
case ILCode.Sub_Ovf:
case ILCode.Sub_Ovf_Un:
{
if (byteCode.InferredType is PointerType) {
if (byteCode.Arguments[0].ExpectedType is PointerType) {
arg2 = DivideBySize(arg2, ((PointerType)byteCode.InferredType).ElementType);
return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Subtract, arg2)
.WithAnnotation(IntroduceUnsafeModifier.PointerArithmeticAnnotation);
}
}
return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Subtract, arg2);
}
case ILCode.Div: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Divide, arg2);
case ILCode.Div_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Divide, arg2);
case ILCode.Mul: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Multiply, arg2);
@ -235,9 +262,6 @@ namespace ICSharpCode.Decompiler.Ast @@ -235,9 +262,6 @@ namespace ICSharpCode.Decompiler.Ast
case ILCode.Mul_Ovf_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Multiply, arg2);
case ILCode.Rem: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Modulus, arg2);
case ILCode.Rem_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Modulus, arg2);
case ILCode.Sub: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Subtract, arg2);
case ILCode.Sub_Ovf: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Subtract, arg2);
case ILCode.Sub_Ovf_Un: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.Subtract, arg2);
case ILCode.And: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.BitwiseAnd, arg2);
case ILCode.Or: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.BitwiseOr, arg2);
case ILCode.Xor: return new Ast.BinaryOperatorExpression(arg1, BinaryOperatorType.ExclusiveOr, arg2);
@ -472,7 +496,20 @@ namespace ICSharpCode.Decompiler.Ast @@ -472,7 +496,20 @@ namespace ICSharpCode.Decompiler.Ast
return InlineAssembly(byteCode, args);
}
case ILCode.Leave: return new GotoStatement() { Label = ((ILLabel)operand).Name };
case ILCode.Localloc: return InlineAssembly(byteCode, args);
case ILCode.Localloc:
{
PointerType ptrType = byteCode.InferredType as PointerType;
TypeReference type;
if (ptrType != null) {
type = ptrType.ElementType;
} else {
type = typeSystem.Byte;
}
return new StackAllocExpression {
Type = AstBuilder.ConvertType(type),
CountExpression = DivideBySize(arg1, type)
};
}
case ILCode.Mkrefany: return InlineAssembly(byteCode, args);
case ILCode.Newobj: {
Cecil.TypeReference declaringType = ((MethodReference)operand).DeclaringType;
@ -542,6 +579,45 @@ namespace ICSharpCode.Decompiler.Ast @@ -542,6 +579,45 @@ namespace ICSharpCode.Decompiler.Ast
}
}
/// <summary>
/// Divides expr by the size of 'type'.
/// </summary>
Expression DivideBySize(Expression expr, TypeReference type)
{
CastExpression cast = expr as CastExpression;
if (cast != null && cast.Type is PrimitiveType && ((PrimitiveType)cast.Type).Keyword == "int")
expr = cast.Expression.Detach();
Expression sizeOfExpression;
switch (TypeAnalysis.GetInformationAmount(type)) {
case 1:
case 8:
sizeOfExpression = new PrimitiveExpression(1);
break;
case 16:
sizeOfExpression = new PrimitiveExpression(2);
break;
case 32:
sizeOfExpression = new PrimitiveExpression(4);
break;
case 64:
sizeOfExpression = new PrimitiveExpression(8);
break;
default:
sizeOfExpression = new SizeOfExpression { Type = AstBuilder.ConvertType(type) };
break;
}
BinaryOperatorExpression boe = expr as BinaryOperatorExpression;
if (boe != null && boe.Operator == BinaryOperatorType.Multiply && sizeOfExpression.Match(boe.Right) != null)
return boe.Left.Detach();
if (sizeOfExpression.Match(expr) != null)
return new PrimitiveExpression(1);
return new BinaryOperatorExpression(expr, BinaryOperatorType.Divide, sizeOfExpression);
}
Expression MakeDefaultValue(TypeReference type)
{
TypeDefinition typeDef = type.Resolve();

24
ICSharpCode.Decompiler/Ast/Transforms/IntroduceUnsafeModifier.cs

@ -8,6 +8,10 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -8,6 +8,10 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
{
public class IntroduceUnsafeModifier : DepthFirstAstVisitor<object, bool>, IAstTransform
{
public static readonly object PointerArithmeticAnnotation = new PointerArithmetic();
sealed class PointerArithmetic {}
public void Run(AstNode compilationUnit)
{
compilationUnit.AcceptVisitor(this, null);
@ -42,17 +46,23 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -42,17 +46,23 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
public override bool VisitUnaryOperatorExpression(UnaryOperatorExpression unaryOperatorExpression, object data)
{
base.VisitUnaryOperatorExpression(unaryOperatorExpression, data);
bool result = base.VisitUnaryOperatorExpression(unaryOperatorExpression, data);
if (unaryOperatorExpression.Operator == UnaryOperatorType.Dereference) {
BinaryOperatorExpression bop = unaryOperatorExpression.Expression as BinaryOperatorExpression;
if (bop != null && bop.Operator == BinaryOperatorType.Add) {
// TODO: transform "*(ptr + int)" to "ptr[int]"
if (bop != null && bop.Operator == BinaryOperatorType.Add && bop.Annotation<PointerArithmetic>() != null) {
// transform "*(ptr + int)" to "ptr[int]"
IndexerExpression indexer = new IndexerExpression();
indexer.Target = bop.Left.Detach();
indexer.Arguments.Add(bop.Right.Detach());
indexer.CopyAnnotationsFrom(unaryOperatorExpression);
indexer.CopyAnnotationsFrom(bop);
unaryOperatorExpression.ReplaceWith(indexer);
}
return true;
} else if (unaryOperatorExpression.Operator == UnaryOperatorType.AddressOf) {
return true;
} else {
return false;
return result;
}
}
@ -71,5 +81,11 @@ namespace ICSharpCode.Decompiler.Ast.Transforms @@ -71,5 +81,11 @@ namespace ICSharpCode.Decompiler.Ast.Transforms
}
return result;
}
public override bool VisitStackAllocExpression(StackAllocExpression stackAllocExpression, object data)
{
base.VisitStackAllocExpression(stackAllocExpression, data);
return true;
}
}
}

49
ICSharpCode.Decompiler/ILAst/PeepholeTransform.cs

@ -320,25 +320,25 @@ namespace ICSharpCode.Decompiler.ILAst @@ -320,25 +320,25 @@ namespace ICSharpCode.Decompiler.ILAst
}
// find where pinnedVar is reset to 0:
for (int j = initEndPos; j < block.Body.Count; j++) {
int j;
for (j = initEndPos; j < block.Body.Count; j++) {
ILVariable v2;
ILExpression storedVal;
// stloc(pinned_Var, conv.u(ldc.i4(0)))
if (block.Body[j].Match(ILCode.Stloc, out v2, out storedVal) && v2 == pinnedVar) {
if (IsNullOrZero(storedVal)) {
// Create fixed statement from i to j
fixedStmt = new ILFixedStatement();
fixedStmt.Initializers.Add(initValue);
fixedStmt.BodyBlock = new ILBlock(block.Body.GetRange(initEndPos, j - initEndPos)); // from initEndPos to j-1 (inclusive)
block.Body.RemoveRange(i + 1, j - i); // from i+1 to j (inclusive)
block.Body[i] = fixedStmt;
if (pinnedVar.Type.IsByReference)
pinnedVar.Type = new PointerType(((ByReferenceType)pinnedVar.Type).ElementType);
break;
}
}
}
// Create fixed statement from i to j
fixedStmt = new ILFixedStatement();
fixedStmt.Initializers.Add(initValue);
fixedStmt.BodyBlock = new ILBlock(block.Body.GetRange(initEndPos, j - initEndPos)); // from initEndPos to j-1 (inclusive)
block.Body.RemoveRange(i + 1, Math.Min(j, block.Body.Count - 1) - i); // from i+1 to j (inclusive)
block.Body[i] = fixedStmt;
if (pinnedVar.Type.IsByReference)
pinnedVar.Type = new PointerType(((ByReferenceType)pinnedVar.Type).ElementType);
}
bool IsNullOrZero(ILExpression expr)
@ -350,7 +350,7 @@ namespace ICSharpCode.Decompiler.ILAst @@ -350,7 +350,7 @@ namespace ICSharpCode.Decompiler.ILAst
bool MatchFixedInitializer(ILBlock block, int i, out ILVariable pinnedVar, out ILExpression initValue, out int nextPos)
{
if (block.Body[i].Match(ILCode.Stloc, out pinnedVar, out initValue) && pinnedVar.IsPinned) {
if (block.Body[i].Match(ILCode.Stloc, out pinnedVar, out initValue) && pinnedVar.IsPinned && !IsNullOrZero(initValue)) {
initValue = (ILExpression)block.Body[i];
nextPos = i + 1;
HandleStringFixing(pinnedVar, block.Body, ref nextPos, ref initValue);
@ -365,19 +365,22 @@ namespace ICSharpCode.Decompiler.ILAst @@ -365,19 +365,22 @@ namespace ICSharpCode.Decompiler.ILAst
&& ifStmt.TrueBlock.Body[0].Match(ILCode.Stloc, out pinnedVar, out trueValue)
&& pinnedVar.IsPinned && IsNullOrZero(trueValue))
{
ILVariable stlocVar;
ILExpression falseValue;
if (ifStmt.FalseBlock != null && ifStmt.FalseBlock.Body.Count == 1
&& ifStmt.FalseBlock.Body[0].Match(ILCode.Stloc, out stlocVar, out falseValue) && stlocVar == pinnedVar)
{
ILVariable loadedVariable;
if (falseValue.Code == ILCode.Ldelema
&& falseValue.Arguments[0].Match(ILCode.Ldloc, out loadedVariable) && loadedVariable == arrayVariable
&& IsNullOrZero(falseValue.Arguments[1]))
if (ifStmt.FalseBlock != null && ifStmt.FalseBlock.Body.Count == 1 && ifStmt.FalseBlock.Body[0] is ILFixedStatement) {
ILFixedStatement fixedStmt = (ILFixedStatement)ifStmt.FalseBlock.Body[0];
ILVariable stlocVar;
ILExpression falseValue;
if (fixedStmt.Initializers.Count == 1 && fixedStmt.BodyBlock.Body.Count == 0
&& fixedStmt.Initializers[0].Match(ILCode.Stloc, out stlocVar, out falseValue) && stlocVar == pinnedVar)
{
initValue = new ILExpression(ILCode.Stloc, pinnedVar, arrayLoadingExpr);
nextPos = i + 1;
return true;
ILVariable loadedVariable;
if (falseValue.Code == ILCode.Ldelema
&& falseValue.Arguments[0].Match(ILCode.Ldloc, out loadedVariable) && loadedVariable == arrayVariable
&& IsNullOrZero(falseValue.Arguments[1]))
{
initValue = new ILExpression(ILCode.Stloc, pinnedVar, arrayLoadingExpr);
nextPos = i + 1;
return true;
}
}
}
}

81
ICSharpCode.Decompiler/ILAst/TypeAnalysis.cs

@ -437,20 +437,26 @@ namespace ICSharpCode.Decompiler.ILAst @@ -437,20 +437,26 @@ namespace ICSharpCode.Decompiler.ILAst
case ILCode.Neg:
return InferTypeForExpression(expr.Arguments.Single(), expectedType);
case ILCode.Add:
return InferArgumentsInAddition(expr, null, expectedType);
case ILCode.Sub:
return InferArgumentsInSubtraction(expr, null, expectedType);
case ILCode.Mul:
case ILCode.Or:
case ILCode.And:
case ILCode.Xor:
return InferArgumentsInBinaryOperator(expr, null, expectedType);
case ILCode.Add_Ovf:
return InferArgumentsInAddition(expr, true, expectedType);
case ILCode.Sub_Ovf:
return InferArgumentsInSubtraction(expr, true, expectedType);
case ILCode.Mul_Ovf:
case ILCode.Div:
case ILCode.Rem:
return InferArgumentsInBinaryOperator(expr, true, expectedType);
case ILCode.Add_Ovf_Un:
return InferArgumentsInAddition(expr, false, expectedType);
case ILCode.Sub_Ovf_Un:
return InferArgumentsInSubtraction(expr, false, expectedType);
case ILCode.Mul_Ovf_Un:
case ILCode.Div_Un:
case ILCode.Rem_Un:
@ -593,11 +599,11 @@ namespace ICSharpCode.Decompiler.ILAst @@ -593,11 +599,11 @@ namespace ICSharpCode.Decompiler.ILAst
case ILCode.Conv_I:
case ILCode.Conv_Ovf_I:
case ILCode.Conv_Ovf_I_Un:
return HandleConversion(nativeInt, true, expr.Arguments[0], expectedType, typeSystem.IntPtr);
return HandleConversion(NativeInt, true, expr.Arguments[0], expectedType, typeSystem.IntPtr);
case ILCode.Conv_U:
case ILCode.Conv_Ovf_U:
case ILCode.Conv_Ovf_U_Un:
return HandleConversion(nativeInt, false, expr.Arguments[0], expectedType, typeSystem.UIntPtr);
return HandleConversion(NativeInt, false, expr.Arguments[0], expectedType, typeSystem.UIntPtr);
case ILCode.Conv_R4:
return typeSystem.Single;
case ILCode.Conv_R8:
@ -671,17 +677,17 @@ namespace ICSharpCode.Decompiler.ILAst @@ -671,17 +677,17 @@ namespace ICSharpCode.Decompiler.ILAst
TypeReference HandleConversion(int targetBitSize, bool targetSigned, ILExpression arg, TypeReference expectedType, TypeReference targetType)
{
if (targetBitSize >= nativeInt && expectedType is PointerType) {
if (targetBitSize >= NativeInt && expectedType is PointerType) {
InferTypeForExpression(arg, expectedType);
return expectedType;
}
TypeReference argType = InferTypeForExpression(arg, null);
if (targetBitSize >= nativeInt && argType is ByReferenceType) {
if (targetBitSize >= NativeInt && argType is ByReferenceType) {
// conv instructions on managed references mean that the GC should stop tracking them, so they become pointers:
PointerType ptrType = new PointerType(((ByReferenceType)argType).ElementType);
InferTypeForExpression(arg, ptrType);
return ptrType;
} else if (targetBitSize >= nativeInt && argType is PointerType) {
} else if (targetBitSize >= NativeInt && argType is PointerType) {
return argType;
}
return (GetInformationAmount(expectedType) == targetBitSize && IsSigned(expectedType) == targetSigned) ? expectedType : targetType;
@ -791,6 +797,62 @@ namespace ICSharpCode.Decompiler.ILAst @@ -791,6 +797,62 @@ namespace ICSharpCode.Decompiler.ILAst
}
}
TypeReference InferArgumentsInAddition(ILExpression expr, bool? isSigned, TypeReference expectedType)
{
ILExpression left = expr.Arguments[0];
ILExpression right = expr.Arguments[1];
TypeReference leftPreferred = DoInferTypeForExpression(left, expectedType);
if (leftPreferred is PointerType) {
left.InferredType = left.ExpectedType = leftPreferred;
InferTypeForExpression(right, typeSystem.IntPtr);
return leftPreferred;
} else {
TypeReference rightPreferred = DoInferTypeForExpression(right, expectedType);
if (rightPreferred is PointerType) {
InferTypeForExpression(left, typeSystem.IntPtr);
right.InferredType = right.ExpectedType = rightPreferred;
return rightPreferred;
} else if (leftPreferred == rightPreferred) {
return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = leftPreferred;
} else if (rightPreferred == DoInferTypeForExpression(left, rightPreferred)) {
return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = rightPreferred;
} else if (leftPreferred == DoInferTypeForExpression(right, leftPreferred)) {
return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = leftPreferred;
} else {
left.ExpectedType = right.ExpectedType = TypeWithMoreInformation(leftPreferred, rightPreferred);
left.InferredType = DoInferTypeForExpression(left, left.ExpectedType);
right.InferredType = DoInferTypeForExpression(right, right.ExpectedType);
return left.ExpectedType;
}
}
}
TypeReference InferArgumentsInSubtraction(ILExpression expr, bool? isSigned, TypeReference expectedType)
{
ILExpression left = expr.Arguments[0];
ILExpression right = expr.Arguments[1];
TypeReference leftPreferred = DoInferTypeForExpression(left, expectedType);
if (leftPreferred is PointerType) {
left.InferredType = left.ExpectedType = leftPreferred;
InferTypeForExpression(right, typeSystem.IntPtr);
return leftPreferred;
} else {
TypeReference rightPreferred = DoInferTypeForExpression(right, expectedType);
if (leftPreferred == rightPreferred) {
return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = leftPreferred;
} else if (rightPreferred == DoInferTypeForExpression(left, rightPreferred)) {
return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = rightPreferred;
} else if (leftPreferred == DoInferTypeForExpression(right, leftPreferred)) {
return left.InferredType = right.InferredType = left.ExpectedType = right.ExpectedType = leftPreferred;
} else {
left.ExpectedType = right.ExpectedType = TypeWithMoreInformation(leftPreferred, rightPreferred);
left.InferredType = DoInferTypeForExpression(left, left.ExpectedType);
right.InferredType = DoInferTypeForExpression(right, right.ExpectedType);
return left.ExpectedType;
}
}
}
TypeReference TypeWithMoreInformation(TypeReference leftPreferred, TypeReference rightPreferred)
{
int left = GetInformationAmount(leftPreferred);
@ -805,9 +867,12 @@ namespace ICSharpCode.Decompiler.ILAst @@ -805,9 +867,12 @@ namespace ICSharpCode.Decompiler.ILAst
}
}
const int nativeInt = 33; // treat native int as between int32 and int64
/// <summary>
/// Information amount used for IntPtr.
/// </summary>
public const int NativeInt = 33; // treat native int as between int32 and int64
static int GetInformationAmount(TypeReference type)
public static int GetInformationAmount(TypeReference type)
{
if (type == null)
return 0;
@ -841,7 +906,7 @@ namespace ICSharpCode.Decompiler.ILAst @@ -841,7 +906,7 @@ namespace ICSharpCode.Decompiler.ILAst
return 64;
case MetadataType.IntPtr:
case MetadataType.UIntPtr:
return nativeInt;
return NativeInt;
default:
return 100; // we consider structs/objects to have more information than any primitives
}

9
ICSharpCode.Decompiler/Tests/UnsafeCode.cs

@ -65,4 +65,13 @@ public class UnsafeCode @@ -65,4 +65,13 @@ public class UnsafeCode
*e = 'e';
}
}
public unsafe string StackAlloc(int count)
{
char* a = stackalloc char[count];
for (int i = 0; i < count; i++) {
a[i] = (char)i;
}
return PointerReferenceExpression((double*)a);
}
}

Loading…
Cancel
Save