// AbstractAddParameterCodeFixProvider.cs
// Copyright (c) Microsoft.  All Rights Reserved.  Licensed under the Apache License, Version 2.0.  See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeGeneration;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.AddParameter
{
    /// <summary>
    /// Base code fix that offers to add a parameter to a source-defined constructor so that an
    /// object-creation expression whose arguments don't line up with any constructor can bind.
    /// Language-specific subclasses supply the syntax node types and the ids of the
    /// "too many arguments" diagnostics.
    /// </summary>
    internal abstract class AbstractAddParameterCodeFixProvider<
        TArgumentSyntax,
        TAttributeArgumentSyntax,
        TArgumentListSyntax,
        TAttributeArgumentListSyntax,
        TInvocationExpressionSyntax,
        TObjectCreationExpressionSyntax> : CodeFixProvider
        where TArgumentSyntax : SyntaxNode
        where TArgumentListSyntax : SyntaxNode
        where TAttributeArgumentListSyntax : SyntaxNode
        where TInvocationExpressionSyntax : SyntaxNode
        where TObjectCreationExpressionSyntax : SyntaxNode
    {
        /// <summary>
        /// Format used to render a constructor's parameter list in the code action title,
        /// e.g. <c>C(int, params string[])</c>.
        /// </summary>
        private static readonly SymbolDisplayFormat SimpleFormat =
            new SymbolDisplayFormat(
                typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameOnly,
                genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
                parameterOptions: SymbolDisplayParameterOptions.IncludeParamsRefOut | SymbolDisplayParameterOptions.IncludeType,
                miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes);

        /// <summary>
        /// Ids of diagnostics that report a call passing too many arguments.  For these there is
        /// no single offending argument to anchor the fix on.
        /// </summary>
        protected abstract ImmutableArray<string> TooManyArgumentsDiagnosticIds { get; }

        public override async Task RegisterCodeFixesAsync(CodeFixContext context)
        {
            var cancellationToken = context.CancellationToken;
            var diagnostic = context.Diagnostics.First();

            var document = context.Document;
            var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

            var initialNode = root.FindNode(diagnostic.Location.SourceSpan);

            // Walk outward from the diagnostic to the nearest object creation or invocation and
            // offer fixes for that expression only.
            for (var node = initialNode; node != null; node = node.Parent)
            {
                if (node is TObjectCreationExpressionSyntax objectCreation)
                {
                    var argumentOpt = TryGetRelevantArgument(initialNode, node, diagnostic);
                    await HandleObjectCreationExpressionAsync(context, objectCreation, argumentOpt).ConfigureAwait(false);
                    return;
                }
                else if (node is TInvocationExpressionSyntax invocationExpression)
                {
                    var argumentOpt = TryGetRelevantArgument(initialNode, node, diagnostic);
                    await HandleInvocationExpressionAsync(context, invocationExpression, argumentOpt).ConfigureAwait(false);
                    return;
                }
            }
        }

        /// <summary>
        /// Returns the argument of <paramref name="node"/> that contains the diagnostic location,
        /// or <c>null</c> when the diagnostic is a "too many arguments" one (those aren't tied to
        /// a particular argument) or no such argument exists.
        /// </summary>
        private TArgumentSyntax TryGetRelevantArgument(
            SyntaxNode initialNode, SyntaxNode node, Diagnostic diagnostic)
        {
            if (this.TooManyArgumentsDiagnosticIds.Contains(diagnostic.Id))
            {
                return null;
            }

            // Ancestors are enumerated innermost-first, so the last one that still lies inside
            // 'node' is the outermost such argument.
            return initialNode.GetAncestorsOrThis<TArgumentSyntax>()
                              .LastOrDefault(a => a.AncestorsAndSelf().Contains(node));
        }

        private Task HandleInvocationExpressionAsync(
            CodeFixContext context, TInvocationExpressionSyntax invocationExpression, TArgumentSyntax argumentOpt)
        {
            // Currently we only support this for 'new obj' calls.
            return SpecializedTasks.EmptyTask;
        }

        /// <summary>
        /// Registers one "Add parameter to ..." code action for every source constructor of the
        /// created type that could be made to accept the given arguments by adding one parameter.
        /// </summary>
        /// <param name="context">The code fix context actions are registered with.</param>
        /// <param name="objectCreation">The object-creation expression being fixed.</param>
        /// <param name="argumentOpt">
        /// The argument the diagnostic points at, or <c>null</c> if any argument may be the one to add.
        /// </param>
        private async Task HandleObjectCreationExpressionAsync(
            CodeFixContext context,
            TObjectCreationExpressionSyntax objectCreation,
            TArgumentSyntax argumentOpt)
        {
            var document = context.Document;
            var cancellationToken = context.CancellationToken;
            var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var syntaxFacts = document.GetLanguageService<ISyntaxFactsService>();

            // Not supported if this is "new { ... }" (as there are no parameters at all).
            var typeNode = syntaxFacts.GetObjectCreationType(objectCreation);
            if (typeNode == null)
            {
                return;
            }

            // If we can't figure out the type being created, or the type isn't in source,
            // then there's nothing we can do.
            var type = semanticModel.GetSymbolInfo(typeNode, cancellationToken).GetAnySymbol() as INamedTypeSymbol;
            if (type == null || !type.IsNonImplicitAndFromSource())
            {
                return;
            }

            var arguments = (SeparatedSyntaxList<TArgumentSyntax>)syntaxFacts.GetArgumentsOfObjectCreationExpression(objectCreation);

            var comparer = syntaxFacts.IsCaseSensitive
                ? StringComparer.Ordinal
                : CaseInsensitiveComparison.Comparer;

            // Pooled builder: must be returned with Free() once we're done with it.
            var constructorsAndArgumentToAdd = ArrayBuilder<(IMethodSymbol constructor, TArgumentSyntax argument, int index)>.GetInstance();
            try
            {
                foreach (var constructor in type.InstanceConstructors.OrderBy(m => m.Parameters.Length))
                {
                    // Only constructors that currently take fewer (non-params) parameters than
                    // there are arguments can be fixed by adding a parameter.
                    if (!constructor.IsNonImplicitAndFromSource() ||
                        NonParamsParameterCount(constructor) >= arguments.Count)
                    {
                        continue;
                    }

                    var argumentToAdd = DetermineFirstArgumentToAdd(
                        semanticModel, syntaxFacts, comparer, constructor,
                        arguments, cancellationToken);

                    if (argumentToAdd == null)
                    {
                        continue;
                    }

                    if (argumentOpt != null && argumentToAdd != argumentOpt)
                    {
                        // We were trying to fix a specific argument, but the argument we want
                        // to fix is something different.  That means there was an error earlier
                        // than this argument.  Which means we're looking at a non-viable
                        // constructor.  Skip this one.
                        continue;
                    }

                    constructorsAndArgumentToAdd.Add(
                        (constructor, argumentToAdd, arguments.IndexOf(argumentToAdd)));
                }

                // Order by the furthest argument index to the nearest argument index.  The ones with
                // larger argument indexes mean that we matched more earlier arguments (and thus are
                // likely to be the correct match).
                foreach (var tuple in constructorsAndArgumentToAdd.OrderByDescending(t => t.index))
                {
                    var constructor = tuple.constructor;
                    var argumentToAdd = tuple.argument;

                    var parameters = constructor.Parameters.Select(p => p.ToDisplayString(SimpleFormat));
                    var signature = $"{type.Name}({string.Join(", ", parameters)})";

                    var title = string.Format(FeaturesResources.Add_parameter_to_0, signature);

                    context.RegisterCodeFix(
                        new MyCodeAction(title, c => FixAsync(document, constructor, argumentToAdd, arguments, c)),
                        context.Diagnostics);
                }
            }
            finally
            {
                constructorsAndArgumentToAdd.Free();
            }
        }

        /// <summary>
        /// Number of parameters of <paramref name="method"/>, not counting a trailing params parameter.
        /// </summary>
        private static int NonParamsParameterCount(IMethodSymbol method)
            => method.IsParams() ? method.Parameters.Length - 1 : method.Parameters.Length;

        /// <summary>
        /// Produces the document containing <paramref name="method"/>'s declaration with a new
        /// parameter added for <paramref name="argument"/>.  The returned document may differ from
        /// <paramref name="invocationDocument"/> when the constructor is declared elsewhere.
        /// </summary>
        private async Task<Document> FixAsync(
            Document invocationDocument,
            IMethodSymbol method,
            TArgumentSyntax argument,
            SeparatedSyntaxList<TArgumentSyntax> argumentList,
            CancellationToken cancellationToken)
        {
            var methodDeclaration = await method.DeclaringSyntaxReferences[0].GetSyntaxAsync(cancellationToken).ConfigureAwait(false);

            var methodDocument = invocationDocument.Project.Solution.GetDocument(methodDeclaration.SyntaxTree);
            if (methodDocument == null)
            {
                // The declaration isn't in a document of this solution; there is nothing to edit.
                return invocationDocument;
            }

            var generator = SyntaxGenerator.GetGenerator(invocationDocument.Project.Solution.Workspace, method.Language);

            var syntaxFacts = invocationDocument.GetLanguageService<ISyntaxFactsService>();
            var semanticFacts = invocationDocument.GetLanguageService<ISemanticFactsService>();
            var argumentName = syntaxFacts.GetNameForArgument(argument);
            var expression = syntaxFacts.GetExpressionOfArgument(argument);

            // Parameter names must be unique according to the rules of the language the
            // constructor is declared in, which may not be the language of the call site.
            var isCaseSensitive = methodDocument.GetLanguageService<ISyntaxFactsService>().IsCaseSensitive;

            var semanticModel = await invocationDocument.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var parameterType = semanticModel.GetTypeInfo(expression, cancellationToken).Type ?? semanticModel.Compilation.ObjectType;

            var newMethodDeclaration = GetNewMethodDeclaration(
                method, argument, argumentList, generator, methodDeclaration,
                semanticFacts, argumentName, expression, semanticModel,
                parameterType, isCaseSensitive, cancellationToken);

            var root = await methodDeclaration.SyntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
            var newRoot = root.ReplaceNode(methodDeclaration, newMethodDeclaration);

            return methodDocument.WithSyntaxRoot(newRoot);
        }

        /// <summary>
        /// Returns <paramref name="declaration"/> with a new parameter of type
        /// <paramref name="parameterType"/> added for <paramref name="argument"/>.
        /// Named arguments keep their name and go at the end (but before a params parameter);
        /// positional arguments get a generated, unique name at the argument's position.
        /// </summary>
        private static SyntaxNode GetNewMethodDeclaration(
            IMethodSymbol method,
            TArgumentSyntax argument,
            SeparatedSyntaxList<TArgumentSyntax> argumentList,
            SyntaxGenerator generator,
            SyntaxNode declaration,
            ISemanticFactsService semanticFacts,
            string argumentName,
            SyntaxNode expression,
            SemanticModel semanticModel,
            ITypeSymbol parameterType,
            bool isCaseSensitive,
            CancellationToken cancellationToken)
        {
            if (!string.IsNullOrWhiteSpace(argumentName))
            {
                var newParameterDeclaration = CreateParameterDeclaration(generator, parameterType, argumentName);

                // A params parameter has to stay last, so the new parameter must go right before
                // it; appending after it would produce code that doesn't compile.
                return method.IsParams()
                    ? generator.InsertParameters(declaration, method.Parameters.Length - 1, new[] { newParameterDeclaration })
                    : generator.AddParameters(declaration, new[] { newParameterDeclaration });
            }
            else
            {
                var name = semanticFacts.GenerateNameForExpression(
                    semanticModel, expression, capitalize: false, cancellationToken: cancellationToken);
                var uniqueName = NameGenerator.EnsureUniqueness(
                    name, method.Parameters.Select(p => p.Name), isCaseSensitive);

                var newParameterDeclaration = CreateParameterDeclaration(generator, parameterType, uniqueName);
                var argumentIndex = argumentList.IndexOf(argument);
                return generator.InsertParameters(
                    declaration, argumentIndex, new[] { newParameterDeclaration });
            }
        }

        /// <summary>
        /// Creates a plain (non-ref, non-params, attribute-free) parameter declaration node.
        /// </summary>
        private static SyntaxNode CreateParameterDeclaration(
            SyntaxGenerator generator, ITypeSymbol parameterType, string name)
        {
            var newParameterSymbol = CodeGenerationSymbolFactory.CreateParameterSymbol(
                attributes: default(ImmutableArray<AttributeData>),
                refKind: RefKind.None,
                isParams: false,
                type: parameterType,
                name: name);

            return generator.ParameterDeclaration(newParameterSymbol);
        }

        /// <summary>
        /// Finds the first argument in <paramref name="arguments"/> that <paramref name="method"/>
        /// cannot accept: a named argument with no matching parameter name, a positional argument
        /// beyond the parameter list, or a positional argument whose type doesn't match its
        /// parameter.  Returns <c>null</c> if none is found, or if the mismatch is absorbed by a
        /// params parameter.
        /// </summary>
        private static TArgumentSyntax DetermineFirstArgumentToAdd(
            SemanticModel semanticModel,
            ISyntaxFactsService syntaxFacts,
            StringComparer comparer,
            IMethodSymbol method,
            SeparatedSyntaxList<TArgumentSyntax> arguments,
            CancellationToken cancellationToken)
        {
            var methodParameterNames = new HashSet<string>(comparer);
            methodParameterNames.AddRange(method.Parameters.Select(p => p.Name));

            for (int i = 0, n = arguments.Count; i < n; i++)
            {
                var argument = arguments[i];
                var argumentName = syntaxFacts.GetNameForArgument(argument);

                if (!string.IsNullOrWhiteSpace(argumentName))
                {
                    // If the user provided an argument-name and we don't have any parameters that
                    // match, then this is the argument we want to add a parameter for.
                    if (!methodParameterNames.Contains(argumentName))
                    {
                        return argument;
                    }
                }
                else
                {
                    // Positional argument.  If the position is beyond what the method supports,
                    // then this definitely is an argument we could add.
                    if (i >= method.Parameters.Length)
                    {
                        if (method.Parameters.LastOrDefault()?.IsParams == true)
                        {
                            // Last parameter is a params.  We can't place any parameters past it.
                            return null;
                        }

                        return argument;
                    }

                    // Now check the type of the argument versus the type of the parameter.  If they
                    // don't match, then this is the argument we should make the parameter for.
                    var argumentTypeInfo = semanticModel.GetTypeInfo(
                        syntaxFacts.GetExpressionOfArgument(argument), cancellationToken);
                    if (argumentTypeInfo.Type == null && argumentTypeInfo.ConvertedType == null)
                    {
                        // Didn't know the type of the argument.  We shouldn't assume it doesn't
                        // match a parameter.
                        continue;
                    }

                    var parameter = method.Parameters[i];

                    if (!TypeInfoMatchesType(argumentTypeInfo, parameter.Type))
                    {
                        if (TypeInfoMatchesWithParamsExpansion(argumentTypeInfo, parameter))
                        {
                            // The argument matched if we expanded out the params-parameter.
                            // As the params-parameter has to be last, there's nothing else to
                            // do here.
                            return null;
                        }

                        return argument;
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// True if <paramref name="parameter"/> is a params array whose element type matches the
        /// argument's type (i.e. the argument binds via params expansion).
        /// </summary>
        private static bool TypeInfoMatchesWithParamsExpansion(TypeInfo argumentTypeInfo, IParameterSymbol parameter)
            => parameter.IsParams &&
               parameter.Type is IArrayTypeSymbol arrayType &&
               TypeInfoMatchesType(argumentTypeInfo, arrayType.ElementType);

        /// <summary>
        /// True if <paramref name="type"/> equals either the natural or the converted type of the argument.
        /// </summary>
        private static bool TypeInfoMatchesType(TypeInfo argumentTypeInfo, ITypeSymbol type)
            => type.Equals(argumentTypeInfo.Type) || type.Equals(argumentTypeInfo.ConvertedType);

        private class MyCodeAction : CodeAction.DocumentChangeAction
        {
            public MyCodeAction(
                string title,
                Func<CancellationToken, Task<Document>> createChangedDocument)
                : base(title, createChangedDocument)
            {
            }
        }
    }
}