未验证 提交 e616ed7f 编写于 作者: S Sam Harwell 提交者: GitHub

Merge pull request #20240 from alrz/smart-as-p2

Look for type check in local declaration and return statements
......@@ -14,10 +14,9 @@ public async Task FixAllInDocument1()
await TestInRegularAndScriptAsync(
@"class C
{
void M()
int M()
{
string a;
int[] b;
{|FixAllInDocument:var|} x = o as string;
if (x != null)
{
......@@ -32,14 +31,16 @@ void M()
{
}
if ((b = o as int[]) != null)
{
}
var c = o as string;
var d = c != null ? 1 : 0;
var e = o as string;
return e != null ? 1 : 0;
}
}",
@"class C
{
void M()
int M()
{
if (o is string x)
{
......@@ -53,9 +54,9 @@ void M()
{
}
if (o is int[] b)
{
}
var d = o is string c ? 1 : 0;
return o is string e ? 1 : 0;
}
}");
}
......
......@@ -53,8 +53,12 @@ private static ExpressionSyntax GetCondition(SyntaxNode node)
return ((WhileStatementSyntax)node).Condition;
case SyntaxKind.IfStatement:
return ((IfStatementSyntax)node).Condition;
case SyntaxKind.ReturnStatement:
return ((ReturnStatementSyntax)node).Expression;
case SyntaxKind.LocalDeclarationStatement:
return ((LocalDeclarationStatementSyntax)node).Declaration.Variables[0].Initializer.Value;
default:
throw ExceptionUtilities.Unreachable;
throw ExceptionUtilities.UnexpectedValue(node.Kind());
}
}
......@@ -64,12 +68,12 @@ private static ExpressionSyntax GetCondition(SyntaxNode node)
CancellationToken cancellationToken)
{
var localDeclarationLocation = diagnostic.AdditionalLocations[0];
var ifOrWhileStatementLocation = diagnostic.AdditionalLocations[1];
var targetStatementLocation = diagnostic.AdditionalLocations[1];
var conditionLocation = diagnostic.AdditionalLocations[2];
var asExpressionLocation = diagnostic.AdditionalLocations[3];
var localDeclaration = (LocalDeclarationStatementSyntax)localDeclarationLocation.FindNode(cancellationToken);
var ifOrWhileStatement = (StatementSyntax)ifOrWhileStatementLocation.FindNode(cancellationToken);
var targetStatement = (StatementSyntax)targetStatementLocation.FindNode(cancellationToken);
var conditionPart = (BinaryExpressionSyntax)conditionLocation.FindNode(cancellationToken);
var asExpression = (BinaryExpressionSyntax)asExpressionLocation.FindNode(cancellationToken);
......@@ -79,7 +83,7 @@ private static ExpressionSyntax GetCondition(SyntaxNode node)
SyntaxFactory.SingleVariableDesignation(
localDeclaration.Declaration.Variables[0].Identifier.WithoutTrivia())));
var currentCondition = GetCondition(ifOrWhileStatement);
var currentCondition = GetCondition(targetStatement);
var updatedCondition = currentCondition.ReplaceNode(conditionPart, updatedConditionPart);
var block = (BlockSyntax)localDeclaration.Parent;
......@@ -93,7 +97,7 @@ private static ExpressionSyntax GetCondition(SyntaxNode node)
(s, g) => s.WithPrependedNonIndentationTriviaFrom(localDeclaration));
editor.RemoveNode(localDeclaration, SyntaxRemoveOptions.KeepUnbalancedDirectives);
editor.ReplaceNode(ifOrWhileStatement, (currentStatement, g) =>
editor.ReplaceNode(targetStatement, (currentStatement, g) =>
{
var updatedStatement = currentStatement.ReplaceNode(GetCondition(currentStatement), updatedCondition);
return updatedStatement.WithAdditionalAnnotations(Formatter.Annotation);
......
......@@ -39,7 +39,10 @@ public CSharpAsAndNullCheckDiagnosticAnalyzer()
protected override void InitializeWorker(AnalysisContext context)
=> context.RegisterSyntaxNodeAction(SyntaxNodeAction,
SyntaxKind.IfStatement, SyntaxKind.WhileStatement);
SyntaxKind.IfStatement,
SyntaxKind.WhileStatement,
SyntaxKind.ReturnStatement,
SyntaxKind.LocalDeclarationStatement);
private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
{
......@@ -68,34 +71,33 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
return;
}
var ifOrWhileStatement = (StatementSyntax)node;
var leftmostCondition = GetLeftmostCondition(ifOrWhileStatement);
if (!leftmostCondition.IsKind(SyntaxKind.NotEqualsExpression, out BinaryExpressionSyntax comparison))
var targetStatement = (StatementSyntax)node;
var leftmostCondition = GetLeftmostCondition(targetStatement);
if (!leftmostCondition.IsKind(SyntaxKind.NotEqualsExpression, out BinaryExpressionSyntax notEquals))
{
return;
}
var operand = GetNullCheckOperand(comparison.Left, comparison.Right)?.WalkDownParentheses();
var operand = GetNullCheckOperand(notEquals.Left, notEquals.Right)?.WalkDownParentheses();
if (operand == null)
{
return;
}
// if/while has to be in a block so we can at least look for a preceding local variable declaration.
if (!ifOrWhileStatement.Parent.IsKind(SyntaxKind.Block, out BlockSyntax parentBlock))
if (!targetStatement.Parent.IsKind(SyntaxKind.Block, out BlockSyntax parentBlock))
{
return;
}
var blockStatements = parentBlock.Statements;
if (!TryGetTypeCheckParts(operand, ifOrWhileStatement, blockStatements,
if (!TryGetTypeCheckParts(operand, targetStatement, parentBlock,
out var declarator, out var asExpression))
{
return;
}
var semanticModel = syntaxContext.SemanticModel;
if (semanticModel.GetSymbolInfo(comparison).GetAnySymbol().IsUserDefinedOperator())
if (semanticModel.GetSymbolInfo(notEquals).GetAnySymbol().IsUserDefinedOperator())
{
return;
}
......@@ -139,7 +141,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
// in the Else branch of the IfStatement, or after the IfStatement. Make sure
// that doesn't cause definite assignment issues.
if (IsAccessedBeforeAssignment(semanticModel, localSymbol,
declarationStatement, ifOrWhileStatement, blockStatements, cancellationToken))
declarationStatement, targetStatement, parentBlock, cancellationToken))
{
return;
}
......@@ -147,7 +149,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
// Looks good!
var additionalLocations = ImmutableArray.Create(
declarationStatement.GetLocation(),
ifOrWhileStatement.GetLocation(),
targetStatement.GetLocation(),
leftmostCondition.GetLocation(),
asExpression.GetLocation());
......@@ -162,22 +164,23 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
SemanticModel semanticModel,
ISymbol localVariable,
StatementSyntax declarationStatement,
StatementSyntax usageStatement,
SyntaxList<StatementSyntax> blockStatements,
StatementSyntax targetStatement,
BlockSyntax parentBlock,
CancellationToken cancellationToken)
{
var isAssigned = false;
var isAccessedBeforeAssignment = false;
var usageIndex = blockStatements.IndexOf(usageStatement);
var declarationIndex = blockStatements.IndexOf(declarationStatement);
var statements = parentBlock.Statements;
var targetIndex = statements.IndexOf(targetStatement);
var declarationIndex = statements.IndexOf(declarationStatement);
// Since we're going to remove this declaration-statement,
// we need to first ensure that it's not used up to the target statement.
for (var index = declarationIndex + 1; index < usageIndex; index++)
for (var index = declarationIndex + 1; index < targetIndex; index++)
{
CheckDefiniteAssignment(
semanticModel, localVariable, blockStatements[index],
semanticModel, localVariable, statements[index],
out isAssigned, out isAccessedBeforeAssignment,
cancellationToken);
......@@ -189,7 +192,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
// In case of an if-statement, we need to check if the variable
// is being accessed before assignment in the else clause.
if (usageStatement is IfStatementSyntax ifStatement)
if (targetStatement.IsKind(SyntaxKind.IfStatement, out IfStatementSyntax ifStatement))
{
CheckDefiniteAssignment(
semanticModel, localVariable, ifStatement.Else,
......@@ -208,10 +211,10 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
}
// Make sure that no access is made to the variable before assignment in the subsequent statements
for (int index = usageIndex + 1, n = blockStatements.Count; index < n; index++)
for (int index = targetIndex + 1, n = statements.Count; index < n; index++)
{
CheckDefiniteAssignment(
semanticModel, localVariable, blockStatements[index],
semanticModel, localVariable, statements[index],
out isAssigned, out isAccessedBeforeAssignment,
cancellationToken);
......@@ -224,7 +227,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
{
// The scope of pattern variables in a while-statement does not leak out to
// the enclosing block so we bail also if there is any assignments afterwards.
return usageStatement.Kind() == SyntaxKind.WhileStatement;
return targetStatement.Kind() == SyntaxKind.WhileStatement;
}
}
......@@ -261,8 +264,8 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
private static bool TryGetTypeCheckParts(
SyntaxNode operand,
StatementSyntax usageStatement,
SyntaxList<StatementSyntax> blockStatements,
StatementSyntax targetStatement,
BlockSyntax parentBlock,
out SyntaxNode variableDeclarator,
out SyntaxNode asExpression)
{
......@@ -279,12 +282,12 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
//
// That's because in this case, unlike the original code, we're type-checking in every iteration
// so we do not replace simple null check with the "is" operator if it's in a while loop
case SyntaxKind.IdentifierName when usageStatement.Kind() != SyntaxKind.WhileStatement:
case SyntaxKind.IdentifierName when targetStatement.Kind() != SyntaxKind.WhileStatement:
{
// var x = e as T;
// if (x != null) F(x);
var identifier = (IdentifierNameSyntax)operand;
var declarator = TryFindVariableDeclarator(identifier, usageStatement, blockStatements);
var declarator = TryFindVariableDeclarator(identifier, targetStatement, parentBlock);
var initializerValue = declarator?.Initializer?.Value;
if (!initializerValue.IsKind(SyntaxKind.AsExpression))
{
......@@ -308,7 +311,7 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
}
var identifier = (IdentifierNameSyntax)assignment.Left;
var declarator = TryFindVariableDeclarator(identifier, usageStatement, blockStatements);
var declarator = TryFindVariableDeclarator(identifier, targetStatement, parentBlock);
if (declarator == null)
{
break;
......@@ -326,12 +329,13 @@ private void SyntaxNodeAction(SyntaxNodeAnalysisContext syntaxContext)
}
private static VariableDeclaratorSyntax TryFindVariableDeclarator(
IdentifierNameSyntax identifier, StatementSyntax usageStatement, SyntaxList<StatementSyntax> blockStatements)
IdentifierNameSyntax identifier, StatementSyntax targetStatement, BlockSyntax parentBlock)
{
var usageIndex = blockStatements.IndexOf(usageStatement);
for (var index = usageIndex - 1; index >= 0; index--)
var statement = parentBlock.Statements;
var targetIndex = statement.IndexOf(targetStatement);
for (var index = targetIndex - 1; index >= 0; index--)
{
if (!blockStatements[index].IsKind(SyntaxKind.LocalDeclarationStatement,
if (!statement[index].IsKind(SyntaxKind.LocalDeclarationStatement,
out LocalDeclarationStatementSyntax declarationStatement))
{
continue;
......@@ -369,7 +373,7 @@ private static SyntaxNode GetLeftmostCondition(SyntaxNode node)
{
while (true)
{
switch (node.Kind())
switch (node?.Kind())
{
case SyntaxKind.WhileStatement:
node = ((WhileStatementSyntax)node).Condition;
......@@ -377,6 +381,15 @@ private static SyntaxNode GetLeftmostCondition(SyntaxNode node)
case SyntaxKind.IfStatement:
node = ((IfStatementSyntax)node).Condition;
continue;
case SyntaxKind.ReturnStatement:
node = ((ReturnStatementSyntax)node).Expression;
continue;
case SyntaxKind.LocalDeclarationStatement:
var declarators = ((LocalDeclarationStatementSyntax)node).Declaration.Variables;
// We require this to be the only declarator in the declaration statement
// to simplify definitive assignment check and the code fix for now
node = declarators.Count == 1 ? declarators[0].Initializer?.Value : null;
continue;
case SyntaxKind.ParenthesizedExpression:
node = ((ParenthesizedExpressionSyntax)node).Expression;
continue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册