ICodeDefinitionFactoryExtensions.cs 20.4 KB
Newer Older
1
// Copyright (c) Microsoft.  All Rights Reserved.  Licensed under the Apache License, Version 2.0.  See License.txt in the project root for license information.
P
Pilchie 已提交
2

3
using System;
P
Pilchie 已提交
4 5 6 7
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
C
CyrusNajmabadi 已提交
8
using System.Threading.Tasks;
P
Pilchie 已提交
9
using Microsoft.CodeAnalysis.CodeGeneration;
10
using Microsoft.CodeAnalysis.Editing;
C
CyrusNajmabadi 已提交
11
using Microsoft.CodeAnalysis.ErrorReporting;
P
Pilchie 已提交
12
using Microsoft.CodeAnalysis.FindSymbols;
13
using Microsoft.CodeAnalysis.Simplification;
P
Pilchie 已提交
14 15 16 17 18 19
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Shared.Extensions
{
    internal static partial class ICodeDefinitionFactoryExtensions
    {
20
        /// <summary>
        /// Builds the statement <c>throw new NotImplementedException();</c> using the
        /// compilation's <c>NotImplementedException</c> type.
        /// </summary>
        public static SyntaxNode CreateThrowNotImplementedStatement(
            this SyntaxGenerator codeDefinitionFactory,
            Compilation compilation)
        {
            // addImport: false -- the exception type is referenced without requesting a new import.
            var exceptionTypeSyntax = codeDefinitionFactory.TypeExpression(
                compilation.NotImplementedExceptionType(), addImport: false);

            // Parameterless constructor call: no arguments are passed.
            var newException = codeDefinitionFactory.ObjectCreationExpression(
                exceptionTypeSyntax,
                SpecializedCollections.EmptyList<SyntaxNode>());

            return codeDefinitionFactory.ThrowStatement(newException);
        }

C
CyrusNajmabadi 已提交
30 31 32
        /// <summary>
        /// Returns a one-statement body consisting solely of
        /// <c>throw new NotImplementedException();</c>.
        /// </summary>
        public static ImmutableArray<SyntaxNode> CreateThrowNotImplementedStatementBlock(
            this SyntaxGenerator codeDefinitionFactory, Compilation compilation)
        {
            var throwStatement = CreateThrowNotImplementedStatement(codeDefinitionFactory, compilation);
            return ImmutableArray.Create(throwStatement);
        }
P
Pilchie 已提交
33

C
CyrusNajmabadi 已提交
34
        /// <summary>
        /// Produces one argument per parameter, each referring to the parameter by name and
        /// carrying over the parameter's <see cref="RefKind"/>, in declaration order.
        /// </summary>
        public static ImmutableArray<SyntaxNode> CreateArguments(
            this SyntaxGenerator factory,
            ImmutableArray<IParameterSymbol> parameters)
            => parameters.SelectAsArray(parameter => CreateArgument(factory, parameter));

        /// <summary>
        /// Creates an argument that forwards <paramref name="parameter"/> by name, preserving
        /// its ref/out/in kind.
        /// </summary>
        private static SyntaxNode CreateArgument(
            this SyntaxGenerator factory,
            IParameterSymbol parameter)
        {
            var passingKind = parameter.RefKind;
            var nameReference = factory.IdentifierName(parameter.Name);
            return factory.Argument(passingKind, nameReference);
        }

        /// <summary>
        /// Creates a public constructor named <paramref name="typeName"/> with the same
        /// parameters as <paramref name="constructor"/>, whose only job is to forward those
        /// parameters to the base constructor. The body itself is empty.
        /// </summary>
        public static IMethodSymbol CreateBaseDelegatingConstructor(
            this SyntaxGenerator factory,
            IMethodSymbol constructor,
            string typeName)
        {
            var forwardedParameters = constructor.Parameters;

            // With no parameters there is no need to write out "base()"; it is implied,
            // so leave the base-constructor arguments unset in that case.
            var baseArguments = forwardedParameters.Length > 0
                ? factory.CreateArguments(forwardedParameters)
                : default(ImmutableArray<SyntaxNode>);

            return CodeGenerationSymbolFactory.CreateConstructorSymbol(
                attributes: default(ImmutableArray<AttributeData>),
                accessibility: Accessibility.Public,
                modifiers: new DeclarationModifiers(),
                typeName: typeName,
                parameters: forwardedParameters,
                statements: default(ImmutableArray<SyntaxNode>),
                baseConstructorArguments: baseArguments);
        }

        /// <summary>
        /// Lazily yields the new fields to generate for <paramref name="parameters"/>, followed by
        /// a public constructor that assigns each parameter to its (existing or new) field.
        /// </summary>
        /// <param name="typeName">Name of the type receiving the constructor.</param>
        /// <param name="containingTypeOpt">The receiving type, if known; used to decide whether
        /// the constructor must also chain to <c>this()</c>.</param>
        /// <param name="parameters">Constructor parameters, in order.</param>
        /// <param name="parameterToExistingFieldMap">Parameter name to an already-declared member.</param>
        /// <param name="parameterToNewFieldMap">Parameter name to the name of a field to create.</param>
        /// <param name="cancellationToken">Currently not observed by this method.</param>
        public static IEnumerable<ISymbol> CreateFieldDelegatingConstructor(
            this SyntaxGenerator factory,
            string typeName,
            INamedTypeSymbol containingTypeOpt,
            ImmutableArray<IParameterSymbol> parameters,
            IDictionary<string, ISymbol> parameterToExistingFieldMap,
            IDictionary<string, string> parameterToNewFieldMap,
            CancellationToken cancellationToken)
        {
            // New fields come first...
            foreach (var newField in factory.CreateFieldsForParameters(parameters, parameterToNewFieldMap))
            {
                yield return newField;
            }

            // ...then the constructor. Each assignment statement is marked for simplification.
            var assignments = factory
                .CreateAssignmentStatements(parameters, parameterToExistingFieldMap, parameterToNewFieldMap)
                .Select(assignment => assignment.WithAdditionalAnnotations(Simplifier.Annotation))
                .ToImmutableArray();

            yield return CodeGenerationSymbolFactory.CreateConstructorSymbol(
                attributes: default(ImmutableArray<AttributeData>),
                accessibility: Accessibility.Public,
                modifiers: new DeclarationModifiers(),
                typeName: typeName,
                parameters: parameters,
                statements: assignments,
                thisConstructorArguments: GetThisConstructorArguments(containingTypeOpt, parameterToExistingFieldMap));
        }

C
CyrusNajmabadi 已提交
95
        /// <summary>
        /// Computes the <c>this(...)</c> initializer arguments for a generated field-delegating
        /// constructor.
        /// </summary>
        /// <param name="containingTypeOpt">The type receiving the constructor; may be null.</param>
        /// <param name="parameterToExistingFieldMap">
        /// Parameter name to the existing member it is assigned to. May be null, which is treated
        /// as an empty map (matching the null tolerance of the <c>TryGetValue</c> helpers).
        /// </param>
        /// <returns>
        /// An empty array (a call to the default constructor) when generating into a struct that
        /// still has an instance field not covered by the map; otherwise <c>default</c>, meaning
        /// no <c>this(...)</c> initializer.
        /// </returns>
        private static ImmutableArray<SyntaxNode> GetThisConstructorArguments(
            INamedTypeSymbol containingTypeOpt,
            IDictionary<string, ISymbol> parameterToExistingFieldMap)
        {
            if (containingTypeOpt == null || containingTypeOpt.TypeKind != TypeKind.Struct)
            {
                return default(ImmutableArray<SyntaxNode>);
            }

            // Special case.  If we're generating a struct constructor, then we'll need to
            // initialize all fields in the struct, not just the ones we're assigning.  Collect
            // the explicitly declared instance fields the map already covers.  A set is used
            // (instead of comparing counts) so that two parameters mapped to the same field are
            // not mistaken for two distinct initialized fields.
            var initializedFields = new HashSet<IFieldSymbol>(
                parameterToExistingFieldMap?.Values
                    .OfType<IFieldSymbol>()
                    .Where(f => !f.IsImplicitlyDeclared && !f.IsStatic)
                ?? Enumerable.Empty<IFieldSymbol>());

            // Any instance field left unassigned (implicitly declared ones are never in the set
            // above) means we must also call the default constructor.
            var hasUnassignedField = containingTypeOpt.GetMembers()
                .OfType<IFieldSymbol>()
                .Where(f => !f.IsStatic)
                .Any(f => !initializedFields.Contains(f));

            return hasUnassignedField
                ? ImmutableArray<SyntaxNode>.Empty
                : default(ImmutableArray<SyntaxNode>);
        }

        /// <summary>
        /// Lazily yields a private field for every non-out parameter that has an entry in
        /// <paramref name="parameterToNewFieldMap"/>. Each field has the parameter's type and the
        /// mapped name.
        /// </summary>
        /// <param name="parameters">Parameters to create fields for.</param>
        /// <param name="parameterToNewFieldMap">Parameter name to new field name; may be null,
        /// in which case nothing is yielded.</param>
        public static IEnumerable<IFieldSymbol> CreateFieldsForParameters(
            this SyntaxGenerator factory,
            IList<IParameterSymbol> parameters,
            IDictionary<string, string> parameterToNewFieldMap)
        {
            foreach (var parameter in parameters)
            {
                // Out parameters never get a field; they are assigned default(T) instead.
                if (parameter.RefKind == RefKind.Out)
                {
                    continue;
                }

                // For non-out parameters, create a field the parameter will be assigned to.
                // TODO: I'm not sure that's what we really want for ref parameters.
                // Use the name produced by the lookup rather than indexing the map a second time.
                if (TryGetValue(parameterToNewFieldMap, parameter.Name, out var fieldName))
                {
                    yield return CodeGenerationSymbolFactory.CreateFieldSymbol(
                        attributes: default(ImmutableArray<AttributeData>),
                        accessibility: Accessibility.Private,
                        modifiers: default(DeclarationModifiers),
                        type: parameter.Type,
                        name: fieldName);
                }
            }
        }

        // Looks up key in a possibly-null map. A null map behaves like a map without the key:
        // returns false and sets value to null.
        private static bool TryGetValue(IDictionary<string, string> dictionary, string key, out string value)
        {
            if (dictionary == null)
            {
                value = null;
                return false;
            }

            return dictionary.TryGetValue(key, out value);
        }

        // Looks up key in a possibly-null symbol map and returns the found symbol's Name.
        // A null map or a missing key returns false and sets value to null.
        private static bool TryGetValue(IDictionary<string, ISymbol> dictionary, string key, out string value)
        {
            ISymbol existing = null;
            var found = dictionary != null && dictionary.TryGetValue(key, out existing);
            value = found ? existing.Name : null;
            return found;
        }

        /// <summary>
        /// Lazily yields one assignment statement per parameter:
        /// <list type="bullet">
        /// <item>out parameters: <c>p = default(T);</c></item>
        /// <item>other parameters with a mapped field (existing fields checked first, then new
        /// ones): <c>this.field = p;</c></item>
        /// </list>
        /// Non-out parameters with no mapped field produce nothing.
        /// </summary>
        public static IEnumerable<SyntaxNode> CreateAssignmentStatements(
            this SyntaxGenerator factory,
            IList<IParameterSymbol> parameters,
            IDictionary<string, ISymbol> parameterToExistingFieldMap,
            IDictionary<string, string> parameterToNewFieldMap)
        {
            foreach (var parameter in parameters)
            {
                var name = parameter.Name;

                if (parameter.RefKind == RefKind.Out)
                {
                    // An out param gets no field. Instead, assign it the default value for its
                    // type (i.e. "default(...)").
                    yield return factory.ExpressionStatement(
                        factory.AssignmentStatement(
                            factory.IdentifierName(name),
                            factory.DefaultExpression(parameter.Type)));
                    continue;
                }

                // Any other parameter is assigned to its field, if it has one.
                // TODO: I'm not sure that's what we really want for ref parameters.
                string fieldName;
                var hasField =
                    TryGetValue(parameterToExistingFieldMap, name, out fieldName) ||
                    TryGetValue(parameterToNewFieldMap, name, out fieldName);
                if (!hasField)
                {
                    continue;
                }

                // "this.field" is marked with the simplifier annotation.
                var target = factory.MemberAccessExpression(factory.ThisExpression(), factory.IdentifierName(fieldName))
                                    .WithAdditionalAnnotations(Simplifier.Annotation);
                yield return factory.ExpressionStatement(
                    factory.AssignmentStatement(target, factory.IdentifierName(name)));
            }
        }

C
CyrusNajmabadi 已提交
209
        /// <summary>
        /// Creates an overriding property for <paramref name="overriddenProperty"/> in
        /// <paramref name="containingType"/>. Accessor bodies are chosen by these rules:
        /// <list type="bullet">
        /// <item>abstract base: both accessors throw <c>NotImplementedException</c>;</item>
        /// <item>C# indexer: <c>return base[...]</c> / <c>base[...] = value</c>;</item>
        /// <item>parameterized property: the base accessor methods are called directly when C#
        /// overrides a property whose source definition is Visual Basic; otherwise the base
        /// property is invoked by name with the parameters;</item>
        /// <item>plain property: <c>return base.P</c> / <c>base.P = value</c>.</item>
        /// </list>
        /// An accessor is generated only if the base accessor exists and is accessible from
        /// <paramref name="containingType"/>. A private base setter is never generated.
        /// </summary>
        public static async Task<IPropertySymbol> OverridePropertyAsync(
            this SyntaxGenerator codeFactory,
            IPropertySymbol overriddenProperty,
            DeclarationModifiers modifiers,
            INamedTypeSymbol containingType,
            Document document,
            CancellationToken cancellationToken)
        {
            SyntaxNode getBody = null;
            SyntaxNode setBody = null;

            if (overriddenProperty.IsAbstract)
            {
                // Implement an abstract property by throwing not implemented in accessors.
                var compilation = await document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
                var statement = codeFactory.CreateThrowNotImplementedStatement(compilation);

                getBody = statement;
                setBody = statement;
            }
            else if (overriddenProperty.IsIndexer() && document.Project.Language == LanguageNames.CSharp)
            {
                // Indexer: return or set base[]. Only in C#, since VB must refer to these by name.
                getBody = codeFactory.ReturnStatement(
                    codeFactory.ElementAccessExpression(
                        codeFactory.BaseExpression(),
                        codeFactory.CreateArguments(overriddenProperty.Parameters)));

                setBody = codeFactory.ExpressionStatement(
                    codeFactory.AssignmentStatement(
                        codeFactory.ElementAccessExpression(
                            codeFactory.BaseExpression(),
                            codeFactory.CreateArguments(overriddenProperty.Parameters)),
                        codeFactory.IdentifierName("value")));
            }
            else if (overriddenProperty.GetParameters().Any())
            {
                // Call accessors directly if C# overriding VB.
                var callAccessorsDirectly = false;
                if (document.Project.Language == LanguageNames.CSharp)
                {
                    // FindSourceDefinitionAsync can return null (no source definition, e.g. a
                    // metadata symbol), so read Language null-safely instead of dereferencing.
                    var sourceDefinition = await SymbolFinder.FindSourceDefinitionAsync(
                        overriddenProperty, document.Project.Solution, cancellationToken).ConfigureAwait(false);
                    callAccessorsDirectly = sourceDefinition?.Language == LanguageNames.VisualBasic;
                }

                if (callAccessorsDirectly)
                {
                    var getName = overriddenProperty.GetMethod?.Name;
                    var setName = overriddenProperty.SetMethod?.Name;

                    getBody = getName == null
                        ? null
                        : codeFactory.ReturnStatement(
                            codeFactory.InvocationExpression(
                                codeFactory.MemberAccessExpression(
                                    codeFactory.BaseExpression(),
                                    codeFactory.IdentifierName(getName)),
                                codeFactory.CreateArguments(overriddenProperty.Parameters)));

                    // The setter method's parameters include the trailing value parameter.
                    setBody = setName == null
                        ? null
                        : codeFactory.ExpressionStatement(
                            codeFactory.InvocationExpression(
                                codeFactory.MemberAccessExpression(
                                    codeFactory.BaseExpression(),
                                    codeFactory.IdentifierName(setName)),
                                codeFactory.CreateArguments(overriddenProperty.SetMethod.GetParameters())));
                }
                else
                {
                    getBody = codeFactory.ReturnStatement(
                        codeFactory.InvocationExpression(
                            codeFactory.MemberAccessExpression(
                                codeFactory.BaseExpression(),
                                codeFactory.IdentifierName(overriddenProperty.Name)),
                            codeFactory.CreateArguments(overriddenProperty.Parameters)));

                    setBody = codeFactory.ExpressionStatement(
                        codeFactory.AssignmentStatement(
                            codeFactory.InvocationExpression(
                                codeFactory.MemberAccessExpression(
                                    codeFactory.BaseExpression(),
                                    codeFactory.IdentifierName(overriddenProperty.Name)),
                                codeFactory.CreateArguments(overriddenProperty.Parameters)),
                            codeFactory.IdentifierName("value")));
                }
            }
            else
            {
                // Regular property: return or set the base property.
                getBody = codeFactory.ReturnStatement(
                    codeFactory.MemberAccessExpression(
                        codeFactory.BaseExpression(),
                        codeFactory.IdentifierName(overriddenProperty.Name)));

                setBody = codeFactory.ExpressionStatement(
                    codeFactory.AssignmentStatement(
                        codeFactory.MemberAccessExpression(
                            codeFactory.BaseExpression(),
                            codeFactory.IdentifierName(overriddenProperty.Name)),
                        codeFactory.IdentifierName("value")));
            }

            // Only generate a getter if the base getter is accessible. Accessibility is computed
            // only once the accessor is known to exist.
            IMethodSymbol accessorGet = null;
            var baseGetter = overriddenProperty.GetMethod;
            if (baseGetter != null && baseGetter.IsAccessibleWithin(containingType))
            {
                accessorGet = CodeGenerationSymbolFactory.CreateMethodSymbol(
                    baseGetter,
                    accessibility: baseGetter.ComputeResultantAccessibility(containingType),
                    statements: ImmutableArray.Create(getBody),
                    modifiers: modifiers);
            }

            // Only generate a setter if the base setter is accessible (and not private).
            IMethodSymbol accessorSet = null;
            var baseSetter = overriddenProperty.SetMethod;
            if (baseSetter != null &&
                baseSetter.IsAccessibleWithin(containingType) &&
                baseSetter.DeclaredAccessibility != Accessibility.Private)
            {
                accessorSet = CodeGenerationSymbolFactory.CreateMethodSymbol(
                    baseSetter,
                    accessibility: baseSetter.ComputeResultantAccessibility(containingType),
                    statements: ImmutableArray.Create(setBody),
                    modifiers: modifiers);
            }

            return CodeGenerationSymbolFactory.CreatePropertySymbol(
                overriddenProperty,
                accessibility: overriddenProperty.ComputeResultantAccessibility(containingType),
                modifiers: modifiers,
                name: overriddenProperty.Name,
                isIndexer: overriddenProperty.IsIndexer(),
                getMethod: accessorGet,
                setMethod: accessorSet);
        }

        /// <summary>
        /// Creates an overriding event symbol based on <paramref name="overriddenEvent"/>. The
        /// name is kept, and the accessibility is recomputed for
        /// <paramref name="newContainingType"/>. No attributes or explicit interface are applied.
        /// </summary>
        public static IEventSymbol OverrideEvent(
            this SyntaxGenerator codeFactory,
            IEventSymbol overriddenEvent,
            DeclarationModifiers modifiers,
            INamedTypeSymbol newContainingType)
        {
            var resultantAccessibility = overriddenEvent.ComputeResultantAccessibility(newContainingType);

            return CodeGenerationSymbolFactory.CreateEventSymbol(
                overriddenEvent,
                attributes: default(ImmutableArray<AttributeData>),
                accessibility: resultantAccessibility,
                modifiers: modifiers,
                explicitInterfaceSymbol: null,
                name: overriddenEvent.Name);
        }

355 356 357 358 359 360 361 362
        /// <summary>
        /// Creates an override of <paramref name="symbol"/> (method, property, or event) for
        /// <paramref name="containingType"/>. When <paramref name="modifiersOpt"/> is null, the
        /// symbol's own modifiers are used, with override set and abstract/virtual cleared.
        /// Any other symbol kind is treated as unreachable.
        /// </summary>
        public static async Task<ISymbol> OverrideAsync(
            this SyntaxGenerator generator,
            ISymbol symbol,
            INamedTypeSymbol containingType,
            Document document,
            DeclarationModifiers? modifiersOpt = null,
            CancellationToken cancellationToken = default(CancellationToken))
        {
            var modifiers = modifiersOpt ?? GetOverrideModifiers(symbol);

            switch (symbol)
            {
                case IMethodSymbol method:
                    return await generator.OverrideMethodAsync(
                        method, modifiers, containingType, document, cancellationToken).ConfigureAwait(false);

                case IPropertySymbol property:
                    return await generator.OverridePropertyAsync(
                        property, modifiers, containingType, document, cancellationToken).ConfigureAwait(false);

                case IEventSymbol ev:
                    return generator.OverrideEvent(ev, modifiers, containingType);

                default:
                    throw ExceptionUtilities.Unreachable;
            }
        }

385 386 387 388 389 390
        // Default modifiers for an override: start from the overridden symbol's modifiers,
        // mark it as override, and clear abstract and virtual.
        private static DeclarationModifiers GetOverrideModifiers(ISymbol symbol)
        {
            var inherited = symbol.GetSymbolModifiers();
            var asOverride = inherited.WithIsOverride(true);
            var notAbstract = asOverride.WithIsAbstract(false);
            return notAbstract.WithIsVirtual(false);
        }

C
CyrusNajmabadi 已提交
391
        /// <summary>
        /// Creates an overriding method for <paramref name="overriddenMethod"/> with a single
        /// statement body:
        /// <list type="bullet">
        /// <item>abstract base: <c>throw new NotImplementedException();</c></item>
        /// <item>otherwise: a call to <c>base.M&lt;T...&gt;(args)</c>, emitted as an expression
        /// statement for void methods and as a return statement otherwise.</item>
        /// </list>
        /// </summary>
        private static async Task<IMethodSymbol> OverrideMethodAsync(
            this SyntaxGenerator codeFactory,
            IMethodSymbol overriddenMethod,
            DeclarationModifiers modifiers,
            INamedTypeSymbol newContainingType,
            Document newDocument,
            CancellationToken cancellationToken)
        {
            SyntaxNode bodyStatement;

            if (overriddenMethod.IsAbstract)
            {
                // Abstract: Throw not implemented.
                var compilation = await newDocument.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
                bodyStatement = codeFactory.CreateThrowNotImplementedStatement(compilation);
            }
            else
            {
                // Otherwise, call the base method with the same type arguments and parameters.
                var typeArguments = overriddenMethod.GetTypeArguments();
                var baseReference = codeFactory.BaseExpression();
                var calleeName = typeArguments.IsDefaultOrEmpty
                    ? codeFactory.IdentifierName(overriddenMethod.Name)
                    : codeFactory.GenericName(overriddenMethod.Name, typeArguments);
                var baseCall = codeFactory.InvocationExpression(
                    codeFactory.MemberAccessExpression(baseReference, calleeName),
                    codeFactory.CreateArguments(overriddenMethod.GetParameters()));

                bodyStatement = overriddenMethod.ReturnsVoid
                    ? codeFactory.ExpressionStatement(baseCall)
                    : codeFactory.ReturnStatement(baseCall);
            }

            return CodeGenerationSymbolFactory.CreateMethodSymbol(
                overriddenMethod,
                accessibility: overriddenMethod.ComputeResultantAccessibility(newContainingType),
                modifiers: modifiers,
                statements: ImmutableArray.Create(bodyStatement));
        }
    }
432
}