#develop (short for SharpDevelop) is a free IDE for .NET programming languages.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

363 lines
13 KiB

// <file>
// <copyright see="prj:///doc/copyright.txt"/>
// <license see="prj:///doc/license.txt"/>
// <owner name="Siegfried Pammer" email="siegfriedpammer@gmail.com"/>
// <version>$Revision$</version>
// </file>
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using ICSharpCode.Core;
using ICSharpCode.NRefactory.Ast;
using ICSharpCode.SharpDevelop.Dom;
using ICSharpCode.SharpDevelop.Dom.Refactoring;
using ICSharpCode.SharpDevelop.Editor;
using Dom = ICSharpCode.SharpDevelop.Dom;
namespace SharpRefactoring.Gui
{
/// <summary>
/// Interaction logic for OverrideEqualsGetHashCodeMethodsDialog.xaml
/// </summary>
public partial class OverrideEqualsGetHashCodeMethodsDialog : AbstractInlineRefactorDialog
{
// Class the dialog was opened for.
IClass selectedClass;

/// <summary>
/// Creates the dialog for <paramref name="selectedClass"/> and initializes the
/// "add IEquatable" check box (caption and enabled state).
/// </summary>
/// <param name="editor">Editor forwarded to the base dialog.</param>
/// <param name="anchor">Text anchor forwarded to the base dialog.</param>
/// <param name="selectedClass">Class the members are generated for; must not be null.</param>
/// <exception cref="ArgumentNullException"><paramref name="selectedClass"/> is null.</exception>
public OverrideEqualsGetHashCodeMethodsDialog(ITextEditor editor, ITextAnchor anchor, IClass selectedClass)
    : base(null, editor, anchor)
{
    if (selectedClass == null)
        throw new ArgumentNullException("selectedClass");

    InitializeComponent();

    this.selectedClass = selectedClass;

    // The caption is a localized format string with the interface name filled in.
    string captionFormat = StringParser.Parse("${res:AddIns.SharpRefactoring.OverrideEqualsGetHashCodeMethods.AddInterface}");
    string interfaceName = "IEquatable<" + selectedClass.Name + ">";
    addIEquatable.Content = string.Format(captionFormat, interfaceName);

    // The option is disabled as soon as one base type is a generic return type whose
    // type parameter is bound to a type carrying the selected class's name.
    bool matchingBaseTypeFound = false;
    foreach (IReturnType baseType in selectedClass.BaseTypes) {
        if (!baseType.IsGenericReturnType)
            continue;
        var boundTo = baseType.CastToGenericReturnType().TypeParameter.BoundTo;
        if (boundTo != null && boundTo.Name == selectedClass.Name) {
            matchingBaseTypeFound = true;
            break;
        }
    }
    addIEquatable.IsEnabled = !matchingBaseTypeFound;
}
/// <summary>
/// Multipliers used when generating a C# GetHashCode body: the hash code of each
/// instance field is multiplied by one entry of this table, and the table is reused
/// cyclically when a class has more fields than there are entries.
/// The reference is read-only because the table is a fixed lookup that is never reassigned.
/// </summary>
static readonly int[] largePrimes = {
    1000000007, 1000000009, 1000000021, 1000000033,
    1000000087, 1000000093, 1000000097, 1000000103,
    1000000123, 1000000181, 1000000207, 1000000223,
    1000000241, 1000000271, 1000000289, 1000000297,
    1000000321, 1000000349, 1000000363, 1000000403,
    1000000409, 1000000411, 1000000427, 1000000433,
    1000000439, 1000000447, 1000000453, 1000000459,
    1000000483, 1000000513, 1000000531, 1000000579
};
/// <summary>
/// Tells whether the class underlying <paramref name="type"/> is a struct or an enum.
/// </summary>
/// <param name="type">Return type to inspect.</param>
/// <returns>
/// true if an underlying class exists and its class type is Struct or Enum; otherwise false.
/// </returns>
static bool IsValueType(IReturnType type)
{
    IClass underlying = type.GetUnderlyingClass();
    if (underlying == null)
        return false;

    switch (underlying.ClassType) {
        case Dom.ClassType.Struct:
        case Dom.ClassType.Enum:
            return true;
        default:
            return false;
    }
}
/// <summary>
/// Decides whether two values of <paramref name="type"/> are compared with the
/// equality operator in the generated code.
/// </summary>
/// <param name="type">Return type of the field being compared.</param>
/// <returns>
/// true for structs and enums other than System.Single and System.Double, and for
/// System.String; false for everything else, including types without an underlying class.
/// </returns>
static bool CanCompareEqualityWithOperator(IReturnType type)
{
    IClass underlying = type.GetUnderlyingClass();
    if (underlying == null)
        return false;

    // float and double are excluded even though they are value types.
    string fullName = underlying.FullyQualifiedName;
    if (fullName == "System.Single" || fullName == "System.Double")
        return false;

    switch (underlying.ClassType) {
        case Dom.ClassType.Struct:
        case Dom.ClassType.Enum:
            return true;
        default:
            // string is the only reference type compared with the operator.
            return fullName == "System.String";
    }
}
/// <summary>
/// Builds the expression that compares one field of <c>this</c> with the same field
/// of another instance.
/// </summary>
/// <param name="other">Identifier of the variable holding the other instance.</param>
/// <param name="field">Field to compare.</param>
/// <returns>
/// <c>this.field == other.field</c> when the field type can be compared with the
/// equality operator; otherwise a call <c>System.Object.Equals(this.field, other.field)</c>.
/// </returns>
static Expression TestEquality(string other, IField field)
{
    Expression ownValue = new MemberReferenceExpression(new ThisReferenceExpression(), field.Name);
    Expression otherValue = new MemberReferenceExpression(new IdentifierExpression(other), field.Name);

    if (!CanCompareEqualityWithOperator(field.ReturnType)) {
        // Fall back to the static two-argument Equals on System.Object.
        Expression objectEquals = new MemberReferenceExpression(
            new TypeReferenceExpression(new TypeReference("System.Object", true)), "Equals");
        InvocationExpression equalsCall = new InvocationExpression(objectEquals);
        equalsCall.Arguments.Add(ownValue);
        equalsCall.Arguments.Add(otherValue);
        return equalsCall;
    }

    return new BinaryOperatorExpression(ownValue, BinaryOperatorType.Equality, otherValue);
}
/// <summary>
/// Produces the source text of the GetHashCode override, the Equals override(s) and,
/// when enabled in the options, the equality and inequality operator overloads
/// for <paramref name="currentClass"/>.
/// </summary>
/// <param name="generator">Generator that turns each AST node into source text.</param>
/// <param name="currentClass">Class or struct the members are generated for.</param>
/// <returns>The generated text, wrapped in a #region block when that option is set.</returns>
protected override string GenerateCode(CodeGenerator generator, IClass currentClass)
{
    // Indentation is the whitespace reported for the start offset of the caret's line.
    var caretLine = editor.Document.GetLineForOffset(editor.Caret.Offset);
    string indent = DocumentUtilitites.GetWhitespaceAfter(editor.Document, caretLine.Offset);
    StringBuilder output = new StringBuilder();

    if (Options.AddIEquatableInterface) {
        // TODO : add IEquatable<T> to class
        // IAmbience ambience = currentClass.CompilationUnit.Language.GetAmbience();
        //
        // IReturnType baseRType = currentClass.CompilationUnit.ProjectContent.GetClass("System.IEquatable", 1).DefaultReturnType;
        //
        // IClass newClass = new DefaultClass(currentClass.CompilationUnit, currentClass.FullyQualifiedName, currentClass.Modifiers, currentClass.Region, null);
        //
        // foreach (IReturnType type in currentClass.BaseTypes) {
        //     newClass.BaseTypes.Add(type);
        // }
        //
        // newClass.BaseTypes.Add(new ConstructedReturnType(baseRType, new List<IReturnType>() { currentClass.DefaultReturnType }));
        //
        // ambience.ConversionFlags = ConversionFlags.IncludeBody;
        //
        // string a = ambience.Convert(currentClass);
        //
        // int startOffset = editor.Document.PositionToOffset(currentClass.Region.BeginLine, currentClass.Region.BeginColumn);
        // int endOffset = editor.Document.PositionToOffset(currentClass.BodyRegion.EndLine, currentClass.BodyRegion.EndColumn);
        //
        // editor.Document.Replace(startOffset, endOffset - startOffset, a);
    }

    if (Options.SurroundWithRegion) {
        output.AppendLine("#region Equals and GetHashCode implementation");
    }

    // GetHashCode comes first, followed by every Equals method, each preceded by "\n".
    output.Append(generator.GenerateCode(CreateGetHashCodeOverride(currentClass), indent));

    List<string> equalsSources = new List<string>();
    foreach (MethodDeclaration equalsMethod in CreateEqualsOverrides(currentClass)) {
        equalsSources.Add(generator.GenerateCode(equalsMethod, indent));
    }
    output.Append("\n").Append(string.Join("\n", equalsSources));

    if (Options.AddOperatorOverrides) {
        BlockStatement equalityBody = CreateEqualityOperatorBody(currentClass);
        BlockStatement inequalityBody = CreateInequalityOperatorBody();

        OperatorDeclaration equalityOperator = CreateOperatorOverload(OverloadableOperatorType.Equality, currentClass, equalityBody);
        output.Append("\n").Append(generator.GenerateCode(equalityOperator, indent));

        OperatorDeclaration inequalityOperator = CreateOperatorOverload(OverloadableOperatorType.InEquality, currentClass, inequalityBody);
        output.Append("\n").Append(generator.GenerateCode(inequalityOperator, indent));
    }

    if (Options.SurroundWithRegion) {
        output.AppendLine(indent + "#endregion");
    }

    return output.ToString();
}

// Builds the call ReferenceEquals(first, second).
static InvocationExpression CreateReferenceEqualsCall(Expression first, Expression second)
{
    return new InvocationExpression(
        new IdentifierExpression("ReferenceEquals"),
        new List<Expression>() { first, second });
}

// Builds the body of operator ==: for classes two ReferenceEquals guards
// (same reference => true, either side null => false), then "return lhs.Equals(rhs);".
static BlockStatement CreateEqualityOperatorBody(IClass currentClass)
{
    BlockStatement body = new BlockStatement();

    if (currentClass.ClassType == Dom.ClassType.Class) {
        body.Children.Add(new IfElseStatement(
            CreateReferenceEqualsCall(new IdentifierExpression("lhs"), new IdentifierExpression("rhs")),
            new ReturnStatement(new PrimitiveExpression(true))));

        body.Children.Add(new IfElseStatement(
            new BinaryOperatorExpression(
                CreateReferenceEqualsCall(new IdentifierExpression("lhs"), new PrimitiveExpression(null)),
                BinaryOperatorType.LogicalOr,
                CreateReferenceEqualsCall(new IdentifierExpression("rhs"), new PrimitiveExpression(null))),
            new ReturnStatement(new PrimitiveExpression(false))));
    }

    body.Children.Add(new ReturnStatement(
        new InvocationExpression(
            new MemberReferenceExpression(new IdentifierExpression("lhs"), "Equals"),
            new List<Expression>() { new IdentifierExpression("rhs") })));

    return body;
}

// Builds the body of operator !=: "return !(lhs == rhs);".
static BlockStatement CreateInequalityOperatorBody()
{
    Expression comparison = new BinaryOperatorExpression(
        new IdentifierExpression("lhs"),
        BinaryOperatorType.Equality,
        new IdentifierExpression("rhs"));

    Expression negated = new UnaryOperatorExpression(
        new ParenthesizedExpression(comparison),
        UnaryOperatorType.Not);

    BlockStatement body = new BlockStatement();
    body.Children.Add(new ReturnStatement(negated));
    return body;
}
/// <summary>
/// Creates the Equals method declarations for <paramref name="currentClass"/>.
/// </summary>
/// <remarks>
/// For a struct two methods are returned: an <c>Equals(object)</c> override that checks the
/// argument's type and forwards to the typed overload, and a typed <c>Equals(T other)</c>
/// that compares the fields. For every other class type a single <c>Equals(object)</c>
/// override is returned; it try-casts the argument into a local named <c>other</c>,
/// returns false when that local is null, and then compares the fields.
/// Static fields are ignored. When there are no instance fields the comparison is
/// the literal <c>true</c>.
/// </remarks>
/// <param name="currentClass">Class or struct the methods are generated for.</param>
/// <returns>The generated method declarations, in the order they should be emitted.</returns>
List<MethodDeclaration> CreateEqualsOverrides(IClass currentClass)
{
    List<MethodDeclaration> methods = new List<MethodDeclaration>();

    TypeReference boolReference = new TypeReference("System.Boolean", true);
    TypeReference objectReference = new TypeReference("System.Object", true);

    // public override bool Equals(object obj)
    MethodDeclaration method = new MethodDeclaration {
        Name = "Equals",
        Modifier = Modifiers.Public | Modifiers.Override,
        TypeReference = boolReference
    };
    method.Parameters.Add(new ParameterDeclarationExpression(objectReference, "obj"));
    method.Body = new BlockStatement();

    TypeReference currentType = ConvertType(currentClass.DefaultReturnType);

    Expression expr = null;

    if (currentClass.ClassType == Dom.ClassType.Struct) {
        // return obj is CurrentType && Equals((CurrentType)obj);
        expr = new TypeOfIsExpression(new IdentifierExpression("obj"), currentType);
        expr = new ParenthesizedExpression(expr);
        expr = new BinaryOperatorExpression(
            expr, BinaryOperatorType.LogicalAnd,
            new InvocationExpression(
                new IdentifierExpression("Equals"),
                new List<Expression> {
                    new CastExpression(currentType, new IdentifierExpression("obj"), CastType.Cast)
                }));
        method.Body.AddChild(new ReturnStatement(expr));

        methods.Add(method);

        // IEquatable implementation: public bool Equals(CurrentType other)
        // This is a new overload, not an override of any base member, so it must not
        // carry the override modifier; with it the generated C# is rejected (CS0115).
        method = new MethodDeclaration {
            Name = "Equals",
            Modifier = Modifiers.Public,
            TypeReference = boolReference
        };
        method.Parameters.Add(new ParameterDeclarationExpression(currentType, "other"));
        method.Body = new BlockStatement();
    } else {
        // CurrentType other = obj as CurrentType;
        method.Body.AddChild(new LocalVariableDeclaration(new VariableDeclaration(
            "other",
            new CastExpression(currentType, new IdentifierExpression("obj"), CastType.TryCast),
            currentType)));
        // if (other == null) return false;
        method.Body.AddChild(new IfElseStatement(
            new BinaryOperatorExpression(new IdentifierExpression("other"), BinaryOperatorType.ReferenceEquality, new PrimitiveExpression(null, "null")),
            new ReturnStatement(new PrimitiveExpression(false, "false"))));

        // Disabled reference-equality shortcut, kept for reference:
        // expr = new BinaryOperatorExpression(new ThisReferenceExpression(),
        //                                     BinaryOperatorType.ReferenceEquality,
        //                                     new IdentifierExpression("obj"));
        // method.Body.AddChild(new IfElseStatement(expr, new ReturnStatement(new PrimitiveExpression(true, "true"))));
    }

    // Chain the per-field comparisons with && in declaration order.
    expr = null;
    foreach (IField field in currentClass.Fields) {
        if (field.IsStatic) continue;

        if (expr == null) {
            expr = TestEquality("other", field);
        } else {
            expr = new BinaryOperatorExpression(expr, BinaryOperatorType.LogicalAnd,
                                                TestEquality("other", field));
        }
    }

    // No instance fields: every instance compares equal.
    method.Body.AddChild(new ReturnStatement(expr ?? new PrimitiveExpression(true, "true")));

    methods.Add(method);

    return methods;
}
/// <summary>
/// Creates the <c>public override int GetHashCode()</c> declaration for
/// <paramref name="currentClass"/>.
/// </summary>
/// <remarks>
/// The body declares a local <c>hashCode</c> initialized to 0, folds in the hash code of
/// every non-static field and returns the local. When the project language is C#, the
/// updates are placed in an unchecked block and each field hash is multiplied by an entry
/// of <c>largePrimes</c> and added; otherwise the field hash is XOR-ed in. Fields that
/// <c>IsValueType</c> does not report as value types are guarded by a null check.
/// </remarks>
/// <param name="currentClass">Class or struct the method is generated for.</param>
/// <returns>The GetHashCode method declaration.</returns>
MethodDeclaration CreateGetHashCodeOverride(IClass currentClass)
{
    TypeReference int32Type = new TypeReference("System.Int32", true);
    VariableDeclaration accumulator = new VariableDeclaration("hashCode", new PrimitiveExpression(0, "0"), int32Type);

    MethodDeclaration hashMethod = new MethodDeclaration();
    hashMethod.Name = "GetHashCode";
    hashMethod.Modifier = Modifiers.Public | Modifiers.Override;
    hashMethod.TypeReference = int32Type;
    hashMethod.Body = new BlockStatement();

    hashMethod.Body.AddChild(new LocalVariableDeclaration(accumulator));

    List<IField> instanceFields = currentClass.Fields.Where(f => !f.IsStatic).ToList();

    if (instanceFields.Count > 0) {
        bool multiplyByPrimes = currentClass.ProjectContent.Language == LanguageProperties.CSharp;

        // With prime multiplication the updates go into a nested unchecked block;
        // otherwise they are added directly to the method body.
        BlockStatement target = hashMethod.Body;
        if (multiplyByPrimes) {
            target = new BlockStatement();
            hashMethod.Body.AddChild(new UncheckedStatement(target));
        }

        for (int index = 0; index < instanceFields.Count; index++) {
            IField field = instanceFields[index];

            Expression fieldHash = new InvocationExpression(
                new MemberReferenceExpression(new IdentifierExpression(field.Name), "GetHashCode"));

            Expression update;
            if (multiplyByPrimes) {
                // hashCode += <prime> * field.GetHashCode(); primes are reused cyclically.
                int prime = largePrimes[index % largePrimes.Length];
                update = new AssignmentExpression(
                    new IdentifierExpression(accumulator.Name),
                    AssignmentOperatorType.Add,
                    new BinaryOperatorExpression(new PrimitiveExpression(prime, prime.ToString()),
                                                 BinaryOperatorType.Multiply,
                                                 fieldHash));
            } else {
                // hashCode ^= field.GetHashCode();
                update = new AssignmentExpression(
                    new IdentifierExpression(accumulator.Name),
                    AssignmentOperatorType.ExclusiveOr,
                    fieldHash);
            }

            Statement statement = new ExpressionStatement(update);
            if (!IsValueType(field.ReturnType)) {
                // if (field != null) <update>
                statement = new IfElseStatement(
                    new BinaryOperatorExpression(new IdentifierExpression(field.Name),
                                                 BinaryOperatorType.ReferenceInequality,
                                                 new PrimitiveExpression(null, "null")),
                    statement);
            }
            target.AddChild(statement);
        }
    }

    hashMethod.Body.AddChild(new ReturnStatement(new IdentifierExpression(accumulator.Name)));
    return hashMethod;
}
/// <summary>
/// Creates a <c>public static bool</c> operator declaration taking two parameters,
/// <c>lhs</c> and <c>rhs</c>, both of the type of <paramref name="currentClass"/>.
/// </summary>
/// <param name="op">Operator to overload.</param>
/// <param name="currentClass">Class or struct whose type is used for both parameters.</param>
/// <param name="body">Statement block used as the operator body.</param>
/// <returns>The operator declaration.</returns>
OperatorDeclaration CreateOperatorOverload(OverloadableOperatorType op, IClass currentClass, BlockStatement body)
{
    OperatorDeclaration declaration = new OperatorDeclaration();
    declaration.OverloadableOperator = op;
    declaration.TypeReference = new TypeReference("System.Boolean", true);

    // Each parameter gets its own converted type reference.
    declaration.Parameters.Add(new ParameterDeclarationExpression(ConvertType(currentClass.DefaultReturnType), "lhs"));
    declaration.Parameters.Add(new ParameterDeclarationExpression(ConvertType(currentClass.DefaultReturnType), "rhs"));

    declaration.Modifier = Modifiers.Public | Modifiers.Static;
    declaration.Body = body;
    return declaration;
}
}
}