Browse Source

Fix #2475: Add support for sealed records and records with interface

pull/2476/head
SilverFox 4 years ago
parent
commit
d0d70a6496
  1. 16
      ICSharpCode.Decompiler.Tests/TestCases/Pretty/Records.cs
  2. 152
      ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs
  3. 3
      ICSharpCode.Decompiler/CSharp/Transforms/TransformFieldAndConstructorInitializers.cs

16
ICSharpCode.Decompiler.Tests/TestCases/Pretty/Records.cs

@ -15,6 +15,12 @@ @@ -15,6 +15,12 @@
public string S = "abc";
}
public record Interface(int B) : IRecord;
public interface IRecord
{
}
public record Pair<A, B>
{
public A First { get; init; }
@ -26,11 +32,13 @@ @@ -26,11 +32,13 @@
public record PrimaryCtor(int A, string B);
public record PrimaryCtorWithField(int A, string B)
{
public double C;
public double C = 1.0;
public string D = A + B;
}
public record PrimaryCtorWithProperty(int A, string B)
{
public double C { get; init; }
public double C { get; init; } = 1.0;
public string D { get; } = A + B;
}
public record Properties
@ -48,6 +56,10 @@ @@ -48,6 +56,10 @@
}
}
public sealed record Sealed(string A);
public sealed record SealedDerived(int B) : Base(B.ToString());
public class WithExpressionTests
{
public Fields Test(Fields input)

152
ICSharpCode.Decompiler/CSharp/RecordDecompiler.cs

@ -37,6 +37,8 @@ namespace ICSharpCode.Decompiler.CSharp @@ -37,6 +37,8 @@ namespace ICSharpCode.Decompiler.CSharp
readonly CancellationToken cancellationToken;
readonly List<IMember> orderedMembers;
readonly bool isInheritedRecord;
readonly bool isStruct;
readonly bool isSealed;
readonly IMethod primaryCtor;
readonly IType baseClass;
readonly Dictionary<IField, IProperty> backingFieldToAutoProperty = new Dictionary<IField, IProperty>();
@ -51,7 +53,9 @@ namespace ICSharpCode.Decompiler.CSharp @@ -51,7 +53,9 @@ namespace ICSharpCode.Decompiler.CSharp
this.settings = settings;
this.cancellationToken = cancellationToken;
this.baseClass = recordTypeDef.DirectBaseTypes.FirstOrDefault(b => b.Kind == TypeKind.Class);
this.isInheritedRecord = !baseClass.IsKnownType(KnownTypeCode.Object);
this.isStruct = baseClass.IsKnownType(KnownTypeCode.ValueType);
this.isInheritedRecord = !isStruct && !baseClass.IsKnownType(KnownTypeCode.Object);
this.isSealed = recordTypeDef.IsSealed;
DetectAutomaticProperties();
this.orderedMembers = DetectMemberOrder(recordTypeDef, backingFieldToAutoProperty);
this.primaryCtor = DetectPrimaryConstructor();
@ -180,10 +184,11 @@ namespace ICSharpCode.Decompiler.CSharp @@ -180,10 +184,11 @@ namespace ICSharpCode.Decompiler.CSharp
if (method.Parameters.Count == 0)
return false;
if (body.Instructions.Count != method.Parameters.Count + 2)
var addonInst = isStruct ? 1 : 2;
if (body.Instructions.Count < method.Parameters.Count + addonInst)
return false;
for (int i = 0; i < body.Instructions.Count - 2; i++)
for (int i = 0; i < method.Parameters.Count; i++)
{
if (!body.Instructions[i].MatchStFld(out var target, out var field, out var valueInst))
return false;
@ -199,9 +204,12 @@ namespace ICSharpCode.Decompiler.CSharp @@ -199,9 +204,12 @@ namespace ICSharpCode.Decompiler.CSharp
autoPropertyToPrimaryCtorParameter.Add(property, method.Parameters[i]);
}
var baseCtorCall = body.Instructions.SecondToLastOrDefault() as CallInstruction;
if (baseCtorCall == null)
return false;
if (!isStruct)
{
var baseCtorCall = body.Instructions.SecondToLastOrDefault() as CallInstruction;
if (baseCtorCall == null)
return false;
}
var returnInst = body.Instructions.LastOrDefault();
return returnInst != null && returnInst.MatchReturn(out var retVal) && retVal.MatchNop();
@ -233,6 +241,8 @@ namespace ICSharpCode.Decompiler.CSharp @@ -233,6 +241,8 @@ namespace ICSharpCode.Decompiler.CSharp
/// </summary>
public IMethod PrimaryConstructor => primaryCtor;
public bool IsInheritedRecord => isInheritedRecord;
bool IsRecordType(IType type)
{
return type.GetDefinition() == recordTypeDef
@ -309,7 +319,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -309,7 +319,7 @@ namespace ICSharpCode.Decompiler.CSharp
{
switch (property.Name)
{
case "EqualityContract":
case "EqualityContract" when !isStruct:
return IsGeneratedEqualityContract(property);
default:
return IsPropertyDeclaredByPrimaryConstructor(property);
@ -333,7 +343,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -333,7 +343,7 @@ namespace ICSharpCode.Decompiler.CSharp
Debug.Assert(method.IsConstructor && method.Parameters.Count == 1);
if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
return false;
if (method.Accessibility != Accessibility.Protected)
if (method.Accessibility != Accessibility.Protected && (!isSealed || method.Accessibility != Accessibility.Private))
return false;
if (orderedMembers == null)
return false;
@ -393,10 +403,10 @@ namespace ICSharpCode.Decompiler.CSharp @@ -393,10 +403,10 @@ namespace ICSharpCode.Decompiler.CSharp
// protected virtual Type EqualityContract {
// [CompilerGenerated] get => typeof(R);
// }
Debug.Assert(property.Name == "EqualityContract");
if (property.Accessibility != Accessibility.Protected)
Debug.Assert(!isStruct && property.Name == "EqualityContract");
if (property.Accessibility != Accessibility.Protected && (!isSealed || property.Accessibility != Accessibility.Private))
return false;
if (!(property.IsVirtual || property.IsOverride))
if (!(isSealed || property.IsVirtual || property.IsOverride))
return false;
if (property.IsSealed)
return false;
@ -428,11 +438,11 @@ namespace ICSharpCode.Decompiler.CSharp @@ -428,11 +438,11 @@ namespace ICSharpCode.Decompiler.CSharp
Debug.Assert(method.Name == "PrintMembers");
if (method.Parameters.Count != 1)
return false;
if (!method.IsOverridable)
if (!isSealed && !method.IsOverridable)
return false;
if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
return false;
if (method.Accessibility != Accessibility.Protected)
if (method.Accessibility != Accessibility.Protected && (!isSealed || method.Accessibility != Accessibility.Private))
return false;
if (orderedMembers == null)
return false;
@ -551,7 +561,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -551,7 +561,7 @@ namespace ICSharpCode.Decompiler.CSharp
{
return false; // static fields/properties are not printed
}
if (member.Name == "EqualityContract")
if (!isStruct && member.Name == "EqualityContract")
{
return false; // EqualityContract is never printed
}
@ -617,7 +627,8 @@ namespace ICSharpCode.Decompiler.CSharp @@ -617,7 +627,8 @@ namespace ICSharpCode.Decompiler.CSharp
// if (callvirt PrintMembers(ldloc this, ldloc stringBuilder)) { trueInst }
if (!body.Instructions[3].MatchIfInstruction(out var condition, out var trueInst))
return true;
if (!(condition is CallVirt { Method: { Name: "PrintMembers" } } printMembersCall))
if (!((condition is CallInstruction { Method: { Name: "PrintMembers" } } printMembersCall) &&
(condition is CallVirt || (isSealed && condition is Call))))
return false;
if (printMembersCall.Arguments.Count != 2)
return false;
@ -670,7 +681,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -670,7 +681,7 @@ namespace ICSharpCode.Decompiler.CSharp
Debug.Assert(method.Name == "Equals" && method.Parameters.Count == 1);
if (method.Parameters.Count != 1)
return false;
if (!method.IsOverridable)
if (!isSealed && !method.IsOverridable)
return false;
if (method.GetAttributes().Any() || method.GetReturnTypeAttributes().Any())
return false;
@ -698,58 +709,61 @@ namespace ICSharpCode.Decompiler.CSharp @@ -698,58 +709,61 @@ namespace ICSharpCode.Decompiler.CSharp
var conditions = UnpackLogicAndChain(returnValue);
Debug.Assert(conditions.Count >= 1);
int pos = 0;
if (isInheritedRecord)
{
// call BaseClass::Equals(ldloc this, ldloc other)
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is Call { Method: { Name: "Equals" } } call))
return false;
if (!NormalizeTypeVisitor.TypeErasure.EquivalentTypes(call.Method.DeclaringType, baseClass))
return false;
if (call.Arguments.Count != 2)
return false;
if (!call.Arguments[0].MatchLdThis())
return false;
if (!call.Arguments[1].MatchLdLoc(other))
return false;
pos++;
}
else
if (!isStruct)
{
// comp.o(ldloc other != ldnull)
if (pos >= conditions.Count)
return false;
if (!conditions[pos].MatchCompNotEqualsNull(out var arg))
return false;
if (!arg.MatchLdLoc(other))
return false;
pos++;
// call op_Equality(callvirt get_EqualityContract(ldloc this), callvirt get_EqualityContract(ldloc other))
// Special-cased because Roslyn isn't using EqualityComparer<T> here.
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is Call { Method: { IsOperator: true, Name: "op_Equality" } } opEqualityCall))
return false;
if (!opEqualityCall.Method.DeclaringType.IsKnownType(KnownTypeCode.Type))
return false;
if (opEqualityCall.Arguments.Count != 2)
return false;
if (!MatchGetEqualityContract(opEqualityCall.Arguments[0], out var target1))
return false;
if (!MatchGetEqualityContract(opEqualityCall.Arguments[1], out var target2))
return false;
if (!target1.MatchLdThis())
return false;
if (!target2.MatchLdLoc(other))
return false;
pos++;
if (isInheritedRecord)
{
// call BaseClass::Equals(ldloc this, ldloc other)
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is Call { Method: { Name: "Equals" } } call))
return false;
if (!NormalizeTypeVisitor.TypeErasure.EquivalentTypes(call.Method.DeclaringType, baseClass))
return false;
if (call.Arguments.Count != 2)
return false;
if (!call.Arguments[0].MatchLdThis())
return false;
if (!call.Arguments[1].MatchLdLoc(other))
return false;
pos++;
}
else
{
// comp.o(ldloc other != ldnull)
if (pos >= conditions.Count)
return false;
if (!conditions[pos].MatchCompNotEqualsNull(out var arg))
return false;
if (!arg.MatchLdLoc(other))
return false;
pos++;
// call op_Equality(callvirt get_EqualityContract(ldloc this), callvirt get_EqualityContract(ldloc other))
// Special-cased because Roslyn isn't using EqualityComparer<T> here.
if (pos >= conditions.Count)
return false;
if (!(conditions[pos] is Call { Method: { IsOperator: true, Name: "op_Equality" } } opEqualityCall))
return false;
if (!opEqualityCall.Method.DeclaringType.IsKnownType(KnownTypeCode.Type))
return false;
if (opEqualityCall.Arguments.Count != 2)
return false;
if (!MatchGetEqualityContract(opEqualityCall.Arguments[0], out var target1))
return false;
if (!MatchGetEqualityContract(opEqualityCall.Arguments[1], out var target2))
return false;
if (!target1.MatchLdThis())
return false;
if (!target2.MatchLdLoc(other))
return false;
pos++;
}
}
foreach (var member in orderedMembers)
{
if (!MemberConsideredForEquality(member))
continue;
if (member.Name == "EqualityContract")
if (!isStruct && member.Name == "EqualityContract")
{
continue; // already special-cased
}
@ -771,7 +785,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -771,7 +785,7 @@ namespace ICSharpCode.Decompiler.CSharp
return false;
if (!member1.Equals(member))
return false;
if (!target2.MatchLdLoc(other))
if (!(isStruct ? target2.MatchLdLoca(other) : target2.MatchLdLoc(other)))
return false;
if (!member2.Equals(member))
return false;
@ -800,10 +814,12 @@ namespace ICSharpCode.Decompiler.CSharp @@ -800,10 +814,12 @@ namespace ICSharpCode.Decompiler.CSharp
}
}
private static bool MatchGetEqualityContract(ILInstruction inst, out ILInstruction target)
private bool MatchGetEqualityContract(ILInstruction inst, out ILInstruction target)
{
target = null;
if (!(inst is CallVirt { Method: { Name: "get_EqualityContract" } } call))
if (!(inst is CallInstruction { Method: { Name: "get_EqualityContract" } } call))
return false;
if (!(inst is CallVirt || (isSealed && inst is Call)))
return false;
if (call.Arguments.Count != 1)
return false;
@ -830,7 +846,7 @@ namespace ICSharpCode.Decompiler.CSharp @@ -830,7 +846,7 @@ namespace ICSharpCode.Decompiler.CSharp
return false;
if (member is IProperty property)
{
if (property.Name == "EqualityContract")
if (!isStruct && property.Name == "EqualityContract")
return !isInheritedRecord;
return autoPropertyToBackingField.ContainsKey(property);
}
@ -985,14 +1001,14 @@ namespace ICSharpCode.Decompiler.CSharp @@ -985,14 +1001,14 @@ namespace ICSharpCode.Decompiler.CSharp
{
target = null;
member = null;
if (inst is CallVirt
if (inst is CallInstruction
{
Method:
{
AccessorKind: System.Reflection.MethodSemanticsAttributes.Getter,
AccessorOwner: IProperty property
}
} call)
} call && (call is CallVirt || (isSealed && call is Call)))
{
if (call.Arguments.Count != 1)
return false;

3
ICSharpCode.Decompiler/CSharp/Transforms/TransformFieldAndConstructorInitializers.cs

@ -118,7 +118,8 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms @@ -118,7 +118,8 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms
&& currentCtor.Equals(record.PrimaryConstructor)
&& ci.ConstructorInitializerType == ConstructorInitializerType.Base)
{
if (constructorDeclaration.Parent is TypeDeclaration { BaseTypes: { Count: >= 1 } } typeDecl)
if (record.IsInheritedRecord &&
constructorDeclaration.Parent is TypeDeclaration { BaseTypes: { Count: >= 1 } } typeDecl)
{
var baseType = typeDecl.BaseTypes.First();
var newBaseType = new InvocationAstType();

Loading…
Cancel
Save