Browse Source

Started analysis of yield return statements.

pull/70/head
Daniel Grunwald 15 years ago
parent
commit
6feadf3840
  1. 2
      ICSharpCode.Decompiler/Ast/DecompilerContext.cs
  2. 1
      ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj
  3. 117
      ICSharpCode.Decompiler/ILAst/DefaultDictionary.cs
  4. 5
      ICSharpCode.Decompiler/ILAst/ILAstBuilder.cs
  5. 4
      ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs
  6. 10
      ICSharpCode.Decompiler/ILAst/ILAstTypes.cs
  7. 387
      ICSharpCode.Decompiler/ILAst/YieldReturnDecompiler.cs

2
ICSharpCode.Decompiler/Ast/DecompilerContext.cs

@ -12,7 +12,7 @@ namespace ICSharpCode.Decompiler @@ -12,7 +12,7 @@ namespace ICSharpCode.Decompiler
public CancellationToken CancellationToken;
public TypeDefinition CurrentType;
public MethodDefinition CurrentMethod;
public DecompilerSettings Settings;
public DecompilerSettings Settings = new DecompilerSettings();
public DecompilerContext Clone()
{

1
ICSharpCode.Decompiler/ICSharpCode.Decompiler.csproj

@ -87,6 +87,7 @@ @@ -87,6 +87,7 @@
<Compile Include="FlowAnalysis\TransformToSsa.cs" />
<Compile Include="GraphVizGraph.cs" />
<Compile Include="ILAst\ArrayInitializers.cs" />
<Compile Include="ILAst\DefaultDictionary.cs" />
<Compile Include="ILAst\GotoRemoval.cs" />
<Compile Include="ILAst\ILAstBuilder.cs" />
<Compile Include="ILAst\ILAstOptimizer.cs" />

117
ICSharpCode.Decompiler/ILAst/DefaultDictionary.cs

@ -0,0 +1,117 @@ @@ -0,0 +1,117 @@
// <file>
// <copyright see="prj:///doc/copyright.txt"/>
// <license see="prj:///doc/license.txt"/>
// <owner name="Daniel Grunwald" email="daniel@danielgrunwald.de"/>
// <version>$Revision$</version>
// </file>
using System;
using System.Collections;
using System.Collections.Generic;
namespace ICSharpCode.Decompiler.ILAst
{
/// <summary>
/// Dictionary with default values.
/// </summary>
sealed class DefaultDictionary<TKey, TValue> : IDictionary<TKey, TValue>
{
readonly IDictionary<TKey, TValue> dict;
readonly Func<TKey, TValue> defaultProvider;
public DefaultDictionary(TValue defaultValue, IDictionary<TKey, TValue> dictionary = null)
: this(key => defaultValue, dictionary)
{
}
public DefaultDictionary(Func<TKey, TValue> defaultProvider = null, IDictionary<TKey, TValue> dictionary = null)
{
this.dict = dictionary ?? new Dictionary<TKey, TValue>();
this.defaultProvider = defaultProvider ?? (key => default(TValue));
}
public TValue this[TKey key] {
get {
TValue val;
if (dict.TryGetValue(key, out val))
return val;
else
return dict[key] = defaultProvider(key);
}
set {
dict[key] = value;
}
}
public ICollection<TKey> Keys {
get { return dict.Keys; }
}
public ICollection<TValue> Values {
get { return dict.Values; }
}
public int Count {
get { return dict.Count; }
}
bool ICollection<KeyValuePair<TKey, TValue>>.IsReadOnly {
get { return false; }
}
public bool ContainsKey(TKey key)
{
return dict.ContainsKey(key);
}
public void Add(TKey key, TValue value)
{
dict.Add(key, value);
}
public bool Remove(TKey key)
{
return dict.Remove(key);
}
public bool TryGetValue(TKey key, out TValue value)
{
return dict.TryGetValue(key, out value);
}
void ICollection<KeyValuePair<TKey, TValue>>.Add(KeyValuePair<TKey, TValue> item)
{
dict.Add(item);
}
public void Clear()
{
dict.Clear();
}
bool ICollection<KeyValuePair<TKey, TValue>>.Contains(KeyValuePair<TKey, TValue> item)
{
return dict.Contains(item);
}
void ICollection<KeyValuePair<TKey, TValue>>.CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{
dict.CopyTo(array, arrayIndex);
}
bool ICollection<KeyValuePair<TKey, TValue>>.Remove(KeyValuePair<TKey, TValue> item)
{
return dict.Remove(item);
}
IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator()
{
return dict.GetEnumerator();
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return dict.GetEnumerator();
}
}
}

5
ICSharpCode.Decompiler/ILAst/ILAstBuilder.cs

@ -609,9 +609,10 @@ namespace ICSharpCode.Decompiler.ILAst @@ -609,9 +609,10 @@ namespace ICSharpCode.Decompiler.ILAst
tryCatchBlock.CatchBlocks.Add(catchBlock);
} else if (eh.HandlerType == ExceptionHandlerType.Finally) {
tryCatchBlock.FinallyBlock = new ILBlock(handlerAst);
// TODO: ldexception
} else if (eh.HandlerType == ExceptionHandlerType.Fault) {
tryCatchBlock.FaultBlock = new ILBlock(handlerAst);
} else {
// TODO
// TODO: ExceptionHandlerType.Filter
}
}

4
ICSharpCode.Decompiler/ILAst/ILAstOptimizer.cs

@ -12,6 +12,7 @@ namespace ICSharpCode.Decompiler.ILAst @@ -12,6 +12,7 @@ namespace ICSharpCode.Decompiler.ILAst
public enum ILAstOptimizationStep
{
ReduceBranchInstructionSet,
YieldReturn,
SplitToMovableBlocks,
PeepholeOptimizations,
FindLoops,
@ -38,6 +39,9 @@ namespace ICSharpCode.Decompiler.ILAst @@ -38,6 +39,9 @@ namespace ICSharpCode.Decompiler.ILAst
ReduceBranchInstructionSet(block);
}
if (abortBeforeStep == ILAstOptimizationStep.YieldReturn) return;
YieldReturnDecompiler.Run(context, method);
if (abortBeforeStep == ILAstOptimizationStep.SplitToMovableBlocks) return;
foreach(ILBlock block in method.GetSelfAndChildrenRecursive<ILBlock>().ToList()) {
SplitToBasicBlocks(block);

10
ICSharpCode.Decompiler/ILAst/ILAstTypes.cs

@ -150,6 +150,7 @@ namespace ICSharpCode.Decompiler.ILAst @@ -150,6 +150,7 @@ namespace ICSharpCode.Decompiler.ILAst
public ILBlock TryBlock;
public List<CatchBlock> CatchBlocks;
public ILBlock FinallyBlock;
public ILBlock FaultBlock;
public override IEnumerable<ILNode> GetChildren()
{
@ -158,6 +159,8 @@ namespace ICSharpCode.Decompiler.ILAst @@ -158,6 +159,8 @@ namespace ICSharpCode.Decompiler.ILAst
foreach (var catchBlock in this.CatchBlocks) {
yield return catchBlock;
}
if (this.FaultBlock != null)
yield return this.FaultBlock;
if (this.FinallyBlock != null)
yield return this.FinallyBlock;
}
@ -172,6 +175,13 @@ namespace ICSharpCode.Decompiler.ILAst @@ -172,6 +175,13 @@ namespace ICSharpCode.Decompiler.ILAst
foreach (CatchBlock block in CatchBlocks) {
block.WriteTo(output);
}
if (FaultBlock != null) {
output.WriteLine("fault {");
output.Indent();
FaultBlock.WriteTo(output);
output.Unindent();
output.WriteLine("}");
}
if (FinallyBlock != null) {
output.WriteLine("finally {");
output.Indent();

387
ICSharpCode.Decompiler/ILAst/YieldReturnDecompiler.cs

@ -2,6 +2,10 @@ @@ -2,6 +2,10 @@
// This code is distributed under MIT X11 license (for details please see \doc\license.txt)
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Mono.Cecil;
namespace ICSharpCode.Decompiler.ILAst
{
@ -10,6 +14,387 @@ namespace ICSharpCode.Decompiler.ILAst @@ -10,6 +14,387 @@ namespace ICSharpCode.Decompiler.ILAst
// For a description on the code generated by the C# compiler for yield return:
// http://csharpindepth.com/Articles/Chapter6/IteratorBlockImplementation.aspx
// not implemented yet...
// The idea here is:
// - Figure out whether the current method is instanciating an enumerator
// - Figure out which of the fields is the state field
// - Construct an exception table based on states. This allows us to determine, for each state, what the parent try block is.
/// <summary>
/// This exception is thrown when we find something else than we expect from the C# compiler.
/// This aborts the analysis and makes the whole transform fail.
/// </summary>
class YieldAnalysisFailedException : Exception {}
DecompilerContext context;
TypeDefinition enumeratorType;
MethodDefinition enumeratorCtor;
FieldDefinition stateField;
Dictionary<ParameterDefinition, FieldDefinition> parameterToFieldMap;
#region Run() method
public static void Run(DecompilerContext context, ILBlock method)
{
if (!context.Settings.YieldReturn)
return; // abort if enumerator decompilation is disabled
var yrd = new YieldReturnDecompiler();
yrd.context = context;
if (!yrd.MatchEnumeratorCreationPattern(method))
return;
yrd.enumeratorType = yrd.enumeratorCtor.DeclaringType;
#if !DEBUG
try {
#endif
yrd.AnalyzeCtor();
yrd.ConstructExceptionTable();
#if !DEBUG
} catch (YieldAnalysisFailedException) {
return;
}
#endif
}
#endregion
#region Match the enumerator creation pattern
bool MatchEnumeratorCreationPattern(ILBlock method)
{
if (method.Body.Count == 0)
return false;
ILExpression ret;
if (method.Body.Count == 1) {
// ret(newobj(...))
if (method.Body[0].Match(ILCode.Ret, out ret) && ret.Arguments.Count == 1)
return MatchEnumeratorCreationNewObj(ret.Arguments[0], out enumeratorCtor);
else
return false;
}
// stloc(var_1, newobj(..)
ILExpression stloc;
if (!method.Body[0].Match(ILCode.Stloc, out stloc))
return false;
if (!MatchEnumeratorCreationNewObj(stloc.Arguments[0], out enumeratorCtor))
return false;
parameterToFieldMap = new Dictionary<ParameterDefinition, FieldDefinition>();
int i = 1;
ILExpression stfld;
while (i < method.Body.Count && method.Body[i].Match(ILCode.Stfld, out stfld)) {
// stfld(..., ldloc(var_1), ldarg(...))
ILExpression ldloc, ldarg;
if (!(stfld.Arguments[0].Match(ILCode.Ldloc, out ldloc) && stfld.Arguments[1].Match(ILCode.Ldarg, out ldarg)))
return false;
if (ldloc.Operand != stloc.Operand || !(stfld.Operand is FieldDefinition))
return false;
parameterToFieldMap[(ParameterDefinition)ldarg.Operand] = (FieldDefinition)stfld.Operand;
i++;
}
ILExpression stloc2;
if (i < method.Body.Count && method.Body[i].Match(ILCode.Stloc, out stloc2)) {
// stloc(var_2, ldloc(var_1))
if (stloc2.Arguments[0].Code != ILCode.Ldloc || stloc2.Arguments[0].Operand != stloc.Operand)
return false;
i++;
} else {
// the compiler might skip the above instruction in release builds; in that case, it directly returns stloc.Operand
stloc2 = stloc;
}
ILExpression br;
if (i + 1 < method.Body.Count && method.Body[i].Match(ILCode.Br, out br)) {
if (br.Operand != method.Body[i + 1])
return false;
i += 2;
}
if (i < method.Body.Count && method.Body[i].Match(ILCode.Ret, out ret)) {
if (ret.Arguments[0].Code == ILCode.Ldloc && ret.Arguments[0].Operand == stloc2.Operand) {
return true;
}
}
return false;
}
bool MatchEnumeratorCreationNewObj(ILExpression expr, out MethodDefinition ctor)
{
// newobj(CurrentType/...::.ctor, ldc.i4(-2))
ctor = expr.Operand as MethodDefinition;
if (expr.Code != ILCode.Newobj || expr.Arguments.Count != 1)
return false;
if (expr.Arguments[0].Code != ILCode.Ldc_I4 || (int)expr.Arguments[0].Operand != -2)
return false;
if (ctor == null || !ctor.DeclaringType.IsCompilerGenerated())
return false;
return ctor.DeclaringType.DeclaringType == context.CurrentType;
}
#endregion
#region Figure out what the state field is
void AnalyzeCtor()
{
ILBlock method = CreateILAst(enumeratorCtor);
foreach (ILNode node in method.Body) {
ILExpression stfld;
if (node.Match(ILCode.Stfld, out stfld)
&& stfld.Arguments[0].Code == ILCode.Ldarg && ((ParameterDefinition)stfld.Arguments[0].Operand).Index < 0
&& stfld.Arguments[1].Code == ILCode.Ldarg && ((ParameterDefinition)stfld.Arguments[1].Operand).Index == 0)
{
stateField = stfld.Operand as FieldDefinition;
}
}
if (stateField == null)
throw new YieldAnalysisFailedException();
}
ILBlock CreateILAst(MethodDefinition method)
{
if (method == null || !method.HasBody)
throw new YieldAnalysisFailedException();
ILBlock ilMethod = new ILBlock();
ILAstBuilder astBuilder = new ILAstBuilder();
ilMethod.Body = astBuilder.Build(method, true);
ILAstOptimizer optimizer = new ILAstOptimizer();
optimizer.Optimize(context, ilMethod, ILAstOptimizationStep.YieldReturn);
return ilMethod;
}
#endregion
#region Construction of the exception table
// We construct the exception table by analyzing the enumerator's Dispose() method.
// Assumption: there are no loops/backward jumps
// We 'run' the code, with "state" being a symbolic variable
// so it can form expressions like "state + x" (when there's a sub instruction)
// For each instruction, we maintain a list of value ranges for state for which the instruction is reachable.
// This is (int.MinValue, int.MaxValue) for the first instruction.
// These ranges are propagated depending on the conditional jumps performed by the code.
#region struct Interval / class StateRange
struct Interval
{
public readonly int Start, End;
public Interval(int start, int end)
{
Debug.Assert(start <= end || (start == 0 && end == -1));
this.Start = start;
this.End = end;
}
public override string ToString()
{
return string.Format("({0} to {1})", Start, End);
}
}
class StateRange
{
readonly List<Interval> data = new List<Interval>();
public StateRange()
{
}
public StateRange(int start, int end)
{
this.data.Add(new Interval(start, end));
}
public void UnionWith(StateRange other)
{
data.AddRange(other.data);
}
/// <summary>
/// Unions this state range with (other intersect (minVal to maxVal))
/// </summary>
public void UnionWith(StateRange other, int minVal, int maxVal)
{
foreach (Interval v in other.data) {
int start = Math.Max(v.Start, minVal);
int end = Math.Min(v.End, maxVal);
if (start <= end)
data.Add(new Interval(start, end));
}
}
/// <summary>
/// Merges overlapping interval ranges.
/// </summary>
public void Simplify()
{
if (data.Count < 2)
return;
data.Sort((a, b) => a.Start.CompareTo(b.Start));
Interval prev = data[0];
int prevIndex = 0;
for (int i = 1; i < data.Count; i++) {
Interval next = data[i];
Debug.Assert(prev.Start <= next.Start);
if (next.Start <= prev.End + 1) { // intervals overlapping or touching
prev = new Interval(prev.Start, Math.Max(prev.End, next.End));
data[prevIndex] = prev;
data[i] = new Interval(0, -1); // mark as deleted
} else {
prev = next;
prevIndex = i;
}
}
data.RemoveAll(i => i.Start > i.End); // remove all entries that were marked as deleted
}
public override string ToString()
{
return string.Join(",", data);
}
public Interval ToInterval()
{
if (data.Count == 1)
return data[0];
else
throw new YieldAnalysisFailedException();
}
}
#endregion
DefaultDictionary<ILNode, StateRange> ranges;
void ConstructExceptionTable()
{
MethodDefinition disposeMethod = enumeratorType.Methods.FirstOrDefault(m => m.Name == "System.IDisposable.Dispose");
ILBlock ilMethod = CreateILAst(disposeMethod);
ranges = new DefaultDictionary<ILNode, StateRange>(node => new StateRange());
ranges[ilMethod] = new StateRange(int.MinValue, int.MaxValue);
AssignStateRanges(ilMethod);
// Now look at the finally blocks:
foreach (var tryFinally in ilMethod.GetSelfAndChildrenRecursive<ILTryCatchBlock>()) {
Interval interval = ranges[tryFinally.TryBlock.Body[0]].ToInterval();
var finallyBody = tryFinally.FinallyBlock.Body;
if (!(finallyBody.Count == 2 || finallyBody.Count == 3))
throw new YieldAnalysisFailedException();
ILExpression call = finallyBody[0] as ILExpression;
if (call == null || call.Code != ILCode.Call || call.Arguments.Count != 1)
throw new YieldAnalysisFailedException();
if (call.Arguments[0].Code != ILCode.Ldarg || ((ParameterDefinition)call.Arguments[0].Operand).Index >= 0)
throw new YieldAnalysisFailedException();
if (finallyBody.Count == 3 && !finallyBody[1].Match(ILCode.Nop))
throw new YieldAnalysisFailedException();
if (!finallyBody[finallyBody.Count - 1].Match(ILCode.Endfinally))
throw new YieldAnalysisFailedException();
Debug.WriteLine("State " + interval + " -> " + call.Operand.ToString());
}
}
#region Assign StateRanges / Symbolic Execution
void AssignStateRanges(ILBlock block)
{
if (block.Body.Count == 0)
return;
ranges[block.Body[0]].UnionWith(ranges[block]);
for (int i = 0; i < block.Body.Count; i++) {
StateRange nodeRange = ranges[block.Body[i]];
nodeRange.Simplify();
ILLabel label = block.Body[i] as ILLabel;
if (label != null) {
ranges[block.Body[i + 1]].UnionWith(nodeRange);
continue;
}
ILTryCatchBlock tryFinally = block.Body[i] as ILTryCatchBlock;
if (tryFinally != null) {
if (tryFinally.CatchBlocks.Count != 0 || tryFinally.FaultBlock != null || tryFinally.FinallyBlock == null)
throw new YieldAnalysisFailedException();
ranges[tryFinally.TryBlock].UnionWith(nodeRange);
AssignStateRanges(tryFinally.TryBlock);
continue;
}
ILExpression expr = block.Body[i] as ILExpression;
if (expr == null)
throw new YieldAnalysisFailedException();
switch (expr.Code) {
case ILCode.Switch:
SymbolicValue val = Eval(expr.Arguments[0]);
if (val.Type != SymbolicValueType.State)
throw new YieldAnalysisFailedException();
ILLabel[] targetLabels = (ILLabel[])expr.Operand;
for (int j = 0; j < targetLabels.Length; j++) {
int state = j - val.Constant;
ranges[targetLabels[j]].UnionWith(nodeRange, state, state);
}
ranges[block.Body[i + 1]].UnionWith(nodeRange, int.MinValue, -1 - val.Constant);
ranges[block.Body[i + 1]].UnionWith(nodeRange, targetLabels.Length - val.Constant, int.MaxValue);
break;
case ILCode.Br:
case ILCode.Leave:
ranges[(ILLabel)expr.Operand].UnionWith(nodeRange);
break;
case ILCode.Nop:
ranges[block.Body[i + 1]].UnionWith(nodeRange);
break;
case ILCode.Ret:
break;
default:
throw new YieldAnalysisFailedException();
}
}
}
enum SymbolicValueType
{
IntegerConstant,
State,
This
}
struct SymbolicValue
{
public readonly int Constant;
public readonly SymbolicValueType Type;
public SymbolicValue(SymbolicValueType type, int constant = 0)
{
this.Type = type;
this.Constant = constant;
}
public override string ToString()
{
return string.Format("[SymbolicValue {0}: {1}]", this.Type, this.Constant);
}
}
SymbolicValue Eval(ILExpression expr)
{
switch (expr.Code) {
case ILCode.Sub:
SymbolicValue left = Eval(expr.Arguments[0]);
SymbolicValue right = Eval(expr.Arguments[1]);
if (left.Type != SymbolicValueType.State && right.Type != SymbolicValueType.IntegerConstant)
throw new YieldAnalysisFailedException();
if (right.Type != SymbolicValueType.IntegerConstant)
throw new YieldAnalysisFailedException();
return new SymbolicValue(left.Type, unchecked ( left.Constant - right.Constant ));
case ILCode.Ldfld:
if (Eval(expr.Arguments[0]).Type != SymbolicValueType.This)
throw new YieldAnalysisFailedException();
if (expr.Operand != stateField)
throw new YieldAnalysisFailedException();
return new SymbolicValue(SymbolicValueType.State);
case ILCode.Ldarg:
if (((ParameterDefinition)expr.Operand).Index < 0)
return new SymbolicValue(SymbolicValueType.This);
else
throw new YieldAnalysisFailedException();
case ILCode.Ldc_I4:
return new SymbolicValue(SymbolicValueType.IntegerConstant, (int)expr.Operand);
default:
throw new YieldAnalysisFailedException();
}
}
#endregion
#endregion
}
}

Loading…
Cancel
Save