提交 307e781a 编写于 作者: M mattwar

SyntaxGenerator changes (changeset 1384116)

上级 3de74bcb
......@@ -282,6 +282,11 @@ private void AddTrivia(SyntaxNode node)
{
this.AddEndOfLine();
}
if ((this.options & SyntaxRemoveOptions.AddElasticMarker) != 0)
{
this.AddResidualTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticMarker));
}
}
private void AddTrivia(SyntaxToken token, SyntaxNode node)
......@@ -316,6 +321,11 @@ private void AddTrivia(SyntaxToken token, SyntaxNode node)
{
this.AddEndOfLine();
}
if ((this.options & SyntaxRemoveOptions.AddElasticMarker) != 0)
{
this.AddResidualTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticMarker));
}
}
private void AddTrivia(SyntaxNode node, SyntaxToken token)
......@@ -350,6 +360,11 @@ private void AddTrivia(SyntaxNode node, SyntaxToken token)
{
this.AddEndOfLine();
}
if ((this.options & SyntaxRemoveOptions.AddElasticMarker) != 0)
{
this.AddResidualTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticMarker));
}
}
private TextSpan GetRemovedSpan(TextSpan span, TextSpan fullSpan)
......
......@@ -42,5 +42,10 @@ public enum SyntaxRemoveOptions
/// Ensure that at least one EndOfLine trivia is kept if one was present
/// </summary>
KeepEndOfLine = 0x10,
/// <summary>
/// Adds elastic marker trivia
/// </summary>
AddElasticMarker = 0x20
}
}
\ No newline at end of file
......@@ -214,6 +214,10 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Syntax
ElseIf (Me._options And SyntaxRemoveOptions.KeepEndOfLine) <> 0 AndAlso HasEndOfLine(node.GetTrailingTrivia()) Then
Me.AddEndOfLine()
End If
If (Me._options And SyntaxRemoveOptions.AddElasticMarker) <> 0 Then
Me.AddResidualTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticMarker))
End If
End Sub
Private Sub AddTrivia(token As SyntaxToken, node As SyntaxNode)
......@@ -237,6 +241,10 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Syntax
ElseIf (Me._options And SyntaxRemoveOptions.KeepEndOfLine) <> 0 AndAlso HasEndOfLine(node.GetTrailingTrivia()) Then
Me.AddEndOfLine()
End If
If (Me._options And SyntaxRemoveOptions.AddElasticMarker) <> 0 Then
Me.AddResidualTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticMarker))
End If
End Sub
Private Sub AddTrivia(node As SyntaxNode, token As SyntaxToken)
......@@ -260,6 +268,10 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.Syntax
(HasEndOfLine(node.GetTrailingTrivia()) OrElse HasEndOfLine(token.LeadingTrivia) OrElse HasEndOfLine(token.TrailingTrivia)) Then
Me.AddEndOfLine()
End If
If (Me._options And SyntaxRemoveOptions.AddElasticMarker) <> 0 Then
Me.AddResidualTrivia(SyntaxFactory.TriviaList(SyntaxFactory.ElasticMarker))
End If
End Sub
Private Function GetRemovedSpan(span As TextSpan, fullSpan As TextSpan) As TextSpan
......
......@@ -47,17 +47,15 @@ private static SyntaxNode GetDeclaration(ISymbol symbol)
return (symbol.DeclaringSyntaxReferences.Length > 0) ? symbol.DeclaringSyntaxReferences[0].GetSyntax() : null;
}
private SyntaxNode GetExplicitlyAssignedField(IFieldSymbol originalField, SyntaxGenerator generator)
private SyntaxNode GetExplicitlyAssignedField(IFieldSymbol originalField, SyntaxNode declaration, SyntaxGenerator generator)
{
var originalDeclaration = GetDeclaration(originalField);
var originalInitializer = generator.GetInitializer(originalDeclaration);
var originalInitializer = generator.GetExpression(declaration);
if (originalInitializer != null || !originalField.HasConstantValue)
{
return originalDeclaration;
return declaration;
}
return generator.WithInitializer(originalDeclaration, generator.LiteralExpression(originalField.ConstantValue));
return generator.WithExpression(declaration, generator.LiteralExpression(originalField.ConstantValue));
}
private async Task<Document> GetUpdatedDocumentForRuleNameRenameAsync(Document document, IFieldSymbol field, CancellationToken cancellationToken)
......@@ -66,16 +64,15 @@ private async Task<Document> GetUpdatedDocumentForRuleNameRenameAsync(Document d
return newSolution.GetDocument(document.Id);
}
private IList<SyntaxNode> GetNewFieldsForRuleNameMultipleZero(INamedTypeSymbol enumType, IEnumerable<IFieldSymbol> zeroValuedFields, SyntaxGenerator generator)
private async Task ApplyRuleNameMultipleZeroAsync(SymbolEditor editor, INamedTypeSymbol enumType, CancellationToken cancellationToken)
{
// Diagnostic: Remove all members that have the value zero from '{0}' except for one member that is named 'None'.
// Fix: Remove all members that have the value zero except for one member that is named 'None'.
bool needsNewZeroValuedNoneField = true;
var set = zeroValuedFields.ToSet();
var set = CA1008DiagnosticAnalyzer.GetZeroValuedFields(enumType).ToSet();
bool makeNextFieldExplicit = false;
var newFields = new List<SyntaxNode>();
foreach (IFieldSymbol field in enumType.GetMembers().Where(m => m.Kind == SymbolKind.Field))
{
var isZeroValued = set.Contains(field);
......@@ -83,15 +80,12 @@ private IList<SyntaxNode> GetNewFieldsForRuleNameMultipleZero(INamedTypeSymbol e
if (!isZeroValued || isZeroValuedNamedNone)
{
var newField = GetDeclaration(field);
if (makeNextFieldExplicit)
{
newField = GetExplicitlyAssignedField(field, generator);
await editor.EditOneDeclarationAsync(field, (d, g) => GetExplicitlyAssignedField(field, d, g), cancellationToken);
makeNextFieldExplicit = false;
}
newFields.Add(newField);
if (isZeroValuedNamedNone)
{
needsNewZeroValuedNoneField = false;
......@@ -99,58 +93,30 @@ private IList<SyntaxNode> GetNewFieldsForRuleNameMultipleZero(INamedTypeSymbol e
}
else
{
await editor.EditOneDeclarationAsync(field, (d, g) => null); // removes the field declaration
makeNextFieldExplicit = true;
}
}
if (needsNewZeroValuedNoneField)
{
var firstZeroValuedField = zeroValuedFields.First();
var newField = generator.EnumMember("None");
newFields.Insert(0, newField);
await editor.EditOneDeclarationAsync(enumType, (d, g) => g.InsertMembers(d, 0, g.EnumMember("None")), cancellationToken);
}
return newFields;
}
private Document GetUpdatedDocumentForRuleNameMultipleZero(Document document, SyntaxNode root, SyntaxNode nodeToFix, INamedTypeSymbol enumType, IEnumerable<IFieldSymbol> zeroValuedFields, CancellationToken cancellationToken)
{
Contract.ThrowIfFalse(zeroValuedFields.Count() > 1);
var generator = SyntaxGenerator.GetGenerator(document);
var newFields = GetNewFieldsForRuleNameMultipleZero(enumType, zeroValuedFields, generator);
return GetUpdatedDocumentWithFix(document, root, nodeToFix, newFields, cancellationToken);
}
private IList<SyntaxNode> GetNewFieldsForRuleNameNoZeroValue(INamedTypeSymbol enumType, SyntaxGenerator generator)
private async Task ApplyRuleNameNoZeroValueAsync(SymbolEditor editor, INamedTypeSymbol enumType, CancellationToken cancellationToken)
{
// Diagnostic: Add a member to '{0}' that has a value of zero with a suggested name of 'None'.
// Fix: Add a zero-valued member 'None' to enum.
var newFields = new List<SyntaxNode>();
var newField = generator.EnumMember("None");
newFields.Add(newField);
foreach (var member in enumType.GetMembers().Where(m => m.Kind == SymbolKind.Field))
// remove any non-zero member named 'None'
foreach (IFieldSymbol field in enumType.GetMembers().Where(m => m.Kind == SymbolKind.Field))
{
if (!CA1008DiagnosticAnalyzer.IsMemberNamedNone(member))
if (CA1008DiagnosticAnalyzer.IsMemberNamedNone(field))
{
var decl = GetDeclaration(member);
if (decl != null)
{
newFields.Add(decl);
}
await editor.EditOneDeclarationAsync(field, (d, g) => null);
}
}
return newFields;
}
private Document GetUpdatedDocumentForRuleNameNoZeroValue(Document document, SyntaxNode root, SyntaxNode nodeToFix, INamedTypeSymbol enumType, CancellationToken cancellationToken)
{
var generator = SyntaxGenerator.GetGenerator(document);
var newFields = GetNewFieldsForRuleNameNoZeroValue(enumType, generator);
return GetUpdatedDocumentWithFix(document, root, nodeToFix, newFields, cancellationToken);
// insert zero-valued member 'None' to top
await editor.EditOneDeclarationAsync(enumType, (d, g) => g.InsertMembers(d, 0, g.EnumMember("None")), cancellationToken);
}
protected virtual SyntaxNode GetParentNodeOrSelfToFix(SyntaxNode nodeToFix)
......@@ -162,8 +128,7 @@ private Document GetUpdatedDocumentWithFix(Document document, SyntaxNode root, S
{
nodeToFix = GetParentNodeOrSelfToFix(nodeToFix);
var g = SyntaxGenerator.GetGenerator(document);
var newEnumSyntax = g.WithMembers(nodeToFix, newFields)
.WithAdditionalAnnotations(Formatting.Formatter.Annotation);
var newEnumSyntax = g.AddMembers(nodeToFix, newFields);
var newRoot = root.ReplaceNode(nodeToFix, newEnumSyntax);
return document.WithSyntaxRoot(newRoot);
}
......@@ -173,6 +138,8 @@ internal sealed override async Task<Document> GetUpdatedDocumentAsync(Document d
ISymbol declaredSymbol = model.GetDeclaredSymbol(nodeToFix, cancellationToken);
Contract.ThrowIfNull(declaredSymbol);
var editor = new SymbolEditor(document);
foreach (var customTag in diagnostic.Descriptor.CustomTags)
{
switch (customTag)
......@@ -181,12 +148,12 @@ internal sealed override async Task<Document> GetUpdatedDocumentAsync(Document d
return await GetUpdatedDocumentForRuleNameRenameAsync(document, (IFieldSymbol)declaredSymbol, cancellationToken).ConfigureAwait(false);
case CA1008DiagnosticAnalyzer.RuleMultipleZeroCustomTag:
var enumType = (INamedTypeSymbol)declaredSymbol;
var zeroValuedFields = CA1008DiagnosticAnalyzer.GetZeroValuedFields(enumType);
return GetUpdatedDocumentForRuleNameMultipleZero(document, root, nodeToFix, enumType, zeroValuedFields, cancellationToken);
await ApplyRuleNameMultipleZeroAsync(editor, (INamedTypeSymbol)declaredSymbol, cancellationToken);
return editor.GetChangedDocuments().First();
case CA1008DiagnosticAnalyzer.RuleNoZeroCustomTag:
return GetUpdatedDocumentForRuleNameNoZeroValue(document, root, nodeToFix, (INamedTypeSymbol)declaredSymbol, cancellationToken);
await ApplyRuleNameNoZeroValueAsync(editor, (INamedTypeSymbol)declaredSymbol, cancellationToken);
return editor.GetChangedDocuments().First();
}
}
......
......@@ -387,6 +387,99 @@ partial class C
Assert.Equal(expected, actual);
}
[Fact]
public void TestEditDeclarationWithLocation_SequentialEdits_SameLocation()
{
var code =
@"partial class C
{
}
partial class C
{
}";
var expected =
@"partial class C
{
}
partial class C
{
void m()
{
}
void m2()
{
}
}";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var location = symbol.Locations.Last();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditOneDeclarationAsync(symbol, location, (d, g) => g.AddMembers(d, g.MethodDeclaration("m"))).Result;
Assert.Equal(1, newSymbol.GetMembers("m").Length);
// reuse location from original symbol/solution
var newSymbol2 = (INamedTypeSymbol)editor.EditOneDeclarationAsync(newSymbol, location, (d, g) => g.AddMembers(d, g.MethodDeclaration("m2"))).Result;
Assert.Equal(1, newSymbol2.GetMembers("m").Length);
Assert.Equal(1, newSymbol2.GetMembers("m2").Length);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestEditDeclarationWithLocation_SequentialEdits_NewLocation()
{
var code =
@"partial class C
{
}
partial class C
{
}";
var expected =
@"partial class C
{
}
partial class C
{
void m()
{
}
void m2()
{
}
}";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var location = symbol.Locations.Last();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditOneDeclarationAsync(symbol, location, (d, g) => g.AddMembers(d, g.MethodDeclaration("m"))).Result;
Assert.Equal(1, newSymbol.GetMembers("m").Length);
// use location from new symbol
var newLocation = newSymbol.Locations.Last();
var newSymbol2 = (INamedTypeSymbol)editor.EditOneDeclarationAsync(newSymbol, newLocation, (d, g) => g.AddMembers(d, g.MethodDeclaration("m2"))).Result;
Assert.Equal(1, newSymbol2.GetMembers("m").Length);
Assert.Equal(1, newSymbol2.GetMembers("m2").Length);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestEditDeclarationWithMember()
{
......@@ -423,12 +516,160 @@ void m2()
var member = symbol.GetMembers("m").First();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditDeclarationsWithMemberDeclaredAsync(symbol, member, (d, g) => g.AddMembers(d, g.MethodDeclaration("m2"))).Result;
var newSymbol = (INamedTypeSymbol)editor.EditOneDeclarationAsync(symbol, member, (d, g) => g.AddMembers(d, g.MethodDeclaration("m2"))).Result;
Assert.Equal(1, newSymbol.GetMembers("m").Length);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestChangeLogicalIdentityReturnsCorrectSymbol_OneDeclaration()
{
// proves that APIs return the correct new symbol even after a change that changes the symbol's logical identity.
var code =
@"class C
{
}";
var expected =
@"class X
{
}";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditOneDeclarationAsync(symbol, (d, g) => g.WithName(d, "X")).Result;
Assert.Equal("X", newSymbol.Name);
// original symbols cannot be rebound after identity change.
var reboundSymbol = editor.GetCurrentSymbolAsync(symbol).Result;
Assert.Null(reboundSymbol);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestChangeLogicalIdentityReturnsCorrectSymbol_AllDeclarations()
{
// proves that APIs return the correct new symbol even after a change that changes the symbol's logical identity.
var code =
@"partial class C
{
}
partial class C
{
}";
var expected =
@"partial class X
{
}
partial class X
{
}";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditAllDeclarationsAsync(symbol, (d, g) => g.WithName(d, "X")).Result;
Assert.Equal("X", newSymbol.Name);
// original symbols cannot be rebound after identity change.
var reboundSymbol = editor.GetCurrentSymbolAsync(symbol).Result;
Assert.Null(reboundSymbol);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestRemovedDeclarationReturnsNull()
{
var code =
@"class C
{
}";
var expected =
@"";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditOneDeclarationAsync(symbol, (d, g) => null).Result;
Assert.Null(newSymbol);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestRemovedOneOfManyDeclarationsReturnsChangedSymbol()
{
var code =
@"partial class C
{
}
partial class C
{
}";
var expected =
@"
partial class C
{
}";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditOneDeclarationAsync(symbol, (d, g) => null).Result;
Assert.NotNull(newSymbol);
Assert.Equal("C", newSymbol.Name);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
[Fact]
public void TestRemoveAllOfManyDeclarationsReturnsNull()
{
var code =
@"partial class C
{
}
partial class C
{
}";
var expected =
@"";
var solution = GetSolution(code);
var symbol = GetSymbols(solution, "C").First();
var editor = new SymbolEditor(solution);
var newSymbol = (INamedTypeSymbol)editor.EditAllDeclarationsAsync(symbol, (d, g) => null).Result;
Assert.Null(newSymbol);
var actual = GetActual(editor.GetChangedDocuments().First());
Assert.Equal(expected, actual);
}
}
}
......@@ -1167,7 +1167,7 @@ public class C { } // end").Members[0];
var added = g.AddAttributes(cls, g.Attribute("a"));
VerifySyntax<ClassDeclarationSyntax>(added, "// comment\r\n[a]\r\npublic class C\r\n{\r\n} // end\r\n");
var removed = g.RemoveAttributes(added);
var removed = g.RemoveAllAttributes(added);
VerifySyntax<ClassDeclarationSyntax>(removed, "// comment\r\npublic class C\r\n{\r\n} // end\r\n");
var attrWithComment = g.GetAttributes(added).First();
......@@ -1352,6 +1352,37 @@ private void AssertMemberNamesEqual(string expectedName, SyntaxNode declaration)
AssertNamesEqual(new[] { expectedName }, g.GetMembers(declaration));
}
[Fact]
public void TestAddNamespaceImports()
{
AssertMemberNamesEqual("x.y", g.AddNamespaceImports(g.CompilationUnit(), g.NamespaceImportDeclaration("x.y")));
AssertMemberNamesEqual(new[] { "x.y", "z" }, g.AddNamespaceImports(g.CompilationUnit(), g.NamespaceImportDeclaration("x.y"), g.IdentifierName("z")));
AssertMemberNamesEqual("", g.AddNamespaceImports(g.CompilationUnit(), g.MethodDeclaration("m")));
AssertMemberNamesEqual(new[] { "x", "y.z" }, g.AddNamespaceImports(g.CompilationUnit(g.IdentifierName("x")), g.DottedName("y.z")));
}
[Fact]
public void TestRemoveNamespaceImports()
{
TestRemoveAllNamespaceImports(g.CompilationUnit(g.NamespaceImportDeclaration("x")));
TestRemoveAllNamespaceImports(g.CompilationUnit(g.NamespaceImportDeclaration("x"), g.IdentifierName("y")));
TestRemoveNamespaceImport(g.CompilationUnit(g.NamespaceImportDeclaration("x")), "x", new string[] { });
TestRemoveNamespaceImport(g.CompilationUnit(g.NamespaceImportDeclaration("x"), g.IdentifierName("y")), "x", new[] { "y" });
TestRemoveNamespaceImport(g.CompilationUnit(g.NamespaceImportDeclaration("x"), g.IdentifierName("y")), "y", new[] { "x" });
}
private void TestRemoveAllNamespaceImports(SyntaxNode declaration)
{
Assert.Equal(0, g.GetNamespaceImports(g.RemoveNamespaceImports(declaration, g.GetNamespaceImports(declaration))).Count);
}
private void TestRemoveNamespaceImport(SyntaxNode declaration, string name, string[] remainingNames)
{
var newDecl = g.RemoveNamespaceImports(declaration, g.GetNamespaceImports(declaration).First(m => g.GetName(m) == name));
AssertMemberNamesEqual(remainingNames, newDecl);
}
[Fact]
public void TestAddMembers()
{
......@@ -1371,21 +1402,29 @@ public void TestAddMembers()
}
[Fact]
public void TestWithMembers()
public void TestRemoveMembers()
{
// remove all members
TestRemoveAllMembers(g.ClassDeclaration("c", members: new[] { g.MethodDeclaration("m") }));
TestRemoveAllMembers(g.StructDeclaration("s", members: new[] { g.MethodDeclaration("m") }));
TestRemoveAllMembers(g.InterfaceDeclaration("i", members: new[] { g.MethodDeclaration("m") }));
TestRemoveAllMembers(g.EnumDeclaration("i", members: new[] { g.EnumMember("v") }));
TestRemoveAllMembers(g.NamespaceDeclaration("n", new[] { g.NamespaceDeclaration("n") }));
TestRemoveAllMembers(g.CompilationUnit(declarations: new[] { g.NamespaceDeclaration("n") }));
TestRemoveMember(g.ClassDeclaration("c", members: new[] { g.MethodDeclaration("m1"), g.MethodDeclaration("m2") }), "m1", new[] { "m2" });
TestRemoveMember(g.StructDeclaration("s", members: new[] { g.MethodDeclaration("m1"), g.MethodDeclaration("m2") }), "m1", new[] { "m2" });
}
private void TestRemoveAllMembers(SyntaxNode declaration)
{
AssertMemberNamesEqual("m", g.WithMembers(g.ClassDeclaration("d"), new[] { g.MethodDeclaration("m") }));
AssertMemberNamesEqual("m", g.WithMembers(g.StructDeclaration("s"), new[] { g.MethodDeclaration("m") }));
AssertMemberNamesEqual("m", g.WithMembers(g.InterfaceDeclaration("i"), new[] { g.MethodDeclaration("m") }));
AssertMemberNamesEqual("v", g.WithMembers(g.EnumDeclaration("e"), new[] { g.EnumMember("v") }));
AssertMemberNamesEqual("n2", g.WithMembers(g.NamespaceDeclaration("n"), new[] { g.NamespaceDeclaration("n2") }));
AssertMemberNamesEqual("n", g.WithMembers(g.CompilationUnit(), new[] { g.NamespaceDeclaration("n") }));
Assert.Equal(0, g.GetMembers(g.RemoveMembers(declaration, g.GetMembers(declaration))).Count);
}
Assert.Equal(0, g.GetMembers(g.WithMembers(g.ClassDeclaration("d", members: new[] { g.MethodDeclaration("m") }), null)).Count);
Assert.Equal(0, g.GetMembers(g.WithMembers(g.StructDeclaration("s", members: new[] { g.MethodDeclaration("m") }), null)).Count);
Assert.Equal(0, g.GetMembers(g.WithMembers(g.InterfaceDeclaration("i", members: new[] { g.MethodDeclaration("m") }), null)).Count);
Assert.Equal(0, g.GetMembers(g.WithMembers(g.EnumDeclaration("i", members: new[] { g.EnumMember("v") }), null)).Count);
Assert.Equal(0, g.GetMembers(g.WithMembers(g.NamespaceDeclaration("n", new[] { g.NamespaceDeclaration("n") }), null)).Count);
Assert.Equal(0, g.GetMembers(g.WithMembers(g.CompilationUnit(declarations: new[] { g.NamespaceDeclaration("n") }), null)).Count);
private void TestRemoveMember(SyntaxNode declaration, string name, string[] remainingNames)
{
var newDecl = g.RemoveMembers(declaration, g.GetMembers(declaration).First(m => g.GetName(m) == name));
AssertMemberNamesEqual(remainingNames, newDecl);
}
[Fact]
......@@ -1633,37 +1672,59 @@ public void TestGetParameters()
}
[Fact]
public void TestWithParameters()
public void TestAddParameters()
{
Assert.Equal(1, g.GetParameters(g.WithParameters(g.MethodDeclaration("m"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(1, g.GetParameters(g.WithParameters(g.ConstructorDeclaration(), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(2, g.GetParameters(g.WithParameters(g.IndexerDeclaration(new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) }, g.IdentifierName("t")), new[] { g.ParameterDeclaration("p2", g.IdentifierName("t2")), g.ParameterDeclaration("p3", g.IdentifierName("t3")) })).Count);
Assert.Equal(1, g.GetParameters(g.AddParameters(g.MethodDeclaration("m"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(1, g.GetParameters(g.AddParameters(g.ConstructorDeclaration(), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(3, g.GetParameters(g.AddParameters(g.IndexerDeclaration(new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) }, g.IdentifierName("t")), new[] { g.ParameterDeclaration("p2", g.IdentifierName("t2")), g.ParameterDeclaration("p3", g.IdentifierName("t3")) })).Count);
Assert.Equal(1, g.GetParameters(g.WithParameters(g.ValueReturningLambdaExpression(g.IdentifierName("expr")), new[] { g.LambdaParameter("p") })).Count);
Assert.Equal(1, g.GetParameters(g.WithParameters(g.VoidReturningLambdaExpression(g.IdentifierName("expr")), new[] { g.LambdaParameter("p") })).Count);
Assert.Equal(1, g.GetParameters(g.AddParameters(g.ValueReturningLambdaExpression(g.IdentifierName("expr")), new[] { g.LambdaParameter("p") })).Count);
Assert.Equal(1, g.GetParameters(g.AddParameters(g.VoidReturningLambdaExpression(g.IdentifierName("expr")), new[] { g.LambdaParameter("p") })).Count);
Assert.Equal(1, g.GetParameters(g.WithParameters(g.DelegateDeclaration("d"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(1, g.GetParameters(g.AddParameters(g.DelegateDeclaration("d"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(0, g.GetParameters(g.WithParameters(g.ClassDeclaration("c"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(0, g.GetParameters(g.WithParameters(g.IdentifierName("x"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(0, g.GetParameters(g.AddParameters(g.ClassDeclaration("c"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
Assert.Equal(0, g.GetParameters(g.AddParameters(g.IdentifierName("x"), new[] { g.ParameterDeclaration("p", g.IdentifierName("t")) })).Count);
}
[Fact]
public void TestGetInitializer()
public void TestGetExpression()
{
Assert.Equal("x", g.GetInitializer(g.FieldDeclaration("f", g.IdentifierName("t"), initializer: g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetInitializer(g.ParameterDeclaration("p", g.IdentifierName("t"), initializer: g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetInitializer(g.LocalDeclarationStatement("loc", initializer: g.IdentifierName("x"))).ToString());
Assert.Null(g.GetInitializer(g.IdentifierName("e")));
// initializers
Assert.Equal("x", g.GetExpression(g.FieldDeclaration("f", g.IdentifierName("t"), initializer: g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.ParameterDeclaration("p", g.IdentifierName("t"), initializer: g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.LocalDeclarationStatement("loc", initializer: g.IdentifierName("x"))).ToString());
// lambda bodies
Assert.Null(g.GetExpression(g.ValueReturningLambdaExpression("p", new[] { g.IdentifierName("x") })));
Assert.Equal(1, g.GetStatements(g.ValueReturningLambdaExpression("p", new[] { g.IdentifierName("x") })).Count);
Assert.Equal("x", g.GetExpression(g.ValueReturningLambdaExpression(g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.VoidReturningLambdaExpression(g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.ValueReturningLambdaExpression("p", g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.VoidReturningLambdaExpression("p", g.IdentifierName("x"))).ToString());
Assert.Null(g.GetExpression(g.IdentifierName("e")));
}
[Fact]
public void TestWithInitializer()
public void TestWithExpression()
{
Assert.Equal("x", g.GetInitializer(g.WithInitializer(g.FieldDeclaration("f", g.IdentifierName("t")), g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetInitializer(g.WithInitializer(g.ParameterDeclaration("p", g.IdentifierName("t")), g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetInitializer(g.WithInitializer(g.LocalDeclarationStatement(g.IdentifierName("t"), "loc"), g.IdentifierName("x"))).ToString());
Assert.Null(g.GetInitializer(g.WithInitializer(g.IdentifierName("e"), g.IdentifierName("x"))));
// initializers
Assert.Equal("x", g.GetExpression(g.WithExpression(g.FieldDeclaration("f", g.IdentifierName("t")), g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.WithExpression(g.ParameterDeclaration("p", g.IdentifierName("t")), g.IdentifierName("x"))).ToString());
Assert.Equal("x", g.GetExpression(g.WithExpression(g.LocalDeclarationStatement(g.IdentifierName("t"), "loc"), g.IdentifierName("x"))).ToString());
// lambda bodies
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression("p", new[] { g.IdentifierName("x") }), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression("p", new[] { g.IdentifierName("x") }), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression(new[] { g.IdentifierName("x") }), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression(new[] { g.IdentifierName("x") }), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression("p", g.IdentifierName("x")), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression("p", g.IdentifierName("x")), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression(g.IdentifierName("x")), g.IdentifierName("y"))).ToString());
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression(g.IdentifierName("x")), g.IdentifierName("y"))).ToString());
Assert.Null(g.GetExpression(g.WithExpression(g.IdentifierName("e"), g.IdentifierName("x"))));
}
[Fact]
......
......@@ -23,6 +23,11 @@ public SymbolEditor(Solution solution)
this.currentSolution = solution;
}
public SymbolEditor(Document document)
: this(document.Project.Solution)
{
}
/// <summary>
/// The original solution.
/// </summary>
......@@ -79,20 +84,20 @@ public async Task<ISymbol> GetCurrentSymbolAsync(ISymbol symbol, CancellationTok
var project = this.currentSolution.GetProject(symbol.ContainingAssembly);
if (project != null)
{
return await GetCurrentSymbolAsync(project.Id, symbolId, cancellationToken).ConfigureAwait(false);
return await GetSymbolAsync(this.currentSolution, project.Id, symbolId, cancellationToken).ConfigureAwait(false);
}
// check to see if it is from original solution
project = this.originalSolution.GetProject(symbol.ContainingAssembly);
if (project != null)
{
return await GetCurrentSymbolAsync(project.Id, symbolId, cancellationToken).ConfigureAwait(false);
return await GetSymbolAsync(this.currentSolution, project.Id, symbolId, cancellationToken).ConfigureAwait(false);
}
// try to find symbol from any project (from current solution) with matching assembly name
foreach (var projectId in this.GetProjectsForAssembly(symbol.ContainingAssembly))
{
var currentSymbol = await GetCurrentSymbolAsync(projectId, symbolId, cancellationToken).ConfigureAwait(false);
var currentSymbol = await GetSymbolAsync(this.currentSolution, projectId, symbolId, cancellationToken).ConfigureAwait(false);
if (currentSymbol != null)
{
return currentSymbol;
......@@ -122,9 +127,9 @@ private ImmutableArray<ProjectId> GetProjectsForAssembly(IAssemblySymbol assembl
return projectIds;
}
private async Task<ISymbol> GetCurrentSymbolAsync(ProjectId projectId, string symbolId, CancellationToken cancellationToken)
private async Task<ISymbol> GetSymbolAsync(Solution solution, ProjectId projectId, string symbolId, CancellationToken cancellationToken)
{
var comp = await this.currentSolution.GetProject(projectId).GetCompilationAsync(cancellationToken).ConfigureAwait(false);
var comp = await solution.GetProject(projectId).GetCompilationAsync(cancellationToken).ConfigureAwait(false);
var symbols = DocumentationCommentId.GetSymbolsForDeclarationId(symbolId, comp).ToList();
if (symbols.Count == 1)
......@@ -150,10 +155,41 @@ private async Task<ISymbol> GetCurrentSymbolAsync(ProjectId projectId, string sy
return null;
}
/// <summary>
/// Get's the current symbol for a declaration at a specified position within a documment.
/// </summary>
private Task<ISymbol> GetCurrentSymbolAsync(DocumentId docId, int position, DeclarationKind kind, CancellationToken cancellationToken)
{
return this.GetSymbolAsync(this.currentSolution, docId, position, kind, cancellationToken);
}
private DeclarationKind GetKind(SyntaxNode declaration)
{
return SyntaxGenerator.GetGenerator(this.currentSolution.Workspace, declaration.Language).GetDeclarationKind(declaration);
}
private async Task<ISymbol> GetSymbolAsync(Solution solution, DocumentId docId, int position, DeclarationKind kind, CancellationToken cancellationToken)
{
var doc = solution.GetDocument(docId);
var model = await doc.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var generator = SyntaxGenerator.GetGenerator(doc);
var node = model.SyntaxTree.GetRoot().FindToken(position).Parent;
var decl = generator.GetDeclaration(node, kind);
if (decl != null)
{
return model.GetDeclaredSymbol(decl, cancellationToken);
}
else
{
return null;
}
}
/// <summary>
/// Gets the declaration syntax nodes for a given symbol.
/// </summary>
public IEnumerable<SyntaxNode> GetDeclarations(ISymbol symbol)
private IEnumerable<SyntaxNode> GetDeclarations(ISymbol symbol)
{
return symbol.DeclaringSyntaxReferences
.Select(sr => sr.GetSyntax())
......@@ -172,6 +208,7 @@ private bool TryGetBestDeclarationForSingleEdit(ISymbol symbol, out SyntaxNode d
/// <summary>
/// Enables editting the definition of one of the symbol's declarations.
/// Partial types and methods may have more than one declaration.
/// </summary>
/// <param name="symbol">The symbol to edit.</param>
/// <param name="declarationEditor">The function that produces the changed declaration.</param>
......@@ -184,86 +221,147 @@ private bool TryGetBestDeclarationForSingleEdit(ISymbol symbol, out SyntaxNode d
{
var currentSymbol = await this.GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
CheckSymbolArgument(currentSymbol, symbol);
SyntaxNode declaration;
if (TryGetBestDeclarationForSingleEdit(currentSymbol, out declaration))
{
var doc = this.currentSolution.GetDocument(declaration.SyntaxTree);
var root = declaration.SyntaxTree.GetRoot();
var generator = SyntaxGenerator.GetGenerator(this.currentSolution.Workspace, declaration.Language);
var newDecl = declarationEditor(declaration, generator);
var newRoot = root.ReplaceNode(declaration, newDecl);
var newDoc = doc.WithSyntaxRoot(newRoot);
this.currentSolution = newDoc.Project.Solution;
return await this.GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
return await this.EditDeclarationAsync(currentSymbol, declaration, declarationEditor, cancellationToken).ConfigureAwait(false);
}
return null;
}
private void CheckSymbolArgument(ISymbol currentSymbol, ISymbol argSymbol)
{
if (currentSymbol == null)
{
throw new ArgumentException(string.Format("The symbol '{0}' cannot be located within the current solution.".NeedsLocalization(), argSymbol.Name));
}
}
private async Task<ISymbol> EditDeclarationAsync(ISymbol currentSymbol, SyntaxNode declaration, Func<SyntaxNode, SyntaxGenerator, SyntaxNode> declarationEditor, CancellationToken cancellationToken)
{
var doc = this.currentSolution.GetDocument(declaration.SyntaxTree);
var root = declaration.SyntaxTree.GetRoot();
var generator = SyntaxGenerator.GetGenerator(this.currentSolution.Workspace, declaration.Language);
var newDecl = declarationEditor(declaration, generator);
SyntaxNode newRoot;
if (newDecl != null)
{
newRoot = root.ReplaceNode(declaration, newDecl);
}
else
{
newRoot = root.RemoveNode(declaration, SyntaxGenerator.DefaultRemoveOptions);
}
var newDoc = doc.WithSyntaxRoot(newRoot);
this.currentSolution = newDoc.Project.Solution;
if (newDecl != null)
{
return await this.GetCurrentSymbolAsync(doc.Id, declaration.Span.Start, GetKind(newDecl), cancellationToken).ConfigureAwait(false);
}
else
{
return await this.GetCurrentSymbolAsync(currentSymbol, cancellationToken).ConfigureAwait(false);
}
}
/// <summary>
/// Enables editting the definition of one of the symbol's declarations.
/// Partial types and methods may have more than one declaration.
/// </summary>
/// <param name="symbol">The symbol to edit.</param>
/// <param name="location">A location within one of the symbol's declarations.</param>
/// <param name="declarationEditor">The function that produces the changed declaration.</param>
/// <param name="cancellationToken">An optional <see cref="CancellationToken"/>.</param>
/// <returns>The new symbol including the changes.</returns>
public Task<ISymbol> EditOneDeclarationAsync(
public async Task<ISymbol> EditOneDeclarationAsync(
ISymbol symbol,
Location location,
Func<SyntaxNode, SyntaxGenerator, SyntaxNode> declarationEditor,
CancellationToken cancellationToken = default(CancellationToken))
{
var sourceTree = location.SourceTree;
return this.EditAllDeclarationsAsync(symbol,
(d, g) =>
{
if (d.SyntaxTree == sourceTree && d.FullSpan.IntersectsWith(location.SourceSpan.Start))
{
return declarationEditor(d, g);
}
else
{
return d;
}
},
cancellationToken);
var doc = this.currentSolution.GetDocument(sourceTree);
if (doc != null)
{
return await this.EditOneDeclarationAsync(symbol, doc.Id, location.SourceSpan.Start, declarationEditor, cancellationToken).ConfigureAwait(false);
}
doc = this.originalSolution.GetDocument(sourceTree);
if (doc != null)
{
return await this.EditOneDeclarationAsync(symbol, doc.Id, location.SourceSpan.Start, declarationEditor, cancellationToken).ConfigureAwait(false);
}
throw new ArgumentException("The location specified is not part of the solution.", nameof(location));
}
private async Task<ISymbol> EditOneDeclarationAsync(
ISymbol symbol,
DocumentId documentId,
int position,
Func<SyntaxNode, SyntaxGenerator, SyntaxNode> declarationEditor,
CancellationToken cancellationToken = default(CancellationToken))
{
var currentSymbol = await this.GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
CheckSymbolArgument(currentSymbol, symbol);
var decl = this.GetDeclarations(currentSymbol).FirstOrDefault(d =>
{
var doc = this.currentSolution.GetDocument(d.SyntaxTree);
return doc != null && doc.Id == documentId && d.FullSpan.IntersectsWith(position);
});
if (decl == null)
{
throw new ArgumentNullException("The position is not within the symbol's declaration".NeedsLocalization(), nameof(position));
}
return await this.EditDeclarationAsync(currentSymbol, decl, declarationEditor, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Enables editting the symbol's declaration where the member is also declared.
/// Partial types and methods may have more than one declaration.
/// </summary>
/// <param name="symbol">The symbol to edit.</param>
/// <param name="member">A symbol whose declaration is contained within one of the primary symbol's declarations.</param>
/// <param name="declarationEditor">The function that produces the changed declaration.</param>
/// <param name="cancellationToken">An optional <see cref="CancellationToken"/>.</param>
/// <returns>The new symbol including the changes.</returns>
public async Task<ISymbol> EditDeclarationsWithMemberDeclaredAsync(
public async Task<ISymbol> EditOneDeclarationAsync(
ISymbol symbol,
ISymbol member,
Func<SyntaxNode, SyntaxGenerator, SyntaxNode> declarationEditor,
CancellationToken cancellationToken = default(CancellationToken))
{
var currentSymbol = await this.GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
CheckSymbolArgument(currentSymbol, symbol);
var currentMember = await this.GetCurrentSymbolAsync(member, cancellationToken).ConfigureAwait(false);
var memberDecls = this.GetDeclarations(currentMember);
CheckSymbolArgument(currentMember, member);
return await this.EditAllDeclarationsAsync(symbol,
(d, g) =>
{
if (memberDecls.Any(md => md.SyntaxTree == d.SyntaxTree && d.FullSpan.IntersectsWith(md.FullSpan)))
{
return declarationEditor(d, g);
}
else
{
return d;
}
},
cancellationToken).ConfigureAwait(false);
// get first symbol declaration that encompasses at least one of the member declarations
var memberDecls = this.GetDeclarations(currentMember).ToList();
var declaration = this.GetDeclarations(currentSymbol).FirstOrDefault(d => memberDecls.Any(md => md.SyntaxTree == d.SyntaxTree && d.FullSpan.IntersectsWith(md.FullSpan)));
if (declaration == null)
{
throw new ArgumentException(string.Format("The member '{0}' is not declared within the declaration of the symbol.".NeedsLocalization(), member.Name));
}
return await this.EditDeclarationAsync(currentSymbol, declaration, declarationEditor, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Enables editting all the symbol's declarations.
/// Enables editting all the symbol's declarations.
/// Partial types and methods may have more than one declaration.
/// </summary>
/// <param name="symbol">The symbol to be editted.</param>
/// <param name="declarationEditor">The function that produces a changed declaration.</param>
......@@ -273,18 +371,41 @@ public async Task<ISymbol> EditAllDeclarationsAsync(ISymbol symbol, Func<SyntaxN
{
var currentSymbol = await this.GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
var docMap = new Dictionary<SyntaxTree, Document>();
var changeMap = new Dictionary<SyntaxNode, SyntaxNode>();
foreach (var decls in this.GetDeclarations(currentSymbol).GroupBy(d => d.SyntaxTree))
{
var doc = this.currentSolution.GetDocument(decls.Key);
docMap.Add(decls.Key, doc);
var root = decls.Key.GetRoot();
var generator = SyntaxGenerator.GetGenerator(doc);
var newRoot = root.ReplaceNodes(decls, (original, rewritten) => declarationEditor(original, generator));
var changes = decls.Select(d => new KeyValuePair<SyntaxNode, SyntaxNode>(d, declarationEditor(d, generator)));
changeMap.AddRange(changes);
var newRoot = root.ReplaceNodes(decls, (original, rewritten) => changeMap[original]);
var newDoc = doc.WithSyntaxRoot(newRoot);
this.currentSolution = newDoc.Project.Solution;
}
return await this.GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
// try to find new symbol using the first lexically changed decl in one of the trees, because the position will not have changed
var firstTreeChanges = changeMap.GroupBy(kvp => kvp.Key.SyntaxTree).FirstOrDefault();
if (firstTreeChanges != null)
{
var doc = docMap[firstTreeChanges.Key];
var firstChangedDecl = firstTreeChanges.OrderBy(kvp => kvp.Key.SpanStart).FirstOrDefault().Value;
if (firstChangedDecl != null)
{
return await GetCurrentSymbolAsync(doc.Id, firstChangedDecl.SpanStart, GetKind(firstChangedDecl), cancellationToken).ConfigureAwait(false);
}
}
// if prior method fails (possibly due to declaration being removed), attempt to rebind the original symbol
return await GetCurrentSymbolAsync(symbol, cancellationToken).ConfigureAwait(false);
}
}
}
......@@ -20,6 +20,8 @@ namespace Microsoft.CodeAnalysis.CodeGeneration
/// </summary>
public abstract class SyntaxGenerator : ILanguageService
{
public static SyntaxRemoveOptions DefaultRemoveOptions = SyntaxRemoveOptions.KeepUnbalancedDirectives | SyntaxRemoveOptions.AddElasticMarker;
/// <summary>
/// Gets the <see cref="SyntaxGenerator"/> for the specified language.
/// </summary>
......@@ -58,6 +60,26 @@ public SyntaxNode GetDeclaration(SyntaxNode node)
return null;
}
/// <summary>
/// Returns the enclosing declaration of the specified kind or null.
/// </summary>
public SyntaxNode GetDeclaration(SyntaxNode node, DeclarationKind kind)
{
while (node != null)
{
if (GetDeclarationKind(node) == kind)
{
return node;
}
else
{
node = node.Parent;
}
}
return null;
}
/// <summary>
/// Creates a field declaration.
/// </summary>
......@@ -519,11 +541,15 @@ public SyntaxNode WithTypeConstraint(SyntaxNode declaration, string typeParamete
/// <summary>
/// Creates a namespace declaration.
/// </summary>
/// <param name="name">The name of the namespace.</param>
/// <param name="declarations">Zero or more namespace or type declarations.</param>
public abstract SyntaxNode NamespaceDeclaration(SyntaxNode name, IEnumerable<SyntaxNode> declarations);
/// <summary>
/// Creates a namespace declaration.
/// </summary>
/// <param name="name">The name of the namespace.</param>
/// <param name="declarations">Zero or more namespace or type declarations.</param>
public SyntaxNode NamespaceDeclaration(SyntaxNode name, params SyntaxNode[] declarations)
{
return NamespaceDeclaration(name, (IEnumerable<SyntaxNode>)declarations);
......@@ -532,6 +558,8 @@ public SyntaxNode NamespaceDeclaration(SyntaxNode name, params SyntaxNode[] decl
/// <summary>
/// Creates a namespace declaration.
/// </summary>
/// <param name="name">The name of the namespace.</param>
/// <param name="declarations">Zero or more namespace or type declarations.</param>
public SyntaxNode NamespaceDeclaration(string name, IEnumerable<SyntaxNode> declarations)
{
return NamespaceDeclaration(DottedName(name), declarations);
......@@ -540,6 +568,8 @@ public SyntaxNode NamespaceDeclaration(string name, IEnumerable<SyntaxNode> decl
/// <summary>
/// Creates a namespace declaration.
/// </summary>
/// <param name="name">The name of the namespace.</param>
/// <param name="declarations">Zero or more namespace or type declarations.</param>
public SyntaxNode NamespaceDeclaration(string name, params SyntaxNode[] declarations)
{
return NamespaceDeclaration(DottedName(name), (IEnumerable<SyntaxNode>)declarations);
......@@ -548,11 +578,13 @@ public SyntaxNode NamespaceDeclaration(string name, params SyntaxNode[] declarat
/// <summary>
/// Creates a compilation unit declaration
/// </summary>
public abstract SyntaxNode CompilationUnit(IEnumerable<SyntaxNode> declarations = null);
/// <param name="declarations">Zero or more namespace import, namespace or type declarations.</param>
public abstract SyntaxNode CompilationUnit(IEnumerable<SyntaxNode> declarations);
/// <summary>
/// Creates a compilation unit declaration
/// </summary>
/// <param name="declarations">Zero or more namespace import, namespace or type declarations.</param>
public SyntaxNode CompilationUnit(params SyntaxNode[] declarations)
{
return CompilationUnit((IEnumerable<SyntaxNode>)declarations);
......@@ -561,11 +593,13 @@ public SyntaxNode CompilationUnit(params SyntaxNode[] declarations)
/// <summary>
/// Creates a namespace import declaration.
/// </summary>
/// <param name="name">The name of the namespace being imported.</param>
public abstract SyntaxNode NamespaceImportDeclaration(SyntaxNode name);
/// <summary>
/// Creates a namespace import declaration.
/// </summary>
/// <param name="name">The name of the namespace being imported.</param>
public SyntaxNode NamespaceImportDeclaration(string name)
{
return NamespaceImportDeclaration(DottedName(name));
......@@ -623,24 +657,40 @@ public SyntaxNode AttributeArgument(SyntaxNode expression)
}
/// <summary>
/// Gets the attributes of a declaration.
/// Removes all attributes from the declaration, including return attributes.
/// </summary>
public abstract IEnumerable<SyntaxNode> GetAttributes(SyntaxNode declaration);
public abstract SyntaxNode RemoveAllAttributes(SyntaxNode declaration);
/// <summary>
/// Removes all attributes from the declaration, including return attributes.
/// Gets the attributes of a declaration, not including the return attributes.
/// </summary>
public abstract SyntaxNode RemoveAttributes(SyntaxNode declaration);
public abstract IReadOnlyList<SyntaxNode> GetAttributes(SyntaxNode declaration);
/// <summary>
/// Removes specific attributes from the declaration.
/// </summary>
public abstract SyntaxNode RemoveAttributes(SyntaxNode declaration, IEnumerable<SyntaxNode> attributes);
/// <summary>
/// Creates a new instance of the declaration with the attributes inserted.
/// </summary>
public abstract SyntaxNode InsertAttributes(SyntaxNode declaration, int index, IEnumerable<SyntaxNode> attributes);
/// <summary>
/// Creates a new instance of the declaration with the attributes inserted.
/// </summary>
public SyntaxNode InsertAttributes(SyntaxNode declaration, int index, params SyntaxNode[] attributes)
{
return this.InsertAttributes(declaration, index, (IEnumerable<SyntaxNode>)attributes);
}
/// <summary>
/// Creates a new instance of a declaration with the specified attributes added.
/// </summary>
public abstract SyntaxNode AddAttributes(SyntaxNode declaration, IEnumerable<SyntaxNode> attributes);
public SyntaxNode AddAttributes(SyntaxNode declaration, IEnumerable<SyntaxNode> attributes)
{
return this.InsertAttributes(declaration, this.GetAttributes(declaration).Count, attributes);
}
/// <summary>
/// Creates a new instance of a declaration with the specified attributes added.
......@@ -650,21 +700,93 @@ public SyntaxNode AddAttributes(SyntaxNode declaration, params SyntaxNode[] attr
return AddAttributes(declaration, (IEnumerable<SyntaxNode>)attributes);
}
public abstract IEnumerable<SyntaxNode> GetReturnAttributes(SyntaxNode declaration);
/// <summary>
/// Gets the return attributes from the declaration.
/// </summary>
public abstract IReadOnlyList<SyntaxNode> GetReturnAttributes(SyntaxNode declaration);
public abstract SyntaxNode WithReturnAttributes(SyntaxNode declaration, IEnumerable<SyntaxNode> attributes);
/// <summary>
/// Removes the specified return attributes from the declaration.
/// </summary>
public abstract SyntaxNode RemoveReturnAttributes(SyntaxNode declaration, IEnumerable<SyntaxNode> attributes);
/// <summary>
/// Creates a new instance of a method declaration with return attributes inserted.
/// </summary>
public abstract SyntaxNode InsertReturnAttributes(SyntaxNode declaration, int index, IEnumerable<SyntaxNode> attributes);
/// <summary>
/// Creates a new instance of a method declaration with return attributes inserted.
/// </summary>
public SyntaxNode InsertReturnAttributes(SyntaxNode declaration, int index, params SyntaxNode[] attributes)
{
return this.InsertReturnAttributes(declaration, index, attributes);
}
/// <summary>
/// Creates a new instance of a method declaration with return attributes added.
/// </summary>
public abstract SyntaxNode AddReturnAttributes(SyntaxNode methodDeclaration, IEnumerable<SyntaxNode> attributes);
public SyntaxNode AddReturnAttributes(SyntaxNode declaration, IEnumerable<SyntaxNode> attributes)
{
return this.InsertReturnAttributes(declaration, this.GetReturnAttributes(declaration).Count, attributes);
}
/// <summary>
/// Creates a new instance of a method declaration node with return attributes added.
/// </summary>
public SyntaxNode AddReturnAttributes(SyntaxNode methodDeclaration, params SyntaxNode[] attributes)
public SyntaxNode AddReturnAttributes(SyntaxNode declaration, params SyntaxNode[] attributes)
{
return AddReturnAttributes(methodDeclaration, (IEnumerable<SyntaxNode>)attributes);
return AddReturnAttributes(declaration, (IEnumerable<SyntaxNode>)attributes);
}
/// <summary>
/// Gets the namespace imports that are part of the declaration.
/// </summary>
public abstract IReadOnlyList<SyntaxNode> GetNamespaceImports(SyntaxNode declaration);
/// <summary>
/// Creates a new instance of the declaration with the namespace imports inserted.
/// </summary>
public abstract SyntaxNode InsertNamespaceImports(SyntaxNode declaration, int index, IEnumerable<SyntaxNode> imports);
/// <summary>
/// Creates a new instance of the declaration with the namespace imports inserted.
/// </summary>
public SyntaxNode InsertNamespaceImports(SyntaxNode declaration, int index, params SyntaxNode[] imports)
{
return this.InsertNamespaceImports(declaration, index, imports);
}
/// <summary>
/// Creates a new instance of the declaration with the namespace imports added.
/// </summary>
public SyntaxNode AddNamespaceImports(SyntaxNode declaration, IEnumerable<SyntaxNode> imports)
{
return this.InsertNamespaceImports(declaration, this.GetNamespaceImports(declaration).Count, imports);
}
/// <summary>
/// Creates a new instance of the declaration with the namespace imports added.
/// </summary>
public SyntaxNode AddNamespaceImports(SyntaxNode declaration, params SyntaxNode[] imports)
{
return this.AddNamespaceImports(declaration, (IEnumerable<SyntaxNode>)imports);
}
/// <summary>
/// Creates a new instance of the declaration with the specified namespace imports removed.
/// </summary>
public SyntaxNode RemoveNamespaceImports(SyntaxNode declaration, IEnumerable<SyntaxNode> imports)
{
return declaration.RemoveNodes(imports, DefaultRemoveOptions);
}
/// <summary>
/// Creates a new instance of the declaration with the specified namespace imports removed.
/// </summary>
public SyntaxNode RemoveNamespaceImports(SyntaxNode declaration, params SyntaxNode[] imports)
{
return this.RemoveNamespaceImports(declaration, (IEnumerable<SyntaxNode>)imports);
}
/// <summary>
......@@ -673,14 +795,41 @@ public SyntaxNode AddReturnAttributes(SyntaxNode methodDeclaration, params Synta
public abstract IReadOnlyList<SyntaxNode> GetMembers(SyntaxNode declaration);
/// <summary>
/// Creates a new instance of the declaration with the members specified.
/// Creates a new instance of the declaration with the members removed.
/// </summary>
public SyntaxNode RemoveMembers(SyntaxNode declaration, IEnumerable<SyntaxNode> members)
{
return declaration.RemoveNodes(members, DefaultRemoveOptions);
}
/// <summary>
/// Creates a new instance of the declaration with the members removed.
/// </summary>
public abstract SyntaxNode WithMembers(SyntaxNode declaration, IEnumerable<SyntaxNode> members);
public SyntaxNode RemoveMembers(SyntaxNode declaration, params SyntaxNode[] members)
{
return this.RemoveMembers(declaration, (IEnumerable<SyntaxNode>)members);
}
/// <summary>
/// Creates a new instance of the declaration with the members inserted.
/// </summary>
public abstract SyntaxNode InsertMembers(SyntaxNode declaration, int index, IEnumerable<SyntaxNode> members);
/// <summary>
/// Creates a new instance of the declaration with the members inserted.
/// </summary>
public SyntaxNode InsertMembers(SyntaxNode declaration, int index, params SyntaxNode[] members)
{
return this.InsertMembers(declaration, index, (IEnumerable<SyntaxNode>)members);
}
/// <summary>
/// Creates a new instance of the declaration with the members added to the end.
/// </summary>
public abstract SyntaxNode AddMembers(SyntaxNode declaration, IEnumerable<SyntaxNode> members);
public SyntaxNode AddMembers(SyntaxNode declaration, IEnumerable<SyntaxNode> members)
{
return this.InsertMembers(declaration, this.GetMembers(declaration).Count, members);
}
/// <summary>
/// Creates a new instance of the declaration with the members added to the end.
......@@ -741,19 +890,35 @@ public SyntaxNode AddMembers(SyntaxNode declaration, params SyntaxNode[] members
public abstract IReadOnlyList<SyntaxNode> GetParameters(SyntaxNode declaration);
/// <summary>
/// Changes the list of parameters for the declaration.
/// Inserts the parameters at the specified index into the declaration.
/// </summary>
public abstract SyntaxNode WithParameters(SyntaxNode declaration, IEnumerable<SyntaxNode> parameters);
public abstract SyntaxNode InsertParameters(SyntaxNode declaration, int index, IEnumerable<SyntaxNode> parameters);
/// <summary>
/// Gets the initializer expression for the declaration.
/// Adds the parameters to the declaration.
/// </summary>
public abstract SyntaxNode GetInitializer(SyntaxNode declaration);
public SyntaxNode AddParameters(SyntaxNode declaration, IEnumerable<SyntaxNode> parameters)
{
return this.InsertParameters(declaration, this.GetParameters(declaration).Count, parameters);
}
/// <summary>
/// Removes the specified parameters from the declaration.
/// </summary>
public SyntaxNode RemoveParameters(SyntaxNode declaration, IEnumerable<SyntaxNode> parameters)
{
return declaration.RemoveNodes(parameters, DefaultRemoveOptions);
}
/// <summary>
/// Changes the intializer expression for the declaration.
/// Gets the expression associated with the declaration.
/// </summary>
public abstract SyntaxNode WithInitializer(SyntaxNode declaration, SyntaxNode initializer);
public abstract SyntaxNode GetExpression(SyntaxNode declaration);
/// <summary>
/// Changes the expression associated with the declaration.
/// </summary>
public abstract SyntaxNode WithExpression(SyntaxNode declaration, SyntaxNode expression);
/// <summary>
/// Gets the statements for the body of the declaration.
......@@ -789,14 +954,9 @@ public SyntaxNode AddMembers(SyntaxNode declaration, params SyntaxNode[] members
#region Utility
protected static TNode WithoutTrivia<TNode>(TNode node) where TNode : SyntaxNode
{
return node.WithoutLeadingTrivia().WithoutTrailingTrivia();
}
protected static SyntaxNode PreserveTrivia<TNode>(TNode node, Func<TNode, SyntaxNode> nodeChanger) where TNode : SyntaxNode
{
var nodeWithoutTrivia = WithoutTrivia(node);
var nodeWithoutTrivia = node.WithoutLeadingTrivia().WithoutTrailingTrivia();
var changedNode = nodeChanger(nodeWithoutTrivia);
......
......@@ -1647,7 +1647,7 @@ End Class ' end</x>.Value).Members(0)
Class C
End Class ' end</x>.Value)
Dim removed = g.RemoveAttributes(added)
Dim removed = g.RemoveAllAttributes(added)
VerifySyntax(Of ClassBlockSyntax)(
removed,
<x>' comment
......@@ -1939,35 +1939,102 @@ End Function</x>.Value)
End Sub
<Fact>
Public Sub TestWithParameters()
Assert.Equal(1, g.GetParameters(g.WithParameters(g.MethodDeclaration("m"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(1, g.GetParameters(g.WithParameters(g.ConstructorDeclaration(), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(2, g.GetParameters(g.WithParameters(g.IndexerDeclaration({g.ParameterDeclaration("p", g.IdentifierName("t"))}, g.IdentifierName("t")), {g.ParameterDeclaration("p2", g.IdentifierName("t2")), g.ParameterDeclaration("p3", g.IdentifierName("t3"))})).Count)
Public Sub TestAddParameters()
Assert.Equal(1, g.GetParameters(g.AddParameters(g.MethodDeclaration("m"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(1, g.GetParameters(g.AddParameters(g.ConstructorDeclaration(), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(3, g.GetParameters(g.AddParameters(g.IndexerDeclaration({g.ParameterDeclaration("p", g.IdentifierName("t"))}, g.IdentifierName("t")), {g.ParameterDeclaration("p2", g.IdentifierName("t2")), g.ParameterDeclaration("p3", g.IdentifierName("t3"))})).Count)
Assert.Equal(1, g.GetParameters(g.WithParameters(g.ValueReturningLambdaExpression(g.IdentifierName("expr")), {g.LambdaParameter("p")})).Count)
Assert.Equal(1, g.GetParameters(g.WithParameters(g.VoidReturningLambdaExpression(g.IdentifierName("expr")), {g.LambdaParameter("p")})).Count)
Assert.Equal(1, g.GetParameters(g.AddParameters(g.ValueReturningLambdaExpression(g.IdentifierName("expr")), {g.LambdaParameter("p")})).Count)
Assert.Equal(1, g.GetParameters(g.AddParameters(g.VoidReturningLambdaExpression(g.IdentifierName("expr")), {g.LambdaParameter("p")})).Count)
Assert.Equal(1, g.GetParameters(g.WithParameters(g.DelegateDeclaration("d"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(1, g.GetParameters(g.AddParameters(g.DelegateDeclaration("d"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(0, g.GetParameters(g.WithParameters(g.ClassDeclaration("c"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(0, g.GetParameters(g.WithParameters(g.IdentifierName("x"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(0, g.GetParameters(g.WithParameters(g.PropertyDeclaration("p", g.IdentifierName("t")), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(0, g.GetParameters(g.AddParameters(g.ClassDeclaration("c"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(0, g.GetParameters(g.AddParameters(g.IdentifierName("x"), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
Assert.Equal(0, g.GetParameters(g.AddParameters(g.PropertyDeclaration("p", g.IdentifierName("t")), {g.ParameterDeclaration("p", g.IdentifierName("t"))})).Count)
End Sub
<Fact>
Public Sub TestGetInitializer()
Assert.Equal("x", g.GetInitializer(g.FieldDeclaration("f", g.IdentifierName("t"), initializer:=g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetInitializer(g.ParameterDeclaration("p", g.IdentifierName("t"), initializer:=g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetInitializer(g.LocalDeclarationStatement("loc", initializer:=g.IdentifierName("x"))).ToString())
Assert.Null(g.GetInitializer(g.IdentifierName("e")))
Public Sub TestGetExpression()
' initializers
Assert.Equal("x", g.GetExpression(g.FieldDeclaration("f", g.IdentifierName("t"), initializer:=g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.ParameterDeclaration("p", g.IdentifierName("t"), initializer:=g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.LocalDeclarationStatement("loc", initializer:=g.IdentifierName("x"))).ToString())
' lambda bodies
Assert.Null(g.GetExpression(g.ValueReturningLambdaExpression("p", {g.IdentifierName("x")})))
Assert.Equal(1, g.GetStatements(g.ValueReturningLambdaExpression("p", {g.IdentifierName("x")})).Count)
Assert.Equal("x", g.GetExpression(g.ValueReturningLambdaExpression(g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.VoidReturningLambdaExpression(g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.ValueReturningLambdaExpression("p", g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.VoidReturningLambdaExpression("p", g.IdentifierName("x"))).ToString())
Assert.Null(g.GetExpression(g.IdentifierName("e")))
End Sub
<Fact>
Public Sub TestWithInitializer()
Assert.Equal("x", g.GetInitializer(g.WithInitializer(g.FieldDeclaration("f", g.IdentifierName("t")), g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetInitializer(g.WithInitializer(g.ParameterDeclaration("p", g.IdentifierName("t")), g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetInitializer(g.WithInitializer(g.LocalDeclarationStatement(g.IdentifierName("t"), "loc"), g.IdentifierName("x"))).ToString())
Assert.Null(g.GetInitializer(g.WithInitializer(g.IdentifierName("e"), g.IdentifierName("x"))))
Public Sub TestWithExpression()
' initializers
Assert.Equal("x", g.GetExpression(g.WithExpression(g.FieldDeclaration("f", g.IdentifierName("t")), g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.WithExpression(g.ParameterDeclaration("p", g.IdentifierName("t")), g.IdentifierName("x"))).ToString())
Assert.Equal("x", g.GetExpression(g.WithExpression(g.LocalDeclarationStatement(g.IdentifierName("t"), "loc"), g.IdentifierName("x"))).ToString())
' lambda bodies
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression("p", {g.IdentifierName("x")}), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression("p", {g.IdentifierName("x")}), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression({g.IdentifierName("x")}), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression({g.IdentifierName("x")}), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression("p", g.IdentifierName("x")), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression("p", g.IdentifierName("x")), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.ValueReturningLambdaExpression(g.IdentifierName("x")), g.IdentifierName("y"))).ToString())
Assert.Equal("y", g.GetExpression(g.WithExpression(g.VoidReturningLambdaExpression(g.IdentifierName("x")), g.IdentifierName("y"))).ToString())
VerifySyntax(Of SingleLineLambdaExpressionSyntax)(
g.WithExpression(g.ValueReturningLambdaExpression({g.IdentifierName("s")}), g.IdentifierName("e")),
<x>Function() e</x>.Value)
Assert.Null(g.GetExpression(g.WithExpression(g.IdentifierName("e"), g.IdentifierName("x"))))
End Sub
<Fact>
Public Sub TestWithExpression_LambdaChanges()
' multi line function changes to single line function
VerifySyntax(Of SingleLineLambdaExpressionSyntax)(
g.WithExpression(g.ValueReturningLambdaExpression({g.IdentifierName("s")}), g.IdentifierName("e")),
<x>Function() e</x>.Value)
' multi line sub changes to single line sub
VerifySyntax(Of SingleLineLambdaExpressionSyntax)(
g.WithExpression(g.VoidReturningLambdaExpression({g.IdentifierName("s")}), g.IdentifierName("e")),
<x>Sub() e</x>.Value)
' single line function changes to multi-line function with null expression
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithExpression(g.ValueReturningLambdaExpression(g.IdentifierName("e")), Nothing),
<x>Function()
End Function</x>.Value)
' single line sub changes to multi line sub with null expression
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithExpression(g.VoidReturningLambdaExpression(g.IdentifierName("e")), Nothing),
<x>Sub()
End Sub</x>.Value)
' multi line function no-op when assigned null expression
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithExpression(g.ValueReturningLambdaExpression({g.IdentifierName("s")}), Nothing),
<x>Function()
s
End Function</x>.Value)
' multi line sub no-op when assigned null expression
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithExpression(g.VoidReturningLambdaExpression({g.IdentifierName("s")}), Nothing),
<x>Sub()
s
End Sub</x>.Value)
Assert.Null(g.GetExpression(g.WithExpression(g.IdentifierName("e"), g.IdentifierName("x"))))
End Sub
<Fact>
......@@ -1980,9 +2047,11 @@ End Function</x>.Value)
Assert.Equal(0, g.GetStatements(g.ConstructorDeclaration()).Count)
Assert.Equal(2, g.GetStatements(g.ConstructorDeclaration(statements:=stmts)).Count)
Assert.Equal(0, g.GetStatements(g.VoidReturningLambdaExpression(g.IdentifierName("e"))).Count)
Assert.Equal(0, g.GetStatements(g.VoidReturningLambdaExpression({})).Count)
Assert.Equal(2, g.GetStatements(g.VoidReturningLambdaExpression(stmts)).Count)
Assert.Equal(0, g.GetStatements(g.ValueReturningLambdaExpression(g.IdentifierName("e"))).Count)
Assert.Equal(0, g.GetStatements(g.ValueReturningLambdaExpression({})).Count)
Assert.Equal(2, g.GetStatements(g.ValueReturningLambdaExpression(stmts)).Count)
......@@ -1995,12 +2064,69 @@ End Function</x>.Value)
Assert.Equal(2, g.GetStatements(g.WithStatements(g.MethodDeclaration("m"), stmts)).Count)
Assert.Equal(2, g.GetStatements(g.WithStatements(g.ConstructorDeclaration(), stmts)).Count)
Assert.Equal(2, g.GetStatements(g.WithStatements(g.VoidReturningLambdaExpression({}), stmts)).Count)
Assert.Equal(2, g.GetStatements(g.WithStatements(g.ValueReturningLambdaExpression({}), stmts)).Count)
Assert.Equal(2, g.GetStatements(g.WithStatements(g.VoidReturningLambdaExpression(g.IdentifierName("e")), stmts)).Count)
Assert.Equal(2, g.GetStatements(g.WithStatements(g.ValueReturningLambdaExpression(g.IdentifierName("e")), stmts)).Count)
Assert.Equal(0, g.GetStatements(g.WithStatements(g.IdentifierName("x"), stmts)).Count)
End Sub
<Fact>
Public Sub TestWithStatements_LambdaChanges()
Dim stmts = {g.ExpressionStatement(g.IdentifierName("x")), g.ExpressionStatement(g.IdentifierName("y"))}
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.VoidReturningLambdaExpression({}), stmts),
<x>Sub()
x
y
End Sub</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.ValueReturningLambdaExpression({}), stmts),
<x>Function()
x
y
End Function</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.VoidReturningLambdaExpression(g.IdentifierName("e")), stmts),
<x>Sub()
x
y
End Sub</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.ValueReturningLambdaExpression(g.IdentifierName("e")), stmts),
<x>Function()
x
y
End Function</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.VoidReturningLambdaExpression(stmts), {}),
<x>Sub()
End Sub</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.ValueReturningLambdaExpression(stmts), {}),
<x>Function()
End Function</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.VoidReturningLambdaExpression(g.IdentifierName("e")), {}),
<x>Sub()
End Sub</x>.Value)
VerifySyntax(Of MultiLineLambdaExpressionSyntax)(
g.WithStatements(g.ValueReturningLambdaExpression(g.IdentifierName("e")), {}),
<x>Function()
End Function</x>.Value)
End Sub
<Fact>
Public Sub TestGetAccessorStatements()
Dim stmts = {g.ExpressionStatement(g.AssignmentStatement(g.IdentifierName("x"), g.IdentifierName("y"))), g.ExpressionStatement(g.InvocationExpression(g.IdentifierName("fn"), g.IdentifierName("arg")))}
......@@ -2043,8 +2169,9 @@ End Function</x>.Value)
Assert.Equal(0, g.GetSetAccessorStatements(g.WithSetAccessorStatements(g.IdentifierName("x"), stmts)).Count)
End Sub
Private Sub AssertNamesEqual(expectedNames As String(), actualNodes As IEnumerable(Of SyntaxNode))
Private Sub AssertNamesEqual(expectedNames As String(), actualNodes As IReadOnlyList(Of SyntaxNode))
Dim actualNames = actualNodes.Select(Function(n) g.GetName(n)).ToArray()
Assert.Equal(expectedNames.Length, actualNames.Length)
Dim expected = String.Join(", ", expectedNames)
Dim actual = String.Join(", ", actualNames)
Assert.Equal(expected, actual)
......@@ -2068,23 +2195,6 @@ End Function</x>.Value)
AssertMemberNamesEqual("c", g.CompilationUnit(declarations:={g.ClassDeclaration("c")}))
End Sub
<Fact>
Public Sub TestWithMembers()
AssertMemberNamesEqual("m", g.WithMembers(g.ClassDeclaration("d"), {g.MethodDeclaration("m")}))
AssertMemberNamesEqual("m", g.WithMembers(g.StructDeclaration("s"), {g.MethodDeclaration("m")}))
AssertMemberNamesEqual("m", g.WithMembers(g.InterfaceDeclaration("i"), {g.MethodDeclaration("m")}))
AssertMemberNamesEqual("v", g.WithMembers(g.EnumDeclaration("e"), {g.EnumMember("v")}))
AssertMemberNamesEqual("n2", g.WithMembers(g.NamespaceDeclaration("n"), {g.NamespaceDeclaration("n2")}))
AssertMemberNamesEqual("n", g.WithMembers(g.CompilationUnit(), {g.NamespaceDeclaration("n")}))
Assert.Equal(0, g.GetMembers(g.WithMembers(g.ClassDeclaration("d", members:={g.MethodDeclaration("m")}), Nothing)).Count)
Assert.Equal(0, g.GetMembers(g.WithMembers(g.StructDeclaration("s", members:={g.MethodDeclaration("m")}), Nothing)).Count)
Assert.Equal(0, g.GetMembers(g.WithMembers(g.InterfaceDeclaration("i", members:={g.MethodDeclaration("m")}), Nothing)).Count)
Assert.Equal(0, g.GetMembers(g.WithMembers(g.EnumDeclaration("i", members:={g.EnumMember("v")}), Nothing)).Count)
Assert.Equal(0, g.GetMembers(g.WithMembers(g.NamespaceDeclaration("n", {g.NamespaceDeclaration("n")}), Nothing)).Count)
Assert.Equal(0, g.GetMembers(g.WithMembers(g.CompilationUnit(declarations:={g.NamespaceDeclaration("n")}), Nothing)).Count)
End Sub
<Fact>
Public Sub TestAddMembers()
AssertMemberNamesEqual("m", g.AddMembers(g.ClassDeclaration("d"), {g.MethodDeclaration("m")}))
......@@ -2101,5 +2211,25 @@ End Function</x>.Value)
AssertMemberNamesEqual({"n1", "n2"}, g.AddMembers(g.NamespaceDeclaration("n", {g.NamespaceDeclaration("n1")}), {g.NamespaceDeclaration("n2")}))
AssertMemberNamesEqual({"n1", "n2"}, g.AddMembers(g.CompilationUnit(declarations:={g.NamespaceDeclaration("n1")}), {g.NamespaceDeclaration("n2")}))
End Sub
<Fact>
Public Sub TestRemoveMembers()
TestRemoveAllMembers(g.RemoveMembers(g.ClassDeclaration("d", members:={g.MethodDeclaration("m")})))
TestRemoveAllMembers(g.RemoveMembers(g.StructDeclaration("s", members:={g.MethodDeclaration("m")})))
TestRemoveAllMembers(g.RemoveMembers(g.InterfaceDeclaration("i", members:={g.MethodDeclaration("m")})))
TestRemoveAllMembers(g.RemoveMembers(g.EnumDeclaration("i", members:={g.EnumMember("v")})))
TestRemoveAllMembers(g.AddMembers(g.NamespaceDeclaration("n", {g.NamespaceDeclaration("n1")})))
TestRemoveAllMembers(g.AddMembers(g.CompilationUnit(declarations:={g.NamespaceDeclaration("n1")})))
End Sub
Private Sub TestRemoveAllMembers(declaration As SyntaxNode)
Assert.Equal(0, g.GetMembers(g.RemoveMembers(declaration, g.GetMembers(declaration))).Count)
End Sub
Private Sub TestRemoveMember(declaration As SyntaxNode, name As String, remainingNames As String())
Dim newDecl = g.RemoveMembers(declaration, g.GetMembers(declaration).First(Function(m) g.GetName(m) = name))
AssertMemberNamesEqual(remainingNames, newDecl)
End Sub
End Class
End Namespace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册