diff --git a/src/Generator/Passes/EqualiseAccessOfOverrideAndBasePass.cs b/src/Generator/Passes/EqualiseAccessOfOverrideAndBasePass.cs index 32669568..26b464b3 100644 --- a/src/Generator/Passes/EqualiseAccessOfOverrideAndBasePass.cs +++ b/src/Generator/Passes/EqualiseAccessOfOverrideAndBasePass.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using CppSharp.AST; @@ -20,39 +21,18 @@ namespace CppSharp.Passes VisitOptions.VisitTemplateArguments = false; } - public override bool VisitASTContext(ASTContext context) - { - var result = base.VisitASTContext(context); - - foreach (var baseOverride in basesOverrides) - { - var access = baseOverride.Value.Max(o => o.Access); - foreach (var @override in baseOverride.Value) - @override.Access = access; - } - - return result; - } - public override bool VisitMethodDecl(Method method) { - if (!base.VisitMethodDecl(method) || !method.IsOverride) - return false; - - var baseMethod = method.GetRootBaseMethod(); - if (!baseMethod.IsGenerated) + if (!base.VisitMethodDecl(method) || !method.OverriddenMethods.Any()) return false; - HashSet overrides; - if (basesOverrides.ContainsKey(baseMethod)) - overrides = basesOverrides[baseMethod]; - else - overrides = basesOverrides[baseMethod] = new HashSet { baseMethod }; - overrides.Add(method); + var virtuals = new List(method.OverriddenMethods); + virtuals.Add(method); + AccessSpecifier access = virtuals.Max(o => o.Access); + foreach (var @virtual in virtuals) + @virtual.Access = access; return true; } - - private Dictionary> basesOverrides = new Dictionary>(); } }