diff --git a/ICSharpCode.Decompiler/IL/Transforms/TransformCollectionAndObjectInitializers.cs b/ICSharpCode.Decompiler/IL/Transforms/TransformCollectionAndObjectInitializers.cs index 722c9c888..0c2bc44dd 100644 --- a/ICSharpCode.Decompiler/IL/Transforms/TransformCollectionAndObjectInitializers.cs +++ b/ICSharpCode.Decompiler/IL/Transforms/TransformCollectionAndObjectInitializers.cs @@ -43,22 +43,27 @@ namespace ICSharpCode.Decompiler.IL.Transforms { ILInstruction inst = body.Instructions[pos]; // Match stloc(v, newobj) - if (inst.MatchStLoc(out var v, out var initInst)) { + if (inst.MatchStLoc(out var v, out var initInst) && (v.Kind == VariableKind.Local || v.Kind == VariableKind.StackSlot)) { + Block initializerBlock = null; switch (initInst) { case NewObj newObjInst: if (DelegateConstruction.IsDelegateConstruction(newObjInst) || DelegateConstruction.IsPotentialClosure(context, newObjInst)) return false; - if (newObjInst.Method.DeclaringType.Kind != TypeKind.Struct && v.Kind != VariableKind.StackSlot) - return false; break; case DefaultValue defaultVal: break; + case Block existingInitializer: + if (existingInitializer.Type == BlockType.CollectionInitializer || existingInitializer.Type == BlockType.ObjectInitializer) { + initializerBlock = existingInitializer; + break; + } + return false; default: return false; } context.Step("CollectionOrObjectInitializer", inst); int initializerItemsCount = 0; - var blockType = BlockType.CollectionInitializer; + var blockType = initializerBlock?.Type ?? BlockType.CollectionInitializer; // Detect initializer type by scanning the following statements // each must be a callvirt with ldloc v as first argument // if the method is a setter we're dealing with an object initializer @@ -68,10 +73,15 @@ namespace ICSharpCode.Decompiler.IL.Transforms initializerItemsCount++; if (initializerItemsCount == 0) return false; - Block initBlock = new Block(blockType); - var finalSlot = context.Function.RegisterVariable(VariableKind.StackSlot, v.Type); - initBlock.FinalInstruction = new LdLoc(finalSlot); - initBlock.Instructions.Add(new StLoc(finalSlot, initInst.Clone())); + ILVariable finalSlot; + if (initializerBlock == null) { + initializerBlock = new Block(blockType); + finalSlot = context.Function.RegisterVariable(VariableKind.StackSlot, v.Type); + initializerBlock.FinalInstruction = new LdLoc(finalSlot); + initializerBlock.Instructions.Add(new StLoc(finalSlot, initInst.Clone())); + } else { + finalSlot = ((LdLoc)initializerBlock.FinalInstruction).Variable; + } for (int i = 1; i <= initializerItemsCount; i++) { switch (body.Instructions[i + pos]) { case CallInstruction call: @@ -81,19 +91,18 @@ namespace ICSharpCode.Decompiler.IL.Transforms foreach (var load in newTarget.Descendants.OfType()) if (load is LdLoc || load is LdLoca) load.Variable = finalSlot; - initBlock.Instructions.Add(newCall); + initializerBlock.Instructions.Add(newCall); break; case StObj stObj: var newStObj = (StObj)stObj.Clone(); foreach (var load in newStObj.Target.Descendants.OfType()) if (load is LdLoc || load is LdLoca) load.Variable = finalSlot; - initBlock.Instructions.Add(newStObj); + initializerBlock.Instructions.Add(newStObj); break; } - } - initInst.ReplaceWith(initBlock); + initInst.ReplaceWith(initializerBlock); for (int i = 0; i < initializerItemsCount; i++) body.Instructions.RemoveAt(pos + 1); ILInlining.InlineIfPossible(body, ref pos, context); @@ -151,7 +160,7 @@ namespace ICSharpCode.Decompiler.IL.Transforms case CallInstruction call: if (!(call is CallVirt || call is Call)) goto default; method = call.Method; - if (!IsMethodApplicable(method)) goto default; + if (!IsMethodApplicable(method, call.Arguments)) goto default; instruction = call.Arguments[0]; if (values == null) { values = new List(call.Arguments.Skip(1)); @@ -206,13 +215,42 @@ namespace ICSharpCode.Decompiler.IL.Transforms return (kind, path, values, target); } - static bool IsMethodApplicable(IMethod method) + static bool IsMethodApplicable(IMethod method, IList arguments) { - if (method.IsStatic) + if (!method.IsExtensionMethod && method.IsStatic) return false; if (method.IsAccessor) return true; - return "Add".Equals(method.Name, StringComparison.Ordinal); + if (!"Add".Equals(method.Name, StringComparison.Ordinal) || arguments.Count == 0) + return false; + var targetType = GetReturnTypeFromInstruction(arguments[0]); + if (targetType == null) + return false; + return targetType.GetAllBaseTypes().Any(i => i.IsKnownType(KnownTypeCode.IEnumerable) || i.IsKnownType(KnownTypeCode.IEnumerableOfT)); + } + + static IType GetReturnTypeFromInstruction(ILInstruction instruction) + { + // this switch must match the one in GetAccessPath + switch (instruction) { + case CallInstruction call: + if (!(call is CallVirt || call is Call)) goto default; + return call.Method.ReturnType; + case LdObj ldobj: + if (ldobj.Target is LdFlda ldflda) + return ldflda.Field.ReturnType; + goto default; + case StObj stobj: + if (stobj.Target is LdFlda ldflda2) + return ldflda2.Field.ReturnType; + goto default; + case LdLoc ldloc: + return ldloc.Variable.Type; + case LdLoca ldloca: + return ldloca.Variable.Type; + default: + return null; + } } public override bool Equals(object obj)