Browse Source

Records: hide generated Equals() method

pull/2251/head
Daniel Grunwald 4 years ago
parent
commit
648e7f9f87
  1. 265
      ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs

265
ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs

@ -36,6 +36,8 @@ namespace ICSharpCode.Decompiler.CSharp @@ -36,6 +36,8 @@ namespace ICSharpCode.Decompiler.CSharp
readonly CancellationToken cancellationToken;
readonly List<IMember> orderedMembers;
readonly bool isInheritedRecord;
readonly Dictionary<IField, IProperty> backingFieldToAutoProperty = new Dictionary<IField, IProperty>();
readonly Dictionary<IProperty, IField> autoPropertyToBackingField = new Dictionary<IProperty, IField>();
public RecordDecompiler(IDecompilerTypeSystem dts, ITypeDefinition recordTypeDef, CancellationToken cancellationToken)
{
@ -43,10 +45,104 @@ namespace ICSharpCode.Decompiler.CSharp @@ -43,10 +45,104 @@ namespace ICSharpCode.Decompiler.CSharp
this.recordTypeDef = recordTypeDef;
this.cancellationToken = cancellationToken;
this.isInheritedRecord = recordTypeDef.DirectBaseTypes.Any(b => b.Kind == TypeKind.Class && !b.IsKnownType(KnownTypeCode.Object));
this.orderedMembers = DetectMemberOrder(recordTypeDef);
DetectAutomaticProperties();
this.orderedMembers = DetectMemberOrder(recordTypeDef, backingFieldToAutoProperty);
}
static List<IMember> DetectMemberOrder(ITypeDefinition recordTypeDef)
void DetectAutomaticProperties()
{
foreach (var p in recordTypeDef.Properties)
{
cancellationToken.ThrowIfCancellationRequested();
if (IsAutoProperty(p, out var field))
{
backingFieldToAutoProperty.Add(field, p);
autoPropertyToBackingField.Add(p, field);
}
}
bool IsAutoProperty(IProperty p, out IField field)
{
field = null;
if (p.Parameters.Count != 0)
return false;
if (p.Getter != null)
{
if (!IsAutoGetter(p.Getter, out field))
return false;
}
if (p.Setter != null)
{
if (!IsAutoSetter(p.Setter, out var field2))
return false;
if (field != null)
{
if (!field.Equals(field2))
return false;
}
else
{
field = field2;
}
}
if (field == null)
return false;
if (!IsRecordType(field.DeclaringType))
return false;
return field.Name == $"<{p.Name}>k__BackingField";
}
bool IsAutoGetter(IMethod method, out IField field)
{
field = null;
var body = DecompileBody(method);
if (body == null)
return false;
// return this.field;
if (!body.Instructions[0].MatchReturn(out var retVal))
return false;
if (method.IsStatic)
{
return retVal.MatchLdsFld(out field);
}
else
{
if (!retVal.MatchLdFld(out var target, out field))
return false;
return target.MatchLdThis();
}
}
bool IsAutoSetter(IMethod method, out IField field)
{
field = null;
Debug.Assert(!method.IsStatic);
var body = DecompileBody(method);
if (body == null)
return false;
// this.field = value;
ILInstruction valueInst;
if (method.IsStatic)
{
if (!body.Instructions[0].MatchStsFld(out field, out valueInst))
return false;
}
else
{
if (!body.Instructions[0].MatchStFld(out var target, out field, out valueInst))
return false;
if (!target.MatchLdThis())
return false;
}
if (!valueInst.MatchLdLoc(out var value))
return false;
if (!(value.Kind == VariableKind.Parameter && value.Index == 0))
return false;
return body.Instructions[1].MatchReturn(out var retVal) && retVal.MatchNop();
}
}
static List<IMember> DetectMemberOrder(ITypeDefinition recordTypeDef, Dictionary<IField, IProperty> backingFieldToAutoProperty)
{
// For records, the order of members is important:
// Equals/GetHashCode/PrintMembers must agree on an order of fields+properties.
@ -54,7 +150,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -54,7 +150,7 @@ namespace ICSharpCode.Decompiler.CSharp
// need to detect the correct interleaving.
// We could try to detect this from the PrintMembers body, but let's initially
// restrict ourselves to the common case where the record only uses properties.
if (recordTypeDef.Fields.All(f => f.Name.StartsWith("<", StringComparison.Ordinal) && f.Name.EndsWith("BackingField", StringComparison.Ordinal)))
if (recordTypeDef.Fields.All(backingFieldToAutoProperty.ContainsKey))
{
return recordTypeDef.Properties.ToList<IMember>();
}
@ -96,7 +192,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -96,7 +192,7 @@ namespace ICSharpCode.Decompiler.CSharp
else if (IsRecordType(paramType))
{
// virtual bool Equals(R? other): generated unless user-declared
return false;
return IsGeneratedEquals(method);
}
else
{
@ -161,6 +257,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -161,6 +257,7 @@ namespace ICSharpCode.Decompiler.CSharp
return false;
return IsRecordType(ty);
}
private bool IsGeneratedPrintMembers(IMethod method)
{
Debug.Assert(method.Name == "PrintMembers");
@ -206,10 +303,19 @@ namespace ICSharpCode.Decompiler.CSharp @@ -206,10 +303,19 @@ namespace ICSharpCode.Decompiler.CSharp
bool needsComma = false;
foreach (var member in orderedMembers)
{
if (member.IsStatic)
{
continue; // static fields/properties are not printed
}
if (member.Name == "EqualityContract")
{
continue; // EqualityContract is never printed
}
if (member.IsExplicitInterfaceImplementation)
{
continue; // explicit interface impls are not printed
}
cancellationToken.ThrowIfCancellationRequested();
/*
callvirt Append(ldloc builder, ldstr "A")
callvirt Append(ldloc builder, ldstr " = ")
@ -349,6 +455,157 @@ namespace ICSharpCode.Decompiler.CSharp @@ -349,6 +455,157 @@ namespace ICSharpCode.Decompiler.CSharp
}
}
private bool IsGeneratedEquals(IMethod method)
{
// virtual bool Equals(R? other) {
// return other != null && EqualityContract == other.EqualityContract && EqualityComparer<int>.Default.Equals(A, other.A) && ...;
// }
Debug.Assert(method.Name == "Equals" && method.Parameters.Count == 1);
if (method.Parameters.Count != 1)
return false;
if (!method.IsOverridable)
return false;
if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
return false;
if (orderedMembers == null)
return false;
var body = DecompileBody(method);
if (body == null)
return false;
if (!body.Instructions[0].MatchReturn(out var returnValue))
return false;
var variables = body.Ancestors.OfType<ILFunction>().Single().Variables;
var other = variables.Single(v => v.Kind == VariableKind.Parameter && v.Index == 0);
Debug.Assert(IsRecordType(other.Type));
var conditions = UnpackLogicAndChain(returnValue);
Debug.Assert(conditions.Count >= 1);
int pos = 0;
if (isInheritedRecord)
{
return false; // TODO: implement this case
}
else
{
// comp.o(ldloc other != ldnull)
if (pos >= conditions.Count)
return false;
if (!conditions[pos].MatchCompNotEqualsNull(out var arg))
return false;
if (!arg.MatchLdLoc(other))
return false;
pos++;
// call op_Equality(callvirt get_EqualityContract(ldloc this), callvirt get_EqualityContract(ldloc other))
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is Call { Method: { IsOperator: true, Name: "op_Equality" } } opEqualityCall))
return false;
if (!opEqualityCall.Method.DeclaringType.IsKnownType(KnownTypeCode.Type))
return false;
if (opEqualityCall.Arguments.Count != 2)
return false;
if (!MatchGetEqualityContract(opEqualityCall.Arguments[0], out var target1))
return false;
if (!MatchGetEqualityContract(opEqualityCall.Arguments[1], out var target2))
return false;
if (!target1.MatchLdThis())
return false;
if (!target2.MatchLdLoc(other))
return false;
pos++;
}
foreach (var member in orderedMembers)
{
if (member.IsStatic)
continue;
if (member.Name == "EqualityContract")
continue; // already special-cased
IField field = member as IField;
if (field == null)
{
if (!(member is IProperty property))
return false;
if (!autoPropertyToBackingField.TryGetValue(property, out field))
continue; // Equals ignores non-automatic properties
}
// EqualityComparer<int>.Default.Equals(A, other.A)
// callvirt Equals(call get_Default(), ldfld <A>k__BackingField(ldloc this), ldfld <A>k__BackingField(ldloc other))
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is CallVirt { Method: { Name: "Equals" } } equalsCall))
return false;
if (equalsCall.Arguments.Count != 3)
return false;
if (!IsEqualityComparerGetDefaultCall(equalsCall.Arguments[0], field.Type))
return false;
if (!MatchLdFld(equalsCall.Arguments[1], field, out var target1))
return false;
if (!MatchLdFld(equalsCall.Arguments[2], field, out var target2))
return false;
if (!target1.MatchLdThis())
return false;
if (!target2.MatchLdLoc(other))
return false;
pos++;
}
return pos == conditions.Count;
}
static bool MatchLdFld(ILInstruction inst, IField field, out ILInstruction target)
{
if (inst.MatchLdFld(out target, out IField f))
{
return f.Equals(field);
}
else
{
return false;
}
}
static List<ILInstruction> UnpackLogicAndChain(ILInstruction rootOfChain)
{
var result = new List<ILInstruction>();
Visit(rootOfChain);
return result;
void Visit(ILInstruction inst)
{
if (inst.MatchLogicAnd(out var lhs, out var rhs))
{
Visit(lhs);
Visit(rhs);
}
else
{
result.Add(inst);
}
}
}
private static bool MatchGetEqualityContract(ILInstruction inst, out ILInstruction target)
{
target = null;
if (!(inst is CallVirt { Method: { Name: "get_EqualityContract" } } call))
return false;
if (call.Arguments.Count != 1)
return false;
target = call.Arguments[0];
return true;
}
private static bool IsEqualityComparerGetDefaultCall(ILInstruction inst, IType type)
{
if (!(inst is Call { Method: { Name: "get_Default", IsStatic: true } } call))
return false;
if (!(call.Method.DeclaringType is { Name: "EqualityComparer", Namespace: "System.Collections.Generic" }))
return false;
if (call.Method.DeclaringType.TypeArguments.Count != 1)
return false;
if (!NormalizeTypeVisitor.TypeErasure.EquivalentTypes(call.Method.DeclaringType.TypeArguments[0], type))
return false;
return call.Arguments.Count == 0;
}
Block DecompileBody(IMethod method)
{
if (method == null || method.MetadataToken.IsNil)

Loading…
Cancel
Save