From 20c4ad2eba4d23c35f4bc323418a5fcb41b06796 Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 14 Nov 2016 13:49:25 -0800 Subject: [PATCH] Ensure CancellationToken parameters go last in an extracted method. --- .../ExtractMethod/ExtractMethodTests.cs | 45 ++++++++++++++++++- .../ExtractMethod/MethodExtractor.Analyzer.cs | 27 ++++------- .../MethodExtractor.CodeGenerator.cs | 2 +- .../MethodExtractor.VariableInfo.cs | 8 +++- .../MethodExtractor.VariableSymbol.cs | 33 ++++++++------ 5 files changed, 79 insertions(+), 36 deletions(-) diff --git a/src/EditorFeatures/CSharpTest/CodeActions/ExtractMethod/ExtractMethodTests.cs b/src/EditorFeatures/CSharpTest/CodeActions/ExtractMethod/ExtractMethodTests.cs index 0c8fecd5257..a3679231f33 100644 --- a/src/EditorFeatures/CSharpTest/CodeActions/ExtractMethod/ExtractMethodTests.cs +++ b/src/EditorFeatures/CSharpTest/CodeActions/ExtractMethod/ExtractMethodTests.cs @@ -1170,5 +1170,48 @@ private static void NewMethod(out int r, out int y) } }"); } - } + + [WorkItem(15218, "https://github.com/dotnet/roslyn/issues/15218")] + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsExtractMethod)] + public async Task TestCancellationTokenGoesLast() + { + await TestAsync( +@"using System; +using System.Threading; + +class C +{ + void M(CancellationToken ct) + { + var v = 0; + + [|if (true) + { + ct.ThrowIfCancellationRequested(); + Console.WriteLine(v); + }|] + } +}", +@"using System; +using System.Threading; + +class C +{ + void M(CancellationToken ct) + { + var v = 0; + {|Rename:NewMethod|}(v, ct); + } + + private static void NewMethod(int v, CancellationToken ct) + { + if (true) + { + ct.ThrowIfCancellationRequested(); + Console.WriteLine(v); + } + } +}"); + } + } } \ No newline at end of file diff --git a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.Analyzer.cs b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.Analyzer.cs index c188c78c2ee..72506907aad 100644 --- a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.Analyzer.cs +++ b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.Analyzer.cs @@ -82,7 +82,7 @@ public async Task AnalyzeAsync() // collects various variable informations // extracted code contains return value var isInExpressionOrHasReturnStatement = IsInExpressionOrHasReturnStatement(model); - var signatureTuple = GetSignatureInformation(model, dataFlowAnalysisData, variableInfoMap, isInExpressionOrHasReturnStatement); + var signatureTuple = GetSignatureInformation(dataFlowAnalysisData, variableInfoMap, isInExpressionOrHasReturnStatement); var parameters = signatureTuple.Item1; var returnType = signatureTuple.Item2; @@ -198,20 +198,17 @@ private void WrapReturnTypeInTask(SemanticModel model, ref ITypeSymbol returnTyp } private Tuple, ITypeSymbol, VariableInfo, bool> GetSignatureInformation( - SemanticModel model, DataFlowAnalysis dataFlowAnalysisData, IDictionary variableInfoMap, bool isInExpressionOrHasReturnStatement) { + var model = _semanticDocument.SemanticModel; + var compilation = model.Compilation; if (isInExpressionOrHasReturnStatement) { // check whether current selection contains return statement var parameters = GetMethodParameters(variableInfoMap.Values); - var returnType = this.SelectionResult.GetContainingScopeType(); - if (returnType == null) - { - returnType = model.Compilation.GetSpecialType(SpecialType.System_Object); - } + var returnType = SelectionResult.GetContainingScopeType() ?? compilation.GetSpecialType(SpecialType.System_Object); var unsafeAddressTakenUsed = ContainsVariableUnsafeAddressTaken(dataFlowAnalysisData, variableInfoMap.Keys); return Tuple.Create(parameters, returnType, default(VariableInfo), unsafeAddressTakenUsed); @@ -221,15 +218,9 @@ private void WrapReturnTypeInTask(SemanticModel model, ref ITypeSymbol returnTyp // no return statement var parameters = MarkVariableInfoToUseAsReturnValueIfPossible(GetMethodParameters(variableInfoMap.Values)); var variableToUseAsReturnValue = parameters.FirstOrDefault(v => v.UseAsReturnValue); - var returnType = default(ITypeSymbol); - if (variableToUseAsReturnValue != null) - { - returnType = variableToUseAsReturnValue.GetVariableType(_semanticDocument); - } - else - { - returnType = model.Compilation.GetSpecialType(SpecialType.System_Void); - } + var returnType = variableToUseAsReturnValue != null + ? variableToUseAsReturnValue.GetVariableType(_semanticDocument) + : compilation.GetSpecialType(SpecialType.System_Void); var unsafeAddressTakenUsed = ContainsVariableUnsafeAddressTaken(dataFlowAnalysisData, variableInfoMap.Keys); return Tuple.Create(parameters, returnType, variableToUseAsReturnValue, unsafeAddressTakenUsed); @@ -353,9 +344,7 @@ private IList MarkVariableInfoToUseAsReturnValueIfPossible(IList GetMethodParameters(ICollection variableInfo) { var list = new List(variableInfo); - - list.Sort(VariableInfo.Compare); - + VariableInfo.SortVariables(_semanticDocument.SemanticModel.Compilation, list); return list; } diff --git a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.CodeGenerator.cs b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.CodeGenerator.cs index f3af131c855..a8554079323 100644 --- a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.CodeGenerator.cs +++ b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.CodeGenerator.cs @@ -131,7 +131,7 @@ protected VariableInfo GetOutermostVariableToMoveIntoMethodDefinition(Cancellati return null; } - variables.Sort(VariableInfo.Compare); + VariableInfo.SortVariables(SemanticDocument.SemanticModel.Compilation, variables); return variables[0]; } diff --git a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableInfo.cs b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableInfo.cs index d863da74c3f..725e548b6c5 100644 --- a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableInfo.cs +++ b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableInfo.cs @@ -128,10 +128,14 @@ public SyntaxToken GetIdentifierTokenAtDeclaration(SyntaxNode node) return node.GetAnnotatedTokens(_variableSymbol.IdentifierTokenAnnotation).SingleOrDefault(); } - public static int Compare(VariableInfo left, VariableInfo right) + public static void SortVariables(Compilation compilation, List list) { - return VariableSymbol.Compare(left._variableSymbol, right._variableSymbol); + var cancellationTokenType = compilation.GetTypeByMetadataName(typeof(CancellationToken).FullName); + list.Sort((v1, v2) => Compare(v1, v2, cancellationTokenType)); } + + private static int Compare(VariableInfo left, VariableInfo right, INamedTypeSymbol cancellationTokenType) + => VariableSymbol.Compare(left._variableSymbol, right._variableSymbol, cancellationTokenType); } } } diff --git a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableSymbol.cs b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableSymbol.cs index 97fe8321921..0ec80693217 100644 --- a/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableSymbol.cs +++ b/src/Features/Core/Portable/ExtractMethod/MethodExtractor.VariableSymbol.cs @@ -45,8 +45,24 @@ protected VariableSymbol(Compilation compilation, ITypeSymbol type) /// public ITypeSymbol OriginalType { get; } - public static int Compare(VariableSymbol left, VariableSymbol right) + public static int Compare( + VariableSymbol left, + VariableSymbol right, + INamedTypeSymbol cancellationTokenType) { + // CancellationTokens always go at the end of method signature. + var leftIsCancellationToken = left.OriginalType.Equals(cancellationTokenType); + var rightIsCancellationToken = right.OriginalType.Equals(cancellationTokenType); + + if (leftIsCancellationToken && !rightIsCancellationToken) + { + return 1; + } + else if (!leftIsCancellationToken && rightIsCancellationToken) + { + return -1; + } + if (left.DisplayOrder == right.DisplayOrder) { return left.CompareTo(right); @@ -97,10 +113,7 @@ protected class ParameterVariableSymbol : NotMovableVariableSymbol, IComparable< _parameterSymbol = parameterSymbol; } - public override int DisplayOrder - { - get { return 0; } - } + public override int DisplayOrder => 0; protected override int CompareTo(VariableSymbol right) { @@ -191,10 +204,7 @@ protected class LocalVariableSymbol : VariableSymbol, IComparable 1; protected override int CompareTo(VariableSymbol right) { @@ -310,10 +320,7 @@ protected class QueryVariableSymbol : NotMovableVariableSymbol, IComparable 2; protected override int CompareTo(VariableSymbol right) { -- GitLab