Browse Source

[CodeActions] IterateViaForeach: Avoid infinite recursion in cyclic inheritance structures.

newNRvisualizers
Simon Lindgren 13 years ago
parent
commit
764280a957
  1. 36
      ICSharpCode.NRefactory.CSharp/Refactoring/CodeActions/IterateViaForeachAction.cs

36
ICSharpCode.NRefactory.CSharp/Refactoring/CodeActions/IterateViaForeachAction.cs

@ -123,7 +123,7 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring @@ -123,7 +123,7 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring
static ForeachStatement MakeForeach(Expression node, IType type, RefactoringContext context)
{
var namingHelper = new NamingHelper(context);
return new ForeachStatement() {
return new ForeachStatement {
VariableType = new SimpleType("var"),
VariableName = namingHelper.GenerateVariableName(type),
InExpression = node.Clone(),
@ -131,36 +131,32 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring @@ -131,36 +131,32 @@ namespace ICSharpCode.NRefactory.CSharp.Refactoring
};
}
static IType GetElementType(ResolveResult rr, RefactoringContext context)
static IType GetElementType(ResolveResult rr, BaseRefactoringContext context)
{
IType type = GetInterfaceType(rr.Type, "System.Collections.Generic.IEnumerable") ??
GetInterfaceType(rr.Type, "System.Collections.IEnumerable");
var type = GetCollectionType(rr.Type);
if (type == null)
return null;
var parameterizedType = type as ParameterizedType;
if (parameterizedType != null)
return parameterizedType.TypeArguments.First();
else
return context.Compilation.FindType(KnownTypeCode.Object);
return context.Compilation.FindType(KnownTypeCode.Object);
}
static IType GetInterfaceType(IType type, string typeName)
static IType GetCollectionType(IType type)
{
string fullName = null;
if (type.TypeParameterCount > 0)
fullName = type.FullName.Split('`').First();
else
fullName = type.FullName;
if (fullName == typeName)
return type;
foreach (var baseType in type.DirectBaseTypes) {
IType retType = GetInterfaceType(baseType, typeName);
if (retType != null)
return retType;
IType collectionType = null;
foreach (var baseType in type.GetAllBaseTypes()) {
var baseTypeDefinition = baseType.GetDefinition();
if (baseTypeDefinition.IsKnownType(KnownTypeCode.IEnumerableOfT)) {
collectionType = baseType;
break;
} else if (baseTypeDefinition.IsKnownType(KnownTypeCode.IEnumerable)) {
collectionType = baseType;
// Don't break, continue in case type implements IEnumerable<T>
}
}
return null;
return collectionType;
}
#endregion

Loading…
Cancel
Save