diff --git a/ICSharpCode.Decompiler.Tests/Helpers/Tester.cs b/ICSharpCode.Decompiler.Tests/Helpers/Tester.cs index de6e0a8f1..ad00927aa 100644 --- a/ICSharpCode.Decompiler.Tests/Helpers/Tester.cs +++ b/ICSharpCode.Decompiler.Tests/Helpers/Tester.cs @@ -296,14 +296,17 @@ namespace ICSharpCode.Decompiler.Tests.Helpers if (flags.HasFlag(CompilerOptions.UseRoslyn)) { + var languageVersion = flags.HasFlag(CompilerOptions.Preview) + ? Microsoft.CodeAnalysis.CSharp.LanguageVersion.Preview + : Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp8; var parseOptions = new CSharpParseOptions( preprocessorSymbols: preprocessorSymbols.ToArray(), - languageVersion: flags.HasFlag(CompilerOptions.Preview) ? Microsoft.CodeAnalysis.CSharp.LanguageVersion.Preview : Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp8 + languageVersion: languageVersion ); var syntaxTrees = sourceFileNames.Select(f => SyntaxFactory.ParseSyntaxTree(File.ReadAllText(f), parseOptions, path: f, encoding: Encoding.UTF8)); if (flags.HasFlag(CompilerOptions.ReferenceCore)) { - syntaxTrees = syntaxTrees.Concat(new[] { SyntaxFactory.ParseSyntaxTree(targetFrameworkAttributeSnippet) }); + syntaxTrees = syntaxTrees.Concat(new[] { SyntaxFactory.ParseSyntaxTree(targetFrameworkAttributeSnippet, parseOptions) }); } IEnumerable references; if (flags.HasFlag(CompilerOptions.ReferenceCore)) diff --git a/ICSharpCode.Decompiler.Tests/ICSharpCode.Decompiler.Tests.csproj b/ICSharpCode.Decompiler.Tests/ICSharpCode.Decompiler.Tests.csproj index 379fa9104..f61b24404 100644 --- a/ICSharpCode.Decompiler.Tests/ICSharpCode.Decompiler.Tests.csproj +++ b/ICSharpCode.Decompiler.Tests/ICSharpCode.Decompiler.Tests.csproj @@ -96,6 +96,7 @@ + diff --git a/ICSharpCode.Decompiler.Tests/PrettyTestRunner.cs b/ICSharpCode.Decompiler.Tests/PrettyTestRunner.cs index 617a7f2a0..7874c49f1 100644 --- a/ICSharpCode.Decompiler.Tests/PrettyTestRunner.cs +++ b/ICSharpCode.Decompiler.Tests/PrettyTestRunner.cs @@ -534,6 +534,12 @@ namespace ICSharpCode.Decompiler.Tests RunForLibrary(cscOptions: cscOptions); } + [Test] + public void CS9_ExtensionGetEnumerator([ValueSource(nameof(dotnetCoreOnlyOptions))] CompilerOptions cscOptions) + { + RunForLibrary(cscOptions: cscOptions | CompilerOptions.Preview); + } + void RunForLibrary([CallerMemberName] string testName = null, AssemblerOptions asmOptions = AssemblerOptions.None, CompilerOptions cscOptions = CompilerOptions.None, DecompilerSettings decompilerSettings = null) { Run(testName, asmOptions | AssemblerOptions.Library, cscOptions | CompilerOptions.Library, decompilerSettings); diff --git a/ICSharpCode.Decompiler.Tests/TestCases/Pretty/CS9_ExtensionGetEnumerator.cs b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/CS9_ExtensionGetEnumerator.cs new file mode 100644 index 000000000..f5097e0fc --- /dev/null +++ b/ICSharpCode.Decompiler.Tests/TestCases/Pretty/CS9_ExtensionGetEnumerator.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace ICSharpCode.Decompiler.Tests.TestCases.Pretty +{ + public class CS9_ExtensionGetEnumerator + { + public class NonGeneric + { + } + + public class Generic + { + } + + public void Test(NonGeneric c) + { + foreach (object? item in c) + { + Console.WriteLine(item); + } + } + + public void Test(Generic c) + { + foreach (int item in c) + { + Console.WriteLine(item); + } + } + + public async void TestAsync(Generic c) + { + await foreach (int item in c) + { + Console.WriteLine(item); + } + } + } + + public static class CS9_ExtensionGetEnumerator_Ext + { + public static IEnumerator GetEnumerator(this CS9_ExtensionGetEnumerator.NonGeneric c) + { + throw null; + } + public static IEnumerator GetEnumerator(this CS9_ExtensionGetEnumerator.Generic c) + { + throw null; + } + public static IAsyncEnumerator GetAsyncEnumerator(this CS9_ExtensionGetEnumerator.Generic c) + { + throw null; + } + } +} \ No newline at end of file diff --git a/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs b/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs index 3a2af8620..8b798edb1 100644 --- a/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs +++ b/ICSharpCode.Decompiler/CSharp/CSharpDecompiler.cs @@ -1456,7 +1456,14 @@ namespace ICSharpCode.Decompiler.CSharp if (localSettings.DecompileMemberBodies) { AddDefinesForConditionalAttributes(function, decompileRun); - var statementBuilder = new StatementBuilder(typeSystem, decompilationContext, function, localSettings, CancellationToken); + var statementBuilder = new StatementBuilder( + typeSystem, + decompilationContext, + function, + localSettings, + decompileRun, + CancellationToken + ); body = statementBuilder.ConvertAsBlock(function.Body); Comment prev = null; diff --git a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs index b29800632..139c05490 100644 --- a/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/ExpressionBuilder.cs @@ -2205,7 +2205,14 @@ namespace ICSharpCode.Decompiler.CSharp ame.IsAsync = function.IsAsync; ame.Parameters.AddRange(MakeParameters(function.Parameters, function)); ame.HasParameterList = ame.Parameters.Count > 0; - StatementBuilder builder = new StatementBuilder(typeSystem, this.decompilationContext, function, settings, cancellationToken); + var builder = new StatementBuilder( + typeSystem, + this.decompilationContext, + function, + settings, + statementBuilder.decompileRun, + cancellationToken + ); var body = builder.ConvertAsBlock(function.Body); Comment prev = null; diff --git a/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs b/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs index 294c58930..ef41da4ef 100644 --- a/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs +++ b/ICSharpCode.Decompiler/CSharp/StatementBuilder.cs @@ -24,6 +24,8 @@ using System.Threading; using ICSharpCode.Decompiler.CSharp.Syntax; using ICSharpCode.Decompiler.CSharp.Syntax.PatternMatching; +using ICSharpCode.Decompiler.CSharp.Transforms; +using ICSharpCode.Decompiler.CSharp.TypeSystem; using ICSharpCode.Decompiler.IL; using ICSharpCode.Decompiler.IL.Transforms; using ICSharpCode.Decompiler.Semantics; @@ -37,6 +39,7 @@ namespace ICSharpCode.Decompiler.CSharp internal readonly ExpressionBuilder exprBuilder; readonly ILFunction currentFunction; readonly IDecompilerTypeSystem typeSystem; + internal readonly DecompileRun decompileRun; readonly DecompilerSettings settings; readonly CancellationToken cancellationToken; @@ -44,16 +47,28 @@ namespace ICSharpCode.Decompiler.CSharp internal IType currentResultType; internal bool currentIsIterator; - public StatementBuilder(IDecompilerTypeSystem typeSystem, ITypeResolveContext decompilationContext, ILFunction currentFunction, DecompilerSettings settings, CancellationToken cancellationToken) + public StatementBuilder(IDecompilerTypeSystem typeSystem, ITypeResolveContext decompilationContext, + ILFunction currentFunction, DecompilerSettings settings, DecompileRun decompileRun, + CancellationToken cancellationToken) { Debug.Assert(typeSystem != null && decompilationContext != null); - this.exprBuilder = new ExpressionBuilder(this, typeSystem, decompilationContext, currentFunction, settings, cancellationToken); + this.exprBuilder = new ExpressionBuilder( + this, + typeSystem, + decompilationContext, + currentFunction, + settings, + cancellationToken + ); this.currentFunction = currentFunction; this.currentReturnContainer = (BlockContainer)currentFunction.Body; this.currentIsIterator = currentFunction.IsIterator; - this.currentResultType = currentFunction.IsAsync ? currentFunction.AsyncReturnType : currentFunction.ReturnType; + this.currentResultType = currentFunction.IsAsync + ? currentFunction.AsyncReturnType + : currentFunction.ReturnType; this.typeSystem = typeSystem; this.settings = settings; + this.decompileRun = decompileRun; this.cancellationToken = cancellationToken; } @@ -462,6 +477,13 @@ namespace ICSharpCode.Decompiler.CSharp new MemberReferenceExpression(new AnyNode("collection").ToExpression(), "GetAsyncEnumerator") } ); + static readonly InvocationExpression extensionGetEnumeratorPattern = new InvocationExpression( + new Choice { + new MemberReferenceExpression(new AnyNode("type").ToExpression(), "GetEnumerator"), + new MemberReferenceExpression(new AnyNode("type").ToExpression(), "GetAsyncEnumerator") + }, + new AnyNode("collection") + ); static readonly Expression moveNextConditionPattern = new Choice { new InvocationExpression(new MemberReferenceExpression(new NamedNode("enumerator", new IdentifierExpression(Pattern.AnyString)), "MoveNext")), new UnaryOperatorExpression(UnaryOperatorType.Await, new InvocationExpression(new MemberReferenceExpression(new NamedNode("enumerator", new IdentifierExpression(Pattern.AnyString)), "MoveNextAsync"))) @@ -549,10 +571,38 @@ namespace ICSharpCode.Decompiler.CSharp { return null; } - // Check if the using resource matches the GetEnumerator pattern. - var m = getEnumeratorPattern.Match(resource); + Match m; + if (settings.ExtensionMethods && settings.ForEachWithGetEnumeratorExtension) + { + // Check if the using resource matches the GetEnumerator pattern ... + m = getEnumeratorPattern.Match(resource); + if (!m.Success) + { + // ... or the extension GetEnumeratorPattern. + m = extensionGetEnumeratorPattern.Match(resource); + if (!m.Success) + return null; + // Validate that the invocation is an extension method invocation. + var context = new CSharpTypeResolveContext( + typeSystem.MainModule, + decompileRun.UsingScope.Resolve(typeSystem) + ); + if (!IntroduceExtensionMethods.CanTransformToExtensionMethodCall(context, + (InvocationExpression)resource)) + { + return null; + } + } + } + else + { + // Check if the using resource matches the GetEnumerator pattern. + m = getEnumeratorPattern.Match(resource); + if (!m.Success) + return null; + } // The using body must be a BlockContainer. - if (!(inst.Body is BlockContainer container) || !m.Success) + if (!(inst.Body is BlockContainer container)) return null; bool isAsync = ((MemberReferenceExpression)((InvocationExpression)resource).Target).MemberName == "GetAsyncEnumerator"; if (isAsync != inst.IsAsync) @@ -1220,7 +1270,14 @@ namespace ICSharpCode.Decompiler.CSharp LocalFunctionDeclarationStatement TranslateFunction(ILFunction function) { - var nestedBuilder = new StatementBuilder(typeSystem, exprBuilder.decompilationContext, function, settings, cancellationToken); + var nestedBuilder = new StatementBuilder( + typeSystem, + exprBuilder.decompilationContext, + function, + settings, + decompileRun, + cancellationToken + ); var astBuilder = exprBuilder.astBuilder; var method = (MethodDeclaration)astBuilder.ConvertEntity(function.ReducedMethod); method.Body = nestedBuilder.ConvertAsBlock(function.Body); diff --git a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs index 832b131a1..b1459639f 100644 --- a/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs +++ b/ICSharpCode.Decompiler/CSharp/Transforms/IntroduceExtensionMethods.cs @@ -106,11 +106,59 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms public override void VisitInvocationExpression(InvocationExpression invocationExpression) { base.VisitInvocationExpression(invocationExpression); + if (!CanTransformToExtensionMethodCall(resolver, invocationExpression, out var memberRefExpr, + out var target, out var firstArgument)) + { + return; + } + var method = (IMethod)invocationExpression.GetSymbol(); + if (firstArgument is DirectionExpression dirExpr) + { + if (!context.Settings.RefExtensionMethods || dirExpr.FieldDirection == FieldDirection.Out) + return; + firstArgument = dirExpr.Expression; + target = firstArgument.GetResolveResult(); + dirExpr.Detach(); + } + else if (firstArgument is NullReferenceExpression) + { + Debug.Assert(context.RequiredNamespacesSuperset.Contains(method.Parameters[0].Type.Namespace)); + firstArgument = firstArgument.ReplaceWith(expr => new CastExpression(context.TypeSystemAstBuilder.ConvertType(method.Parameters[0].Type), expr.Detach())); + } + if (invocationExpression.Target is IdentifierExpression identifierExpression) + { + identifierExpression.Detach(); + memberRefExpr = new MemberReferenceExpression(firstArgument.Detach(), method.Name, identifierExpression.TypeArguments.Detach()); + invocationExpression.Target = memberRefExpr; + } + else + { + memberRefExpr.Target = firstArgument.Detach(); + } + if (invocationExpression.GetResolveResult() is CSharpInvocationResolveResult irr) + { + // do not forget to update the CSharpInvocationResolveResult => set IsExtensionMethodInvocation == true + invocationExpression.RemoveAnnotations(); + var newResolveResult = new CSharpInvocationResolveResult( + irr.TargetResult, irr.Member, irr.Arguments, irr.OverloadResolutionErrors, + isExtensionMethodInvocation: true, irr.IsExpandedForm, irr.IsDelegateInvocation, + irr.GetArgumentToParameterMap(), irr.InitializerStatements); + invocationExpression.AddAnnotation(newResolveResult); + } + } + + static bool CanTransformToExtensionMethodCall(CSharpResolver resolver, + InvocationExpression invocationExpression, out MemberReferenceExpression memberRefExpr, + out ResolveResult target, + out Expression firstArgument) + { var method = invocationExpression.GetSymbol() as IMethod; + memberRefExpr = null; + target = null; + firstArgument = null; if (method == null || !method.IsExtensionMethod || !invocationExpression.Arguments.Any()) - return; + return false; IReadOnlyList typeArguments; - MemberReferenceExpression memberRefExpr; switch (invocationExpression.Target) { case MemberReferenceExpression mre: @@ -122,13 +170,13 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms memberRefExpr = null; break; default: - return; + return false; } - var firstArgument = invocationExpression.Arguments.First(); + firstArgument = invocationExpression.Arguments.First(); if (firstArgument is NamedArgumentExpression) - return; - var target = firstArgument.GetResolveResult(); + return false; + target = firstArgument.GetResolveResult(); if (target is ConstantResolveResult crr && crr.ConstantValue == null) { target = new ConversionResolveResult(method.Parameters[0].Type, crr, Conversion.NullLiteralConversion); @@ -153,41 +201,14 @@ namespace ICSharpCode.Decompiler.CSharp.Transforms } pos++; } - if (!CanTransformToExtensionMethodCall(resolver, method, typeArguments, target, args, argNames)) - return; - if (firstArgument is DirectionExpression dirExpr) - { - if (!context.Settings.RefExtensionMethods || dirExpr.FieldDirection == FieldDirection.Out) - return; - firstArgument = dirExpr.Expression; - target = firstArgument.GetResolveResult(); - dirExpr.Detach(); - } - else if (firstArgument is NullReferenceExpression) - { - Debug.Assert(context.RequiredNamespacesSuperset.Contains(method.Parameters[0].Type.Namespace)); - firstArgument = firstArgument.ReplaceWith(expr => new CastExpression(context.TypeSystemAstBuilder.ConvertType(method.Parameters[0].Type), expr.Detach())); - } - if (invocationExpression.Target is IdentifierExpression identifierExpression) - { - identifierExpression.Detach(); - memberRefExpr = new MemberReferenceExpression(firstArgument.Detach(), method.Name, identifierExpression.TypeArguments.Detach()); - invocationExpression.Target = memberRefExpr; - } - else - { - memberRefExpr.Target = firstArgument.Detach(); - } - if (invocationExpression.GetResolveResult() is CSharpInvocationResolveResult irr) - { - // do not forget to update the CSharpInvocationResolveResult => set IsExtensionMethodInvocation == true - invocationExpression.RemoveAnnotations(); - var newResolveResult = new CSharpInvocationResolveResult( - irr.TargetResult, irr.Member, irr.Arguments, irr.OverloadResolutionErrors, - isExtensionMethodInvocation: true, irr.IsExpandedForm, irr.IsDelegateInvocation, - irr.GetArgumentToParameterMap(), irr.InitializerStatements); - invocationExpression.AddAnnotation(newResolveResult); - } + return CanTransformToExtensionMethodCall(resolver, method, typeArguments, target, args, argNames); + } + + public static bool CanTransformToExtensionMethodCall(CSharpTypeResolveContext resolveContext, + InvocationExpression invocationExpression) + { + return CanTransformToExtensionMethodCall(new CSharpResolver(resolveContext), + invocationExpression, out _, out _, out _); } public static bool CanTransformToExtensionMethodCall(CSharpResolver resolver, IMethod method, diff --git a/ICSharpCode.Decompiler/DecompilerSettings.cs b/ICSharpCode.Decompiler/DecompilerSettings.cs index b36e4196e..bfac69635 100644 --- a/ICSharpCode.Decompiler/DecompilerSettings.cs +++ b/ICSharpCode.Decompiler/DecompilerSettings.cs @@ -134,12 +134,13 @@ namespace ICSharpCode.Decompiler nativeIntegers = false; initAccessors = false; functionPointers = false; + forEachWithGetEnumeratorExtension = false; } } public CSharp.LanguageVersion GetMinimumRequiredVersion() { - if (nativeIntegers || initAccessors || functionPointers) + if (nativeIntegers || initAccessors || functionPointers || forEachWithGetEnumeratorExtension) return CSharp.LanguageVersion.Preview; if (nullableReferenceTypes || readOnlyMethods || asyncEnumerator || asyncUsingAndForEachStatement || staticLocalFunctions || ranges || switchExpressions) @@ -586,6 +587,24 @@ namespace ICSharpCode.Decompiler } } + bool forEachWithGetEnumeratorExtension = true; + + /// + /// Support GetEnumerator extension methods in foreach. + /// + [Category("C# 9.0 (experimental)")] + [Description("DecompilerSettings.DecompileForEachWithGetEnumeratorExtension")] + public bool ForEachWithGetEnumeratorExtension { + get { return forEachWithGetEnumeratorExtension; } + set { + if (forEachWithGetEnumeratorExtension != value) + { + forEachWithGetEnumeratorExtension = value; + OnPropertyChanged(); + } + } + } + bool lockStatement = true; /// diff --git a/ILSpy/Properties/Resources.Designer.cs b/ILSpy/Properties/Resources.Designer.cs index 618957f6b..43bd69c5f 100644 --- a/ILSpy/Properties/Resources.Designer.cs +++ b/ILSpy/Properties/Resources.Designer.cs @@ -801,6 +801,15 @@ namespace ICSharpCode.ILSpy.Properties { } } + /// + /// Looks up a localized string similar to Decompile foreach statements with GetEnumerator extension methods. + /// + public static string DecompilerSettings_DecompileForEachWithGetEnumeratorExtension { + get { + return ResourceManager.GetString("DecompilerSettings.DecompileForEachWithGetEnumeratorExtension", resourceCulture); + } + } + /// /// Looks up a localized string similar to Decompile use of the 'dynamic' type. /// diff --git a/ILSpy/Properties/Resources.resx b/ILSpy/Properties/Resources.resx index 6bf69acf8..d1d5b6c30 100644 --- a/ILSpy/Properties/Resources.resx +++ b/ILSpy/Properties/Resources.resx @@ -297,6 +297,9 @@ Are you sure you want to continue? Decompile expression trees + + Decompile foreach statements with GetEnumerator extension methods + Decompile use of the 'dynamic' type