未验证 提交 a150ad4f 编写于 作者: C Charles Stoner 提交者: GitHub

Allow conversion of GetAwaiter extension method this arg (#38960)

上级 fd502267
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.CodeAnalysis.CSharp.Symbols;
namespace Microsoft.CodeAnalysis.CSharp
{
/// <summary>
/// Internal structure containing all semantic information about an await expression.
/// </summary>
internal sealed class AwaitableInfo
{
public static readonly AwaitableInfo Empty = new AwaitableInfo(getAwaiterMethod: null, isCompletedProperty: null, getResultMethod: null);
public readonly MethodSymbol GetAwaiter;
public readonly PropertySymbol IsCompleted;
public readonly MethodSymbol GetResult;
public bool IsDynamic => GetResult is null;
internal AwaitableInfo(MethodSymbol getAwaiterMethod, PropertySymbol isCompletedProperty, MethodSymbol getResultMethod)
{
this.GetAwaiter = getAwaiterMethod;
this.IsCompleted = isCompletedProperty;
this.GetResult = getResultMethod;
}
internal AwaitableInfo Update(MethodSymbol newGetAwaiter, PropertySymbol newIsCompleted, MethodSymbol newGetResult)
{
if (ReferenceEquals(GetAwaiter, newGetAwaiter) && ReferenceEquals(IsCompleted, newIsCompleted) && ReferenceEquals(GetResult, newGetResult))
{
return this;
}
return new AwaitableInfo(newGetAwaiter, newIsCompleted, newGetResult);
}
}
}
......@@ -2478,6 +2478,9 @@ internal static uint GetValEscape(BoundExpression expr, uint scopeOfTheContainin
// then immediately discarded. The actual expression will be generated during lowering
return scopeOfTheContainingExpression;
case BoundKind.AwaitableValuePlaceholder:
return ((BoundAwaitableValuePlaceholder)expr).ValEscape;
case BoundKind.PointerElementAccess:
case BoundKind.PointerIndirectionOperator:
// Unsafe code will always be allowed to escape.
......@@ -2600,10 +2603,17 @@ internal static bool CheckValEscape(SyntaxNode node, BoundExpression expr, uint
return true;
case BoundKind.DeconstructValuePlaceholder:
var placeholder = (BoundDeconstructValuePlaceholder)expr;
if (placeholder.ValEscape > escapeTo)
if (((BoundDeconstructValuePlaceholder)expr).ValEscape > escapeTo)
{
Error(diagnostics, ErrorCode.ERR_EscapeLocal, node, expr.Syntax);
return false;
}
return true;
case BoundKind.AwaitableValuePlaceholder:
if (((BoundAwaitableValuePlaceholder)expr).ValEscape > escapeTo)
{
Error(diagnostics, ErrorCode.ERR_EscapeLocal, node, placeholder.Syntax);
Error(diagnostics, ErrorCode.ERR_EscapeLocal, node, expr.Syntax);
return false;
}
return true;
......
......@@ -23,7 +23,10 @@ private BoundExpression BindAwait(AwaitExpressionSyntax node, DiagnosticBag diag
private BoundAwaitExpression BindAwait(BoundExpression expression, SyntaxNode node, DiagnosticBag diagnostics)
{
bool hasErrors = false;
AwaitableInfo info = BindAwaitInfo(expression, node, node.Location, diagnostics, ref hasErrors);
var placeholder = new BoundAwaitableValuePlaceholder(expression.Syntax, GetValEscape(expression, LocalScopeDepth), expression.Type);
ReportBadAwaitDiagnostics(node, node.Location, diagnostics, ref hasErrors);
var info = BindAwaitInfo(placeholder, node, diagnostics, ref hasErrors, expressionOpt: expression);
// Spec 7.7.7.2:
// The expression await t is classified the same way as the expression (t).GetAwaiter().GetResult(). Thus,
......@@ -34,24 +37,27 @@ private BoundAwaitExpression BindAwait(BoundExpression expression, SyntaxNode no
return new BoundAwaitExpression(node, expression, info, awaitExpressionType, hasErrors);
}
internal AwaitableInfo BindAwaitInfo(BoundExpression expressionOpt, SyntaxNode node, Location location, DiagnosticBag diagnostics, ref bool hasErrors)
internal void ReportBadAwaitDiagnostics(SyntaxNode node, Location location, DiagnosticBag diagnostics, ref bool hasErrors)
{
hasErrors |= ReportBadAwaitWithoutAsync(location, diagnostics);
hasErrors |= ReportBadAwaitContext(node, location, diagnostics);
}
if (expressionOpt is null)
{
return AwaitableInfo.Empty;
}
else
{
MethodSymbol getAwaiter;
PropertySymbol isCompleted;
MethodSymbol getResult;
hasErrors |= !GetAwaitableExpressionInfo(expressionOpt, out getAwaiter, out isCompleted, out getResult, out _, node, diagnostics);
return new AwaitableInfo(getAwaiter, isCompleted, getResult);
}
internal BoundAwaitableInfo BindAwaitInfo(BoundAwaitableValuePlaceholder placeholder, SyntaxNode node, DiagnosticBag diagnostics, ref bool hasErrors, BoundExpression expressionOpt = null)
{
bool hasGetAwaitableErrors = !GetAwaitableExpressionInfo(
expressionOpt ?? placeholder,
placeholder,
out bool isDynamic,
out BoundExpression getAwaiter,
out PropertySymbol isCompleted,
out MethodSymbol getResult,
getAwaiterGetResultCall: out _,
node,
diagnostics);
hasErrors |= hasGetAwaitableErrors;
return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true };
}
/// <summary>
......@@ -222,13 +228,27 @@ private bool ReportBadAwaitContext(SyntaxNode node, Location location, Diagnosti
/// <returns>True if the expression is awaitable; false otherwise.</returns>
internal bool GetAwaitableExpressionInfo(
BoundExpression expression,
out MethodSymbol getAwaiter,
out BoundExpression getAwaiterGetResultCall,
SyntaxNode node,
DiagnosticBag diagnostics)
{
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, node, diagnostics);
}
private bool GetAwaitableExpressionInfo(
BoundExpression expression,
BoundExpression getAwaiterArgument,
out bool isDynamic,
out BoundExpression getAwaiter,
out PropertySymbol isCompleted,
out MethodSymbol getResult,
out BoundExpression getAwaiterGetResultCall,
SyntaxNode node,
DiagnosticBag diagnostics)
{
Debug.Assert(TypeSymbol.Equals(expression.Type, getAwaiterArgument.Type, TypeCompareKind.ConsiderEverything));
isDynamic = false;
getAwaiter = null;
isCompleted = null;
getResult = null;
......@@ -241,19 +261,19 @@ private bool ReportBadAwaitContext(SyntaxNode node, Location location, Diagnosti
if (expression.HasDynamicType())
{
isDynamic = true;
return true;
}
BoundExpression getAwaiterCall = null;
if (!GetGetAwaiterMethod(expression, node, diagnostics, out getAwaiter, out getAwaiterCall))
if (!GetGetAwaiterMethod(getAwaiterArgument, node, diagnostics, out getAwaiter))
{
return false;
}
TypeSymbol awaiterType = getAwaiter.ReturnType;
TypeSymbol awaiterType = getAwaiter.Type;
return GetIsCompletedProperty(awaiterType, node, expression.Type, diagnostics, out isCompleted)
&& AwaiterImplementsINotifyCompletion(awaiterType, node, diagnostics)
&& GetGetResultMethod(getAwaiterCall, node, expression.Type, diagnostics, out getResult, out getAwaiterGetResultCall);
&& GetGetResultMethod(getAwaiter, node, expression.Type, diagnostics, out getResult, out getAwaiterGetResultCall);
}
/// <summary>
......@@ -287,12 +307,11 @@ private static bool ValidateAwaitedExpression(BoundExpression expression, Syntax
/// NOTE: this is an error in the spec. An extension method of the form
/// Awaiter&lt;T&gt; GetAwaiter&lt;T&gt;(this Task&lt;T&gt;) may be used.
/// </remarks>
private bool GetGetAwaiterMethod(BoundExpression expression, SyntaxNode node, DiagnosticBag diagnostics, out MethodSymbol getAwaiterMethod, out BoundExpression getAwaiterCall)
private bool GetGetAwaiterMethod(BoundExpression expression, SyntaxNode node, DiagnosticBag diagnostics, out BoundExpression getAwaiterCall)
{
if (expression.Type.IsVoidType())
{
Error(diagnostics, ErrorCode.ERR_BadAwaitArgVoidCall, node);
getAwaiterMethod = null;
getAwaiterCall = null;
return false;
}
......@@ -300,7 +319,6 @@ private bool GetGetAwaiterMethod(BoundExpression expression, SyntaxNode node, Di
getAwaiterCall = MakeInvocationExpression(node, expression, WellKnownMemberNames.GetAwaiter, ImmutableArray<BoundExpression>.Empty, diagnostics);
if (getAwaiterCall.HasAnyErrors) // && !expression.HasAnyErrors?
{
getAwaiterMethod = null;
getAwaiterCall = null;
return false;
}
......@@ -308,18 +326,16 @@ private bool GetGetAwaiterMethod(BoundExpression expression, SyntaxNode node, Di
if (getAwaiterCall.Kind != BoundKind.Call)
{
Error(diagnostics, ErrorCode.ERR_BadAwaitArg, node, expression.Type);
getAwaiterMethod = null;
getAwaiterCall = null;
return false;
}
getAwaiterMethod = ((BoundCall)getAwaiterCall).Method;
var getAwaiterMethod = ((BoundCall)getAwaiterCall).Method;
if (getAwaiterMethod is ErrorMethodSymbol ||
HasOptionalOrVariableParameters(getAwaiterMethod) || // We might have been able to resolve a GetAwaiter overload with optional parameters, so check for that here
getAwaiterMethod.ReturnsVoid) // If GetAwaiter returns void, don't bother checking that it returns an Awaiter.
{
Error(diagnostics, ErrorCode.ERR_BadAwaitArg, node, expression.Type);
getAwaiterMethod = null;
getAwaiterCall = null;
return false;
}
......
......@@ -30,7 +30,7 @@ internal sealed class ForEachEnumeratorInfo
public readonly bool IsAsync;
// When async and needs disposal, this stores the information to await the DisposeAsync() invocation
public AwaitableInfo DisposeAwaitableInfo;
public readonly BoundAwaitableInfo DisposeAwaitableInfo;
// When using pattern-based Dispose, this stores the method to invoke to Dispose
public readonly MethodSymbol DisposeMethod;
......@@ -51,7 +51,7 @@ internal sealed class ForEachEnumeratorInfo
MethodSymbol moveNextMethod,
bool isAsync,
bool needsDisposal,
AwaitableInfo disposeAwaitableInfo,
BoundAwaitableInfo disposeAwaitableInfo,
MethodSymbol disposeMethod,
Conversion collectionConversion,
Conversion currentConversion,
......@@ -92,7 +92,7 @@ internal struct Builder
public bool IsAsync;
public bool NeedsDisposal;
public AwaitableInfo DisposeAwaitableInfo;
public BoundAwaitableInfo DisposeAwaitableInfo;
public MethodSymbol DisposeMethod;
public Conversion CollectionConversion;
......
......@@ -205,7 +205,7 @@ private BoundForEachStatement BindForEachPartsWorker(DiagnosticBag diagnostics,
// These occur when special types are missing or malformed, or the patterns are incompletely implemented.
hasErrors |= builder.IsIncomplete;
AwaitableInfo awaitInfo = null;
BoundAwaitableInfo awaitInfo = null;
MethodSymbol getEnumeratorMethod = builder.GetEnumeratorMethod;
if (getEnumeratorMethod != null)
{
......@@ -213,12 +213,14 @@ private BoundForEachStatement BindForEachPartsWorker(DiagnosticBag diagnostics,
}
if (IsAsync)
{
var placeholder = new BoundAwaitableValuePlaceholder(_syntax.Expression, builder.MoveNextMethod?.ReturnType ?? CreateErrorType());
awaitInfo = BindAwaitInfo(placeholder, _syntax.Expression, _syntax.AwaitKeyword.GetLocation(), diagnostics, ref hasErrors);
var expr = _syntax.Expression;
ReportBadAwaitDiagnostics(expr, _syntax.AwaitKeyword.GetLocation(), diagnostics, ref hasErrors);
var placeholder = new BoundAwaitableValuePlaceholder(expr, valEscape: this.LocalScopeDepth, builder.MoveNextMethod?.ReturnType ?? CreateErrorType());
awaitInfo = BindAwaitInfo(placeholder, expr, diagnostics, ref hasErrors);
if (!hasErrors && awaitInfo.GetResult?.ReturnType.SpecialType != SpecialType.System_Boolean)
{
diagnostics.Add(ErrorCode.ERR_BadGetAsyncEnumerator, _syntax.Expression.Location, getEnumeratorMethod.ReturnTypeWithAnnotations, getEnumeratorMethod);
diagnostics.Add(ErrorCode.ERR_BadGetAsyncEnumerator, expr.Location, getEnumeratorMethod.ReturnTypeWithAnnotations, getEnumeratorMethod);
hasErrors = true;
}
}
......@@ -521,10 +523,12 @@ private bool GetAwaitDisposeAsyncInfo(ref ForEachEnumeratorInfo.Builder builder,
? this.GetWellKnownType(WellKnownType.System_Threading_Tasks_ValueTask, diagnostics, this._syntax)
: builder.DisposeMethod.ReturnType;
var placeholder = new BoundAwaitableValuePlaceholder(_syntax.Expression, awaitableType);
bool hasErrors = false;
builder.DisposeAwaitableInfo = BindAwaitInfo(placeholder, _syntax.Expression, _syntax.AwaitKeyword.GetLocation(), diagnostics, ref hasErrors);
var expr = _syntax.Expression;
ReportBadAwaitDiagnostics(expr, _syntax.AwaitKeyword.GetLocation(), diagnostics, ref hasErrors);
var placeholder = new BoundAwaitableValuePlaceholder(expr, valEscape: this.LocalScopeDepth, awaitableType);
builder.DisposeAwaitableInfo = BindAwaitInfo(placeholder, expr, diagnostics, ref hasErrors);
return hasErrors;
}
......
......@@ -96,19 +96,18 @@ internal static BoundStatement BindUsingStatementOrDeclarationFromParts(SyntaxNo
Debug.Assert((object)disposableInterface != null);
bool hasErrors = ReportUseSiteDiagnostics(disposableInterface, diagnostics, hasAwait ? awaitKeyword : usingKeyword);
Conversion iDisposableConversion = Conversion.NoConversion;
Conversion iDisposableConversion;
ImmutableArray<BoundLocalDeclaration> declarationsOpt = default;
BoundMultipleLocalDeclarations multipleDeclarationsOpt = null;
BoundExpression expressionOpt = null;
AwaitableInfo awaitOpt = null;
TypeSymbol declarationTypeOpt = null;
MethodSymbol disposeMethodOpt = null;
TypeSymbol awaitableTypeOpt = null;
MethodSymbol disposeMethodOpt;
TypeSymbol awaitableTypeOpt;
if (isExpression)
{
expressionOpt = usingBinderOpt.BindTargetExpression(diagnostics, originalBinder);
hasErrors |= !populateDisposableConversionOrDisposeMethod(fromExpression: true);
hasErrors |= !populateDisposableConversionOrDisposeMethod(fromExpression: true, out iDisposableConversion, out disposeMethodOpt, out awaitableTypeOpt);
}
else
{
......@@ -122,28 +121,31 @@ internal static BoundStatement BindUsingStatementOrDeclarationFromParts(SyntaxNo
if (declarationTypeOpt.IsDynamic())
{
iDisposableConversion = Conversion.ImplicitDynamic;
disposeMethodOpt = null;
awaitableTypeOpt = null;
}
else
{
hasErrors |= !populateDisposableConversionOrDisposeMethod(fromExpression: false);
hasErrors |= !populateDisposableConversionOrDisposeMethod(fromExpression: false, out iDisposableConversion, out disposeMethodOpt, out awaitableTypeOpt);
}
}
BoundAwaitableInfo awaitOpt = null;
if (hasAwait)
{
BoundAwaitableValuePlaceholder placeholderOpt;
// even if we don't have a proper value to await, we'll still report bad usages of `await`
originalBinder.ReportBadAwaitDiagnostics(syntax, awaitKeyword.GetLocation(), diagnostics, ref hasErrors);
if (awaitableTypeOpt is null)
{
placeholderOpt = null;
awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null) { WasCompilerGenerated = true };
}
else
{
hasErrors |= ReportUseSiteDiagnostics(awaitableTypeOpt, diagnostics, awaitKeyword);
placeholderOpt = new BoundAwaitableValuePlaceholder(syntax, awaitableTypeOpt).MakeCompilerGenerated();
var placeholder = new BoundAwaitableValuePlaceholder(syntax, valEscape: originalBinder.LocalScopeDepth, awaitableTypeOpt).MakeCompilerGenerated();
awaitOpt = originalBinder.BindAwaitInfo(placeholder, syntax, diagnostics, ref hasErrors);
}
// even if we don't have a proper value to await, we'll still report bad usages of `await`
awaitOpt = originalBinder.BindAwaitInfo(placeholderOpt, syntax, awaitKeyword.GetLocation(), diagnostics, ref hasErrors);
}
// This is not awesome, but its factored.
......@@ -168,11 +170,12 @@ internal static BoundStatement BindUsingStatementOrDeclarationFromParts(SyntaxNo
hasErrors);
}
// initializes iDisposableConversion, awaitableTypeOpt and disposeMethodOpt
bool populateDisposableConversionOrDisposeMethod(bool fromExpression)
bool populateDisposableConversionOrDisposeMethod(bool fromExpression, out Conversion iDisposableConversion, out MethodSymbol disposeMethodOpt, out TypeSymbol awaitableTypeOpt)
{
HashSet<DiagnosticInfo> useSiteDiagnostics = null;
iDisposableConversion = classifyConversion(fromExpression, disposableInterface, ref useSiteDiagnostics);
disposeMethodOpt = null;
awaitableTypeOpt = null;
diagnostics.Add(syntax, useSiteDiagnostics);
......
......@@ -77,7 +77,6 @@
</AbstractNode>
<AbstractNode Name="BoundValuePlaceholderBase" Base="BoundExpression">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
</AbstractNode>
<!--
......@@ -85,6 +84,7 @@
It is used to perform intermediate binding, and will not survive the local rewriting.
-->
<Node Name="BoundDeconstructValuePlaceholder" Base="BoundValuePlaceholderBase">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
<Field Name="ValEscape" Type="uint" Null="NotApplicable"/>
</Node>
......@@ -99,9 +99,10 @@
<!--
This node is used to represent an awaitable expression of a certain type, when binding an using-await statement.
It does not survive past initial binding.
-->
<Node Name="BoundAwaitableValuePlaceholder" Base="BoundValuePlaceholderBase">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="allow"/>
<Field Name="ValEscape" Type="uint" Null="NotApplicable"/>
</Node>
<!--
......@@ -109,10 +110,12 @@
It does not survive past initial binding.
-->
<Node Name="BoundDisposableValuePlaceholder" Base="BoundValuePlaceholderBase">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
</Node>
<!-- The implicit collection in an object or collection initializer expression. -->
<Node Name="BoundObjectOrCollectionValuePlaceholder" Base="BoundValuePlaceholderBase">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
</Node>
<!-- only used by codegen -->
......@@ -527,12 +530,21 @@
<Field Name="Expression" Type="BoundExpression"/>
</Node>
<Node Name="BoundAwaitableInfo" Base="BoundNode">
<!-- Used to refer to the awaitable expression in GetAwaiter -->
<Field Name="AwaitableInstancePlaceholder" Type="BoundAwaitableValuePlaceholder?" Null="allow" />
<Field Name="IsDynamic" Type="bool"/>
<Field Name="GetAwaiter" Type="BoundExpression?" Null="allow"/>
<Field Name="IsCompleted" Type="PropertySymbol?" Null="allow"/>
<Field Name="GetResult" Type="MethodSymbol?" Null="allow"/>
</Node>
<Node Name="BoundAwaitExpression" Base="BoundExpression">
<!-- Non-null type is required for this node kind -->
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
<Field Name="Expression" Type="BoundExpression"/>
<Field Name="AwaitableInfo" Type="AwaitableInfo" Null="disallow"/>
<Field Name="AwaitableInfo" Type="BoundAwaitableInfo" Null="disallow"/>
</Node>
<AbstractNode Name="BoundTypeOf" Base="BoundExpression">
......@@ -842,7 +854,7 @@
<Node Name="BoundUsingLocalDeclarations" Base="BoundMultipleLocalDeclarations">
<Field Name="DisposeMethodOpt" Type="MethodSymbol" Null="Allow"/>
<Field Name="IDisposableConversion" Type="Conversion"/>
<Field Name="AwaitOpt" Type="AwaitableInfo?"/>
<Field Name="AwaitOpt" Type="BoundAwaitableInfo?"/>
</Node>
<!--
......@@ -990,7 +1002,7 @@
boxing. If this node has errors, then the conversion may not be present.-->
<Field Name="Expression" Type="BoundExpression"/>
<Field Name="DeconstructionOpt" Type="BoundForEachDeconstructStep?"/>
<Field Name="AwaitOpt" Type="AwaitableInfo?"/>
<Field Name="AwaitOpt" Type="BoundAwaitableInfo?"/>
<Field Name="Body" Type="BoundStatement"/>
<Field Name="Checked" Type="bool"/>
......@@ -1009,7 +1021,7 @@
<Field Name="ExpressionOpt" Type="BoundExpression?"/>
<Field Name="IDisposableConversion" Type="Conversion" />
<Field Name="Body" Type="BoundStatement"/>
<Field Name="AwaitOpt" Type="AwaitableInfo?"/>
<Field Name="AwaitOpt" Type="BoundAwaitableInfo?"/>
<Field Name="DisposeMethodOpt" Type="MethodSymbol?"/>
</Node>
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using Roslyn.Utilities;
namespace Microsoft.CodeAnalysis.CSharp
......@@ -11,20 +10,20 @@ namespace Microsoft.CodeAnalysis.CSharp
/// </summary>
public struct AwaitExpressionInfo : IEquatable<AwaitExpressionInfo>
{
private readonly AwaitableInfo _awaitableInfo;
public IMethodSymbol GetAwaiterMethod { get; }
public IMethodSymbol GetAwaiterMethod => _awaitableInfo?.GetAwaiter;
public IPropertySymbol IsCompletedProperty { get; }
public IPropertySymbol IsCompletedProperty => _awaitableInfo?.IsCompleted;
public IMethodSymbol GetResultMethod { get; }
public IMethodSymbol GetResultMethod => _awaitableInfo?.GetResult;
public bool IsDynamic { get; }
public bool IsDynamic => _awaitableInfo?.IsDynamic == true;
internal AwaitExpressionInfo(AwaitableInfo awaitableInfo)
internal AwaitExpressionInfo(IMethodSymbol getAwaiter, IPropertySymbol isCompleted, IMethodSymbol getResult, bool isDynamic)
{
Debug.Assert(awaitableInfo != null);
_awaitableInfo = awaitableInfo;
GetAwaiterMethod = getAwaiter;
IsCompletedProperty = isCompleted;
GetResultMethod = getResult;
IsDynamic = isDynamic;
}
public override bool Equals(object obj)
......@@ -36,16 +35,13 @@ public bool Equals(AwaitExpressionInfo other)
{
return object.Equals(this.GetAwaiterMethod, other.GetAwaiterMethod)
&& object.Equals(this.IsCompletedProperty, other.IsCompletedProperty)
&& object.Equals(this.GetResultMethod, other.GetResultMethod);
&& object.Equals(this.GetResultMethod, other.GetResultMethod)
&& IsDynamic == other.IsDynamic;
}
public override int GetHashCode()
{
if (_awaitableInfo is null)
{
return 0;
}
return Hash.Combine(GetAwaiterMethod, Hash.Combine(IsCompletedProperty, GetResultMethod.GetHashCode()));
return Hash.Combine(GetAwaiterMethod, Hash.Combine(IsCompletedProperty, Hash.Combine(GetResultMethod, IsDynamic.GetHashCode())));
}
}
}
......@@ -1722,7 +1722,7 @@ internal bool ReturnsAwaitableToVoidOrInt(MethodSymbol method, DiagnosticBag dia
return false;
}
// Early bail so we only even check things that are System.Threading.Tasks.Task(<T>)
// Early bail so we only ever check things that are System.Threading.Tasks.Task(<T>)
if (!(TypeSymbol.Equals(namedType.ConstructedFrom, GetWellKnownType(WellKnownType.System_Threading_Tasks_Task), TypeCompareKind.ConsiderEverything2) ||
TypeSymbol.Equals(namedType.ConstructedFrom, GetWellKnownType(WellKnownType.System_Threading_Tasks_Task_T), TypeCompareKind.ConsiderEverything2)))
{
......@@ -1733,7 +1733,7 @@ internal bool ReturnsAwaitableToVoidOrInt(MethodSymbol method, DiagnosticBag dia
var dumbInstance = new BoundLiteral(syntax, ConstantValue.Null, namedType);
var binder = GetBinder(syntax);
BoundExpression result;
var success = binder.GetAwaitableExpressionInfo(dumbInstance, out _, out _, out _, out result, syntax, diagnostics);
var success = binder.GetAwaitableExpressionInfo(dumbInstance, out result, syntax, diagnostics);
return success &&
(result.Type.IsVoidType() || result.Type.SpecialType == SpecialType.System_Int32);
......
......@@ -3323,7 +3323,6 @@ public ILocalSymbol GetDeclaredSymbol(CatchDeclarationSyntax catchDeclaration, C
case BoundKind.AwaitExpression:
var await = (BoundAwaitExpression)boundNode;
isDynamic = await.AwaitableInfo.IsDynamic;
// TODO:
goto default;
case BoundKind.ConditionalOperator:
......
......@@ -252,6 +252,11 @@ public override BoundNode VisitRangeVariable(BoundRangeVariable node)
return null;
}
public override BoundNode VisitAwaitableInfo(BoundAwaitableInfo node)
{
return null;
}
public override BoundNode VisitBinaryOperator(BoundBinaryOperator node)
{
throw ExceptionUtilities.Unreachable;
......
......@@ -892,13 +892,17 @@ public override AwaitExpressionInfo GetAwaitExpressionInfo(AwaitExpressionSyntax
}
var bound = GetUpperBoundNode(node);
BoundAwaitExpression boundAwait = ((bound as BoundExpressionStatement)?.Expression ?? bound) as BoundAwaitExpression;
if (boundAwait == null)
BoundAwaitableInfo awaitableInfo = (((bound as BoundExpressionStatement)?.Expression ?? bound) as BoundAwaitExpression)?.AwaitableInfo;
if (awaitableInfo == null)
{
return default(AwaitExpressionInfo);
}
return new AwaitExpressionInfo(boundAwait.AwaitableInfo);
return new AwaitExpressionInfo(
getAwaiter: (IMethodSymbol)awaitableInfo.GetAwaiter?.ExpressionSymbol,
isCompleted: awaitableInfo.IsCompleted,
getResult: awaitableInfo.GetResult,
isDynamic: awaitableInfo.IsDynamic);
}
public override ForEachStatementInfo GetForEachStatementInfo(ForEachStatementSyntax node)
......
......@@ -2863,6 +2863,11 @@ public override BoundNode VisitObjectOrCollectionValuePlaceholder(BoundObjectOrC
return null;
}
public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
{
return null;
}
public override sealed BoundNode VisitOutVariablePendingInference(OutVariablePendingInference node)
{
throw ExceptionUtilities.Unreachable;
......
......@@ -114,6 +114,7 @@ private void VerifyExpression(BoundExpression expression, bool overrideSkippedEx
public override BoundNode? VisitForEachStatement(BoundForEachStatement node)
{
Visit(node.IterationVariableType);
Visit(node.AwaitOpt);
Visit(node.Expression);
// https://github.com/dotnet/roslyn/issues/35010: handle the deconstruction
//this.Visit(node.DeconstructionOpt);
......
......@@ -172,6 +172,11 @@ public VisitArgumentResult(VisitResult visitResult, Optional<LocalState> stateFo
/// </summary>
private PooledDictionary<BoundExpression, TypeWithState> _methodGroupReceiverMapOpt;
/// <summary>
/// State of awaitable expressions, for substitution in placeholders within GetAwaiter calls.
/// </summary>
private PooledDictionary<BoundAwaitableValuePlaceholder, (BoundExpression AwaitableExpression, VisitResult Result)> _awaitablePlaceholdersOpt;
/// <summary>
/// True if we're analyzing speculative code. This turns off some initialization steps
/// that would otherwise be taken.
......@@ -328,6 +333,7 @@ private void SetAnalyzedNullability(BoundExpression expr, VisitResult result, bo
protected override void Free()
{
_awaitablePlaceholdersOpt?.Free();
_methodGroupReceiverMapOpt?.Free();
_variableTypes.Free();
_placeholderLocalsOpt?.Free();
......@@ -1560,15 +1566,23 @@ public override BoundNode VisitForStatement(BoundForStatement node)
public override BoundNode VisitForEachStatement(BoundForEachStatement node)
{
DeclareLocals(node.IterationVariables);
Visit(node.AwaitOpt);
return base.VisitForEachStatement(node);
}
public override BoundNode VisitUsingStatement(BoundUsingStatement node)
{
DeclareLocals(node.Locals);
Visit(node.AwaitOpt);
return base.VisitUsingStatement(node);
}
public override BoundNode VisitUsingLocalDeclarations(BoundUsingLocalDeclarations node)
{
Visit(node.AwaitOpt);
return base.VisitUsingLocalDeclarations(node);
}
public override BoundNode VisitFixedStatement(BoundFixedStatement node)
{
DeclareLocals(node.Locals);
......@@ -2653,6 +2667,18 @@ private static void MarkSlotsAsNotNull(ArrayBuilder<int> slots, ref LocalState s
private void LearnFromNonNullTest(BoundExpression expression, ref LocalState state)
{
if (expression.Kind == BoundKind.AwaitableValuePlaceholder)
{
if (_awaitablePlaceholdersOpt != null && _awaitablePlaceholdersOpt.TryGetValue((BoundAwaitableValuePlaceholder)expression, out var value))
{
expression = value.AwaitableExpression;
}
else
{
return;
}
}
var slotBuilder = ArrayBuilder<int>.GetInstance();
GetSlotsToMarkAsNotNullable(expression, slotBuilder);
MarkSlotsAsNotNull(slotBuilder, ref state);
......@@ -7137,7 +7163,14 @@ private TypeWithState InferResultNullabilityOfBinaryLogicalOperator(BoundExpress
public override BoundNode VisitAwaitExpression(BoundAwaitExpression node)
{
var result = base.VisitAwaitExpression(node);
_ = CheckPossibleNullReceiver(node.Expression);
var awaitableInfo = node.AwaitableInfo;
var placeholder = awaitableInfo.AwaitableInstancePlaceholder;
_awaitablePlaceholdersOpt ??= PooledDictionary<BoundAwaitableValuePlaceholder, (BoundExpression AwaitableExpression, VisitResult Result)>.GetInstance();
_awaitablePlaceholdersOpt.Add(placeholder, (node.Expression, _visitResult));
Visit(awaitableInfo);
_awaitablePlaceholdersOpt.Remove(placeholder);
if (node.Type.IsValueType || node.HasErrors || node.AwaitableInfo.GetResult is null)
{
SetNotNullResult(node);
......@@ -7145,7 +7178,7 @@ public override BoundNode VisitAwaitExpression(BoundAwaitExpression node)
else
{
// Update method based on inferred receiver type: see https://github.com/dotnet/roslyn/issues/29605.
SetResultType(node, node.AwaitableInfo.GetResult.ReturnTypeWithAnnotations.ToTypeWithState());
SetResultType(node, awaitableInfo.GetResult.ReturnTypeWithAnnotations.ToTypeWithState());
}
return result;
......@@ -7737,6 +7770,22 @@ public override BoundNode VisitObjectOrCollectionValuePlaceholder(BoundObjectOrC
return null;
}
public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
{
VisitResult result = _awaitablePlaceholdersOpt != null && _awaitablePlaceholdersOpt.TryGetValue(node, out var value) ?
value.Result :
new VisitResult(TypeWithState.Create(node.Type, default));
SetResult(node, result.RValueType, result.LValueType);
return null;
}
public override BoundNode VisitAwaitableInfo(BoundAwaitableInfo node)
{
Visit(node.AwaitableInstancePlaceholder);
Visit(node.GetAwaiter);
return null;
}
protected override string Dump(LocalState state)
{
if (!state.Reachable)
......
......@@ -58,6 +58,8 @@ internal class AsyncMethodToStateMachineRewriter : MethodToStateMachineRewriter
private readonly Dictionary<TypeSymbol, FieldSymbol> _awaiterFields;
private int _nextAwaiterId;
private readonly Dictionary<BoundValuePlaceholderBase, BoundExpression> _placeholderMap;
internal AsyncMethodToStateMachineRewriter(
MethodSymbol method,
int methodOrdinal,
......@@ -86,6 +88,8 @@ internal class AsyncMethodToStateMachineRewriter : MethodToStateMachineRewriter
_dynamicFactory = new LoweredDynamicOperationFactory(F, methodOrdinal);
_awaiterFields = new Dictionary<TypeSymbol, FieldSymbol>(TypeSymbol.EqualsIgnoringDynamicTupleNamesAndNullabilityComparer);
_nextAwaiterId = slotAllocatorOpt?.PreviousAwaiterSlotCount ?? 0;
_placeholderMap = new Dictionary<BoundValuePlaceholderBase, BoundExpression>();
}
private FieldSymbol GetAwaiterField(TypeSymbol awaiterType)
......@@ -288,22 +292,37 @@ public sealed override BoundNode VisitBadExpression(BoundBadExpression node)
private BoundBlock VisitAwaitExpression(BoundAwaitExpression node, BoundExpression resultPlace)
{
var expression = (BoundExpression)Visit(node.Expression);
var awaitablePlaceholder = node.AwaitableInfo.AwaitableInstancePlaceholder;
if (awaitablePlaceholder != null)
{
_placeholderMap.Add(awaitablePlaceholder, expression);
}
var getAwaiter = node.AwaitableInfo.IsDynamic ?
MakeCallMaybeDynamic(expression, null, WellKnownMemberNames.GetAwaiter) :
(BoundExpression)Visit(node.AwaitableInfo.GetAwaiter);
resultPlace = (BoundExpression)Visit(resultPlace);
MethodSymbol getAwaiter = VisitMethodSymbol(node.AwaitableInfo.GetAwaiter);
MethodSymbol getResult = VisitMethodSymbol(node.AwaitableInfo.GetResult);
MethodSymbol isCompletedMethod = ((object)node.AwaitableInfo.IsCompleted != null) ? VisitMethodSymbol(node.AwaitableInfo.IsCompleted.GetMethod) : null;
TypeSymbol type = VisitType(node.Type);
if (awaitablePlaceholder != null)
{
_placeholderMap.Remove(awaitablePlaceholder);
}
// The awaiter temp facilitates EnC method remapping and thus have to be long-lived.
// It transfers the awaiter objects from the old version of the MoveNext method to the new one.
Debug.Assert(node.Syntax.IsKind(SyntaxKind.AwaitExpression) || node.WasCompilerGenerated);
TypeSymbol awaiterType = node.AwaitableInfo.IsDynamic ? DynamicTypeSymbol.Instance : getAwaiter.ReturnType;
var awaiterTemp = F.SynthesizedLocal(awaiterType, syntax: node.Syntax, kind: SynthesizedLocalKind.Awaiter);
var awaiterTemp = F.SynthesizedLocal(getAwaiter.Type, syntax: node.Syntax, kind: SynthesizedLocalKind.Awaiter);
var awaitIfIncomplete = F.Block(
// temp $awaiterTemp = <expr>.GetAwaiter();
F.Assignment(
F.Local(awaiterTemp),
MakeCallMaybeDynamic(expression, getAwaiter, WellKnownMemberNames.GetAwaiter)),
getAwaiter),
// hidden sequence point facilitates EnC method remapping, see explanation on SynthesizedLocalKind.Awaiter:
F.HiddenSequencePoint(),
......@@ -329,6 +348,11 @@ private BoundBlock VisitAwaitExpression(BoundAwaitExpression node, BoundExpressi
getResultStatement);
}
public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
{
return _placeholderMap[node];
}
private BoundExpression MakeCallMaybeDynamic(
BoundExpression receiver,
MethodSymbol methodSymbol = null,
......
......@@ -875,12 +875,6 @@ public override BoundNode VisitUsingStatement(BoundUsingStatement node)
return null;
}
public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
{
Fail(node);
return null;
}
public override BoundNode VisitIfStatement(BoundIfStatement node)
{
Fail(node);
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
namespace Microsoft.CodeAnalysis.CSharp
......@@ -18,7 +17,7 @@ public BoundExpression VisitAwaitExpression(BoundAwaitExpression node, bool used
return RewriteAwaitExpression((BoundExpression)base.VisitAwaitExpression(node), used);
}
private BoundExpression RewriteAwaitExpression(SyntaxNode syntax, BoundExpression rewrittenExpression, AwaitableInfo awaitableInfo, TypeSymbol type, bool used)
private BoundExpression RewriteAwaitExpression(SyntaxNode syntax, BoundExpression rewrittenExpression, BoundAwaitableInfo awaitableInfo, TypeSymbol type, bool used)
{
return RewriteAwaitExpression(new BoundAwaitExpression(syntax, rewrittenExpression, awaitableInfo, type) { WasCompilerGenerated = true }, used);
}
......
......@@ -400,7 +400,7 @@ private bool TryGetDisposeMethod(CommonForEachStatementSyntax forEachSyntax, For
/// Produce:
/// await /* disposeCall */;
/// </summary>
private BoundStatement WrapWithAwait(CommonForEachStatementSyntax forEachSyntax, BoundExpression disposeCall, AwaitableInfo disposeAwaitableInfoOpt)
private BoundStatement WrapWithAwait(CommonForEachStatementSyntax forEachSyntax, BoundExpression disposeCall, BoundAwaitableInfo disposeAwaitableInfoOpt)
{
TypeSymbol awaitExpressionType = disposeAwaitableInfoOpt.GetResult?.ReturnType ?? _compilation.DynamicType;
var awaitExpr = RewriteAwaitExpression(forEachSyntax, disposeCall, disposeAwaitableInfoOpt, awaitExpressionType, used: false);
......
......@@ -4,7 +4,6 @@
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.PooledObjects;
namespace Microsoft.CodeAnalysis.CSharp
{
......@@ -60,7 +59,7 @@ public override BoundNode VisitUsingStatement(BoundUsingStatement node)
ImmutableArray<BoundLocalDeclaration> declarations,
Conversion iDisposableConversion,
MethodSymbol disposeMethodOpt,
AwaitableInfo awaitOpt,
BoundAwaitableInfo awaitOpt,
SyntaxToken awaitKeyword)
{
Debug.Assert(declarations != null);
......@@ -196,7 +195,7 @@ private BoundBlock MakeExpressionUsingStatement(BoundUsingStatement node, BoundB
/// Assumes that the local symbol will be declared (i.e. in the LocalsOpt array) of an enclosing block.
/// Assumes that using statements with multiple locals have already been split up into multiple using statements.
/// </remarks>
private BoundBlock RewriteDeclarationUsingStatement(SyntaxNode usingSyntax, BoundLocalDeclaration localDeclaration, BoundBlock tryBlock, Conversion iDisposableConversion, SyntaxToken awaitKeywordOpt, AwaitableInfo awaitOpt, MethodSymbol methodSymbol)
private BoundBlock RewriteDeclarationUsingStatement(SyntaxNode usingSyntax, BoundLocalDeclaration localDeclaration, BoundBlock tryBlock, Conversion iDisposableConversion, SyntaxToken awaitKeywordOpt, BoundAwaitableInfo awaitOpt, MethodSymbol methodSymbol)
{
SyntaxNode declarationSyntax = localDeclaration.Syntax;
......@@ -253,7 +252,7 @@ private BoundBlock RewriteDeclarationUsingStatement(SyntaxNode usingSyntax, Boun
}
}
private BoundStatement RewriteUsingStatementTryFinally(SyntaxNode syntax, BoundBlock tryBlock, BoundLocal local, SyntaxToken awaitKeywordOpt, AwaitableInfo awaitOpt, MethodSymbol methodOpt)
private BoundStatement RewriteUsingStatementTryFinally(SyntaxNode syntax, BoundBlock tryBlock, BoundLocal local, SyntaxToken awaitKeywordOpt, BoundAwaitableInfo awaitOpt, MethodSymbol methodOpt)
{
// SPEC: When ResourceType is a non-nullable value type, the expansion is:
// SPEC:
......@@ -324,7 +323,7 @@ private BoundStatement RewriteUsingStatementTryFinally(SyntaxNode syntax, BoundB
// "{ dynamic temp1 = x; IDisposable temp2 = (IDisposable) temp1; ... }". Rather, we elide
// the completely unnecessary first temporary.
Debug.Assert((awaitKeywordOpt == default) == (awaitOpt == default(AwaitableInfo)));
Debug.Assert((awaitKeywordOpt == default) == (awaitOpt is null));
BoundExpression disposedExpression;
bool isNullableValueType = local.Type.IsNullableType();
......@@ -396,7 +395,7 @@ private BoundStatement RewriteUsingStatementTryFinally(SyntaxNode syntax, BoundB
return tryFinally;
}
private BoundExpression GenerateDisposeCall(SyntaxNode syntax, BoundExpression disposedExpression, MethodSymbol methodOpt, AwaitableInfo awaitOpt, SyntaxToken awaitKeyword)
private BoundExpression GenerateDisposeCall(SyntaxNode syntax, BoundExpression disposedExpression, MethodSymbol methodOpt, BoundAwaitableInfo awaitOpt, SyntaxToken awaitKeyword)
{
Debug.Assert(awaitOpt is null || awaitKeyword != default);
......
......@@ -48,6 +48,8 @@ internal abstract partial class MethodToClassRewriter : BoundTreeRewriterWithSta
protected readonly DiagnosticBag Diagnostics;
protected readonly VariableSlotAllocator slotAllocatorOpt;
private readonly Dictionary<BoundValuePlaceholderBase, BoundExpression> _placeholderMap;
protected MethodToClassRewriter(VariableSlotAllocator slotAllocatorOpt, TypeCompilationState compilationState, DiagnosticBag diagnostics)
{
Debug.Assert(compilationState != null);
......@@ -56,6 +58,7 @@ protected MethodToClassRewriter(VariableSlotAllocator slotAllocatorOpt, TypeComp
this.CompilationState = compilationState;
this.Diagnostics = diagnostics;
this.slotAllocatorOpt = slotAllocatorOpt;
this._placeholderMap = new Dictionary<BoundValuePlaceholderBase, BoundExpression>();
}
/// <summary>
......@@ -356,16 +359,29 @@ private BoundNode VisitUnhoistedLocal(BoundLocal node)
return base.VisitLocal(node);
}
public override BoundNode VisitAwaitExpression(BoundAwaitExpression node)
public override BoundNode VisitAwaitableInfo(BoundAwaitableInfo node)
{
BoundExpression expression = (BoundExpression)this.Visit(node.Expression);
TypeSymbol type = this.VisitType(node.Type);
var awaitablePlaceholder = node.AwaitableInstancePlaceholder;
if (awaitablePlaceholder is null)
{
return node;
}
AwaitableInfo info = node.AwaitableInfo;
return node.Update(
expression,
info.Update(VisitMethodSymbol(info.GetAwaiter), VisitPropertySymbol(info.IsCompleted), VisitMethodSymbol(info.GetResult)),
type);
var rewrittenPlaceholder = awaitablePlaceholder.Update(awaitablePlaceholder.ValEscape, VisitType(awaitablePlaceholder.Type));
_placeholderMap.Add(awaitablePlaceholder, rewrittenPlaceholder);
var getAwaiter = (BoundExpression)this.Visit(node.GetAwaiter);
var isCompleted = VisitPropertySymbol(node.IsCompleted);
var getResult = VisitMethodSymbol(node.GetResult);
_placeholderMap.Remove(awaitablePlaceholder);
return node.Update(rewrittenPlaceholder, node.IsDynamic, getAwaiter, isCompleted, getResult);
}
public override BoundNode VisitAwaitableValuePlaceholder(BoundAwaitableValuePlaceholder node)
{
return _placeholderMap[node];
}
public override BoundNode VisitAssignmentOperator(BoundAssignmentOperator node)
......
......@@ -349,7 +349,7 @@ internal sealed class AsyncForwardEntryPoint : SynthesizedEntryPointSymbol
// The diagnostics that would be produced here will already have been captured and returned.
var droppedBag = DiagnosticBag.GetInstance();
var success = binder.GetAwaitableExpressionInfo(userMainInvocation, out _, out _, out _, out _getAwaiterGetResultCall, _userMainReturnTypeSyntax, droppedBag);
var success = binder.GetAwaitableExpressionInfo(userMainInvocation, out _getAwaiterGetResultCall, _userMainReturnTypeSyntax, droppedBag);
droppedBag.Free();
Debug.Assert(
......@@ -466,7 +466,7 @@ internal override BoundBlock CreateBody(DiagnosticBag diagnostics)
var initializeCall = CreateParameterlessCall(syntax, scriptLocal, initializer);
BoundExpression getAwaiterGetResultCall;
if (!binder.GetAwaitableExpressionInfo(initializeCall, out _, out _, out _, out getAwaiterGetResultCall, syntax, diagnostics))
if (!binder.GetAwaitableExpressionInfo(initializeCall, out getAwaiterGetResultCall, syntax, diagnostics))
{
return new BoundBlock(
syntax: syntax,
......
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.CodeAnalysis.CSharp.Test.Utilities;
using Xunit;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using System.Threading;
using Microsoft.CodeAnalysis.CSharp.Test.Utilities;
using Microsoft.CodeAnalysis.Test.Utilities;
using System.Reflection.PortableExecutable;
using System.Reflection.Metadata;
using System.Linq;
using Roslyn.Test.Utilities;
using System.Collections.Generic;
using Xunit;
namespace Microsoft.CodeAnalysis.CSharp.UnitTests.CodeGen
{
......@@ -467,7 +467,6 @@ class Program {
Diagnostic(ErrorCode.ERR_ReturnExpected, "Main").WithArguments("Program.Main()").WithLocation(6, 28));
}
[Fact]
public void AsyncEmitMainOfIntTest_StringArgs()
{
......@@ -626,7 +625,6 @@ static Task Main(string[] args)
CompileAndVerify(compilation, expectedReturnCode: 0);
}
[Fact]
public void MainCantBeAsyncWithRefTask()
{
......@@ -1957,5 +1955,108 @@ .maxstack 3
IL_013f: ret
}");
}
[Fact]
public void ValueTask()
{
var source =
@"using System.Threading.Tasks;
class Program
{
static async ValueTask Main()
{
await Task.Delay(0);
}
}";
var comp = CreateCompilationWithTasksExtensions(source, options: TestOptions.ReleaseExe);
comp.VerifyDiagnostics(
// error CS5001: Program does not contain a static 'Main' method suitable for an entry point
Diagnostic(ErrorCode.ERR_NoEntryPoint).WithLocation(1, 1),
// (4,28): warning CS0028: 'Program.Main()' has the wrong signature to be an entry point
// static async ValueTask Main()
Diagnostic(ErrorCode.WRN_InvalidMainSig, "Main").WithArguments("Program.Main()").WithLocation(4, 28));
}
[Fact]
public void ValueTaskOfInt()
{
var source =
@"using System.Threading.Tasks;
class Program
{
static async ValueTask<int> Main()
{
await Task.Delay(0);
return 0;
}
}";
var comp = CreateCompilationWithTasksExtensions(source, options: TestOptions.ReleaseExe);
comp.VerifyDiagnostics(
// error CS5001: Program does not contain a static 'Main' method suitable for an entry point
Diagnostic(ErrorCode.ERR_NoEntryPoint).WithLocation(1, 1),
// (4,33): warning CS0028: 'Program.Main()' has the wrong signature to be an entry point
// static async ValueTask<int> Main()
Diagnostic(ErrorCode.WRN_InvalidMainSig, "Main").WithArguments("Program.Main()").WithLocation(4, 33));
}
[Fact]
public void TasklikeType()
{
var source =
@"using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
namespace System.Runtime.CompilerServices
{
class AsyncMethodBuilderAttribute : System.Attribute
{
public AsyncMethodBuilderAttribute(System.Type t) { }
}
}
[AsyncMethodBuilder(typeof(MyTaskMethodBuilder))]
struct MyTask
{
internal Awaiter GetAwaiter() => new Awaiter();
internal class Awaiter : INotifyCompletion
{
public void OnCompleted(Action a) { }
internal bool IsCompleted => true;
internal void GetResult() { }
}
}
struct MyTaskMethodBuilder
{
private MyTask _task;
public static MyTaskMethodBuilder Create() => new MyTaskMethodBuilder(new MyTask());
internal MyTaskMethodBuilder(MyTask task)
{
_task = task;
}
public void SetStateMachine(IAsyncStateMachine stateMachine) { }
public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine
{
stateMachine.MoveNext();
}
public void SetException(Exception e) { }
public void SetResult() { }
public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : INotifyCompletion where TStateMachine : IAsyncStateMachine { }
public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : ICriticalNotifyCompletion where TStateMachine : IAsyncStateMachine { }
public MyTask Task => _task;
}
class Program
{
static async MyTask Main()
{
await Task.Delay(0);
}
}";
var comp = CreateCompilation(source, options: TestOptions.ReleaseExe);
comp.VerifyDiagnostics(
// error CS5001: Program does not contain a static 'Main' method suitable for an entry point
Diagnostic(ErrorCode.ERR_NoEntryPoint).WithLocation(1, 1),
// (43,25): warning CS0028: 'Program.Main()' has the wrong signature to be an entry point
// static async MyTask Main()
Diagnostic(ErrorCode.WRN_InvalidMainSig, "Main").WithArguments("Program.Main()").WithLocation(43, 25));
}
}
}
......@@ -5541,5 +5541,71 @@ class Program
Diagnostic(ErrorCode.ERR_IllegalStatement, "b ? await Task.Delay(1) : await Task.Delay(2)").WithLocation(8, 9)
);
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterBoxingConversion_01()
{
var source =
@"using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
interface IAwaitable { }
struct StructAwaitable : IAwaitable { }
static class Extensions
{
public static TaskAwaiter GetAwaiter(this IAwaitable x)
{
if (x == null) throw new ArgumentNullException(nameof(x));
Console.Write(x);
return Task.CompletedTask.GetAwaiter();
}
}
class Program
{
static async Task Main()
{
await new StructAwaitable();
}
}";
var comp = CSharpTestBase.CreateCompilation(source, options: TestOptions.ReleaseExe);
CompileAndVerify(comp, expectedOutput: "StructAwaitable");
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterBoxingConversion_02()
{
var source =
@"using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
struct StructAwaitable { }
static class Extensions
{
public static TaskAwaiter GetAwaiter(this object x)
{
if (x == null) throw new ArgumentNullException(nameof(x));
Console.Write(x);
return Task.CompletedTask.GetAwaiter();
}
}
class Program
{
static async Task Main()
{
StructAwaitable? s = new StructAwaitable();
await s;
}
}";
var comp = CSharpTestBase.CreateCompilation(source, options: TestOptions.ReleaseExe);
CompileAndVerify(comp, expectedOutput: "StructAwaitable");
}
}
}
......@@ -4917,5 +4917,59 @@ public async Task<int> DisposeAsync()
comp.VerifyDiagnostics();
CompileAndVerify(comp, expectedOutput: "MoveNextAsync DisposeAsync Done");
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterBoxingConversion()
{
var source =
@"using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
interface I1 { }
interface I2 { }
struct StructAwaitable1 : I1 { }
struct StructAwaitable2 : I2 { }
class Enumerable
{
public Enumerator GetAsyncEnumerator() => new Enumerator();
internal class Enumerator
{
public object Current => null;
public StructAwaitable1 MoveNextAsync() => new StructAwaitable1();
public StructAwaitable2 DisposeAsync() => new StructAwaitable2();
}
}
static class Extensions
{
internal static TaskAwaiter<bool> GetAwaiter(this I1 x)
{
if (x == null) throw new ArgumentNullException(nameof(x));
Console.Write(x);
return Task.FromResult(false).GetAwaiter();
}
internal static TaskAwaiter GetAwaiter(this I2 x)
{
if (x == null) throw new ArgumentNullException(nameof(x));
Console.Write(x);
return Task.CompletedTask.GetAwaiter();
}
}
class Program
{
static async Task Main()
{
await foreach (var o in new Enumerable())
{
}
}
}";
var comp = CreateCompilationWithTasksExtensions(new[] { source, s_IAsyncEnumerable }, options: TestOptions.ReleaseExe);
CompileAndVerify(comp, expectedOutput: "StructAwaitable1StructAwaitable2");
}
}
}
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using ICSharpCode.Decompiler.DebugInfo;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Test.Utilities;
using Microsoft.CodeAnalysis.Test.Utilities;
......@@ -1389,6 +1388,42 @@ .maxstack 3
}");
}
[Fact]
public void Struct_ExplicitImplementation()
{
string source =
@"using System;
using System.Threading.Tasks;
class C
{
internal bool _disposed;
}
struct S : IAsyncDisposable
{
C _c;
S(C c)
{
_c = c;
}
static async Task Main()
{
var s = new S(new C());
await using (s)
{
}
Console.WriteLine(s._c._disposed);
}
ValueTask IAsyncDisposable.DisposeAsync()
{
_c._disposed = true;
return new ValueTask(Task.CompletedTask);
}
}";
var comp = CreateCompilationWithTasksExtensions(new[] { source, s_interfaces }, options: TestOptions.DebugExe);
comp.VerifyDiagnostics();
CompileAndVerify(comp, expectedOutput: "True");
}
[Fact]
public void TestWithNullableExpression()
{
......@@ -1617,19 +1652,27 @@ public class D
var getAwaiter1 = (MethodSymbol)comp.GetMember("C.GetAwaiter");
var isCompleted1 = (PropertySymbol)comp.GetMember("C.IsCompleted");
var getResult1 = (MethodSymbol)comp.GetMember("C.GetResult");
var first = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter1, isCompleted1, getResult1));
var first = new AwaitExpressionInfo(getAwaiter1, isCompleted1, getResult1, false);
var nulls1 = new AwaitExpressionInfo(new AwaitableInfo(null, isCompleted1, getResult1));
var nulls2 = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter1, null, getResult1));
var nulls3 = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter1, isCompleted1, null));
var nulls1 = new AwaitExpressionInfo(null, isCompleted1, getResult1, false);
var nulls2 = new AwaitExpressionInfo(getAwaiter1, null, getResult1, false);
var nulls3 = new AwaitExpressionInfo(getAwaiter1, isCompleted1, null, false);
var nulls4 = new AwaitExpressionInfo(getAwaiter1, isCompleted1, null, true);
Assert.False(first.Equals(nulls1));
Assert.False(first.Equals(nulls2));
Assert.False(first.Equals(nulls3));
Assert.False(first.Equals(nulls4));
Assert.False(nulls1.Equals(first));
Assert.False(nulls2.Equals(first));
Assert.False(nulls3.Equals(first));
Assert.False(nulls4.Equals(first));
_ = nulls1.GetHashCode();
_ = nulls2.GetHashCode();
_ = nulls3.GetHashCode();
_ = nulls4.GetHashCode();
object nullObj = null;
Assert.False(first.Equals(nullObj));
......@@ -1637,10 +1680,10 @@ public class D
var getAwaiter2 = (MethodSymbol)comp.GetMember("D.GetAwaiter");
var isCompleted2 = (PropertySymbol)comp.GetMember("D.IsCompleted");
var getResult2 = (MethodSymbol)comp.GetMember("D.GetResult");
var second1 = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter2, isCompleted1, getResult1));
var second2 = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter1, isCompleted2, getResult1));
var second3 = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter1, isCompleted1, getResult2));
var second4 = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter2, isCompleted2, getResult2));
var second1 = new AwaitExpressionInfo(getAwaiter2, isCompleted1, getResult1, false);
var second2 = new AwaitExpressionInfo(getAwaiter1, isCompleted2, getResult1, false);
var second3 = new AwaitExpressionInfo(getAwaiter1, isCompleted1, getResult2, false);
var second4 = new AwaitExpressionInfo(getAwaiter2, isCompleted2, getResult2, false);
Assert.False(first.Equals(second1));
Assert.False(first.Equals(second2));
......@@ -1655,7 +1698,7 @@ public class D
Assert.True(first.Equals(first));
Assert.True(first.Equals((object)first));
var another = new AwaitExpressionInfo(new AwaitableInfo(getAwaiter1, isCompleted1, getResult1));
var another = new AwaitExpressionInfo(getAwaiter1, isCompleted1, getResult1, false);
Assert.True(first.GetHashCode() == another.GetHashCode());
}
......@@ -2315,5 +2358,44 @@ void M()
Diagnostic(ErrorCode.ERR_NoConvToIAsyncDisp, "var y = new object()").WithArguments("object").WithLocation(7, 22)
);
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterBoxingConversion()
{
var source =
@"using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
struct StructAwaitable { }
class Disposable
{
public StructAwaitable DisposeAsync() => new StructAwaitable();
}
static class Extensions
{
public static TaskAwaiter GetAwaiter(this object x)
{
if (x == null) throw new ArgumentNullException(nameof(x));
Console.Write(x);
return Task.CompletedTask.GetAwaiter();
}
}
class Program
{
static async Task Main()
{
await using (new Disposable())
{
}
}
}";
var comp = CreateCompilationWithTasksExtensions(new[] { source, s_interfaces }, options: TestOptions.ReleaseExe);
CompileAndVerify(comp, expectedOutput: "StructAwaitable");
}
}
}
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.CSharp.Test.Utilities;
......@@ -192,5 +191,29 @@ void Goo(Task<int> t)
var symbolV = (LocalSymbol)semanticModel.GetDeclaredSymbol(decl);
Assert.Equal("System.Int32", symbolV.TypeWithAnnotations.ToTestDisplayString());
}
[Fact]
public void Dynamic()
{
string source =
@"using System.Threading.Tasks;
class Program
{
static async Task Main()
{
dynamic d = Task.CompletedTask;
await d;
}
}";
var comp = CreateCompilation(source);
var tree = comp.SyntaxTrees[0];
var model = comp.GetSemanticModel(tree);
var expr = (AwaitExpressionSyntax)tree.FindNodeOrTokenByKind(SyntaxKind.AwaitExpression).AsNode();
var info = model.GetAwaitExpressionInfo(expr);
Assert.True(info.IsDynamic);
Assert.Null(info.GetAwaiterMethod);
Assert.Null(info.IsCompletedProperty);
Assert.Null(info.GetResultMethod);
}
}
}
......@@ -612,6 +612,7 @@ class C
static async Task<int> M()
{
return await d; //-fieldAccess: dynamic
//-awaitableValuePlaceholder: dynamic
//-awaitExpression: dynamic
//-conversion: int
}
......@@ -635,7 +636,9 @@ static async void M()
{
var x = await await d; //-typeExpression: dynamic
//-fieldAccess: dynamic
//-awaitableValuePlaceholder: dynamic
//-awaitExpression: dynamic
//-awaitableValuePlaceholder: dynamic
//-awaitExpression: dynamic
}
}";
......
......@@ -50155,11 +50155,7 @@ public static class Extensions
}
";
var comp = CreateCompilation(new[] { source }, options: WithNonNullTypesTrue());
comp.VerifyDiagnostics(
// (6,15): warning CS8602: Dereference of a possibly null reference.
// await Async();
Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "Async()").WithLocation(6, 15)
);
comp.VerifyDiagnostics();
}
[Fact]
......@@ -117223,5 +117219,177 @@ static T F<T>() where T : IEnumerable, new()
var comp = CreateCompilation(source);
comp.VerifyDiagnostics();
}
[Fact]
public void GetAwaiterExtensionMethod()
{
var source =
@"#nullable enable
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
class Awaitable { }
static class Program
{
static TaskAwaiter GetAwaiter(this Awaitable? a) => default;
static async Task Main()
{
Awaitable? x = new Awaitable();
Awaitable y = null; // 1
await x;
await y;
}
}";
var comp = CreateCompilation(source);
comp.VerifyDiagnostics(
// (11,23): warning CS8600: Converting null literal or possible null value to non-nullable type.
// Awaitable y = null; // 1
Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "null").WithLocation(11, 23));
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterExtensionMethod_Await()
{
var source =
@"#nullable enable
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
struct StructAwaitable<T> { }
static class Program
{
static TaskAwaiter GetAwaiter(this StructAwaitable<object?> s) => default;
static StructAwaitable<T> Create<T>(T t) => new StructAwaitable<T>();
static async Task Main()
{
object? x = new object();
object y = null; // 1
await Create(x); // 2
await Create(y);
}
}";
var comp = CreateCompilation(source);
comp.VerifyDiagnostics(
// (12,20): warning CS8600: Converting null literal or possible null value to non-nullable type.
// object y = null; // 1
Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "null").WithLocation(12, 20),
// (13,15): warning CS8620: Argument of type 'StructAwaitable<object>' cannot be used for parameter 's' of type 'StructAwaitable<object?>' in 'TaskAwaiter Program.GetAwaiter(StructAwaitable<object?> s)' due to differences in the nullability of reference types.
// await Create(x); // 2
Diagnostic(ErrorCode.WRN_NullabilityMismatchInArgument, "Create(x)").WithArguments("StructAwaitable<object>", "StructAwaitable<object?>", "s", "TaskAwaiter Program.GetAwaiter(StructAwaitable<object?> s)").WithLocation(13, 15));
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterExtensionMethod_AwaitUsing()
{
var source =
@"#nullable enable
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
struct StructAwaitable<T> { }
class Disposable<T>
{
public StructAwaitable<T> DisposeAsync() => new StructAwaitable<T>();
}
static class Program
{
static TaskAwaiter GetAwaiter(this StructAwaitable<object> s) => default;
static Disposable<T> Create<T>(T t) => new Disposable<T>();
static async Task Main()
{
object? x = new object();
object y = null; // 1
await using (Create(x))
{
}
await using (Create(y)) // 2
{
}
}
}";
var comp = CreateCompilationWithTasksExtensions(new[] { IAsyncDisposableDefinition, source });
// Should report warning for GetAwaiter().
comp.VerifyDiagnostics(
// (16,20): warning CS8600: Converting null literal or possible null value to non-nullable type.
// object y = null; // 1
Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "null").WithLocation(16, 20));
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterExtensionMethod_AwaitUsingLocal()
{
var source =
@"#nullable enable
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
struct StructAwaitable<T> { }
class Disposable<T>
{
public StructAwaitable<T> DisposeAsync() => new StructAwaitable<T>();
}
static class Program
{
static TaskAwaiter GetAwaiter(this StructAwaitable<object> s) => default;
static Disposable<T> Create<T>(T t) => new Disposable<T>();
static async Task Main()
{
object? x = new object();
object y = null; // 1
await using var dx = Create(x);
await using var dy = Create(y); // 2
}
}";
var comp = CreateCompilationWithTasksExtensions(new[] { IAsyncDisposableDefinition, source });
// Should report warning for GetAwaiter().
comp.VerifyDiagnostics(
// (16,20): warning CS8600: Converting null literal or possible null value to non-nullable type.
// object y = null; // 1
Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "null").WithLocation(16, 20));
}
[Fact]
[WorkItem(30956, "https://github.com/dotnet/roslyn/issues/30956")]
public void GetAwaiterExtensionMethod_AwaitForEach()
{
var source =
@"#nullable enable
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
struct StructAwaitable1<T> { }
struct StructAwaitable2<T> { }
class Enumerable<T>
{
public Enumerator<T> GetAsyncEnumerator() => new Enumerator<T>();
}
class Enumerator<T>
{
public object Current => null!;
public StructAwaitable1<T> MoveNextAsync() => new StructAwaitable1<T>();
public StructAwaitable2<T> DisposeAsync() => new StructAwaitable2<T>();
}
static class Program
{
static TaskAwaiter<bool> GetAwaiter(this StructAwaitable1<object?> s) => default;
static TaskAwaiter GetAwaiter(this StructAwaitable2<object> s) => default;
static Enumerable<T> Create<T>(T t) => new Enumerable<T>();
static async Task Main()
{
object? x = new object();
object y = null; // 1
await foreach (var o in Create(x)) // 2
{
}
await foreach (var o in Create(y))
{
}
}
}";
var comp = CreateCompilationWithTasksExtensions(new[] { s_IAsyncEnumerable, source });
// Should report warning for GetAwaiter().
comp.VerifyDiagnostics(
// (24,20): warning CS8600: Converting null literal or possible null value to non-nullable type.
// object y = null; // 1
Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "null").WithLocation(24, 20));
}
}
}
......@@ -179,6 +179,17 @@ public sealed class NotNullIfNotNullAttribute : Attribute
public NotNullIfNotNullAttribute(string parameterName) { }
}
}
";
protected const string IAsyncDisposableDefinition = @"
using System.Threading.Tasks;
namespace System
{
public interface IAsyncDisposable
{
ValueTask DisposeAsync();
}
}
";
protected const string AsyncStreamsTypes = @"
......
......@@ -43,9 +43,7 @@
<Compile Remove="Symbols\EmbeddedSymbols\VbMyTemplateText.vb" />
</ItemGroup>
<ItemGroup>
<BoundTreeDefinition Include="BoundTree\BoundNodes.xml">
<SubType>Designer</SubType>
</BoundTreeDefinition>
<None Include="BoundTree\BoundNodes.xml" />
<PublicAPI Include="PublicAPI.Shipped.txt" />
<PublicAPI Include="PublicAPI.Unshipped.txt" />
<None Include="Generated\VisualBasic.Grammar.g4" />
......
......@@ -9184,6 +9184,71 @@ End Module
CompileAndVerify(compilation, expectedOutput:=expectedOutput)
CompileAndVerify(compilation.WithOptions(TestOptions.ReleaseExe), expectedOutput:=expectedOutput)
End Sub
<Fact>
Public Sub GetAwaiterBoxingConversion_01()
Dim source =
"Imports System
Imports System.Runtime.CompilerServices
Imports System.Threading.Tasks
Interface IAwaitable
End Interface
Structure StructAwaitable
Implements IAwaitable
End Structure
Module Program
<Extension>
Function GetAwaiter(x As IAwaitable) As TaskAwaiter
If x Is Nothing Then Throw New ArgumentNullException(Nameof(x))
Console.Write(x)
Return Task.CompletedTask.GetAwaiter()
End Function
Async Function M() As Task
Await New StructAwaitable()
End Function
Sub Main()
M().Wait()
End Sub
End Module"
Dim compilation = CreateCompilation(source, options:=TestOptions.ReleaseExe)
CompileAndVerify(compilation, expectedOutput:="StructAwaitable")
End Sub
<Fact>
Public Sub GetAwaiterBoxingConversion_02()
Dim source =
"Imports System
Imports System.Runtime.CompilerServices
Imports System.Threading.Tasks
Structure StructAwaitable
End Structure
Module Program
<Extension>
Function GetAwaiter(x As Object) As TaskAwaiter
If x Is Nothing Then Throw New ArgumentNullException(Nameof(x))
Console.Write(x)
Return Task.CompletedTask.GetAwaiter()
End Function
Async Function M() As Task
Dim s As StructAwaitable? = New StructAwaitable()
Await s
End Function
Sub Main()
M().Wait()
End Sub
End Module"
Dim compilation = CreateCompilation(source, options:=TestOptions.ReleaseExe)
CompileAndVerify(compilation, expectedOutput:="StructAwaitable")
End Sub
End Class
End Namespace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册