diff --git a/ICSharpCode.NRefactory.CSharp/Refactoring/CodeIssues/RedundantWhereWithPredicateIssue.cs b/ICSharpCode.NRefactory.CSharp/Refactoring/CodeIssues/RedundantWhereWithPredicateIssue.cs index 45c930748f..3b255c029d 100644 --- a/ICSharpCode.NRefactory.CSharp/Refactoring/CodeIssues/RedundantWhereWithPredicateIssue.cs +++ b/ICSharpCode.NRefactory.CSharp/Refactoring/CodeIssues/RedundantWhereWithPredicateIssue.cs @@ -6,7 +6,7 @@ using ICSharpCode.NRefactory.PatternMatching; namespace ICSharpCode.NRefactory.CSharp.Refactoring { - [IssueDescription("Any() should be used with predicate and Where() removed", + [IssueDescription("Any()/First()/etc. should be used with predicate and Where() removed", Description= "Detects redundant Where() with predicate calls followed by Any().", Category = IssueCategories.CodeQualityIssues, Severity = Severity.Hint)] @@ -15,11 +15,11 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring static readonly AstNode pattern = new InvocationExpression ( new MemberReferenceExpression ( - new NamedNode ("whereInvoke", - new InvocationExpression ( - new MemberReferenceExpression (new AnyNode ("target"), "Where"), - new AnyNode ())), - "Any")); + new NamedNode ("whereInvoke", + new InvocationExpression ( + new MemberReferenceExpression (new AnyNode ("target"), "Where"), + new AnyNode ())), + Pattern.AnyString)); public IEnumerable GetIssues(BaseRefactoringContext context) { @@ -41,11 +41,11 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring return; var anyResolve = ctx.Resolve (anyInvoke) as InvocationResolveResult; - if (anyResolve == null || anyResolve.Member.FullName != "System.Linq.Enumerable.Any") + if (anyResolve == null || !HasPredicateVersion(anyResolve.Member)) return; var whereInvoke = match.Get ("whereInvoke").Single (); var whereResolve = ctx.Resolve (whereInvoke) as InvocationResolveResult; - if (whereResolve == null || whereResolve.Member.FullName != "System.Linq.Enumerable.Where") + if (whereResolve == null || whereResolve.Member.Name != "Where" || !IsQueryExtensionClass(whereResolve.Member.DeclaringTypeDefinition)) return; if (whereResolve.Member.Parameters.Count != 2) return; @@ -53,11 +53,47 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring if (predResolve.Type.TypeParameterCount != 2) return; - AddIssue (anyInvoke, "Redundant Where() call with predicate followed by Any()", script => { - var arg = whereInvoke.Arguments.Single ().Clone (); - var target = match.Get ("target").Single ().Clone (); - script.Replace (anyInvoke, new InvocationExpression (new MemberReferenceExpression (target, "Any"), arg)); - }); + AddIssue ( + anyInvoke, string.Format("Redundant Where() call with predicate followed by {0}()", anyResolve.Member.Name), + script => { + var arg = whereInvoke.Arguments.Single ().Clone (); + var target = match.Get ("target").Single ().Clone (); + script.Replace (anyInvoke, new InvocationExpression (new MemberReferenceExpression (target, anyResolve.Member.Name), arg)); + }); + } + + bool IsQueryExtensionClass(ITypeDefinition typeDef) + { + if (typeDef == null || typeDef.Namespace != "System.Linq") + return false; + switch (typeDef.Name) { + case "Enumerable": + case "ParallelEnumerable": + case "Queryable": + return true; + default: + return false; + } + } + + bool HasPredicateVersion(IParameterizedMember member) + { + if (!IsQueryExtensionClass(member.DeclaringTypeDefinition)) + return false; + switch (member.Name) { + case "Any": + case "Count": + case "First": + case "FirstOrDefault": + case "Last": + case "LastOrDefault": + case "LongCount": + case "Single": + case "SingleOrDefault": + return true; + default: + return false; + } } } } diff --git a/ICSharpCode.NRefactory.Tests/CSharp/CodeIssues/RedundantWhereWithPredicateIssueTests.cs b/ICSharpCode.NRefactory.Tests/CSharp/CodeIssues/RedundantWhereWithPredicateIssueTests.cs index 9c712e6726..fc40866ca6 100644 --- a/ICSharpCode.NRefactory.Tests/CSharp/CodeIssues/RedundantWhereWithPredicateIssueTests.cs +++ b/ICSharpCode.NRefactory.Tests/CSharp/CodeIssues/RedundantWhereWithPredicateIssueTests.cs @@ -66,5 +66,28 @@ public class X var issues = GetIssues (new RedundantWhereWithPredicateIssue (), input, out context); Assert.AreEqual (0, issues.Count); } + + [Test] + public void TestWhereCount() + { + var input = @"using System.Linq; +public class CSharpDemo { + public void Bla () { + int[] arr; + var bla = arr.Where (x => x < 4).Count (); + } +}"; + + TestRefactoringContext context; + var issues = GetIssues (new RedundantWhereWithPredicateIssue (), input, out context); + Assert.AreEqual (1, issues.Count); + CheckFix (context, issues, @"using System.Linq; +public class CSharpDemo { + public void Bla () { + int[] arr; + var bla = arr.Count (x => x < 4); + } +}"); + } } }