未验证 提交 db896485 编写于 作者: J Jason Malinowski 提交者: GitHub

Merge pull request #37312 from jasonmalinowski/fix-makemethodasync

Fix make method async/sync and some related code fixes
......@@ -102,6 +102,39 @@ static IEnumerable<int> M()
await TestInRegularAndScriptAsync(initial, expected);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsChangeToIEnumerable)]
public async Task TestChangeToIEnumerableWithListReturningMethodWithNullableArgument()
{
var initial =
@"#nullable enable
using System;
using System.Collections.Generic;
class Program
{
static IList<string?> [|M|]()
{
yield return """";
}
}";
var expected =
@"#nullable enable
using System;
using System.Collections.Generic;
class Program
{
static IEnumerable<string?> M()
{
yield return """";
}
}";
await TestInRegularAndScriptAsync(initial, expected);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsChangeToIEnumerable)]
public async Task TestChangeToIEnumerableGenericIEnumerableMethod()
{
......
......@@ -1274,5 +1274,87 @@ void M()
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodAsynchronous)]
public async Task MethodWithNullableReturn()
{
await TestInRegularAndScriptAsync(
@"using System.Threading.Tasks;
class C
{
string? M()
{
[|await Task.Delay(1);|]
return null;
}
}",
@"using System.Threading.Tasks;
class C
{
async Task<string?> MAsync()
{
await Task.Delay(1);
return null;
}
}");
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodAsynchronous)]
public async Task EnumerableMethodWithNullableType()
{
var initial =
@"using System.Threading.Tasks;
using System.Collections.Generic;
class Program
{
IEnumerable<string?> Test()
{
yield return string.Empty;
[|await Task.Delay(1);|]
}
}" + IAsyncEnumerable;
var expected =
@"using System.Threading.Tasks;
using System.Collections.Generic;
class Program
{
async IAsyncEnumerable<string?> TestAsync()
{
yield return string.Empty;
await Task.Delay(1);
}
}" + IAsyncEnumerable;
await TestInRegularAndScriptAsync(initial, expected);
}
[Fact, Trait(Traits.Feature, Traits.Features.CodeActionsMakeMethodAsynchronous)]
public async Task EnumeratorMethodWithNullableType()
{
var initial =
@"using System.Threading.Tasks;
using System.Collections.Generic;
class Program
{
IEnumerator<string?> Test()
{
yield return string.Empty;
[|await Task.Delay(1);|]
}
}" + IAsyncEnumerable;
var expected =
@"using System.Threading.Tasks;
using System.Collections.Generic;
class Program
{
async IAsyncEnumerator<string?> TestAsync()
{
yield return string.Empty;
await Task.Delay(1);
}
}" + IAsyncEnumerable;
await TestInRegularAndScriptAsync(initial, expected);
}
}
}
......@@ -115,8 +115,8 @@ private bool IsCorrectTypeForYieldReturn(ITypeSymbol typeArgument, ITypeSymbol r
return false;
}
ienumerableGenericSymbol = ienumerableGenericSymbol.Construct(typeArgument);
ienumeratorGenericSymbol = ienumeratorGenericSymbol.Construct(typeArgument);
ienumerableGenericSymbol = ienumerableGenericSymbol.ConstructWithNullability(typeArgument);
ienumeratorGenericSymbol = ienumeratorGenericSymbol.ConstructWithNullability(typeArgument);
if (!CanConvertTypes(typeArgument, returnExpressionType, model))
{
......@@ -137,7 +137,7 @@ private bool IsCorrectTypeForYieldReturn(ITypeSymbol typeArgument, ITypeSymbol r
private bool CanConvertTypes(ITypeSymbol typeArgument, ITypeSymbol returnExpressionType, SemanticModel model)
{
// return false if there is no conversion for the top level type
if (!model.Compilation.ClassifyConversion(typeArgument, returnExpressionType).Exists)
if (!model.Compilation.ClassifyConversion(typeArgument.WithoutNullability(), returnExpressionType.WithoutNullability()).Exists)
{
return false;
}
......
......@@ -57,11 +57,11 @@ protected override async Task<CodeAction> GetCodeFixAsync(SyntaxNode root, Synta
if (arity == 1)
{
var typeArg = type.GetTypeArguments().First();
ienumerableGenericSymbol = ienumerableGenericSymbol.Construct(typeArg);
ienumerableGenericSymbol = ienumerableGenericSymbol.ConstructWithNullability(typeArg);
}
else if (arity == 0 && type is IArrayTypeSymbol)
{
ienumerableGenericSymbol = ienumerableGenericSymbol.Construct((type as IArrayTypeSymbol).ElementType);
ienumerableGenericSymbol = ienumerableGenericSymbol.ConstructWithNullability((type as IArrayTypeSymbol).ElementType);
}
else
{
......
......@@ -102,13 +102,13 @@ protected override bool IsAsyncReturnType(ITypeSymbol type, KnownTypes knownType
{
newReturnType = knownTypes._iAsyncEnumerableOfTTypeOpt is null
? MakeGenericType("IAsyncEnumerable", methodSymbol.ReturnType)
: knownTypes._iAsyncEnumerableOfTTypeOpt.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
: knownTypes._iAsyncEnumerableOfTTypeOpt.ConstructWithNullability(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
else if (IsIEnumerator(returnType, knownTypes) && IsIterator(methodSymbol))
{
newReturnType = knownTypes._iAsyncEnumeratorOfTTypeOpt is null
? MakeGenericType("IAsyncEnumerator", methodSymbol.ReturnType)
: knownTypes._iAsyncEnumeratorOfTTypeOpt.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
: knownTypes._iAsyncEnumeratorOfTTypeOpt.ConstructWithNullability(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
else if (IsIAsyncEnumerableOrEnumerator(returnType, knownTypes))
{
......@@ -118,7 +118,7 @@ protected override bool IsAsyncReturnType(ITypeSymbol type, KnownTypes knownType
{
// If it's not already Task-like, then wrap the existing return type
// in Task<>.
newReturnType = knownTypes._taskOfTType.Construct(methodSymbol.ReturnType).GenerateTypeSyntax();
newReturnType = knownTypes._taskOfTType.ConstructWithNullability(methodSymbol.GetReturnTypeWithAnnotatedNullability()).GenerateTypeSyntax();
}
}
......
......@@ -72,12 +72,12 @@ private static TypeSyntax FixMethodReturnType(IMethodSymbol methodSymbol, TypeSy
else if (returnType.OriginalDefinition.Equals(knownTypes._iAsyncEnumerableOfTTypeOpt))
{
// If the return type is IAsyncEnumerable<T>, then make the new return type IEnumerable<T>.
newReturnType = knownTypes._iEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
newReturnType = knownTypes._iEnumerableOfTType.ConstructWithNullability(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
else if (returnType.OriginalDefinition.Equals(knownTypes._iAsyncEnumeratorOfTTypeOpt))
{
// If the return type is IAsyncEnumerator<T>, then make the new return type IEnumerator<T>.
newReturnType = knownTypes._iEnumeratorOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
newReturnType = knownTypes._iEnumeratorOfTType.ConstructWithNullability(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
}
return newReturnType;
......
......@@ -616,7 +616,8 @@ private bool DoesTypeReferenceTypeParameter(ITypeSymbol type, ITypeParameterSymb
return false;
}
if (type == typeParameter ||
// We want to ignore nullability when comparing as T and T? both are references to the type parameter
if (type.Equals(typeParameter, SymbolEqualityComparer.Default) ||
type.GetTypeArguments().Any(t => DoesTypeReferenceTypeParameter(t, typeParameter, checkedTypes)))
{
return true;
......
......@@ -67,7 +67,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.CodeFixes.Iterator
Return Nothing
End If
ienumerableSymbol = ienumerableSymbol.Construct(method.ReturnType.GetTypeArguments().First())
ienumerableSymbol = ienumerableSymbol.ConstructWithNullability(method.ReturnType.GetTypeArguments().First())
If Not method.ReturnType.Equals(ienumerableSymbol) Then
Return Nothing
......
......@@ -222,7 +222,7 @@ protected static bool IsNonIntersectingNamespace(ISymbol recommendationSymbol, S
? _context.SemanticModel.LookupStaticMembers(position, container)
: SuppressDefaultTupleElements(
container,
_context.SemanticModel.LookupSymbols(position, container, includeReducedExtensionMethods: true));
_context.SemanticModel.LookupSymbols(position, container.WithoutNullability(), includeReducedExtensionMethods: true));
}
/// <summary>
......
......@@ -464,8 +464,8 @@ public static ImmutableArray<ITypeSymbol> GetTypeArguments(this ISymbol symbol)
{
switch (symbol)
{
case IMethodSymbol m: return m.TypeArguments;
case INamedTypeSymbol nt: return nt.TypeArguments;
case IMethodSymbol m: return m.TypeArguments.ZipAsArray(m.TypeArgumentNullableAnnotations, (t, n) => t.WithNullability(n));
case INamedTypeSymbol nt: return nt.TypeArguments.ZipAsArray(nt.TypeArgumentNullableAnnotations, (t, n) => t.WithNullability(n));
default: return ImmutableArray.Create<ITypeSymbol>();
}
}
......
......@@ -66,7 +66,7 @@ public static NullableAnnotation GetNullability(this ITypeSymbol typeSymbol)
}
}
public static T WithoutNullability<T>(this T typeSymbol) where T : ITypeSymbol
public static T WithoutNullability<T>(this T typeSymbol) where T : INamespaceOrTypeSymbol
{
if (typeSymbol is TypeSymbolWithNullableAnnotation typeSymbolWithNullability)
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册