AbstractAddParameterCodeFixProvider.cs 15.2 KB
Newer Older
C
CyrusNajmabadi 已提交
1 2 3 4
// Copyright (c) Microsoft.  All Rights Reserved.  Licensed under the Apache License, Version 2.0.  See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
C
CyrusNajmabadi 已提交
5
using System.Collections.Immutable;
C
CyrusNajmabadi 已提交
6
using System.Linq;
7
using System.Threading;
C
CyrusNajmabadi 已提交
8
using System.Threading.Tasks;
9
using Microsoft.CodeAnalysis.CodeActions;
C
CyrusNajmabadi 已提交
10
using Microsoft.CodeAnalysis.CodeFixes;
11 12 13 14 15 16
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
C
CyrusNajmabadi 已提交
17 18 19

namespace Microsoft.CodeAnalysis.AddParameter
{
    /// <summary>
    /// Language-agnostic base for the "add parameter" code fix. When a diagnostic is reported inside an
    /// object-creation expression for a type defined in source, it offers one code action per candidate
    /// source constructor; each action adds a new parameter for the first argument that constructor
    /// cannot accept.
    /// </summary>
    /// <remarks>
    /// Invocation expressions are recognized but currently produce no fixes
    /// (see <see cref="HandleInvocationExpressionAsync"/>).
    /// </remarks>
    internal abstract class AbstractAddParameterCodeFixProvider<
        TArgumentSyntax,
        TAttributeArgumentSyntax,
        TArgumentListSyntax,
        TAttributeArgumentListSyntax,
        TInvocationExpressionSyntax,
        TObjectCreationExpressionSyntax> : CodeFixProvider
        where TArgumentSyntax : SyntaxNode
        where TArgumentListSyntax : SyntaxNode
        where TAttributeArgumentListSyntax : SyntaxNode
        where TInvocationExpressionSyntax : SyntaxNode
        where TObjectCreationExpressionSyntax : SyntaxNode
    {
        /// <summary>
        /// Format used to render constructor parameters in the code action title,
        /// e.g. <c>C(int, params string[])</c>.
        /// </summary>
        private static readonly SymbolDisplayFormat SimpleFormat =
            new SymbolDisplayFormat(
                typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameOnly,
                genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
                parameterOptions: SymbolDisplayParameterOptions.IncludeParamsRefOut | SymbolDisplayParameterOptions.IncludeType,
                miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes);

        /// <summary>
        /// Diagnostic ids that mean "too many arguments". For these diagnostics the fix does not
        /// single out one particular argument (see <see cref="TryGetRelevantArgument"/>).
        /// </summary>
        protected abstract ImmutableArray<string> TooManyArgumentsDiagnosticIds { get; }

        /// <summary>
        /// Finds the nearest object-creation or invocation expression enclosing the first diagnostic
        /// and delegates to the matching handler to register code actions.
        /// </summary>
        public override async Task RegisterCodeFixesAsync(CodeFixContext context)
        {
            var cancellationToken = context.CancellationToken;
            var diagnostic = context.Diagnostics.First();

            var document = context.Document;
            var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

            var initialNode = root.FindNode(diagnostic.Location.SourceSpan);

            // Walk outward from the diagnostic; the first call-like expression found is the one we fix.
            for (var node = initialNode; node != null; node = node.Parent)
            {
                if (node is TObjectCreationExpressionSyntax objectCreation)
                {
                    var argumentOpt = TryGetRelevantArgument(initialNode, node, diagnostic);
                    await HandleObjectCreationExpressionAsync(context, objectCreation, argumentOpt).ConfigureAwait(false);
                    return;
                }

                if (node is TInvocationExpressionSyntax invocationExpression)
                {
                    var argumentOpt = TryGetRelevantArgument(initialNode, node, diagnostic);
                    await HandleInvocationExpressionAsync(context, invocationExpression, argumentOpt).ConfigureAwait(false);
                    return;
                }
            }
        }

        /// <summary>
        /// Returns the argument (between <paramref name="initialNode"/> and <paramref name="node"/>)
        /// that the diagnostic is about, or <c>null</c> when no specific argument applies.
        /// </summary>
        /// <param name="initialNode">Node located at the diagnostic span.</param>
        /// <param name="node">The enclosing call-like expression.</param>
        /// <param name="diagnostic">The diagnostic being fixed.</param>
        /// <returns>
        /// <c>null</c> for "too many arguments" diagnostics. Otherwise it returns the last argument
        /// yielded while walking ancestors of <paramref name="initialNode"/> that still lies within
        /// <paramref name="node"/>, or <c>null</c> if there is none.
        /// </returns>
        private TArgumentSyntax TryGetRelevantArgument(
            SyntaxNode initialNode, SyntaxNode node, Diagnostic diagnostic)
        {
            if (this.TooManyArgumentsDiagnosticIds.Contains(diagnostic.Id))
            {
                return null;
            }

            return initialNode.GetAncestorsOrThis<TArgumentSyntax>()
                              .LastOrDefault(a => a.AncestorsAndSelf().Contains(node));
        }

        /// <summary>
        /// Placeholder for invocation support; registers nothing.
        /// </summary>
        private Task HandleInvocationExpressionAsync(
            CodeFixContext context, TInvocationExpressionSyntax invocationExpression, TArgumentSyntax argumentOpt)
        {
            // Currently we only support this for 'new obj' calls.
            return SpecializedTasks.EmptyTask;
        }

        /// <summary>
        /// Registers one "add parameter" code action for each source constructor of the created type
        /// that has fewer non-params parameters than there are arguments and has an identifiable
        /// argument to add.
        /// </summary>
        /// <param name="context">The code fix context to register actions on.</param>
        /// <param name="objectCreation">The object-creation expression being fixed.</param>
        /// <param name="argumentOpt">
        /// If non-null, only constructors whose first unmatched argument is this one are offered.
        /// </param>
        private async Task HandleObjectCreationExpressionAsync(
            CodeFixContext context,
            TObjectCreationExpressionSyntax objectCreation,
            TArgumentSyntax argumentOpt)
        {
            var document = context.Document;
            var cancellationToken = context.CancellationToken;
            var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>();

            // Not supported if this is "new { ... }" (as there are no parameters at all).
            var typeNode = syntaxFacts.GetObjectCreationType(objectCreation);
            if (typeNode == null)
            {
                return;
            }

            // If we can't figure out the type being created, or the type isn't in source,
            // then there's nothing we can do.
            var type = semanticModel.GetSymbolInfo(typeNode, cancellationToken).GetAnySymbol() as INamedTypeSymbol;
            if (type == null || !type.IsNonImplicitAndFromSource())
            {
                return;
            }

            var arguments = (SeparatedSyntaxList<TArgumentSyntax>)syntaxFacts.GetArgumentsOfObjectCreationExpression(objectCreation);

            // Parameter-name matching follows the language's identifier case sensitivity.
            var comparer = syntaxFacts.IsCaseSensitive
                ? StringComparer.Ordinal
                : CaseInsensitiveComparison.Comparer;

            var constructorsAndArgumentToAdd = ArrayBuilder<(IMethodSymbol constructor, TArgumentSyntax argument, int index)>.GetInstance();
            try
            {
                foreach (var constructor in type.InstanceConstructors.OrderBy(m => m.Parameters.Length))
                {
                    // Only explicit source constructors with fewer (non-params) parameters than
                    // supplied arguments are candidates for gaining a parameter.
                    if (!constructor.IsNonImplicitAndFromSource() ||
                        NonParamsParameterCount(constructor) >= arguments.Count)
                    {
                        continue;
                    }

                    var argumentToAdd = DetermineFirstArgumentToAdd(
                        semanticModel, syntaxFacts, comparer, constructor,
                        arguments, argumentOpt, cancellationToken);
                    if (argumentToAdd == null)
                    {
                        continue;
                    }

                    if (argumentOpt != null && argumentToAdd != argumentOpt)
                    {
                        // We were trying to fix a specific argument, but the argument we want
                        // to fix is something different.  That means there was an error earlier
                        // than this argument.  Which means we're looking at a non-viable
                        // constructor.  Skip this one.
                        continue;
                    }

                    constructorsAndArgumentToAdd.Add(
                        (constructor, argumentToAdd, arguments.IndexOf(argumentToAdd)));
                }

                // Order by the furthest argument index to the nearest argument index.  The ones with
                // larger argument indexes mean that we matched more earlier arguments (and thus are
                // likely to be the correct match).
                foreach (var tuple in constructorsAndArgumentToAdd.OrderByDescending(t => t.index))
                {
                    var constructor = tuple.constructor;
                    var argumentToAdd = tuple.argument;

                    var parameters = constructor.Parameters.Select(p => p.ToDisplayString(SimpleFormat));
                    var signature = $"{type.Name}({string.Join(", ", parameters)})";

                    var title = string.Format(FeaturesResources.Add_parameter_to_0, signature);

                    context.RegisterCodeFix(
                        new MyCodeAction(title, c => FixAsync(document, constructor, argumentToAdd, arguments, c)),
                        context.Diagnostics);
                }
            }
            finally
            {
                // Return the pooled builder; the registered lambdas capture only locals, not the builder.
                constructorsAndArgumentToAdd.Free();
            }
        }

        /// <summary>
        /// Number of parameters of <paramref name="method"/> excluding a trailing params parameter.
        /// </summary>
        private static int NonParamsParameterCount(IMethodSymbol method)
            => method.IsParams() ? method.Parameters.Length - 1 : method.Parameters.Length;

        /// <summary>
        /// Produces the document containing <paramref name="method"/>'s declaration, updated with a
        /// new parameter for <paramref name="argument"/>.
        /// </summary>
        /// <param name="invocationDocument">Document containing the call site.</param>
        /// <param name="method">The constructor to add a parameter to (must be declared in source).</param>
        /// <param name="argument">The argument that needs a matching parameter.</param>
        /// <param name="argumentList">All arguments at the call site.</param>
        /// <param name="cancellationToken">Cancellation token.</param>
        /// <returns>The document that declares <paramref name="method"/>, with the new root applied.</returns>
        private async Task<Document> FixAsync(
            Document invocationDocument,
            IMethodSymbol method,
            TArgumentSyntax argument,
            SeparatedSyntaxList<TArgumentSyntax> argumentList,
            CancellationToken cancellationToken)
        {
            var generator = SyntaxGenerator.GetGenerator(invocationDocument.Project.Solution.Workspace, method.Language);

            var methodDeclaration = await method.DeclaringSyntaxReferences[0].GetSyntaxAsync(cancellationToken).ConfigureAwait(false);

            var syntaxFacts = invocationDocument.GetLanguageService<ISyntaxFactsService>();
            var semanticFacts = invocationDocument.GetLanguageService<ISemanticFactsService>();
            var argumentName = syntaxFacts.GetNameForArgument(argument);
            var expression = syntaxFacts.GetExpressionOfArgument(argument);

            // The new parameter takes the argument expression's type, falling back to 'object'
            // when that type is unknown.
            var semanticModel = await invocationDocument.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var parameterType = semanticModel.GetTypeInfo(expression, cancellationToken).Type ?? semanticModel.Compilation.ObjectType;

            var newMethodDeclaration = GetNewMethodDeclaration(
                method, argument, argumentList, generator, methodDeclaration,
                semanticFacts, argumentName, expression, semanticModel, parameterType);

            var root = await methodDeclaration.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
            var newRoot = root.ReplaceNode(methodDeclaration, newMethodDeclaration);

            // The constructor may live in a different document than the call site.
            var methodDocument = invocationDocument.Project.Solution.GetDocument(methodDeclaration.SyntaxTree);
            return methodDocument.WithSyntaxRoot(newRoot);
        }

        /// <summary>
        /// Returns <paramref name="declaration"/> with one new parameter of type
        /// <paramref name="parameterType"/>.
        /// </summary>
        /// <remarks>
        /// For a named argument, the parameter uses that name and is appended at the end.
        /// For a positional argument, a name is generated from <paramref name="expression"/>. That
        /// name is made unique against existing parameter names, and the parameter is inserted at the
        /// argument's position.
        /// </remarks>
        private static SyntaxNode GetNewMethodDeclaration(
            IMethodSymbol method,
            TArgumentSyntax argument,
            SeparatedSyntaxList<TArgumentSyntax> argumentList,
            SyntaxGenerator generator,
            SyntaxNode declaration,
            ISemanticFactsService semanticFacts,
            string argumentName,
            SyntaxNode expression,
            SemanticModel semanticModel,
            ITypeSymbol parameterType)
        {
            if (!string.IsNullOrWhiteSpace(argumentName))
            {
                // Named arguments bind by name, so the new parameter can simply go last.
                var namedParameterDeclaration = generator.ParameterDeclaration(
                    CreateParameterSymbol(parameterType, argumentName));
                return generator.AddParameters(declaration, new[] { namedParameterDeclaration });
            }

            var name = semanticFacts.GenerateNameForExpression(semanticModel, expression);
            var uniqueName = NameGenerator.EnsureUniqueness(name, method.Parameters.Select(p => p.Name));

            // Positional arguments bind by position, so insert at the argument's index.
            var argumentIndex = argumentList.IndexOf(argument);
            var positionalParameterDeclaration = generator.ParameterDeclaration(
                CreateParameterSymbol(parameterType, uniqueName));
            return generator.InsertParameters(
                declaration, argumentIndex, new[] { positionalParameterDeclaration });
        }

        /// <summary>
        /// Creates a plain (no attributes, by-value, non-params) parameter symbol for code generation.
        /// </summary>
        private static IParameterSymbol CreateParameterSymbol(ITypeSymbol type, string name)
            => CodeGenerationSymbolFactory.CreateParameterSymbol(
                attributes: default(ImmutableArray<AttributeData>),
                refKind: RefKind.None,
                isParams: false,
                type: type,
                name: name);

        /// <summary>
        /// Scans <paramref name="arguments"/> in order and returns the first one that
        /// <paramref name="method"/> cannot accept, or <c>null</c> if none is found (or the only
        /// mismatch is absorbed by a params parameter).
        /// </summary>
        /// <param name="semanticModel">Semantic model for the call site.</param>
        /// <param name="syntaxFacts">Language syntax helpers.</param>
        /// <param name="comparer">Comparer for matching argument names to parameter names.</param>
        /// <param name="method">The candidate constructor.</param>
        /// <param name="arguments">Arguments at the call site.</param>
        /// <param name="argumentOpt">Argument targeted by the diagnostic, if any (not used in the scan).</param>
        /// <param name="cancellationToken">Cancellation token for semantic queries.</param>
        private TArgumentSyntax DetermineFirstArgumentToAdd(
            SemanticModel semanticModel,
            ISyntaxFactsService syntaxFacts,
            StringComparer comparer,
            IMethodSymbol method,
            SeparatedSyntaxList<TArgumentSyntax> arguments,
            TArgumentSyntax argumentOpt,
            CancellationToken cancellationToken)
        {
            var methodParameterNames = new HashSet<string>(comparer);
            methodParameterNames.AddRange(method.Parameters.Select(p => p.Name));

            for (int i = 0, n = arguments.Count; i < n; i++)
            {
                var argument = arguments[i];
                var argumentName = syntaxFacts.GetNameForArgument(argument);

                if (!string.IsNullOrWhiteSpace(argumentName))
                {
                    // If the user provided an argument-name and we don't have any parameters that
                    // match, then this is the argument we want to add a parameter for.
                    if (!methodParameterNames.Contains(argumentName))
                    {
                        return argument;
                    }

                    continue;
                }

                // Positional argument.  If the position is beyond what the method supports,
                // then this definitely is an argument we could add.
                if (i >= method.Parameters.Length)
                {
                    // Last parameter is a params.  We can't place any parameters past it.
                    return method.Parameters.LastOrDefault()?.IsParams == true
                        ? null
                        : argument;
                }

                // Now check the type of the argument versus the type of the parameter.  If they
                // don't match, then this is the argument we should make the parameter for.
                var argumentTypeInfo = semanticModel.GetTypeInfo(
                    syntaxFacts.GetExpressionOfArgument(argument), cancellationToken);
                if (argumentTypeInfo.Type == null && argumentTypeInfo.ConvertedType == null)
                {
                    // Didn't know the type of the argument.  We shouldn't assume it doesn't
                    // match a parameter.
                    continue;
                }

                var parameter = method.Parameters[i];
                if (TypeInfoMatchesType(argumentTypeInfo, parameter.Type))
                {
                    continue;
                }

                // The argument matched if we expanded out the params-parameter.
                // As the params-parameter has to be last, there's nothing else to
                // do here.
                return TypeInfoMatchesWithParamsExpansion(argumentTypeInfo, parameter)
                    ? null
                    : argument;
            }

            return null;
        }

        /// <summary>
        /// True if <paramref name="parameter"/> is a params array whose element type matches the
        /// argument's type or converted type.
        /// </summary>
        private static bool TypeInfoMatchesWithParamsExpansion(TypeInfo argumentTypeInfo, IParameterSymbol parameter)
            => parameter.IsParams
            && parameter.Type is IArrayTypeSymbol arrayType
            && TypeInfoMatchesType(argumentTypeInfo, arrayType.ElementType);

        /// <summary>
        /// True if <paramref name="type"/> equals either the argument's natural type or its converted type.
        /// </summary>
        private static bool TypeInfoMatchesType(TypeInfo argumentTypeInfo, ITypeSymbol type)
            => type.Equals(argumentTypeInfo.Type) || type.Equals(argumentTypeInfo.ConvertedType);

        /// <summary>
        /// Code action that produces a single changed document.
        /// </summary>
        private class MyCodeAction : CodeAction.DocumentChangeAction
        {
            public MyCodeAction(
                string title,
                Func<CancellationToken, Task<Document>> createChangedDocument)
                : base(title, createChangedDocument)
            {
            }
        }
    }
}