diff --git a/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs b/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs index 33b7671fd1398904a4905a6421008c85e9273612..aead66b605f5f3574cc102b4d2dc435e23a91e3a 100644 --- a/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs +++ b/src/Workspaces/Core/Portable/FindSymbols/FindReferences/DependentProjectsFinder.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +#nullable enable + using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -24,21 +26,29 @@ internal static class DependentProjectsFinder /// /// A helper struct used for keying in . /// - private struct DefinitionProject + private readonly struct DefinitionProject : IEquatable { -#pragma warning disable IDE0052 // Remove unread private members - DefinitionProject is used as a key for dictionaries. - private readonly ProjectId _sourceProjectId; - private readonly string _assemblyName; -#pragma warning restore IDE0052 // Remove unread private members + private readonly ProjectId? _sourceProjectId; + private readonly string? _assemblyName; - public DefinitionProject(ProjectId sourceProjectId, string assemblyName) + public DefinitionProject(ProjectId? sourceProjectId, string assemblyName) { _sourceProjectId = sourceProjectId; _assemblyName = assemblyName; } + + public override bool Equals(object? obj) + => obj is DefinitionProject project && Equals(project); + + public bool Equals(DefinitionProject other) + => EqualityComparer.Default.Equals(_sourceProjectId, other._sourceProjectId) && + _assemblyName == other._assemblyName; + + public override int GetHashCode() + => Hash.Combine(_sourceProjectId, _assemblyName?.GetHashCode() ?? 0); } - private struct DependentProject : IEquatable + private readonly struct DependentProject : IEquatable { public readonly ProjectId ProjectId; public readonly bool HasInternalsAccess; @@ -49,8 +59,8 @@ public DependentProject(ProjectId dependentProjectId, bool hasInternalsAccess) this.HasInternalsAccess = hasInternalsAccess; } - public override bool Equals(object obj) - => obj is DependentProject && this.Equals((DependentProject)obj); + public override bool Equals(object? obj) + => obj is DependentProject project && this.Equals(project); public override int GetHashCode() => Hash.Combine(HasInternalsAccess, ProjectId.GetHashCode()); @@ -103,7 +113,7 @@ private static ImmutableArray GetAllProjects(Solution solution) => solution.Projects.ToImmutableArray(); private static ImmutableArray GetProjects(Solution solution, ImmutableArray projectIds) - => projectIds.SelectAsArray(id => solution.GetProject(id)); + => projectIds.SelectAsArray(id => solution.GetRequiredProject(id)); /// /// This method computes the dependent projects that need to be searched for references of the given . @@ -136,6 +146,7 @@ private static ImmutableArray GetProjects(Solution solution, ImmutableA // Find the projects that reference this assembly. + // If this is a source symbol from a project, try to find that project. var sourceProject = solution.GetProject(containingAssembly, cancellationToken); cancellationToken.ThrowIfCancellationRequested(); @@ -167,7 +178,7 @@ private static ImmutableArray GetProjects(Solution solution, ImmutableA private static async Task> GetDependentProjectsCoreAsync( ISymbol symbol, Solution solution, - Project sourceProject, + Project? sourceProject, SymbolVisibility visibility, CancellationToken cancellationToken) { @@ -213,7 +224,8 @@ private static ImmutableArray GetProjects(Solution solution, ImmutableA return GetProjects(solution, projectIds); } - private static async Task AddSubmissionDependentProjectsAsync(Solution solution, Project sourceProject, HashSet dependentProjects, CancellationToken cancellationToken) + private static async Task AddSubmissionDependentProjectsAsync( + Solution solution, Project? sourceProject, HashSet dependentProjects, CancellationToken cancellationToken) { var isSubmission = sourceProject != null && sourceProject.IsSubmission; if (!isSubmission) @@ -226,26 +238,29 @@ private static async Task AddSubmissionDependentProjectsAsync(Solution solution, // search only submission project foreach (var projectId in solution.ProjectIds) { - var project = solution.GetProject(projectId); + var project = solution.GetRequiredProject(projectId); if (project.IsSubmission && project.SupportsCompilation) { cancellationToken.ThrowIfCancellationRequested(); // If we are referencing another project, store the link in the other direction // so we walk across it later - var compilation = await project.GetCompilationAsync(cancellationToken).ConfigureAwait(false); - var previous = compilation.ScriptCompilationInfo.PreviousScriptCompilation; + var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false); + var previous = compilation.ScriptCompilationInfo?.PreviousScriptCompilation; if (previous != null) { var referencedProject = solution.GetProject(previous.Assembly, cancellationToken); - if (!projectIdsToReferencingSubmissionIds.TryGetValue(referencedProject.Id, out var referencingSubmissions)) + if (referencedProject != null) { - referencingSubmissions = new List(); - projectIdsToReferencingSubmissionIds.Add(referencedProject.Id, referencingSubmissions); - } + if (!projectIdsToReferencingSubmissionIds.TryGetValue(referencedProject.Id, out var referencingSubmissions)) + { + referencingSubmissions = new List(); + projectIdsToReferencingSubmissionIds.Add(referencedProject.Id, referencingSubmissions); + } - referencingSubmissions.Add(project.Id); + referencingSubmissions.Add(project.Id); + } } } } @@ -278,22 +293,23 @@ private static async Task AddSubmissionDependentProjectsAsync(Solution solution, private static bool IsInternalsVisibleToAttribute(AttributeData attr) { var attrType = attr.AttributeClass; - if (attrType == null) - { - return false; - } - - var attributeName = attr.AttributeClass.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat.WithGlobalNamespaceStyle(SymbolDisplayGlobalNamespaceStyle.Omitted)); - return attributeName == "System.Runtime.CompilerServices.InternalsVisibleToAttribute"; + return attrType?.Name == nameof(InternalsVisibleToAttribute) && + attrType.ContainingNamespace?.Name == nameof(System.Runtime.CompilerServices) && + attrType.ContainingNamespace.ContainingNamespace?.Name == nameof(System.Runtime) && + attrType.ContainingNamespace.ContainingNamespace.ContainingNamespace?.Name == nameof(System) && + attrType.ContainingNamespace.ContainingNamespace.ContainingNamespace.ContainingNamespace?.IsGlobalNamespace == true; } - private static async Task AddNonSubmissionDependentProjectsAsync(IAssemblySymbol sourceAssembly, Solution solution, Project sourceProject, HashSet dependentProjects, CancellationToken cancellationToken) + private static async Task AddNonSubmissionDependentProjectsAsync( + IAssemblySymbol sourceAssembly, + Solution solution, + Project? sourceProject, + HashSet dependentProjects, + CancellationToken cancellationToken) { var isSubmission = sourceProject != null && sourceProject.IsSubmission; if (isSubmission) - { return; - } var internalsVisibleToMap = CreateInternalsVisibleToMap(sourceAssembly); @@ -304,7 +320,7 @@ private static async Task AddNonSubmissionDependentProjectsAsync(IAssemblySymbol // things we want to find. foreach (var projectId in solution.ProjectIds) { - var project = solution.GetProject(projectId); + var project = solution.GetRequiredProject(projectId); cancellationToken.ThrowIfCancellationRequested(); @@ -326,7 +342,7 @@ private static async Task AddNonSubmissionDependentProjectsAsync(IAssemblySymbol if (internalsVisibleToMap.Value.Contains(project.AssemblyName) && project.SupportsCompilation) { - var compilation = await project.GetCompilationAsync(cancellationToken).ConfigureAwait(false); + var compilation = await project.GetRequiredCompilationAsync(cancellationToken).ConfigureAwait(false); var targetAssembly = compilation.Assembly; if (sourceAssembly.Language != targetAssembly.Language) @@ -361,13 +377,9 @@ private static Lazy> CreateInternalsVisibleToMap(IAssemblySymbol foreach (var attr in assembly.GetAttributes().Where(IsInternalsVisibleToAttribute)) { var typeNameConstant = attr.ConstructorArguments.FirstOrDefault(); - if (typeNameConstant.Type == null || typeNameConstant.Type.SpecialType != SpecialType.System_String) - { - continue; - } - - var value = (string)typeNameConstant.Value; - if (value == null) + if (typeNameConstant.Type == null || + typeNameConstant.Type.SpecialType != SpecialType.System_String || + !(typeNameConstant.Value is string value)) { continue; } @@ -383,7 +395,11 @@ private static Lazy> CreateInternalsVisibleToMap(IAssemblySymbol return internalsVisibleToMap; } - private static bool HasReferenceTo(IAssemblySymbol containingAssembly, Project sourceProject, Project project, CancellationToken cancellationToken) + private static bool HasReferenceTo( + IAssemblySymbol containingAssembly, + Project? sourceProject, + Project project, + CancellationToken cancellationToken) { if (containingAssembly == null) { @@ -431,16 +447,14 @@ public static bool HasReferenceToAssembly(this Project project, string assemblyN // way for it to have an IAssemblySymbol. And without that, there is no way for it // to have any sort of 'ReferenceTo' the provided 'containingAssembly' symbol. if (!project.SupportsCompilation) - { return null; - } if (!project.TryGetCompilation(out var compilation)) { // WORKAROUND: // perf check metadata reference using newly created empty compilation with only metadata references. - compilation = project.LanguageServices.CompilationFactory.CreateCompilation( - project.AssemblyName, project.CompilationOptions); + compilation = project.LanguageServices.CompilationFactory!.CreateCompilation( + project.AssemblyName, project.CompilationOptions!); compilation = compilation.AddReferences(project.MetadataReferences); }