.NET Decompiler with support for PDB generation, ReadyToRun, Metadata (&more) - cross-platform!
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

400 lines
13 KiB

// Copyright (c) AlphaSierraPapa for the SharpDevelop Team (for details please see \doc\copyright.txt)
// This code is distributed under MIT X11 license (for details please see \doc\license.txt)
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Mono.Cecil;
namespace ICSharpCode.Decompiler.ILAst
{
public class YieldReturnDecompiler
{
// For a description on the code generated by the C# compiler for yield return:
// http://csharpindepth.com/Articles/Chapter6/IteratorBlockImplementation.aspx
// The idea here is:
// - Figure out whether the current method is instantiating an enumerator
// - Figure out which of the fields is the state field
// - Construct an exception table based on states. This allows us to determine, for each state, what the parent try block is.
/// <summary>
/// Signals that the analyzed code does not have the shape we expect the C# compiler to produce.
/// Throwing it aborts the analysis, which makes the whole transform fail.
/// </summary>
class YieldAnalysisFailedException : Exception
{
}
DecompilerContext context;
TypeDefinition enumeratorType;
MethodDefinition enumeratorCtor;
FieldDefinition stateField;
Dictionary<ParameterDefinition, FieldDefinition> parameterToFieldMap;
#region Run() method
/// <summary>
/// Entry point of the yield-return analysis for a single method.
/// Does nothing when the YieldReturn setting is off or when <paramref name="method"/>
/// does not match the enumerator creation pattern. Otherwise the enumerator's constructor
/// is analyzed and the exception table is constructed.
/// </summary>
/// <param name="context">Decompiler context; supplies the settings and the current type.</param>
/// <param name="method">IL AST of the method that may create the enumerator.</param>
/// <remarks>
/// In non-DEBUG builds a <see cref="YieldAnalysisFailedException"/> is swallowed here;
/// in DEBUG builds it propagates to the caller.
/// </remarks>
public static void Run(DecompilerContext context, ILBlock method)
{
    bool enumeratorDecompilationEnabled = context.Settings.YieldReturn;
    if (!enumeratorDecompilationEnabled) {
        return;
    }

    YieldReturnDecompiler decompiler = new YieldReturnDecompiler();
    decompiler.context = context;
    bool createsEnumerator = decompiler.MatchEnumeratorCreationPattern(method);
    if (!createsEnumerator) {
        return;
    }

    // The matched constructor belongs to the compiler-generated enumerator class.
    decompiler.enumeratorType = decompiler.enumeratorCtor.DeclaringType;
    #if !DEBUG
    try {
    #endif
        decompiler.AnalyzeCtor();
        decompiler.ConstructExceptionTable();
    #if !DEBUG
    } catch (YieldAnalysisFailedException) {
        // Analysis failed: give up on this method.
    }
    #endif
}
#endregion
#region Match the enumerator creation pattern
/// <summary>
/// Checks whether <paramref name="method"/> consists solely of creating the compiler-generated
/// enumerator, optionally copying arguments into its fields, and returning it.
/// On success, <c>enumeratorCtor</c> holds the matched constructor and
/// <c>parameterToFieldMap</c> maps each copied parameter to the field it was stored in.
/// </summary>
/// <param name="method">IL AST of the method to inspect.</param>
/// <returns><c>true</c> if the method body matches one of the two recognized shapes.</returns>
bool MatchEnumeratorCreationPattern(ILBlock method)
{
    if (method.Body.Count == 0)
        return false;
    // Initialize the map up front so it is never null after a successful match,
    // including the single-statement form where no parameters are copied.
    parameterToFieldMap = new Dictionary<ParameterDefinition, FieldDefinition>();
    ILExpression ret;
    if (method.Body.Count == 1) {
        // ret(newobj(...))
        if (method.Body[0].Match(ILCode.Ret, out ret) && ret.Arguments.Count == 1)
            return MatchEnumeratorCreationNewObj(ret.Arguments[0], out enumeratorCtor);
        else
            return false;
    }
    // stloc(var_1, newobj(...))
    ILExpression stloc;
    if (!method.Body[0].Match(ILCode.Stloc, out stloc))
        return false;
    if (!MatchEnumeratorCreationNewObj(stloc.Arguments[0], out enumeratorCtor))
        return false;

    int i = 1;
    ILExpression stfld;
    while (i < method.Body.Count && method.Body[i].Match(ILCode.Stfld, out stfld)) {
        // stfld(..., ldloc(var_1), ldarg(...))
        ILExpression ldloc, ldarg;
        if (!(stfld.Arguments[0].Match(ILCode.Ldloc, out ldloc) && stfld.Arguments[1].Match(ILCode.Ldarg, out ldarg)))
            return false;
        // The store must target the freshly created enumerator and a resolved field.
        if (ldloc.Operand != stloc.Operand || !(stfld.Operand is FieldDefinition))
            return false;
        parameterToFieldMap[(ParameterDefinition)ldarg.Operand] = (FieldDefinition)stfld.Operand;
        i++;
    }

    ILExpression stloc2;
    if (i < method.Body.Count && method.Body[i].Match(ILCode.Stloc, out stloc2)) {
        // stloc(var_2, ldloc(var_1))
        if (stloc2.Arguments[0].Code != ILCode.Ldloc || stloc2.Arguments[0].Operand != stloc.Operand)
            return false;
        i++;
    } else {
        // the compiler might skip the above instruction in release builds; in that case, it directly returns stloc.Operand
        stloc2 = stloc;
    }

    ILExpression br;
    if (i + 1 < method.Body.Count && method.Body[i].Match(ILCode.Br, out br)) {
        // Only a branch to the immediately following node is accepted.
        if (br.Operand != method.Body[i + 1])
            return false;
        i += 2;
    }

    if (i < method.Body.Count && method.Body[i].Match(ILCode.Ret, out ret)) {
        // A 'ret' without a value cannot be returning the enumerator; checking the
        // argument count first also avoids indexing into an empty argument list.
        if (ret.Arguments.Count == 1
            && ret.Arguments[0].Code == ILCode.Ldloc
            && ret.Arguments[0].Operand == stloc2.Operand)
        {
            return true;
        }
    }
    return false;
}
/// <summary>
/// Tests whether <paramref name="expr"/> has the shape
/// <c>newobj(CurrentType/...::.ctor, ldc.i4(-2))</c>, i.e. it constructs a compiler-generated
/// type nested directly in the current type, passing the constant -2 as the only argument.
/// </summary>
/// <param name="expr">The expression to test.</param>
/// <param name="ctor">
/// Receives the operand of <paramref name="expr"/> as a <see cref="MethodDefinition"/>
/// (null if it is not one). It is assigned even when the method returns <c>false</c>.
/// </param>
/// <returns><c>true</c> if the expression matches.</returns>
bool MatchEnumeratorCreationNewObj(ILExpression expr, out MethodDefinition ctor)
{
    ctor = expr.Operand as MethodDefinition;

    bool isNewobjWithOneArgument = expr.Code == ILCode.Newobj && expr.Arguments.Count == 1;
    if (!isNewobjWithOneArgument) {
        return false;
    }

    ILExpression argument = expr.Arguments[0];
    bool passesMinusTwo = argument.Code == ILCode.Ldc_I4 && (int)argument.Operand == -2;
    if (!passesMinusTwo) {
        return false;
    }

    if (ctor == null) {
        return false;
    }
    if (!ctor.DeclaringType.IsCompilerGenerated()) {
        return false;
    }
    // The constructed class must be nested directly inside the type being decompiled.
    return ctor.DeclaringType.DeclaringType == context.CurrentType;
}
#endregion
#region Figure out what the state field is
/// <summary>
/// Scans the enumerator constructor for a statement of the form
/// <c>stfld(field, ldarg(this), ldarg(parameter 0))</c> and records that field as the state field.
/// If several statements match, the last one wins.
/// </summary>
/// <exception cref="YieldAnalysisFailedException">No state field could be identified.</exception>
void AnalyzeCtor()
{
    ILBlock ctorAst = CreateILAst(enumeratorCtor);
    foreach (ILNode statement in ctorAst.Body) {
        ILExpression store;
        if (!statement.Match(ILCode.Stfld, out store)) {
            continue;
        }

        // Target object: must be loaded from the parameter with a negative index ('this').
        ILExpression target = store.Arguments[0];
        if (target.Code != ILCode.Ldarg || ((ParameterDefinition)target.Operand).Index >= 0) {
            continue;
        }

        // Stored value: must be the constructor's first declared parameter.
        ILExpression storedValue = store.Arguments[1];
        if (storedValue.Code != ILCode.Ldarg || ((ParameterDefinition)storedValue.Operand).Index != 0) {
            continue;
        }

        stateField = store.Operand as FieldDefinition;
    }

    if (stateField == null) {
        throw new YieldAnalysisFailedException();
    }
}
/// <summary>
/// Builds the IL AST for <paramref name="method"/> and runs the optimizer on it,
/// passing <c>ILAstOptimizationStep.YieldReturn</c> as the step argument.
/// </summary>
/// <param name="method">The method to convert; must be non-null and have a body.</param>
/// <returns>A new <see cref="ILBlock"/> holding the optimized body.</returns>
/// <exception cref="YieldAnalysisFailedException">
/// <paramref name="method"/> is null or has no body.
/// </exception>
ILBlock CreateILAst(MethodDefinition method)
{
    bool hasUsableBody = method != null && method.HasBody;
    if (!hasUsableBody) {
        throw new YieldAnalysisFailedException();
    }

    ILBlock result = new ILBlock();
    result.Body = new ILAstBuilder().Build(method, true);
    new ILAstOptimizer().Optimize(context, result, ILAstOptimizationStep.YieldReturn);
    return result;
}
#endregion
#region Construction of the exception table
// We construct the exception table by analyzing the enumerator's Dispose() method.
// Assumption: there are no loops/backward jumps
// We 'run' the code, with "state" being a symbolic variable
// so it can form expressions like "state + x" (when there's a sub instruction)
// For each instruction, we maintain a list of value ranges for state for which the instruction is reachable.
// This is (int.MinValue, int.MaxValue) for the first instruction.
// These ranges are propagated depending on the conditional jumps performed by the code.
#region struct Interval / class StateRange
/// <summary>
/// A range of state values running from <see cref="Start"/> to <see cref="End"/>, both inclusive.
/// </summary>
struct Interval
{
    public readonly int Start;
    public readonly int End;

    /// <summary>
    /// Creates the interval. In debug builds it is asserted that <paramref name="start"/>
    /// does not exceed <paramref name="end"/>; the pair (0, -1) is the one permitted exception.
    /// </summary>
    public Interval(int start, int end)
    {
        bool isOrdered = start <= end;
        bool isZeroToMinusOne = start == 0 && end == -1;
        Debug.Assert(isOrdered || isZeroToMinusOne);
        Start = start;
        End = end;
    }

    /// <summary>Formats the interval as "(Start to End)".</summary>
    public override string ToString()
    {
        return string.Format("({0} to {1})", this.Start, this.End);
    }
}
/// <summary>
/// A set of state values, stored as a list of inclusive intervals.
/// The list may contain overlapping or unsorted intervals until <see cref="Simplify"/> is called.
/// </summary>
class StateRange
{
    readonly List<Interval> data = new List<Interval>();

    /// <summary>Creates an empty range.</summary>
    public StateRange()
    {
    }

    /// <summary>Creates a range holding the single interval from start to end.</summary>
    public StateRange(int start, int end)
    {
        this.data.Add(new Interval(start, end));
    }

    /// <summary>
    /// Adds all intervals of <paramref name="other"/> to this range.
    /// </summary>
    public void UnionWith(StateRange other)
    {
        data.AddRange(other.data);
    }

    /// <summary>
    /// Unions this state range with (other intersect (minVal to maxVal))
    /// </summary>
    public void UnionWith(StateRange other, int minVal, int maxVal)
    {
        foreach (Interval v in other.data) {
            int start = Math.Max(v.Start, minVal);
            int end = Math.Min(v.End, maxVal);
            // An empty intersection contributes nothing.
            if (start <= end)
                data.Add(new Interval(start, end));
        }
    }

    /// <summary>
    /// Merges overlapping interval ranges.
    /// </summary>
    public void Simplify()
    {
        if (data.Count < 2)
            return;
        data.Sort((a, b) => a.Start.CompareTo(b.Start));
        Interval prev = data[0];
        int prevIndex = 0;
        for (int i = 1; i < data.Count; i++) {
            Interval next = data[i];
            Debug.Assert(prev.Start <= next.Start);
            // Compare in 64-bit arithmetic: with 32-bit ints, prev.End + 1 wraps around to
            // int.MinValue when prev.End == int.MaxValue, and the intervals would not be merged.
            if (next.Start <= (long)prev.End + 1) { // intervals overlapping or touching
                prev = new Interval(prev.Start, Math.Max(prev.End, next.End));
                data[prevIndex] = prev;
                data[i] = new Interval(0, -1); // mark as deleted
            } else {
                prev = next;
                prevIndex = i;
            }
        }
        data.RemoveAll(interval => interval.Start > interval.End); // remove all entries that were marked as deleted
    }

    /// <summary>Returns the intervals as a comma-separated list.</summary>
    public override string ToString()
    {
        return string.Join(",", data);
    }

    /// <summary>
    /// Returns the range as a single interval.
    /// </summary>
    /// <exception cref="YieldAnalysisFailedException">
    /// The range does not consist of exactly one interval.
    /// </exception>
    public Interval ToInterval()
    {
        if (data.Count == 1)
            return data[0];
        else
            throw new YieldAnalysisFailedException();
    }
}
#endregion
DefaultDictionary<ILNode, StateRange> ranges;
/// <summary>
/// Analyzes the enumerator's <c>System.IDisposable.Dispose</c> method: assigns state ranges to
/// its nodes and then, for every try block found, determines the state interval in which the
/// try block is entered and validates the shape of its finally block.
/// The result is currently only written to the debug output.
/// </summary>
/// <exception cref="YieldAnalysisFailedException">
/// The Dispose method is missing or has no body, or its code does not match the expected pattern.
/// </exception>
void ConstructExceptionTable()
{
    MethodDefinition disposeMethod = enumeratorType.Methods.FirstOrDefault(m => m.Name == "System.IDisposable.Dispose");
    // CreateILAst rejects a null method, so a missing Dispose method aborts the analysis here.
    ILBlock ilMethod = CreateILAst(disposeMethod);

    ranges = new DefaultDictionary<ILNode, StateRange>(node => new StateRange());
    // The method entry is reachable for every possible state value.
    ranges[ilMethod] = new StateRange(int.MinValue, int.MaxValue);
    AssignStateRanges(ilMethod);

    // Now look at the finally blocks:
    foreach (var tryFinally in ilMethod.GetSelfAndChildrenRecursive<ILTryCatchBlock>()) {
        // Report unexpected shapes as an analysis failure instead of letting an
        // ArgumentOutOfRangeException / NullReferenceException escape.
        if (tryFinally.FinallyBlock == null || tryFinally.TryBlock.Body.Count == 0)
            throw new YieldAnalysisFailedException();
        Interval interval = ranges[tryFinally.TryBlock.Body[0]].ToInterval();
        var finallyBody = tryFinally.FinallyBlock.Body;
        // Expected finally body: call(ldarg(this)); [nop;] endfinally
        if (!(finallyBody.Count == 2 || finallyBody.Count == 3))
            throw new YieldAnalysisFailedException();
        ILExpression call = finallyBody[0] as ILExpression;
        if (call == null || call.Code != ILCode.Call || call.Arguments.Count != 1)
            throw new YieldAnalysisFailedException();
        // The only argument must be the parameter with a negative index ('this').
        if (call.Arguments[0].Code != ILCode.Ldarg || ((ParameterDefinition)call.Arguments[0].Operand).Index >= 0)
            throw new YieldAnalysisFailedException();
        if (finallyBody.Count == 3 && !finallyBody[1].Match(ILCode.Nop))
            throw new YieldAnalysisFailedException();
        if (!finallyBody[finallyBody.Count - 1].Match(ILCode.Endfinally))
            throw new YieldAnalysisFailedException();
        Debug.WriteLine("State " + interval + " -> " + call.Operand.ToString());
    }
}
#region Assign StateRanges / Symbolic Execution
/// <summary>
/// Propagates state ranges through the nodes of <paramref name="block"/>: each node's range
/// (the state values for which the node is reachable) is forwarded to its successors.
/// Labels and nops forward to the following node, branches to their target label, and
/// switches split the range between their target labels and the following node.
/// Try blocks are processed recursively.
/// </summary>
/// <param name="block">The block to process; its own range must already be set in <c>ranges</c>.</param>
/// <exception cref="YieldAnalysisFailedException">
/// The block contains a construct that is not part of the expected pattern.
/// </exception>
void AssignStateRanges(ILBlock block)
{
    if (block.Body.Count == 0)
        return;
    ranges[block.Body[0]].UnionWith(ranges[block]);
    for (int i = 0; i < block.Body.Count; i++) {
        StateRange nodeRange = ranges[block.Body[i]];
        nodeRange.Simplify();

        ILLabel label = block.Body[i] as ILLabel;
        if (label != null) {
            ranges[GetFallthroughNode(block, i)].UnionWith(nodeRange);
            continue;
        }

        ILTryCatchBlock tryFinally = block.Body[i] as ILTryCatchBlock;
        if (tryFinally != null) {
            // Only plain try-finally is accepted.
            if (tryFinally.CatchBlocks.Count != 0 || tryFinally.FaultBlock != null || tryFinally.FinallyBlock == null)
                throw new YieldAnalysisFailedException();
            ranges[tryFinally.TryBlock].UnionWith(nodeRange);
            AssignStateRanges(tryFinally.TryBlock);
            continue;
        }

        ILExpression expr = block.Body[i] as ILExpression;
        if (expr == null)
            throw new YieldAnalysisFailedException();
        switch (expr.Code) {
            case ILCode.Switch: {
                SymbolicValue val = Eval(expr.Arguments[0]);
                if (val.Type != SymbolicValueType.State)
                    throw new YieldAnalysisFailedException();
                ILLabel[] targetLabels = (ILLabel[])expr.Operand;
                // The switch operates on (state + val.Constant), so target j is taken
                // exactly when state == j - val.Constant.
                for (int j = 0; j < targetLabels.Length; j++) {
                    int state = j - val.Constant;
                    ranges[targetLabels[j]].UnionWith(nodeRange, state, state);
                }
                // All states below the first and above the last target fall through.
                ILNode afterSwitch = GetFallthroughNode(block, i);
                ranges[afterSwitch].UnionWith(nodeRange, int.MinValue, -1 - val.Constant);
                ranges[afterSwitch].UnionWith(nodeRange, targetLabels.Length - val.Constant, int.MaxValue);
                break;
            }
            case ILCode.Br:
            case ILCode.Leave:
                ranges[(ILLabel)expr.Operand].UnionWith(nodeRange);
                break;
            case ILCode.Nop:
                ranges[GetFallthroughNode(block, i)].UnionWith(nodeRange);
                break;
            case ILCode.Ret:
                break;
            default:
                throw new YieldAnalysisFailedException();
        }
    }
}

/// <summary>
/// Returns the node following position <paramref name="index"/> in <paramref name="block"/>.
/// Throws <see cref="YieldAnalysisFailedException"/> when there is none, so that a block ending
/// in a fall-through node aborts the analysis instead of raising an ArgumentOutOfRangeException.
/// </summary>
static ILNode GetFallthroughNode(ILBlock block, int index)
{
    if (index + 1 >= block.Body.Count)
        throw new YieldAnalysisFailedException();
    return block.Body[index + 1];
}
/// <summary>
/// The kinds of value distinguished during symbolic evaluation.
/// </summary>
enum SymbolicValueType
{
    /// <summary>A plain integer whose value is held in the constant.</summary>
    IntegerConstant,
    /// <summary>The value of the state field, offset by the constant.</summary>
    State,
    /// <summary>The 'this' reference.</summary>
    This
}

/// <summary>
/// Result of symbolically evaluating an expression: a kind plus an integer constant.
/// </summary>
struct SymbolicValue
{
    public readonly int Constant;
    public readonly SymbolicValueType Type;

    /// <summary>
    /// Creates a symbolic value of the given kind; <paramref name="constant"/> defaults to 0.
    /// </summary>
    public SymbolicValue(SymbolicValueType type, int constant = 0)
    {
        Constant = constant;
        Type = type;
    }

    /// <summary>Formats the value as "[SymbolicValue Type: Constant]".</summary>
    public override string ToString()
    {
        return string.Format("[SymbolicValue {0}: {1}]", Type, Constant);
    }
}
/// <summary>
/// Symbolically evaluates <paramref name="expr"/>. Supported forms are integer constants,
/// loading 'this', loading the state field from 'this', and subtracting an integer constant
/// from a state or integer value.
/// </summary>
/// <param name="expr">The expression to evaluate.</param>
/// <returns>The symbolic value of the expression.</returns>
/// <exception cref="YieldAnalysisFailedException">
/// The expression is not one of the supported forms.
/// </exception>
SymbolicValue Eval(ILExpression expr)
{
    switch (expr.Code) {
        case ILCode.Sub:
            SymbolicValue left = Eval(expr.Arguments[0]);
            SymbolicValue right = Eval(expr.Arguments[1]);
            // The left operand must be numeric (state or constant); subtracting from
            // 'this' has no meaning and must not produce a value.
            if (left.Type != SymbolicValueType.State && left.Type != SymbolicValueType.IntegerConstant)
                throw new YieldAnalysisFailedException();
            // Only a known constant can be subtracted.
            if (right.Type != SymbolicValueType.IntegerConstant)
                throw new YieldAnalysisFailedException();
            return new SymbolicValue(left.Type, unchecked ( left.Constant - right.Constant ));
        case ILCode.Ldfld:
            // Only a load of the state field from 'this' is understood.
            if (Eval(expr.Arguments[0]).Type != SymbolicValueType.This)
                throw new YieldAnalysisFailedException();
            if (expr.Operand != stateField)
                throw new YieldAnalysisFailedException();
            return new SymbolicValue(SymbolicValueType.State);
        case ILCode.Ldarg:
            // A negative parameter index denotes 'this'; other arguments are not supported.
            if (((ParameterDefinition)expr.Operand).Index < 0)
                return new SymbolicValue(SymbolicValueType.This);
            else
                throw new YieldAnalysisFailedException();
        case ILCode.Ldc_I4:
            return new SymbolicValue(SymbolicValueType.IntegerConstant, (int)expr.Operand);
        default:
            throw new YieldAnalysisFailedException();
    }
}
#endregion
#endregion
}
}