diff --git a/ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs b/ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs index 45bd140d6..dcd2d5530 100644 --- a/ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs +++ b/ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs @@ -36,6 +36,7 @@ namespace ICSharpCode.Decompiler.CSharp readonly CancellationToken cancellationToken; readonly List orderedMembers; readonly bool isInheritedRecord; + readonly IType baseClass; readonly Dictionary backingFieldToAutoProperty = new Dictionary(); readonly Dictionary autoPropertyToBackingField = new Dictionary(); @@ -44,7 +45,8 @@ namespace ICSharpCode.Decompiler.CSharp this.typeSystem = dts; this.recordTypeDef = recordTypeDef; this.cancellationToken = cancellationToken; - this.isInheritedRecord = recordTypeDef.DirectBaseTypes.Any(b => b.Kind == TypeKind.Class && !b.IsKnownType(KnownTypeCode.Object)); + this.baseClass = recordTypeDef.DirectBaseTypes.FirstOrDefault(b => b.Kind == TypeKind.Class); + this.isInheritedRecord = !baseClass.IsKnownType(KnownTypeCode.Object); DetectAutomaticProperties(); this.orderedMembers = DetectMemberOrder(recordTypeDef, backingFieldToAutoProperty); } @@ -199,6 +201,8 @@ namespace ICSharpCode.Decompiler.CSharp return false; } } + case "GetHashCode": + return IsGeneratedGetHashCode(method); case "$" when method.Parameters.Count == 0: // Always generated; Method name cannot be expressed in C# return true; @@ -306,7 +310,7 @@ namespace ICSharpCode.Decompiler.CSharp if (member.IsStatic) { continue; // static fields/properties are not printed - } + } if (member.Name == "EqualityContract") { continue; // EqualityContract is never printed @@ -495,6 +499,7 @@ namespace ICSharpCode.Decompiler.CSharp return false; pos++; // call op_Equality(callvirt get_EqualityContract(ldloc this), callvirt get_EqualityContract(ldloc other)) + // Special-cased because Roslyn isn't using EqualityComparer here. if (pos >= conditions.Count) return false; if (!(conditions[pos] is Call { Method: { IsOperator: true, Name: "op_Equality" } } opEqualityCall)) @@ -515,17 +520,11 @@ namespace ICSharpCode.Decompiler.CSharp } foreach (var member in orderedMembers) { - if (member.IsStatic) + if (!MemberConsideredForEquality(member)) continue; if (member.Name == "EqualityContract") - continue; // already special-cased - IField field = member as IField; - if (field == null) { - if (!(member is IProperty property)) - return false; - if (!autoPropertyToBackingField.TryGetValue(property, out field)) - continue; // Equals ignores non-automatic properties + continue; // already special-cased } // EqualityComparer.Default.Equals(A, other.A) // callvirt Equals(call get_Default(), ldfld k__BackingField(ldloc this), ldfld k__BackingField(ldloc other)) @@ -535,33 +534,25 @@ namespace ICSharpCode.Decompiler.CSharp return false; if (equalsCall.Arguments.Count != 3) return false; - if (!IsEqualityComparerGetDefaultCall(equalsCall.Arguments[0], field.Type)) + if (!IsEqualityComparerGetDefaultCall(equalsCall.Arguments[0], member.ReturnType)) return false; - if (!MatchLdFld(equalsCall.Arguments[1], field, out var target1)) + if (!MatchMemberAccess(equalsCall.Arguments[1], out var target1, out var member1)) return false; - if (!MatchLdFld(equalsCall.Arguments[2], field, out var target2)) + if (!MatchMemberAccess(equalsCall.Arguments[2], out var target2, out var member2)) return false; if (!target1.MatchLdThis()) return false; + if (!member1.Equals(member)) + return false; if (!target2.MatchLdLoc(other)) return false; + if (!member2.Equals(member)) + return false; pos++; } return pos == conditions.Count; } - static bool MatchLdFld(ILInstruction inst, IField field, out ILInstruction target) - { - if (inst.MatchLdFld(out target, out IField f)) - { - return f.Equals(field); - } - else - { - return false; - } - } - static List UnpackLogicAndChain(ILInstruction rootOfChain) { var result = new List(); @@ -581,7 +572,7 @@ namespace ICSharpCode.Decompiler.CSharp } } } - + private static bool MatchGetEqualityContract(ILInstruction inst, out ILInstruction target) { target = null; @@ -606,6 +597,144 @@ namespace ICSharpCode.Decompiler.CSharp return call.Arguments.Count == 0; } + bool MemberConsideredForEquality(IMember member) + { + if (member.IsStatic) + return false; + if (member is IProperty property) + { + if (property.Name == "EqualityContract") + return !isInheritedRecord; + return autoPropertyToBackingField.ContainsKey(property); + } + else + { + return member is IField; + } + } + + bool IsGeneratedGetHashCode(IMethod method) + { + /* + return ( + ( + EqualityComparer.Default.GetHashCode(EqualityContract) * -1521134295 + EqualityComparer.Default.GetHashCode(A) + ) * -1521134295 + EqualityComparer.Default.GetHashCode(B) + ) * -1521134295 + EqualityComparer.Default.GetHashCode(C); + */ + Debug.Assert(method.Name == "GetHashCode"); + if (method.Parameters.Count != 0) + return false; + if (!method.IsOverride || method.IsSealed) + return false; + if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any()) + return false; + if (orderedMembers == null) + return false; + var body = DecompileBody(method); + if (body == null) + return false; + if (!body.Instructions[0].MatchReturn(out var returnValue)) + return false; + var hashedMembers = new List(); + bool foundBaseClassHash = false; + if (!Visit(returnValue)) + return false; + if (foundBaseClassHash != isInheritedRecord) + return false; + return orderedMembers.Where(MemberConsideredForEquality).SequenceEqual(hashedMembers); + + bool Visit(ILInstruction inst) + { + if (inst is BinaryNumericInstruction + { + Operator: BinaryNumericOperator.Add, + CheckForOverflow: false, + Left: BinaryNumericInstruction + { + Operator: BinaryNumericOperator.Mul, + CheckForOverflow: false, + Left: var left, + Right: LdcI4 { Value: -1521134295 } + }, + Right: var right + }) + { + if (!Visit(left)) + return false; + return ProcessIndividualHashCode(right); + } + else + { + return ProcessIndividualHashCode(inst); + } + } + + bool ProcessIndividualHashCode(ILInstruction inst) + { + // base.GetHashCode(): call GetHashCode(ldloc this) + if (inst is Call { Method: { Name: "GetHashCode" } } baseHashCodeCall) + { + if (baseHashCodeCall.Arguments.Count != 1) + return false; + if (!baseHashCodeCall.Arguments[0].MatchLdThis()) + return false; + if (foundBaseClassHash || hashedMembers.Count > 0) + return false; // must be first + foundBaseClassHash = true; + return baseHashCodeCall.Method.DeclaringType.Equals(baseClass); + } + // callvirt GetHashCode(call get_Default(), callvirt get_EqualityContract(ldloc this)) + // callvirt GetHashCode(call get_Default(), ldfld k__BackingField(ldloc this))) + if (!(inst is CallVirt { Method: { Name: "GetHashCode" } } getHashCodeCall)) + return false; + if (getHashCodeCall.Arguments.Count != 2) + return false; + // getHashCodeCall.Arguments[0] checked later + if (!MatchMemberAccess(getHashCodeCall.Arguments[1], out var target, out var member)) + return false; + if (!target.MatchLdThis()) + return false; + if (!IsEqualityComparerGetDefaultCall(getHashCodeCall.Arguments[0], member.ReturnType)) + return false; + hashedMembers.Add(member); + return true; + } + } + + bool MatchMemberAccess(ILInstruction inst, out ILInstruction target, out IMember member) + { + target = null; + member = null; + if (inst is CallVirt + { + Method: + { + AccessorKind: System.Reflection.MethodSemanticsAttributes.Getter, + AccessorOwner: IProperty property + } + } call) + { + if (call.Arguments.Count != 1) + return false; + target = call.Arguments[0]; + member = property; + return true; + } + else if (inst.MatchLdFld(out target, out IField field)) + { + if (backingFieldToAutoProperty.TryGetValue(field, out property)) + member = property; + else + member = field; + return true; + } + else + { + return false; + } + } + Block DecompileBody(IMethod method) { if (method == null || method.MetadataToken.IsNil)