diff --git a/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs b/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs index 1a6aa0f29..680689650 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/TransformDisplayClassUsage.cs @@ -170,18 +170,42 @@ namespace ICSharpCode.Decompiler.IL.Transforms } } } + } + + bool ValidateDisplayClassUses(ILVariable v, DisplayClass displayClass, bool readOnly = false) + { + foreach (var ldloc in v.LoadInstructions) { + if (!ValidateUse(displayClass, ldloc)) + return false; + } + foreach (var ldloca in v.AddressInstructions) { + if (!ValidateUse(displayClass, ldloca)) + return false; + } + return true; - bool ValidateDisplayClassUses(ILVariable v, DisplayClass displayClass) + bool ValidateUse(DisplayClass container, ILInstruction use) { - foreach (var ldloc in v.LoadInstructions) { - if (!ValidateUse(displayClass, ldloc)) - return false; - } - foreach (var ldloca in v.AddressInstructions) { - if (!ValidateUse(displayClass, ldloca)) + IField field; + switch (use.Parent) { + case LdFlda ldflda when ldflda.MatchLdFlda(out _, out field): + var keyField = (IField)field.MemberDefinition; + if (!container.VariablesToDeclare.TryGetValue(keyField, out VariableToDeclare variable) || variable == null) { + if (readOnly) + return false; + variable = AddVariable(container, null, field); + } + container.VariablesToDeclare[keyField] = variable; + return variable != null; + case StObj stobj when stobj.MatchStObj(out var target, out ILInstruction value, out _) && value == use: + if (target.MatchLdFlda(out var load, out field) && load.MatchLdLocRef(out var otherVariable) && displayClasses.TryGetValue(otherVariable, out var otherDisplayClass)) { + if (otherDisplayClass.VariablesToDeclare.TryGetValue((IField)field.MemberDefinition, out var declaredVar)) + return declaredVar.CanPropagate; + } + return true; + default: return false; } - return true; } } @@ -285,8 +309,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms private Block FindDisplayStructInitBlock(ILVariable v) { var root = v.Function.Body; - var block = Visit(root)?.Ancestors.OfType().FirstOrDefault(); - return block; + return Visit(root)?.Ancestors.OfType().FirstOrDefault(); ILInstruction Visit(ILInstruction inst) { @@ -417,28 +440,6 @@ namespace ICSharpCode.Decompiler.IL.Transforms return variable; } - bool ValidateUse(DisplayClass container, ILInstruction use) - { - IField field; - switch (use.Parent) { - case LdFlda ldflda when ldflda.MatchLdFlda(out _, out field): - var keyField = (IField)field.MemberDefinition; - if (!container.VariablesToDeclare.TryGetValue(keyField, out VariableToDeclare variable) || variable == null) { - variable = AddVariable(container, null, field); - } - container.VariablesToDeclare[keyField] = variable; - return variable != null; - case StObj stobj when stobj.MatchStObj(out var target, out ILInstruction value, out _) && value == use: - if (target.MatchLdFlda(out var load, out field) && load.MatchLdLocRef(out var otherVariable) && displayClasses.TryGetValue(otherVariable, out var displayClass)) { - if (displayClass.VariablesToDeclare.TryGetValue((IField)field.MemberDefinition, out var declaredVar)) - return declaredVar.CanPropagate; - } - return true; - default: - return false; - } - } - private void Transform(ILFunction function) { VisitILFunction(function); @@ -630,8 +631,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms } return; } - // TODO : this is dangerous! - if (inst.Value.MatchLdLocRef(out var otherVariable) && displayClasses.TryGetValue(otherVariable, out displayClass)) { + if (inst.Value.MatchLdLocRef(out var otherVariable) && displayClasses.TryGetValue(otherVariable, out displayClass) && ValidateDisplayClassUses(inst.Variable, displayClass)) { instructionsToRemove.Add(inst); displayClasses.Add(inst.Variable, displayClass); }