// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Editor.Shared.Utilities;
using Microsoft.CodeAnalysis.ErrorReporting;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.Notification;
using Microsoft.CodeAnalysis.Packaging;
using Microsoft.CodeAnalysis.Shared.Options;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Microsoft.CodeAnalysis.Text;
using Microsoft.VisualStudio.ComponentModelHost;
using Microsoft.VisualStudio.Editor;
using Microsoft.VisualStudio.LanguageServices.Implementation.ProjectSystem;
using Microsoft.VisualStudio.Shell.Interop;
using NuGet.VisualStudio;
using Roslyn.Utilities;

namespace Microsoft.VisualStudio.LanguageServices.Packaging
{
    /// <summary>
    /// Free threaded wrapper around the NuGet.VisualStudio STA package installer interfaces.
    /// We want to be able to make queries about packages from any thread.  For example, the
    /// add-nuget-reference feature wants to know what packages a project already has
    /// references to.  NuGet.VisualStudio provides this information, but only in a COM STA
    /// manner.  As we don't want our background work to bounce and block on the UI thread
    /// we have this helper class which queries the information on the UI thread and caches
    /// the data so it can be read from the background.
    /// </summary>
    [ExportWorkspaceService(typeof(IPackageInstallerService)), Shared]
    internal partial class PackageInstallerService : ForegroundThreadAffinitizedObject, IPackageInstallerService, IVsSearchProviderCallback
    {
        // Guards _solutionChanged, _changedProjects and _tokenSource.
        private readonly object _gate = new object();
        private readonly IVsEditorAdaptersFactoryService _editorAdaptersFactoryService;

        /// <summary>
        /// The workspace we're connected to.  When we're disconnected this will become 'null'.
        /// That's our signal to stop working.
        /// </summary>
        private VisualStudioWorkspaceImpl _workspace;

        private IVsPackageInstallerServices _packageInstallerServices;
        private IVsPackageInstaller _packageInstaller;
        private IVsPackageUninstaller _packageUninstaller;
        private IVsPackageSourceProvider _packageSourceProvider;

        private CancellationTokenSource _tokenSource = new CancellationTokenSource();

        // We keep track of what types of changes we've seen so we can then determine what to
        // refresh on the UI thread.  If we hear about project changes, we only refresh that
        // project.  If we hear about a solution level change, we'll refresh all projects.
        private bool _solutionChanged;
        private HashSet<ProjectId> _changedProjects = new HashSet<ProjectId>();

        // Cache of package id -> version string for each project, filled in on the foreground
        // thread by ProcessProjectChange and read by the query methods below.
        private readonly ConcurrentDictionary<ProjectId, Dictionary<string, string>> _projectToInstalledPackageAndVersion =
            new ConcurrentDictionary<ProjectId, Dictionary<string, string>>();

        [ImportingConstructor]
        public PackageInstallerService(
            IVsEditorAdaptersFactoryService editorAdaptersFactoryService)
        {
            _editorAdaptersFactoryService = editorAdaptersFactoryService;
        }

        /// <summary>
        /// The package sources most recently read from the NuGet source provider.  Empty until
        /// the service has been connected.
        /// </summary>
        public ImmutableArray<PackageSource> PackageSources { get; private set; } = ImmutableArray<PackageSource>.Empty;

        /// <summary>
        /// Raised on the foreground thread after <see cref="PackageSources"/> has been refreshed.
        /// </summary>
        public event EventHandler PackageSourcesChanged;

        /// <summary>
        /// Connects this service to <paramref name="workspace"/>.  Does nothing when the
        /// <c>PackageSearch</c> option is turned off.  Must be called on the foreground thread.
        /// </summary>
        internal void Connect(VisualStudioWorkspaceImpl workspace)
        {
            this.AssertIsForeground();

            var options = workspace.Options;
            if (!options.GetOption(ServiceComponentOnOffOptions.PackageSearch))
            {
                return;
            }

            ConnectWorker(workspace);
        }

        // Don't inline this method.  The references to nuget types will cause the nuget packages
        // to load.
        [MethodImpl(MethodImplOptions.NoInlining)]
        private void ConnectWorker(VisualStudioWorkspaceImpl workspace)
        {
            // NOTE(review): the generic type arguments on GetVsService/GetExtensions were missing
            // from the reviewed text.  GetExtensions<T> is inferred from the field types; the
            // GetVsService arguments are assumed to be SComponentModel/IComponentModel -- confirm.
            var componentModel = workspace.GetVsService<SComponentModel, IComponentModel>();
            _packageInstallerServices = componentModel.GetExtensions<IVsPackageInstallerServices>().FirstOrDefault();
            _packageInstaller = componentModel.GetExtensions<IVsPackageInstaller>().FirstOrDefault();
            _packageUninstaller = componentModel.GetExtensions<IVsPackageUninstaller>().FirstOrDefault();
            _packageSourceProvider = componentModel.GetExtensions<IVsPackageSourceProvider>().FirstOrDefault();

            if (!this.IsEnabled)
            {
                return;
            }

            // Start listening to workspace changes.
            _workspace = workspace;
            _workspace.WorkspaceChanged += OnWorkspaceChanged;
            _packageSourceProvider.SourcesChanged += OnSourceProviderSourcesChanged;

            // Populate PackageSources immediately rather than waiting for the first change event.
            OnSourceProviderSourcesChanged(null, EventArgs.Empty);
        }

        /// <summary>
        /// True only when every NuGet service this class calls into was resolved.  All four are
        /// required: <see cref="_packageInstaller"/> is dereferenced when installing, so it has
        /// to be part of this check just like the other three.
        /// </summary>
        public bool IsEnabled =>
            _packageInstallerServices != null &&
            _packageInstaller != null &&
            _packageUninstaller != null &&
            _packageSourceProvider != null;

        /// <summary>
        /// Stops listening to workspace and package-source changes and clears the connected
        /// workspace.  Must be called on the foreground thread.
        /// </summary>
        internal void Disconnect(VisualStudioWorkspaceImpl workspace)
        {
            this.AssertIsForeground();

            if (!this.IsEnabled)
            {
                return;
            }

            Debug.Assert(workspace == _workspace);
            _packageSourceProvider.SourcesChanged -= OnSourceProviderSourcesChanged;
            _workspace.WorkspaceChanged -= OnWorkspaceChanged;
            _workspace = null;
        }

        /// <summary>
        /// Refreshes <see cref="PackageSources"/> from the NuGet source provider and raises
        /// <see cref="PackageSourcesChanged"/>.  If invoked off the foreground thread the call
        /// re-posts itself to the foreground thread and returns.
        /// </summary>
        private void OnSourceProviderSourcesChanged(object sender, EventArgs e)
        {
            if (!this.IsForeground())
            {
                this.InvokeBelowInputPriority(() => OnSourceProviderSourcesChanged(sender, e));
                return;
            }

            this.AssertIsForeground();

            PackageSources = _packageSourceProvider.GetSources(includeUnOfficial: true, includeDisabled: false)
                                                   .Select(r => new PackageSource(r.Key, r.Value))
                                                   .ToImmutableArrayOrEmpty();

            PackageSourcesChanged?.Invoke(this, EventArgs.Empty);
        }

        /// <summary>
        /// Attempts to install <paramref name="packageName"/> into the project that owns
        /// <paramref name="documentId"/>.  Must be called on the foreground thread.
        /// </summary>
        /// <returns>
        /// The result of <c>TryInstallAndAddUndoAction</c>, or <c>false</c> when the workspace is
        /// not the connected one, the service is disconnected, or no DTE project could be found.
        /// </returns>
        public bool TryInstallPackage(
            Workspace workspace,
            DocumentId documentId,
            string source,
            string packageName,
            string versionOpt,
            CancellationToken cancellationToken)
        {
            this.AssertIsForeground();

            // The 'workspace == _workspace' line is probably not necessary.  However, we include
            // it just to make sure that someone isn't trying to install a package into a workspace
            // other than the VisualStudioWorkspace.
            if (workspace == _workspace && _workspace != null && _packageInstallerServices != null)
            {
                var projectId = documentId.ProjectId;

                // NOTE(review): type arguments assumed to be SDTE/EnvDTE.DTE (the value is passed
                // on as an EnvDTE.DTE) -- confirm.
                var dte = _workspace.GetVsService<SDTE, EnvDTE.DTE>();
                var dteProject = _workspace.TryGetDTEProject(projectId);
                if (dteProject != null)
                {
                    // Locate the editor buffer for the document so the install can be attached
                    // to that buffer's undo manager.
                    var document = workspace.CurrentSolution.GetDocument(documentId);
                    var text = document.GetTextAsync(cancellationToken).WaitAndGetResult(cancellationToken);
                    var textSnapshot = text.FindCorrespondingEditorTextSnapshot();
                    var textBuffer = textSnapshot?.TextBuffer;
                    var undoManager = GetUndoManager(textBuffer);

                    return TryInstallAndAddUndoAction(source, packageName, versionOpt, dte, dteProject, undoManager);
                }
            }

            return false;
        }

        /// <summary>
        /// Installs <paramref name="packageName"/> into <paramref name="dteProject"/> unless it is
        /// already installed, reporting progress on the status bar.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the package was installed by this call; <c>false</c> if it was already
        /// installed or the install threw.
        /// </returns>
        private bool TryInstallPackage(
            string source,
            string packageName,
            string versionOpt,
            EnvDTE.DTE dte,
            EnvDTE.Project dteProject)
        {
            try
            {
                if (!_packageInstallerServices.IsPackageInstalled(dteProject, packageName))
                {
                    dte.StatusBar.Text = string.Format(ServicesVSResources.Installing_0, packageName);
                    _packageInstaller.InstallPackage(source, dteProject, packageName, versionOpt, ignoreDependencies: false);

                    // Query the version after the install so the status text shows what was
                    // actually installed.
                    var installedVersion = GetInstalledVersion(packageName, dteProject);
                    dte.StatusBar.Text = string.Format(ServicesVSResources.Installing_0_completed,
                        GetStatusBarText(packageName, installedVersion));

                    return true;
                }

                // fall through.
            }
            catch (Exception e) when (FatalError.ReportWithoutCrash(e))
            {
                dte.StatusBar.Text = string.Format(ServicesVSResources.Package_install_failed_0, e.Message);

                // NOTE(review): type argument assumed to be INotificationService -- confirm.
                var notificationService = _workspace.Services.GetService<INotificationService>();
                notificationService?.SendNotification(
                    string.Format(ServicesVSResources.Installing_0_failed_Additional_information_1, packageName, e.Message),
                    severity: NotificationSeverity.Error);

                // fall through.
            }

            return false;
        }

        /// <summary>
        /// Returns the package name, followed by " - version" when a version is known.
        /// </summary>
        private static string GetStatusBarText(string packageName, string installedVersion)
        {
            return installedVersion == null ? packageName : $"{packageName} - {installedVersion}";
        }

        /// <summary>
        /// Uninstalls <paramref name="packageName"/> from <paramref name="dteProject"/> if it is
        /// installed, reporting progress on the status bar.  Must be called on the foreground
        /// thread.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the package was uninstalled by this call; <c>false</c> if it was not
        /// installed or the uninstall threw.
        /// </returns>
        private bool TryUninstallPackage(
            string packageName, EnvDTE.DTE dte, EnvDTE.Project dteProject)
        {
            this.AssertIsForeground();

            try
            {
                if (_packageInstallerServices.IsPackageInstalled(dteProject, packageName))
                {
                    dte.StatusBar.Text = string.Format(ServicesVSResources.Uninstalling_0, packageName);

                    // Capture the version before uninstalling; it can't be queried afterwards.
                    var installedVersion = GetInstalledVersion(packageName, dteProject);
                    _packageUninstaller.UninstallPackage(dteProject, packageName, removeDependencies: true);

                    dte.StatusBar.Text = string.Format(ServicesVSResources.Uninstalling_0_completed,
                        GetStatusBarText(packageName, installedVersion));

                    return true;
                }

                // fall through.
            }
            catch (Exception e) when (FatalError.ReportWithoutCrash(e))
            {
                dte.StatusBar.Text = string.Format(ServicesVSResources.Package_uninstall_failed_0, e.Message);

                // NOTE(review): type argument assumed to be INotificationService -- confirm.
                var notificationService = _workspace.Services.GetService<INotificationService>();
                notificationService?.SendNotification(
                    string.Format(ServicesVSResources.Uninstalling_0_failed_Additional_information_1, packageName, e.Message),
                    severity: NotificationSeverity.Error);

                // fall through.
            }

            return false;
        }

        /// <summary>
        /// Returns the version string NuGet reports for <paramref name="packageName"/> in
        /// <paramref name="dteProject"/>, or <c>null</c> if the package is not listed or the
        /// query threw.  Must be called on the foreground thread.
        /// </summary>
        private string GetInstalledVersion(string packageName, EnvDTE.Project dteProject)
        {
            this.AssertIsForeground();

            try
            {
                var installedPackages = _packageInstallerServices.GetInstalledPackages(dteProject);
                var metadata = installedPackages.FirstOrDefault(m => m.Id == packageName);
                return metadata?.VersionString;
            }
            catch (Exception e) when (FatalError.ReportWithoutCrash(e))
            {
                return null;
            }
        }

        /// <summary>
        /// Records which project (or the whole solution) changed, cancels any in-flight refresh,
        /// and schedules a new refresh on the foreground thread after a one second delay.
        /// </summary>
        private void OnWorkspaceChanged(object sender, WorkspaceChangeEventArgs e)
        {
            ThisCanBeCalledOnAnyThread();

            bool localSolutionChanged = false;
            ProjectId localChangedProject = null;
            switch (e.Kind)
            {
                default:
                    // Nothing to do for any other events.
                    return;

                case WorkspaceChangeKind.ProjectAdded:
                case WorkspaceChangeKind.ProjectChanged:
                case WorkspaceChangeKind.ProjectReloaded:
                case WorkspaceChangeKind.ProjectRemoved:
                    localChangedProject = e.ProjectId;
                    break;

                case WorkspaceChangeKind.SolutionAdded:
                case WorkspaceChangeKind.SolutionChanged:
                case WorkspaceChangeKind.SolutionCleared:
                case WorkspaceChangeKind.SolutionReloaded:
                case WorkspaceChangeKind.SolutionRemoved:
                    localSolutionChanged = true;
                    break;
            }

            lock (_gate)
            {
                // Augment the data that the foreground thread will process.
                _solutionChanged |= localSolutionChanged;
                if (localChangedProject != null)
                {
                    _changedProjects.Add(localChangedProject);
                }

                // Now cancel any inflight work that is processing the data.
                _tokenSource.Cancel();
                _tokenSource = new CancellationTokenSource();

                // And enqueue a new job to process things.  Wait one second before starting.
                // That way if we get a flurry of events we'll end up processing them after
                // they've all come in.
                var cancellationToken = _tokenSource.Token;
                Task.Delay(TimeSpan.FromSeconds(1), cancellationToken)
                    .ContinueWith(_ => ProcessBatchedChangesOnForeground(cancellationToken), cancellationToken, TaskContinuationOptions.OnlyOnRanToCompletion, this.ForegroundTaskScheduler);
            }
        }

        /// <summary>
        /// Processes one pending project and then re-queues itself on the foreground scheduler,
        /// until there are no pending projects, the token is cancelled, or the service has been
        /// disconnected.
        /// </summary>
        private void ProcessBatchedChangesOnForeground(CancellationToken cancellationToken)
        {
            this.AssertIsForeground();

            // If we've been asked to stop, then there's no point proceeding.
            if (cancellationToken.IsCancellationRequested)
            {
                return;
            }

            // If we've been disconnected, then there's no point proceeding.
            if (_workspace == null || _packageInstallerServices == null)
            {
                return;
            }

            // Get a project to process.
            var solution = _workspace.CurrentSolution;
            var projectId = DequeueNextProject(solution);
            if (projectId == null)
            {
                // No project to process, nothing to do.
                return;
            }

            // Process this single project.
            ProcessProjectChange(solution, projectId);

            // After processing this single project, yield so the foreground thread
            // can do more work.  Then go and loop again so we can process the
            // rest of the projects.
            Task.Factory.SafeStartNew(
                () => ProcessBatchedChangesOnForeground(cancellationToken), cancellationToken, ForegroundTaskScheduler);
        }

        /// <summary>
        /// Removes and returns one pending project id, or <c>null</c> when none are pending.  A
        /// pending solution-level change is first expanded into every project in
        /// <paramref name="solution"/> plus every project already in the cache.
        /// </summary>
        private ProjectId DequeueNextProject(Solution solution)
        {
            this.AssertIsForeground();

            lock (_gate)
            {
                // If we detected a solution change, then we need to process all projects.
                // This includes all the projects that we already know about, as well as
                // all the projects in the current workspace solution.
                if (_solutionChanged)
                {
                    _changedProjects.AddRange(solution.ProjectIds);
                    _changedProjects.AddRange(_projectToInstalledPackageAndVersion.Keys);
                }

                _solutionChanged = false;

                // Remove and return the first project in the list.
                var projectId = _changedProjects.FirstOrDefault();
                _changedProjects.Remove(projectId);
                return projectId;
            }
        }

        /// <summary>
        /// Rebuilds the cached package list for <paramref name="projectId"/>.  The existing cache
        /// entry is always dropped; a new one is stored only if the project is still in
        /// <paramref name="solution"/> and has a DTE project.
        /// </summary>
        private void ProcessProjectChange(Solution solution, ProjectId projectId)
        {
            this.AssertIsForeground();

            // Remove anything we have associated with this project.
            Dictionary<string, string> installedPackages;
            _projectToInstalledPackageAndVersion.TryRemove(projectId, out installedPackages);

            if (!solution.ContainsProject(projectId))
            {
                // Project was removed.  Nothing needs to be done.
                return;
            }

            // Project was changed in some way.  Let's go find the set of installed packages for it.
            var dteProject = _workspace.TryGetDTEProject(projectId);
            if (dteProject == null)
            {
                // Don't have a DTE project for this project ID.  not something we can query nuget for.
                return;
            }

            installedPackages = new Dictionary<string, string>();

            // Calling into nuget.  Assume they may fail for any reason.
            try
            {
                var installedPackageMetadata = _packageInstallerServices.GetInstalledPackages(dteProject);
                installedPackages.AddRange(installedPackageMetadata.Select(m => new KeyValuePair<string, string>(m.Id, m.VersionString)));
            }
            catch (Exception e) when (FatalError.ReportWithoutCrash(e))
            {
                // The failure has been reported by the filter; whatever was collected so far
                // (possibly nothing) is still stored below.
            }

            _projectToInstalledPackageAndVersion.AddOrUpdate(projectId, installedPackages, (_1, _2) => installedPackages);
        }

        /// <summary>
        /// Returns whether the cache lists <paramref name="packageName"/> for
        /// <paramref name="projectId"/>.  Reads only cached data.
        /// </summary>
        public bool IsInstalled(Workspace workspace, ProjectId projectId, string packageName)
        {
            ThisCanBeCalledOnAnyThread();

            Dictionary<string, string> installedPackages;
            return _projectToInstalledPackageAndVersion.TryGetValue(projectId, out installedPackages) &&
                installedPackages.ContainsKey(packageName);
        }

        /// <summary>
        /// Returns the distinct cached versions of <paramref name="packageName"/> across all
        /// projects, ordered so that versions that look newer come first.
        /// </summary>
        public IEnumerable<string> GetInstalledVersions(string packageName)
        {
            ThisCanBeCalledOnAnyThread();

            var installedVersions = new HashSet<string>();
            foreach (var installedPackages in _projectToInstalledPackageAndVersion.Values)
            {
                string version = null;
                if (installedPackages?.TryGetValue(packageName, out version) == true && version != null)
                {
                    installedVersions.Add(version);
                }
            }

            // Order the versions with a weak heuristic so that 'newer' versions come first.
            // Essentially, we try to break the version on dots, and then we use a LogicalComparer
            // to try to more naturally order the things we see between the dots.
            var versionsAndSplits = installedVersions.Select(v => new { Version = v, Split = v.Split('.') }).ToList();

            versionsAndSplits.Sort((v1, v2) =>
            {
                var diff = CompareSplit(v1.Split, v2.Split);

                // Fall back to a reversed string comparison when the split parts tie.
                return diff != 0 ? diff : -v1.Version.CompareTo(v2.Version);
            });

            return versionsAndSplits.Select(v => v.Version).ToList();
        }

        /// <summary>
        /// Compares two dot-split version strings part by part.  Returns a negative value when
        /// <paramref name="split1"/> should sort first: its first differing part is logically
        /// larger or, when all shared parts tie, it has more parts.
        /// </summary>
        private int CompareSplit(string[] split1, string[] split2)
        {
            ThisCanBeCalledOnAnyThread();

            for (int i = 0, n = Math.Min(split1.Length, split2.Length); i < n; i++)
            {
                // Prefer things that look larger.  i.e. 7 should come before 6.
                // Use a logical string comparer so that 10 is understood to be
                // greater than 3.
                var diff = -LogicalStringComparer.Instance.Compare(split1[i], split2[i]);
                if (diff != 0)
                {
                    return diff;
                }
            }

            // Choose the one with more parts.
            return split2.Length - split1.Length;
        }

        /// <summary>
        /// Returns the projects in <paramref name="solution"/> whose cached package list contains
        /// <paramref name="packageName"/> at exactly <paramref name="version"/>.
        /// </summary>
        public IEnumerable<Project> GetProjectsWithInstalledPackage(Solution solution, string packageName, string version)
        {
            ThisCanBeCalledOnAnyThread();

            var result = new List<Project>();

            foreach (var kvp in this._projectToInstalledPackageAndVersion)
            {
                var installedPackageAndVersion = kvp.Value;
                if (installedPackageAndVersion != null)
                {
                    string installedVersion;
                    if (installedPackageAndVersion.TryGetValue(packageName, out installedVersion) && installedVersion == version)
                    {
                        // The cache may hold ids that are not in the solution we were handed.
                        var project = solution.GetProject(kvp.Key);
                        if (project != null)
                        {
                            result.Add(project);
                        }
                    }
                }
            }

            return result;
        }

        /// <summary>
        /// Loads the NuGet package and starts a search for <paramref name="packageName"/> through
        /// its search provider.  Returns without doing anything if the shell service or the
        /// NuGet package is unavailable.  Must be called on the foreground thread.
        /// </summary>
        public void ShowManagePackagesDialog(string packageName)
        {
            this.AssertIsForeground();

            // NOTE(review): type arguments assumed to be SVsShell/IVsShell (the result is used
            // via LoadPackage) -- confirm.
            var shell = _workspace.GetVsService<SVsShell, IVsShell>();
            if (shell == null)
            {
                return;
            }

            IVsPackage nugetPackage;
            var nugetGuid = new Guid("5fcc8577-4feb-4d04-ad72-d6c629b083cc");
            shell.LoadPackage(ref nugetGuid, out nugetPackage);
            if (nugetPackage == null)
            {
                return;
            }

            // We're able to launch the package manager (with an item in its search box) by
            // using the IVsSearchProvider API that the nuget package exposes.
            //
            // We get that interface for it and then pass it a SearchQuery that effectively
            // wraps the package name we're looking for.  The NuGet package will then read
            // out that string and populate their search box with it.
            var extensionProvider = (IVsPackageExtensionProvider)nugetPackage;
            var extensionGuid = new Guid("042C2B4B-C7F7-49DB-B7A2-402EB8DC7892");
            var emptyGuid = Guid.Empty;
            var searchProvider = (IVsSearchProvider)extensionProvider.CreateExtensionInstance(ref emptyGuid, ref extensionGuid);
            var task = searchProvider.CreateSearch(dwCookie: 1, pSearchQuery: new SearchQuery(packageName), pSearchCallback: this);
            task.Start();
        }

        /// <summary>Search callback; progress is ignored.</summary>
        public void ReportProgress(IVsSearchTask pTask, uint dwProgress, uint dwMaxProgress)
        {
        }

        /// <summary>Search callback; completion is ignored.</summary>
        public void ReportComplete(IVsSearchTask pTask, uint dwResultsFound)
        {
        }

        /// <summary>
        /// Search callback; invokes the action of each result as it is reported.
        /// </summary>
        public void ReportResult(IVsSearchTask pTask, IVsSearchItemResult pSearchItemResult)
        {
            pSearchItemResult.InvokeAction();
        }

        /// <summary>Search callback; batched results are ignored.</summary>
        public void ReportResults(IVsSearchTask pTask, uint dwResults, IVsSearchItemResult[] pSearchItemResults)
        {
        }

        /// <summary>
        /// Minimal <see cref="IVsSearchQuery"/> that carries only the package name as its search
        /// string and exposes no tokens.
        /// </summary>
        private class SearchQuery : IVsSearchQuery
        {
            public SearchQuery(string packageName)
            {
                this.SearchString = packageName;
            }

            public string SearchString { get; }

            public uint ParseError => 0;

            public uint GetTokens(uint dwMaxTokens, IVsSearchToken[] rgpSearchTokens)
            {
                return 0;
            }
        }
    }
}