using System;
using System.Diagnostics;
using System.Linq;
using System.Reflection.Metadata;
using System.Threading;

using ICSharpCode.Decompiler.IL;
using ICSharpCode.Decompiler.IL.Transforms;
using ICSharpCode.Decompiler.TypeSystem;

namespace ICSharpCode.Decompiler.CSharp
{
	/// <summary>
	/// Decides which members of a record type are compiler-generated, by inspecting their
	/// signatures, attributes and (partially transformed) IL bodies, so that they can be
	/// omitted from the decompiled record declaration.
	/// </summary>
	class RecordDecompiler
	{
		readonly IDecompilerTypeSystem typeSystem;
		readonly ITypeDefinition recordTypeDef;
		readonly CancellationToken cancellationToken;

		public RecordDecompiler(IDecompilerTypeSystem dts, ITypeDefinition recordTypeDef, CancellationToken cancellationToken)
		{
			this.typeSystem = dts;
			this.recordTypeDef = recordTypeDef;
			this.cancellationToken = cancellationToken;
		}

		/// <summary>
		/// Returns true if <paramref name="type"/> is the record type itself, parameterized by the
		/// record's own type parameters (i.e. not some other instantiation of it).
		/// </summary>
		bool IsRecordType(IType type)
		{
			return type.GetDefinition() == recordTypeDef
				&& type.TypeArguments.SequenceEqual(recordTypeDef.TypeParameters);
		}

		/// <summary>
		/// Gets whether the member of the record type will be automatically generated by the compiler.
		/// </summary>
		/// <param name="method">A method declared in the record type.</param>
		/// <returns>True if the method should be omitted from the decompiled record declaration.</returns>
		public bool MethodIsGenerated(IMethod method)
		{
			switch (method.Name)
			{
				// Some members in records are always compiler-generated and lead to a
				// "duplicate definition" error if we emit the generated code.
				case "op_Equality":
				case "op_Inequality":
				{
					// Don't emit comparison operators into C# record definition
					// Note: user can declare additional operator== as long as they have
					// different parameter types.
					return method.Parameters.Count == 2
						&& method.Parameters.All(p => IsRecordType(p.Type));
				}
				case "Equals" when method.Parameters.Count == 1:
				{
					IType paramType = method.Parameters[0].Type;
					if (paramType.IsKnownType(KnownTypeCode.Object))
					{
						// override bool Equals(object? obj): always generated
						return true;
					}
					else if (IsRecordType(paramType))
					{
						// virtual bool Equals(R? other): generated unless user-declared.
						// No attempt is made to tell the two cases apart here, so the
						// method is reported as not generated and will be emitted.
						return false;
					}
					else
					{
						return false;
					}
				}
				case "<Clone>$" when method.Parameters.Count == 0:
					// Always generated; Method name cannot be expressed in C#
					return true;
				case "ToString" when method.Parameters.Count == 0:
					return IsGeneratedToString(method);
				default:
					return false;
			}
		}

		/// <summary>
		/// Gets whether the property of the record type will be automatically generated by the compiler.
		/// </summary>
		internal bool PropertyIsGenerated(IProperty property)
		{
			switch (property.Name)
			{
				case "EqualityContract":
					return IsGeneratedEqualityContract(property);
				default:
					return false;
			}
		}

		/// <summary>
		/// Checks whether <paramref name="property"/> has exactly the shape of the synthesized
		/// <c>EqualityContract</c> property (see the pattern below).
		/// </summary>
		private bool IsGeneratedEqualityContract(IProperty property)
		{
			// Generated member:
			// protected virtual Type EqualityContract {
			//    [CompilerGenerated] get => typeof(R);
			// }
			Debug.Assert(property.Name == "EqualityContract");
			if (property.Accessibility != Accessibility.Protected)
				return false;
			if (!(property.IsVirtual || property.IsOverride))
				return false;
			if (property.IsSealed)
				return false;
			var getter = property.Getter;
			// Must be a get-only property.
			if (!(getter != null && !property.CanSet))
				return false;
			if (property.GetAttributes().Any())
				return false;
			if (getter.GetReturnTypeAttributes().Any())
				return false;
			// The getter carries [CompilerGenerated] and nothing else.
			var attrs = getter.GetAttributes().ToList();
			if (attrs.Count != 1)
				return false;
			if (!attrs[0].AttributeType.IsKnownType(KnownAttribute.CompilerGenerated))
				return false;
			var body = DecompileBody(getter);
			if (body == null || body.Instructions.Count != 1)
				return false;
			if (!(body.Instructions[0] is Leave leave))
				return false;
			// leave IL_0000 (call GetTypeFromHandle(ldtypetoken R))
			if (!TransformExpressionTrees.MatchGetTypeFromHandle(leave.Value, out IType ty))
				return false;
			return IsRecordType(ty);
		}

		/// <summary>
		/// Checks whether <paramref name="method"/> (an override of <c>ToString()</c>) has exactly
		/// the shape of the synthesized record <c>ToString</c>: build "R { " in a StringBuilder,
		/// call PrintMembers, append " " if it returned true, append "}" and return the result.
		/// Any deviation from that pattern means the method is user-written.
		/// </summary>
		private bool IsGeneratedToString(IMethod method)
		{
			Debug.Assert(method.Name == "ToString");
			if (!method.IsOverride)
				return false;
			if (method.IsSealed)
				return false;
			if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
				return false;
			var body = DecompileBody(method);
			// The generated body consists of exactly the six instructions matched below;
			// checking the count up front also keeps the indexing below in range.
			if (body == null || body.Instructions.Count != 6)
				return false;
			// stloc stringBuilder(newobj StringBuilder..ctor())
			if (!body.Instructions[0].MatchStLoc(out var stringBuilder, out var stringBuilderInit))
				return false;
			if (!(stringBuilderInit is NewObj { Arguments: { Count: 0 }, Method: { DeclaringTypeDefinition: { Name: "StringBuilder", Namespace: "System.Text" } } }))
				return false;
			// callvirt Append(ldloc stringBuilder, ldstr "R")
			if (!MatchAppendCallWithValue(body.Instructions[1], recordTypeDef.Name))
				return false;
			// callvirt Append(ldloc stringBuilder, ldstr " { ")
			if (!MatchAppendCallWithValue(body.Instructions[2], " { "))
				return false;
			// if (callvirt PrintMembers(ldloc this, ldloc stringBuilder)) { trueInst }
			// A missing `if` means the body does not follow the generated pattern.
			if (!body.Instructions[3].MatchIfInstruction(out var condition, out var trueInst))
				return false;
			if (!(condition is CallVirt { Method: { Name: "PrintMembers" } } printMembersCall))
				return false;
			if (printMembersCall.Arguments.Count != 2)
				return false;
			if (!printMembersCall.Arguments[0].MatchLdThis())
				return false;
			if (!printMembersCall.Arguments[1].MatchLdLoc(stringBuilder))
				return false;
			// trueInst: callvirt Append(ldloc stringBuilder, ldstr " ")
			if (!MatchAppendCallWithValue(Block.Unwrap(trueInst), " "))
				return false;
			// callvirt Append(ldloc stringBuilder, ldstr "}")
			if (!MatchAppendCallWithValue(body.Instructions[4], "}"))
				return false;
			// leave IL_0000 (callvirt ToString(ldloc stringBuilder))
			if (!(body.Instructions[5] is Leave leave))
				return false;
			if (!(leave.Value is CallVirt { Method: { Name: "ToString" } } toStringCall))
				return false;
			if (toStringCall.Arguments.Count != 1)
				return false;
			return toStringCall.Arguments[0].MatchLdLoc(stringBuilder);

			// Matches `callvirt Append(ldloc stringBuilder, <literal>)` and returns the text the
			// literal appends. Both string literals and char literals passed to an Append(char)
			// overload are accepted, since they append the same text.
			bool MatchAppendCall(ILInstruction inst, out string val)
			{
				val = null;
				if (!(inst is CallVirt { Method: { Name: "Append" } } call))
					return false;
				if (call.Arguments.Count != 2)
					return false;
				if (!call.Arguments[0].MatchLdLoc(stringBuilder))
					return false;
				if (call.Arguments[1].MatchLdStr(out val))
					return true;
				// NOTE(review): single-character pieces may be emitted via Append(char) by some
				// compiler versions; a char literal appears as an ldc.i4 argument.
				if (call.Arguments[1].MatchLdcI4(out int charValue)
					&& call.Method.Parameters.Count == 1
					&& call.Method.Parameters[0].Type.IsKnownType(KnownTypeCode.Char))
				{
					val = ((char)charValue).ToString();
					return true;
				}
				val = null;
				return false;
			}

			bool MatchAppendCallWithValue(ILInstruction inst, string val)
			{
				return MatchAppendCall(inst, out string tmp) && tmp == val;
			}
		}

		/// <summary>
		/// Reads the IL of <paramref name="method"/> from the main module and runs the IL transforms
		/// up to and including the last <see cref="BlockILTransform"/>, then returns the entry block
		/// of the resulting function body.
		/// </summary>
		/// <returns>
		/// The entry block, or null if the method is null, has no method-definition token, has no
		/// body, or its body is neither a <see cref="BlockContainer"/> nor a <see cref="Block"/>.
		/// </returns>
		Block DecompileBody(IMethod method)
		{
			if (method == null || method.MetadataToken.IsNil)
				return null;
			// The explicit cast to MethodDefinitionHandle below requires a MethodDefinition token.
			if (method.MetadataToken.Kind != HandleKind.MethodDefinition)
				return null;
			var metadata = typeSystem.MainModule.metadata;
			var methodDefHandle = (MethodDefinitionHandle)method.MetadataToken;
			var methodDef = metadata.GetMethodDefinition(methodDefHandle);
			if (!methodDef.HasBody())
				return null;

			var genericContext = new GenericContext(
				classTypeParameters: recordTypeDef.TypeParameters,
				methodTypeParameters: null);
			var body = typeSystem.MainModule.PEFile.Reader.GetMethodBody(methodDef.RelativeVirtualAddress);
			var ilReader = new ILReader(typeSystem.MainModule);
			var il = ilReader.ReadIL(methodDefHandle, body, genericContext, ILFunctionKind.TopLevelFunction, cancellationToken);
			var settings = new DecompilerSettings(LanguageVersion.CSharp1);
			var transforms = CSharpDecompiler.GetILTransforms();
			// Remove the last couple transforms -- we don't need variable names etc. here
			int lastBlockTransform = transforms.FindLastIndex(t => t is BlockILTransform);
			transforms.RemoveRange(lastBlockTransform + 1, transforms.Count - (lastBlockTransform + 1));
			il.RunTransforms(transforms,
				new ILTransformContext(il, typeSystem, debugInfo: null, settings) {
					CancellationToken = cancellationToken
				});
			if (il.Body is BlockContainer container)
			{
				return container.EntryPoint;
			}
			else if (il.Body is Block block)
			{
				return block;
			}
			else
			{
				return null;
			}
		}
	}
}