Browse Source

Records: Detect compiler-generated Equals() in derived records.

pull/2251/head
Daniel Grunwald 5 years ago
parent
commit
be9871981a
  1. 27
      ICSharpCode.Decompiler.Tests/TestCases/Pretty/Records.cs
  2. 87
      ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs
  3. 15
      ICSharpCode.Decompiler/TypeSystem/Implementation/AttributeListBuilder.cs

27
ICSharpCode.Decompiler.Tests/TestCases/Pretty/Records.cs

@ -53,6 +53,33 @@
B = 42; B = 42;
} }
} }
public abstract record WithNestedRecords
{
public record A : WithNestedRecords
{
public override string AbstractProp => "A";
}
public record B : WithNestedRecords
{
public override string AbstractProp => "B";
public int? Value {
get;
set;
}
}
public record DerivedGeneric<T> : Pair<T, T?> where T : struct
{
public bool Flag;
}
public abstract string AbstractProp {
get;
}
}
} }
namespace System.Runtime.CompilerServices namespace System.Runtime.CompilerServices
{ {

87
ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs

@ -198,7 +198,7 @@ namespace ICSharpCode.Decompiler.CSharp
case "Equals" when method.Parameters.Count == 1: case "Equals" when method.Parameters.Count == 1:
{ {
IType paramType = method.Parameters[0].Type; IType paramType = method.Parameters[0].Type;
if (paramType.IsKnownType(KnownTypeCode.Object)) if (paramType.IsKnownType(KnownTypeCode.Object) && method.IsOverride)
{ {
// override bool Equals(object? obj): always generated // override bool Equals(object? obj): always generated
return true; return true;
@ -208,6 +208,11 @@ namespace ICSharpCode.Decompiler.CSharp
// virtual bool Equals(R? other): generated unless user-declared // virtual bool Equals(R? other): generated unless user-declared
return IsGeneratedEquals(method); return IsGeneratedEquals(method);
} }
else if (isInheritedRecord && NormalizeTypeVisitor.TypeErasure.EquivalentTypes(paramType, baseClass) && method.IsOverride)
{
// override bool Equals(BaseClass? obj): always generated
return true;
}
else else
{ {
return false; return false;
@ -248,6 +253,8 @@ namespace ICSharpCode.Decompiler.CSharp
Debug.Assert(method.IsConstructor && method.Parameters.Count == 1); Debug.Assert(method.IsConstructor && method.Parameters.Count == 1);
if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any()) if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
return false; return false;
if (method.Accessibility != Accessibility.Protected)
return false;
if (orderedMembers == null) if (orderedMembers == null)
return false; return false;
var body = DecompileBody(method); var body = DecompileBody(method);
@ -345,6 +352,8 @@ namespace ICSharpCode.Decompiler.CSharp
return false; return false;
if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any()) if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
return false; return false;
if (method.Accessibility != Accessibility.Protected)
return false;
if (orderedMembers == null) if (orderedMembers == null)
return false; return false;
var body = DecompileBody(method); var body = DecompileBody(method);
@ -357,18 +366,18 @@ namespace ICSharpCode.Decompiler.CSharp
int pos = 0; int pos = 0;
if (isInheritedRecord) if (isInheritedRecord)
{ {
// Special case: inherited record adding no new members
if (body.Instructions[pos].MatchReturn(out var returnValue)
&& IsBaseCall(returnValue) && !orderedMembers.Any(IsPrintedMember))
{
return true;
}
// if (call PrintMembers(ldloc this, ldloc builder)) Block IL_000f { // if (call PrintMembers(ldloc this, ldloc builder)) Block IL_000f {
// callvirt Append(ldloc builder, ldstr ", ") // callvirt Append(ldloc builder, ldstr ", ")
// } // }
if (!body.Instructions[pos].MatchIfInstruction(out var condition, out var trueInst)) if (!body.Instructions[pos].MatchIfInstruction(out var condition, out var trueInst))
return false; return false;
if (!(condition is CallInstruction { Method: { Name: "PrintMembers" } } call)) if (!IsBaseCall(condition))
return false;
if (call.Arguments.Count != 2)
return false;
if (!call.Arguments[0].MatchLdThis())
return false;
if (!call.Arguments[1].MatchLdLoc(builder))
return false; return false;
// trueInst = callvirt Append(ldloc builder, ldstr ", ") // trueInst = callvirt Append(ldloc builder, ldstr ", ")
trueInst = Block.Unwrap(trueInst); trueInst = Block.Unwrap(trueInst);
@ -377,22 +386,25 @@ namespace ICSharpCode.Decompiler.CSharp
if (!(val.MatchLdStr(out string text) && text == ", ")) if (!(val.MatchLdStr(out string text) && text == ", "))
return false; return false;
pos++; pos++;
}
bool needsComma = false; bool IsBaseCall(ILInstruction inst)
foreach (var member in orderedMembers)
{
if (member.IsStatic)
{ {
continue; // static fields/properties are not printed if (!(inst is CallInstruction { Method: { Name: "PrintMembers" } } call))
return false;
if (call.Arguments.Count != 2)
return false;
if (!call.Arguments[0].MatchLdThis())
return false;
if (!call.Arguments[1].MatchLdLoc(builder))
return false;
return true;
} }
if (member.Name == "EqualityContract")
{
continue; // EqualityContract is never printed
} }
if (member.IsExplicitInterfaceImplementation) bool needsComma = false;
foreach (var member in orderedMembers)
{ {
continue; // explicit interface impls are not printed if (!IsPrintedMember(member))
} continue;
cancellationToken.ThrowIfCancellationRequested(); cancellationToken.ThrowIfCancellationRequested();
/* /*
callvirt Append(ldloc builder, ldstr "A") callvirt Append(ldloc builder, ldstr "A")
@ -453,6 +465,26 @@ namespace ICSharpCode.Decompiler.CSharp
return body.Instructions[pos].MatchReturn(out var retVal) return body.Instructions[pos].MatchReturn(out var retVal)
&& retVal.MatchLdcI4(needsComma ? 1 : 0); && retVal.MatchLdcI4(needsComma ? 1 : 0);
bool IsPrintedMember(IMember member)
{
if (member.IsStatic)
{
return false; // static fields/properties are not printed
}
if (member.Name == "EqualityContract")
{
return false; // EqualityContract is never printed
}
if (member.IsExplicitInterfaceImplementation)
{
return false; // explicit interface impls are not printed
}
if (member.IsOverride)
{
return false; // override is not printed (again), the virtual base property was already printed
}
return true;
}
bool MatchStringBuilderAppendConstant(out string text) bool MatchStringBuilderAppendConstant(out string text)
{ {
@ -573,7 +605,20 @@ namespace ICSharpCode.Decompiler.CSharp
int pos = 0; int pos = 0;
if (isInheritedRecord) if (isInheritedRecord)
{ {
return false; // TODO: implement this case // call BaseClass::Equals(ldloc this, ldloc other)
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is Call { Method: { Name: "Equals" } } call))
return false;
if (!NormalizeTypeVisitor.TypeErasure.EquivalentTypes(call.Method.DeclaringType, baseClass))
return false;
if (call.Arguments.Count != 2)
return false;
if (!call.Arguments[0].MatchLdThis())
return false;
if (!call.Arguments[1].MatchLdLoc(other))
return false;
pos++;
} }
else else
{ {

15
ICSharpCode.Decompiler/TypeSystem/Implementation/AttributeListBuilder.cs

@ -245,7 +245,7 @@ namespace ICSharpCode.Decompiler.TypeSystem.Implementation
return (options & TypeSystemOptions.NullabilityAnnotations) != 0; return (options & TypeSystemOptions.NullabilityAnnotations) != 0;
case "NullableContextAttribute": case "NullableContextAttribute":
return (options & TypeSystemOptions.NullabilityAnnotations) != 0 return (options & TypeSystemOptions.NullabilityAnnotations) != 0
&& (target == SymbolKind.TypeDefinition || target == SymbolKind.Method || target == SymbolKind.Accessor); && (target == SymbolKind.TypeDefinition || IsMethodLike(target));
default: default:
return false; return false;
} }
@ -255,6 +255,19 @@ namespace ICSharpCode.Decompiler.TypeSystem.Implementation
return false; return false;
} }
} }
static bool IsMethodLike(SymbolKind kind)
{
return kind switch
{
SymbolKind.Method => true,
SymbolKind.Operator => true,
SymbolKind.Constructor => true,
SymbolKind.Destructor => true,
SymbolKind.Accessor => true,
_ => false
};
}
#endregion #endregion
#region Security Attributes #region Security Attributes

Loading…
Cancel
Save