未验证 提交 af436523 编写于 作者: M Marie Píchová 提交者: GitHub

[QUIC] API QuicStream (#71969)

* Quic stream API surface

* Fixed test compilation

* Fixed http test compilation

* HttpLoopbackConnection Dispose -> DisposeAsync

* QuicStream implementation

* Fixed some tests

* Fixed all QUIC and HTTP tests

* Fixed exception type for stream closed by connection close

* Feedback

* Fixed WebSocket.Client test build

* Feedback, test fixes

* Fixed build on framework and windows

* Fixed winhandler test

* Swap variable based on order in defining class

* Post merge fixes

* Feedback and build

* Reverted connection state to pass around abort error code

* Fixed exception type.
上级 750157db
......@@ -122,9 +122,9 @@ private void CloseWebSocket()
}
}
public abstract class GenericLoopbackConnection : IDisposable
public abstract class GenericLoopbackConnection : IAsyncDisposable
{
public abstract void Dispose();
public abstract ValueTask DisposeAsync();
public abstract Task InitializeConnectionAsync();
......
......@@ -838,12 +838,12 @@ public async Task SendResponseBodyAsync(int streamId, ReadOnlyMemory<byte> respo
await SendResponseDataAsync(streamId, responseBody, isFinal).ConfigureAwait(false);
}
public override void Dispose()
public override async ValueTask DisposeAsync()
{
// Might have been already shutdown manually via WaitForConnectionShutdownAsync which nulls the _connectionStream.
if (_connectionStream != null)
{
ShutdownIgnoringErrorsAsync(_lastStreamId).GetAwaiter().GetResult();
await ShutdownIgnoringErrorsAsync(_lastStreamId);
}
}
......
......@@ -148,7 +148,7 @@ public override void Dispose()
public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList<HttpHeaderData> headers = null, string content = "")
{
using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false))
await using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false))
{
return await connection.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false);
}
......@@ -156,7 +156,7 @@ public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode st
public override async Task AcceptConnectionAsync(Func<GenericLoopbackConnection, Task> funcAsync)
{
using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false))
await using (Http2LoopbackConnection connection = await EstablishConnectionAsync().ConfigureAwait(false))
{
await funcAsync(connection).ConfigureAwait(false);
}
......
......@@ -56,17 +56,17 @@ public Http3LoopbackConnection(QuicConnection connection)
public long MaxHeaderListSize { get; private set; } = -1;
public override void Dispose()
public override async ValueTask DisposeAsync()
{
// Close any remaining request streams (but NOT control streams, as these should not be closed while the connection is open)
foreach (Http3LoopbackStream stream in _openStreams.Values)
{
stream.Dispose();
await stream.DisposeAsync().ConfigureAwait(false);
}
foreach (QuicStream stream in _delayedStreams)
{
stream.Dispose();
await stream.DisposeAsync().ConfigureAwait(false);
}
// We don't dispose the connection currently, because this causes races when the server connection is closed before
......@@ -79,8 +79,8 @@ public override void Dispose()
_connection.Dispose();
// Dispose control streams so that we release their handles too.
_inboundControlStream?.Dispose();
_outboundControlStream?.Dispose();
await _inboundControlStream?.DisposeAsync().ConfigureAwait(false);
await _outboundControlStream?.DisposeAsync().ConfigureAwait(false);
#endif
}
......@@ -104,7 +104,7 @@ public static int GetRequestId(QuicStream stream)
Debug.Assert(stream.CanRead && stream.CanWrite, "Stream must be a request stream.");
// TODO: QUIC streams can have IDs larger than int.MaxValue; update all our tests to use long rather than int.
return checked((int)stream.StreamId + 1);
return checked((int)stream.Id + 1);
}
public Http3LoopbackStream GetOpenRequest(int requestId = 0)
......@@ -172,9 +172,9 @@ public async Task<Http3LoopbackStream> AcceptRequestStreamAsync()
Assert.True(quicStream.CanWrite, "Expected writeable stream.");
_openStreams.Add(checked((int)quicStream.StreamId), stream);
_openStreams.Add(checked((int)quicStream.Id), stream);
_currentStream = stream;
_currentStreamId = quicStream.StreamId;
_currentStreamId = quicStream.Id;
return stream;
}
......@@ -293,9 +293,9 @@ public async Task WaitForClientDisconnectAsync(bool refuseNewRequests = true)
break;
}
using (stream)
await using (stream)
{
await stream.AbortAndWaitForShutdownAsync(H3_REQUEST_REJECTED);
stream.Abort(H3_REQUEST_REJECTED);
}
}
......
......@@ -82,14 +82,14 @@ public override async Task<GenericLoopbackConnection> EstablishGenericConnection
public override async Task AcceptConnectionAsync(Func<GenericLoopbackConnection, Task> funcAsync)
{
using Http3LoopbackConnection con = await EstablishHttp3ConnectionAsync().ConfigureAwait(false);
await using Http3LoopbackConnection con = await EstablishHttp3ConnectionAsync().ConfigureAwait(false);
await funcAsync(con).ConfigureAwait(false);
await con.ShutdownAsync();
}
public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList<HttpHeaderData> headers = null, string content = "")
{
using var con = (Http3LoopbackConnection)await EstablishGenericConnectionAsync().ConfigureAwait(false);
await using Http3LoopbackConnection con = (Http3LoopbackConnection)await EstablishGenericConnectionAsync().ConfigureAwait(false);
return await con.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false);
}
}
......
......@@ -15,7 +15,7 @@
namespace System.Net.Test.Common
{
internal sealed class Http3LoopbackStream : IDisposable
internal sealed class Http3LoopbackStream : IAsyncDisposable
{
private const int MaximumVarIntBytes = 8;
private const long VarIntMax = (1L << 62) - 1;
......@@ -40,12 +40,9 @@ public Http3LoopbackStream(QuicStream stream)
_stream = stream;
}
public void Dispose()
{
_stream.Dispose();
}
public ValueTask DisposeAsync() => _stream.DisposeAsync();
public long StreamId => _stream.StreamId;
public long StreamId => _stream.Id;
public async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList<HttpHeaderData> headers = null, string content = "")
{
......@@ -285,9 +282,7 @@ public async Task SendResponseBodyAsync(byte[] content, bool isFinal = true)
if (isFinal)
{
_stream.Shutdown();
await _stream.ShutdownCompleted().ConfigureAwait(false);
Dispose();
_stream.CompleteWrites();
}
}
......@@ -389,7 +384,7 @@ async Task WaitForWriteCancellation()
{
try
{
await _stream.WaitForWriteCompletionAsync();
await _stream.WritesClosed;
}
catch (QuicException ex) when (ex.QuicError == QuicError.StreamAborted && ex.ApplicationErrorCode == Http3LoopbackConnection.H3_REQUEST_CANCELLED)
{
......@@ -424,11 +419,9 @@ private async Task DrainResponseData()
}
}
public async Task AbortAndWaitForShutdownAsync(long errorCode)
public void Abort(long errorCode)
{
_stream.AbortRead(errorCode);
_stream.AbortWrite(errorCode);
await _stream.ShutdownCompleted();
_stream.Abort(QuicAbortDirection.Both, errorCode);
}
public async Task<(long? frameType, byte[] payload)> ReadFrameAsync()
......
......@@ -109,15 +109,18 @@ public override async Task<GenericLoopbackConnection> EstablishGenericConnection
{
return connection = await Http2LoopbackServerFactory.Singleton.CreateConnectionAsync(new SocketWrapper(socket), stream, options).ConfigureAwait(false);
}
else
else
{
throw new Exception($"Invalid ClearTextVersion={_options.ClearTextVersion} specified");
}
}
catch
{
connection?.Dispose();
connection = null;
{
if (connection is not null)
{
await connection.DisposeAsync();
connection = null;
}
stream.Dispose();
throw;
}
......@@ -132,7 +135,7 @@ public override async Task<GenericLoopbackConnection> EstablishGenericConnection
public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList<HttpHeaderData> headers = null, string content = "")
{
using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false))
await using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false))
{
return await connection.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false);
}
......@@ -140,7 +143,7 @@ public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode st
public override async Task AcceptConnectionAsync(Func<GenericLoopbackConnection, Task> funcAsync)
{
using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false))
await using (GenericLoopbackConnection connection = await EstablishGenericConnectionAsync().ConfigureAwait(false))
{
await funcAsync(connection).ConfigureAwait(false);
}
......
......@@ -717,7 +717,7 @@ public async Task Credentials_BrokenNtlmFromServer()
Assert.Equal(0, requestData.GetHeaderValueCount("Authorization"));
// Establish a session connection
using var connection = await server.EstablishConnectionAsync();
await using LoopbackServer.Connection connection = await server.EstablishConnectionAsync();
requestData = await connection.ReadRequestDataAsync();
string authHeaderValue = requestData.GetSingleHeaderValue("Authorization");
Assert.Contains("NTLM", authHeaderValue);
......
......@@ -156,7 +156,7 @@ public async Task<Connection> EstablishConnectionAsync()
public async Task AcceptConnectionAsync(Func<Connection, Task> funcAsync)
{
using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false))
await using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false))
{
await funcAsync(connection).ConfigureAwait(false);
}
......@@ -654,7 +654,7 @@ private async Task<byte[]> ReadLineBytesAsync()
return null;
}
public override void Dispose()
public override async ValueTask DisposeAsync()
{
try
{
......@@ -666,7 +666,12 @@ public override void Dispose()
}
catch (Exception) { }
#if !NETSTANDARD2_0 && !NETFRAMEWORK
await _stream.DisposeAsync().ConfigureAwait(false);
#else
_stream.Dispose();
await Task.CompletedTask.ConfigureAwait(false);
#endif
_socket?.Dispose();
}
......@@ -1076,7 +1081,7 @@ public override Task WaitForCloseAsync(CancellationToken cancellationToken)
public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode statusCode = HttpStatusCode.OK, IList<HttpHeaderData> headers = null, string content = "")
{
using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false))
await using (Connection connection = await EstablishConnectionAsync().ConfigureAwait(false))
{
return await connection.HandleRequestAsync(statusCode, headers, content).ConfigureAwait(false);
}
......
......@@ -41,7 +41,7 @@ public async Task UseClientCertOnHttp2_DowngradedToHttp1MutualAuth_Success()
},
async s =>
{
using (LoopbackServer.Connection connection = await s.EstablishConnectionAsync().ConfigureAwait(false))
await using (LoopbackServer.Connection connection = await s.EstablishConnectionAsync().ConfigureAwait(false))
{
SslStream sslStream = connection.Stream as SslStream;
Assert.NotNull(sslStream);
......@@ -76,7 +76,7 @@ public async Task UseClientCertOnHttp2_OSSupportsIt_Success()
},
async s =>
{
using (Http2LoopbackConnection connection = await s.EstablishConnectionAsync().ConfigureAwait(false))
await using (Http2LoopbackConnection connection = await s.EstablishConnectionAsync().ConfigureAwait(false))
{
SslStream sslStream = connection.Stream as SslStream;
Assert.NotNull(sslStream);
......
......@@ -210,7 +210,7 @@ public async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, lon
throw new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure);
}
requestStream!.StreamId = quicStream.StreamId;
requestStream!.StreamId = quicStream.Id;
bool goAway;
lock (SyncObj)
......@@ -542,7 +542,7 @@ await using (stream.ConfigureAwait(false))
NetEventSource.Info(this, $"Ignoring server-initiated stream of unknown type {unknownStreamType}.");
}
stream.AbortRead((long)Http3ErrorCode.StreamCreationError);
stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.StreamCreationError);
stream.Dispose();
return;
}
......
......@@ -260,7 +260,7 @@ public async Task<HttpResponseMessage> SendAsync(CancellationToken cancellationT
// We're either observing GOAWAY, or the cancellationToken parameter has been canceled.
if (cancellationToken.IsCancellationRequested)
{
_stream.AbortWrite((long)Http3ErrorCode.RequestCancelled);
_stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.RequestCancelled);
throw new TaskCanceledException(ex.Message, ex, cancellationToken);
}
else
......@@ -277,7 +277,7 @@ public async Task<HttpResponseMessage> SendAsync(CancellationToken cancellationT
}
catch (Exception ex)
{
_stream.AbortWrite((long)Http3ErrorCode.InternalError);
_stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.InternalError);
if (ex is HttpRequestException)
{
throw;
......@@ -398,7 +398,7 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance
}
else
{
_stream.Shutdown();
_stream.CompleteWrites();
}
if (HttpTelemetry.Log.IsEnabled()) HttpTelemetry.Log.RequestContentStop(writeStream.BytesWritten);
......@@ -814,7 +814,7 @@ private async ValueTask ReadHeadersAsync(long headersLength, CancellationToken c
// https://tools.ietf.org/html/draft-ietf-quic-http-24#section-4.1.1
if (headersLength > _headerBudgetRemaining)
{
_stream.AbortWrite((long)Http3ErrorCode.ExcessiveLoad);
_stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.ExcessiveLoad);
throw new HttpRequestException(SR.Format(SR.net_http_response_headers_exceeded_length, _connection.Pool.Settings._maxResponseHeadersLength * 1024L));
}
......@@ -1201,12 +1201,12 @@ private void HandleReadResponseContentException(Exception ex, CancellationToken
_connection.Abort(ex);
throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex));
case OperationCanceledException oce when oce.CancellationToken == cancellationToken:
_stream.AbortRead((long)Http3ErrorCode.RequestCancelled);
_stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.RequestCancelled);
ExceptionDispatchInfo.Throw(ex); // Rethrow.
return; // Never reached.
}
_stream.AbortRead((long)Http3ErrorCode.InternalError);
_stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.InternalError);
throw new IOException(SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex));
}
......@@ -1264,12 +1264,12 @@ private void AbortStream()
// If the request body isn't completed, cancel it now.
if (_requestContentLengthRemaining != 0) // 0 is used for the end of content writing, -1 is used for unknown Content-Length
{
_stream.AbortWrite((long)Http3ErrorCode.RequestCancelled);
_stream.Abort(QuicAbortDirection.Write, (long)Http3ErrorCode.RequestCancelled);
}
// If the response body isn't completed, cancel it now.
if (_responseDataPayloadRemaining != -1) // -1 is used for EOF, 0 for consumed DATA frame payload before the next read
{
_stream.AbortRead((long)Http3ErrorCode.RequestCancelled);
_stream.Abort(QuicAbortDirection.Read, (long)Http3ErrorCode.RequestCancelled);
}
}
......
......@@ -81,7 +81,7 @@ public async Task AltSvc_ConnectionFrame_UpgradeFrom20_Success()
Task<HttpResponseMessage> firstResponseTask = client.GetAsync(firstServer.Address);
Task serverTask = Task.Run(async () =>
{
using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync();
await using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync();
int streamId = await connection.ReadRequestHeaderAsync();
await connection.WriteFrameAsync(new AltSvcFrame($"https://{firstServer.Address.IdnHost}:{firstServer.Address.Port}", $"h3=\"{secondServer.Address.IdnHost}:{secondServer.Address.Port}\"", streamId: 0));
......@@ -106,7 +106,7 @@ public async Task AltSvc_ResponseFrame_UpgradeFrom20_Success()
Task<HttpResponseMessage> firstResponseTask = client.GetAsync(firstServer.Address);
Task serverTask = Task.Run(async () =>
{
using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync();
await using Http2LoopbackConnection connection = await firstServer.EstablishConnectionAsync();
int streamId = await connection.ReadRequestHeaderAsync();
await connection.SendDefaultResponseHeadersAsync(streamId);
......
......@@ -291,7 +291,7 @@ public async Task Http2_ServerSendsInvalidSettingsValue_Error(SettingId settingI
await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
connection.Dispose();
await connection.DisposeAsync();
}
}
......@@ -2609,7 +2609,7 @@ public async Task ConnectAsync_ReadWriteWebSocketStream()
Assert.Equal(0, await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)));
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
[Fact]
......
......@@ -41,12 +41,12 @@ public async Task ClientSettingsReceived_Success(int headerSizeLimit)
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
(Http3LoopbackStream settingsStream, Http3LoopbackStream requestStream) = await connection.AcceptControlAndRequestStreamAsync();
using (settingsStream)
using (requestStream)
await using (settingsStream)
await using (requestStream)
{
Assert.False(settingsStream.CanWrite, "Expected unidirectional control stream.");
Assert.Equal(headerSizeLimit * 1024L, connection.MaxHeaderListSize);
......@@ -85,10 +85,10 @@ public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit)
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
for (int i = 0; i < streamLimit + 1; ++i)
{
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
}
});
......@@ -123,10 +123,10 @@ public async Task SendStreamLimitRequestsConcurrently_Succeeds(int streamLimit)
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
for (int i = 0; i < streamLimit; ++i)
{
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
}
});
......@@ -171,7 +171,7 @@ public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int stre
Task serverTask = Task.Run(async () =>
{
// Read the first streamLimit requests, keep the streams open to make the last one wait.
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
var streams = new Http3LoopbackStream[streamLimit];
for (int i = 0; i < streamLimit; ++i)
{
......@@ -183,7 +183,7 @@ public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int stre
// Make the last request running independently.
var lastRequest = Task.Run(async () =>
{
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
});
......@@ -194,7 +194,7 @@ public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int stre
for (int i = 0; i < streamLimit; ++i)
{
await streams[i].SendResponseAsync();
streams[i].Dispose();
await streams[i].DisposeAsync();
// After the first request is fully processed, the last request should unblock and get processed.
if (i == 0)
{
......@@ -273,15 +273,15 @@ public async Task ReservedFrameType_Throws()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.SendFrameAsync(ReservedHttp2PriorityFrameId, new byte[8]);
QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () =>
{
await stream.HandleRequestAsync();
using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync();
});
Assert.Equal(UnexpectedFrameErrorCode, ex.ApplicationErrorCode);
......@@ -313,8 +313,8 @@ public async Task RequestSentResponseDisposed_ThrowsOnServer()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync();
await stream.SendResponseHeadersAsync();
......@@ -371,8 +371,8 @@ public async Task RequestSendingResponseDisposed_ThrowsOnServer()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync(false);
await stream.SendResponseHeadersAsync();
......@@ -436,10 +436,10 @@ public async Task ServerCertificateCustomValidationCallback_Succeeds()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync();
await stream2.HandleRequestAsync();
});
......@@ -479,8 +479,8 @@ public async Task EmptyCustomContent_FlushHeaders()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
// Receive headers and unblock the client.
await stream.ReadRequestDataAsync(false);
......@@ -528,7 +528,7 @@ public async Task DisposeHttpClient_Http3ConnectionIsClosed()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
HttpRequestData request = await connection.ReadRequestDataAsync();
await connection.SendResponseAsync();
......@@ -665,8 +665,8 @@ public async Task ResponseCancellation_ServerReceivesCancellation(CancellationTy
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false);
......@@ -746,8 +746,8 @@ public async Task ResponseCancellation_BothCancellationTokenAndDispose_Success()
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false);
......@@ -829,7 +829,7 @@ public async Task Alpn_H3_Success()
Task serverTask = Task.Run(async () =>
{
connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
});
......@@ -850,7 +850,7 @@ public async Task Alpn_H3_Success()
SslApplicationProtocol negotiatedAlpn = ExtractMsQuicNegotiatedAlpn(connection);
Assert.Equal(new SslApplicationProtocol("h3"), negotiatedAlpn);
connection.Dispose();
await connection.DisposeAsync();
}
[Fact]
......@@ -906,7 +906,7 @@ public async Task StatusCodes_ReceiveSuccess(HttpStatusCode statusCode, bool qpa
Task serverTask = Task.Run(async () =>
{
connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync().ConfigureAwait(false);
......@@ -934,7 +934,7 @@ public async Task StatusCodes_ReceiveSuccess(HttpStatusCode statusCode, bool qpa
await serverTask;
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
[Theory]
......@@ -1017,9 +1017,9 @@ public async Task EchoServerStreaming_DifferentMessageSize_Success(int messageSi
await serverTask.WaitAsync(TimeSpan.FromSeconds(60));
serverStream.Dispose();
await serverStream.DisposeAsync();
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
[Fact]
......@@ -1082,9 +1082,9 @@ public async Task RequestContentStreaming_Timeout_BothClientAndServerReceiveCanc
Assert.Equal(268 /*H3_REQUEST_CANCELLED (0x10C)*/, ex.ApplicationErrorCode);
Assert.NotNull(serverStream);
serverStream.Dispose();
await serverStream.DisposeAsync();
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
[Fact]
......@@ -1147,9 +1147,9 @@ public async Task RequestContentStreaming_Cancellation_BothClientAndServerReceiv
Assert.Equal(268 /*H3_REQUEST_CANCELLED (0x10C)*/, ex.ApplicationErrorCode);
Assert.NotNull(serverStream);
serverStream.Dispose();
await serverStream.DisposeAsync();
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
[Fact]
......@@ -1228,9 +1228,9 @@ public async Task DuplexStreaming_RequestCTCancellation_DoesNotApply()
await serverTask.WaitAsync(TimeSpan.FromSeconds(120));
Assert.NotNull(serverStream);
serverStream.Dispose();
await serverStream.DisposeAsync();
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
[Theory]
......@@ -1314,9 +1314,9 @@ public async Task DuplexStreaming_AbortByServer_StreamingCancelled(bool graceful
await serverTask.WaitAsync(TimeSpan.FromSeconds(120));
Assert.NotNull(serverStream);
serverStream.Dispose();
await serverStream.DisposeAsync();
Assert.NotNull(connection);
connection.Dispose();
await connection.DisposeAsync();
}
private static async Task<QuicException> AssertThrowsQuicExceptionAsync(QuicError expectedError, Func<Task> testCode)
......
......@@ -323,7 +323,7 @@ private async Task ProcessIncomingFramesAsync(CancellationToken cancellationToke
}
_output?.WriteLine("ProcessIncomingFramesAsync finished");
_connection.Dispose();
await _connection.DisposeAsync();
}
private void DisablePingResponse() => Interlocked.Exchange(ref _sendPingResponse, 0);
......
......@@ -2250,21 +2250,21 @@ public async Task Http2_MultipleConnectionsEnabled_ManyRequestsEnqueuedSimultane
List<(Http2LoopbackConnection connection, int streamId)> acceptedRequests = new List<(Http2LoopbackConnection connection, int streamId)>();
using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
await using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
for (int i = 0; i < MaxConcurrentStreams; i++)
{
(int streamId, _) = await c1.ReadAndParseRequestHeaderAsync();
acceptedRequests.Add((c1, streamId));
}
using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
await using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
for (int i = 0; i < MaxConcurrentStreams; i++)
{
(int streamId, _) = await c2.ReadAndParseRequestHeaderAsync();
acceptedRequests.Add((c2, streamId));
}
using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
await using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 });
(int finalStreamId, _) = await c3.ReadAndParseRequestHeaderAsync();
acceptedRequests.Add((c3, finalStreamId));
......@@ -2646,7 +2646,7 @@ public async Task ConnectCallback_UseMemoryBuffer_Success(bool useSsl)
Task serverTask = Task.Run(async () =>
{
using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options);
await using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options);
await loopbackConnection.InitializeConnectionAsync();
HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync();
......@@ -2708,7 +2708,7 @@ public async Task ConnectCallback_UseUnixDomainSocket_Success(bool useSsl)
Task<string> clientTask = client.GetStringAsync($"{(options.UseSsl ? "https" : "http")}://{guid}/foo");
Socket serverSocket = await listenSocket.AcceptAsync();
using (GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, new NetworkStream(serverSocket, ownsSocket: true), options))
await using (GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, new NetworkStream(serverSocket, ownsSocket: true), options))
{
await loopbackConnection.InitializeConnectionAsync();
......@@ -2771,7 +2771,7 @@ public async Task ConnectCallback_ConnectionPrefix_Success(bool useSsl)
await serverStream.WriteAsync(ResponsePrefix);
using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options);
await using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, options);
await loopbackConnection.InitializeConnectionAsync();
HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync();
......@@ -3271,7 +3271,7 @@ public async Task PlaintextStreamFilter_ConnectionPrefix_Success(bool useSsl)
await serverStream.WriteAsync(ResponsePrefix);
using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, new GenericLoopbackOptions() { UseSsl = false });
await using GenericLoopbackConnection loopbackConnection = await LoopbackServerFactory.CreateConnectionAsync(socket: null, serverStream, new GenericLoopbackOptions() { UseSsl = false });
await loopbackConnection.InitializeConnectionAsync();
HttpRequestData requestData = await loopbackConnection.ReadRequestDataAsync();
......
......@@ -664,7 +664,7 @@ public void EventSource_ConnectionPoolAtMaxConnections_LogsRequestLeftQueue()
connection = await server.EstablishGenericConnectionAsync();
}
using (connection)
await using (connection)
{
// Dummy request to ensure that the MaxConcurrentStreams setting has been acknowledged
await connection.ReadRequestDataAsync(readBody: false);
......
......@@ -6,6 +6,13 @@
namespace System.Net.Quic
{
[System.FlagsAttribute]
public enum QuicAbortDirection
{
Read = 1,
Write = 2,
Both = 3,
}
public sealed partial class QuicClientConnectionOptions : System.Net.Quic.QuicConnectionOptions
{
public QuicClientConnectionOptions() { }
......@@ -89,37 +96,35 @@ public sealed partial class QuicStream : System.IO.Stream
public override bool CanSeek { get { throw null; } }
public override bool CanTimeout { get { throw null; } }
public override bool CanWrite { get { throw null; } }
public long Id { get { throw null; } }
public override long Length { get { throw null; } }
public override long Position { get { throw null; } set { } }
public bool ReadsCompleted { get { throw null; } }
public System.Threading.Tasks.Task ReadsClosed { get { throw null; } }
public override int ReadTimeout { get { throw null; } set { } }
public long StreamId { get { throw null; } }
public System.Net.Quic.QuicStreamType Type { get { throw null; } }
public System.Threading.Tasks.Task WritesClosed { get { throw null; } }
public override int WriteTimeout { get { throw null; } set { } }
public void AbortRead(long errorCode) { }
public void AbortWrite(long errorCode) { }
public void Abort(System.Net.Quic.QuicAbortDirection abortDirection, long errorCode) { }
public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; }
public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; }
public void CompleteWrites() { }
protected override void Dispose(bool disposing) { }
public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; }
public override int EndRead(System.IAsyncResult asyncResult) { throw null; }
public override void EndWrite(System.IAsyncResult asyncResult) { }
public override void Flush() { }
public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.Task FlushAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override int Read(byte[] buffer, int offset, int count) { throw null; }
public override int Read(System.Span<byte> buffer) { throw null; }
public override System.Threading.Tasks.Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override System.Threading.Tasks.ValueTask<int> ReadAsync(System.Memory<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override int ReadByte() { throw null; }
public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
public override void SetLength(long value) { }
public void Shutdown() { }
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, bool completeWrites, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override void WriteByte(byte value) { }
}
......
......@@ -159,6 +159,15 @@
<data name="net_quic_timeout_use_gt_zero" xml:space="preserve">
<value>Timeout can only be set to 'System.Threading.Timeout.Infinite' or a value &gt; 0.</value>
</data>
<data name="net_quic_unsupported_endpoint_type" xml:space="preserve">
<value>'{0}' in not supported remote endpoint type, expected IP or DNS endpoint."</value>
</data>
<data name="net_quic_not_null_not_empty_connection" xml:space="preserve">
<value>'{0}' must be specified and contain at least one item to establish the connection.</value>
</data>
<data name="net_quic_not_null_ceritifcate" xml:space="preserve">
<value>Server must provide a certificate in '{0}' or '{1}' or via '{2}' for the connection.</value>
</data>
<data name="net_quic_timeout" xml:space="preserve">
<value>Connection timed out waiting for a response from the peer.</value>
</data>
......
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Microsoft.Quic;
namespace System.Net.Quic.Implementations.MsQuic.Internal
{
internal static class MsQuicAddressHelpers
{
internal static unsafe IPEndPoint ToIPEndPoint(this ref QuicAddr quicAddress)
{
// MsQuic always uses storage size as if IPv6 was used
// QuicAddr is native memory, it cannot be moved by GC, thus no need for fixed expression here.
Span<byte> addressBytes = new Span<byte>((byte*)Unsafe.AsPointer(ref quicAddress), Internals.SocketAddress.IPv6AddressSize);
return new Internals.SocketAddress(SocketAddressPal.GetAddressFamily(addressBytes), addressBytes).GetIPEndPoint();
}
internal static unsafe QuicAddr ToQuicAddr(this IPEndPoint iPEndPoint)
{
// TODO: is the layout same for SocketAddress.Buffer and QuicAddr on all platforms?
QuicAddr result = default;
Span<byte> rawAddress = MemoryMarshal.AsBytes(new Span<QuicAddr>(ref result));
Internals.SocketAddress address = IPEndPointExtensions.Serialize(iPEndPoint);
Debug.Assert(address.Size <= rawAddress.Length);
address.Buffer.AsSpan(0, address.Size).CopyTo(rawAddress);
return result;
}
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using Microsoft.Quic;
using static Microsoft.Quic.MsQuic;
#if TARGET_WINDOWS
using Microsoft.Win32;
#endif
namespace System.Net.Quic
{
internal sealed unsafe class MsQuicApi
{
private static readonly Version MinWindowsVersion = new Version(10, 0, 20145, 1000);
private static readonly Version MsQuicVersion = new Version(2, 0);
public MsQuicSafeHandle Registration { get; }
public QUIC_API_TABLE* ApiTable { get; }
// This is workaround for a bug in ILTrimmer.
// Without these DynamicDependency attributes, .ctor() will be removed from the safe handles.
// Remove once fixed: https://github.com/mono/linker/issues/1660
[DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicSafeHandle))]
[DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicContextSafeHandle))]
private MsQuicApi(QUIC_API_TABLE* apiTable)
{
ApiTable = apiTable;
fixed (byte* pAppName = "System.Net.Quic"u8)
{
var cfg = new QUIC_REGISTRATION_CONFIG
{
AppName = (sbyte*)pAppName,
ExecutionProfile = QUIC_EXECUTION_PROFILE.LOW_LATENCY
};
QUIC_HANDLE* handle;
ThrowHelper.ThrowIfMsQuicError(ApiTable->RegistrationOpen(&cfg, &handle), "RegistrationOpen failed");
Registration = new MsQuicSafeHandle(handle, apiTable->RegistrationClose, SafeHandleType.Registration);
}
}
internal static MsQuicApi Api { get; } = null!;
internal static bool IsQuicSupported { get; }
internal static bool Tls13ServerMayBeDisabled { get; }
internal static bool Tls13ClientMayBeDisabled { get; }
static MsQuicApi()
{
if (OperatingSystem.IsWindows())
{
if (!IsWindowsVersionSupported())
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Current Windows version ({Environment.OSVersion}) is not supported by QUIC. Minimal supported version is {MinWindowsVersion}");
}
return;
}
Tls13ServerMayBeDisabled = IsTls13Disabled(true);
Tls13ClientMayBeDisabled = IsTls13Disabled(false);
}
IntPtr msQuicHandle;
if (NativeLibrary.TryLoad($"{Interop.Libraries.MsQuic}.{MsQuicVersion.Major}", typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle) ||
NativeLibrary.TryLoad(Interop.Libraries.MsQuic, typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle))
{
try
{
if (NativeLibrary.TryGetExport(msQuicHandle, "MsQuicOpenVersion", out IntPtr msQuicOpenVersionAddress))
{
QUIC_API_TABLE* apiTable;
delegate* unmanaged[Cdecl]<uint, QUIC_API_TABLE**, int> msQuicOpenVersion = (delegate* unmanaged[Cdecl]<uint, QUIC_API_TABLE**, int>)msQuicOpenVersionAddress;
if (StatusSucceeded(msQuicOpenVersion((uint)MsQuicVersion.Major, &apiTable)))
{
int arraySize = 4;
uint* libVersion = stackalloc uint[arraySize];
uint size = (uint)arraySize * sizeof(uint);
if (StatusSucceeded(apiTable->GetParam(null, QUIC_PARAM_GLOBAL_LIBRARY_VERSION, &size, libVersion)))
{
var version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]);
if (version >= MsQuicVersion)
{
Api = new MsQuicApi(apiTable);
IsQuicSupported = true;
}
else
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Incompatible MsQuic library version '{version}', expecting '{MsQuicVersion}'");
}
}
}
}
}
}
finally
{
if (!IsQuicSupported)
{
NativeLibrary.Free(msQuicHandle);
}
}
}
}
private static bool IsWindowsVersionSupported() => OperatingSystem.IsWindowsVersionAtLeast(MinWindowsVersion.Major,
MinWindowsVersion.Minor, MinWindowsVersion.Build, MinWindowsVersion.Revision);
private static bool IsTls13Disabled(bool isServer)
{
#if TARGET_WINDOWS
string SChannelTls13RegistryKey = isServer
? @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Server"
: @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Client";
using var regKey = Registry.LocalMachine.OpenSubKey(SChannelTls13RegistryKey);
if (regKey is null)
{
return false;
}
if (regKey.GetValue("Enabled") is int enabled && enabled == 0)
{
return true;
}
if (regKey.GetValue("DisabledByDefault") is int disabled && disabled == 1)
{
return true;
}
#endif
return false;
}
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Net.Sockets;
using Microsoft.Quic;
using static Microsoft.Quic.MsQuic;
namespace System.Net.Quic.Implementations.MsQuic.Internal
{
internal static class MsQuicParameterHelpers
{
internal static unsafe IPEndPoint GetIPEndPointParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, AddressFamily? addressFamilyOverride = null)
{
// MsQuic always uses storage size as if IPv6 was used
uint valueLen = (uint)Internals.SocketAddress.IPv6AddressSize;
Span<byte> address = stackalloc byte[Internals.SocketAddress.IPv6AddressSize];
fixed (byte* paddress = &MemoryMarshal.GetReference(address))
{
ThrowHelper.ThrowIfMsQuicError(api.ApiTable->GetParam(
nativeObject.QuicHandle,
param,
&valueLen,
paddress), "GetIPEndPointParam failed.");
}
address = address.Slice(0, (int)valueLen);
return new Internals.SocketAddress(addressFamilyOverride ?? SocketAddressPal.GetAddressFamily(address), address).GetIPEndPoint();
}
internal static unsafe void SetIPEndPointParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, IPEndPoint value)
{
Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(value);
// MsQuic always reads same amount of memory as if IPv6 was used, so we can't pass pointer to socketAddress.Buffer directly
Span<byte> address = stackalloc byte[Internals.SocketAddress.IPv6AddressSize];
socketAddress.Buffer.AsSpan(0, socketAddress.Size).CopyTo(address);
address.Slice(socketAddress.Size).Clear();
fixed (byte* paddress = &MemoryMarshal.GetReference(address))
{
ThrowHelper.ThrowIfMsQuicError(api.ApiTable->SetParam(
nativeObject.QuicHandle,
param,
(uint)address.Length,
paddress), "Could not set IPEndPoint");
}
}
internal static unsafe ushort GetUShortParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param)
{
ushort value;
uint valueLen = (uint)sizeof(ushort);
ThrowHelper.ThrowIfMsQuicError(api.ApiTable->GetParam(
nativeObject.QuicHandle,
param,
&valueLen,
(byte*)&value), "GetUShortParam failed");
Debug.Assert(valueLen == sizeof(ushort));
return value;
}
internal static unsafe void SetUShortParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, ushort value)
{
ThrowHelper.ThrowIfMsQuicError(api.ApiTable->SetParam(
nativeObject.QuicHandle,
param,
sizeof(ushort),
(byte*)&value), "Could not set ushort");
}
internal static unsafe ulong GetULongParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param)
{
ulong value;
uint valueLen = (uint)sizeof(ulong);
ThrowHelper.ThrowIfMsQuicError(api.ApiTable->GetParam(
nativeObject.QuicHandle,
param,
&valueLen,
(byte*)&value), "GetULongParam failed");
Debug.Assert(valueLen == sizeof(ulong));
return value;
}
internal static unsafe void SetULongParam(MsQuicApi api, MsQuicSafeHandle nativeObject, uint param, ulong value)
{
ThrowHelper.ThrowIfMsQuicError(api.ApiTable->SetParam(
nativeObject.QuicHandle,
param,
sizeof(ulong),
(byte*)&value), "Could not set ulong");
}
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
namespace System.Net.Quic.Implementations.MsQuic.Internal
{
/// <summary>
/// A resettable completion source which can be completed multiple times.
/// Used to make methods async between completed events and their associated async method.
/// </summary>
internal sealed class ResettableCompletionSource<T> : IValueTaskSource<T>, IValueTaskSource
{
private ManualResetValueTaskSourceCore<T> _valueTaskSource;
public ResettableCompletionSource()
{
_valueTaskSource.RunContinuationsAsynchronously = true;
}
public ValueTask<T> GetValueTask()
{
return new ValueTask<T>(this, _valueTaskSource.Version);
}
public ValueTask GetTypelessValueTask()
{
return new ValueTask(this, _valueTaskSource.Version);
}
public ValueTaskSourceStatus GetStatus(short token)
{
return _valueTaskSource.GetStatus(token);
}
public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
{
_valueTaskSource.OnCompleted(continuation, state, token, flags);
}
public void Complete(T result)
{
_valueTaskSource.SetResult(result);
}
public void CompleteException(Exception ex)
{
_valueTaskSource.SetException(ex);
}
public T GetResult(short token)
{
bool isValid = token == _valueTaskSource.Version;
try
{
return _valueTaskSource.GetResult(token);
}
finally
{
if (isValid)
{
_valueTaskSource.Reset();
}
}
}
void IValueTaskSource.GetResult(short token)
{
bool isValid = token == _valueTaskSource.Version;
try
{
_valueTaskSource.GetResult(token);
}
finally
{
if (isValid)
{
_valueTaskSource.Reset();
}
}
}
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using Microsoft.Quic;
using static Microsoft.Quic.MsQuic;
#if TARGET_WINDOWS
using Microsoft.Win32;
#endif
namespace System.Net.Quic;
internal sealed unsafe class MsQuicApi
{
private static readonly Version MinWindowsVersion = new Version(10, 0, 20145, 1000);
private static readonly Version MsQuicVersion = new Version(2, 0);
public MsQuicSafeHandle Registration { get; }
public QUIC_API_TABLE* ApiTable { get; }
// This is workaround for a bug in ILTrimmer.
// Without these DynamicDependency attributes, .ctor() will be removed from the safe handles.
// Remove once fixed: https://github.com/mono/linker/issues/1660
[DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicSafeHandle))]
[DynamicDependency(DynamicallyAccessedMemberTypes.PublicConstructors, typeof(MsQuicContextSafeHandle))]
private MsQuicApi(QUIC_API_TABLE* apiTable)
{
ApiTable = apiTable;
fixed (byte* pAppName = "System.Net.Quic"u8)
{
var cfg = new QUIC_REGISTRATION_CONFIG
{
AppName = (sbyte*)pAppName,
ExecutionProfile = QUIC_EXECUTION_PROFILE.LOW_LATENCY
};
QUIC_HANDLE* handle;
ThrowHelper.ThrowIfMsQuicError(ApiTable->RegistrationOpen(&cfg, &handle), "RegistrationOpen failed");
Registration = new MsQuicSafeHandle(handle, apiTable->RegistrationClose, SafeHandleType.Registration);
}
}
internal static MsQuicApi Api { get; } = null!;
internal static bool IsQuicSupported { get; }
internal static bool Tls13ServerMayBeDisabled { get; }
internal static bool Tls13ClientMayBeDisabled { get; }
static MsQuicApi()
{
if (OperatingSystem.IsWindows())
{
if (!IsWindowsVersionSupported())
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Current Windows version ({Environment.OSVersion}) is not supported by QUIC. Minimal supported version is {MinWindowsVersion}");
}
return;
}
Tls13ServerMayBeDisabled = IsTls13Disabled(true);
Tls13ClientMayBeDisabled = IsTls13Disabled(false);
}
IntPtr msQuicHandle;
if (NativeLibrary.TryLoad($"{Interop.Libraries.MsQuic}.{MsQuicVersion.Major}", typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle) ||
NativeLibrary.TryLoad(Interop.Libraries.MsQuic, typeof(MsQuicApi).Assembly, DllImportSearchPath.AssemblyDirectory, out msQuicHandle))
{
try
{
if (NativeLibrary.TryGetExport(msQuicHandle, "MsQuicOpenVersion", out IntPtr msQuicOpenVersionAddress))
{
QUIC_API_TABLE* apiTable;
delegate* unmanaged[Cdecl]<uint, QUIC_API_TABLE**, int> msQuicOpenVersion = (delegate* unmanaged[Cdecl]<uint, QUIC_API_TABLE**, int>)msQuicOpenVersionAddress;
if (StatusSucceeded(msQuicOpenVersion((uint)MsQuicVersion.Major, &apiTable)))
{
int arraySize = 4;
uint* libVersion = stackalloc uint[arraySize];
uint size = (uint)arraySize * sizeof(uint);
if (StatusSucceeded(apiTable->GetParam(null, QUIC_PARAM_GLOBAL_LIBRARY_VERSION, &size, libVersion)))
{
var version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]);
if (version >= MsQuicVersion)
{
Api = new MsQuicApi(apiTable);
IsQuicSupported = true;
}
else
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, $"Incompatible MsQuic library version '{version}', expecting '{MsQuicVersion}'");
}
}
}
}
}
}
finally
{
if (!IsQuicSupported)
{
NativeLibrary.Free(msQuicHandle);
}
}
}
}
private static bool IsWindowsVersionSupported() => OperatingSystem.IsWindowsVersionAtLeast(MinWindowsVersion.Major,
MinWindowsVersion.Minor, MinWindowsVersion.Build, MinWindowsVersion.Revision);
private static bool IsTls13Disabled(bool isServer)
{
#if TARGET_WINDOWS
string SChannelTls13RegistryKey = isServer
? @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Server"
: @"SYSTEM\CurrentControlSet\Control\SecurityProviders\SCHANNEL\Protocols\TLS 1.3\Client";
using var regKey = Registry.LocalMachine.OpenSubKey(SChannelTls13RegistryKey);
if (regKey is null)
{
return false;
}
if (regKey.GetValue("Enabled") is int enabled && enabled == 0)
{
return true;
}
if (regKey.GetValue("DisabledByDefault") is int disabled && disabled == 1)
{
return true;
}
#endif
return false;
}
}
......@@ -89,44 +89,6 @@ public void Initialize(ReadOnlyMemory<byte> buffer)
SetBuffer(0, buffer);
}
/// <summary>
/// Initializes QUIC_BUFFER* with the provided buffers.
/// Note that the struct either needs to be freshly created via new or previously cleaned up with Reset.
/// </summary>
/// <param name="buffers">Buffers to be passed to MsQuic as QUIC_BUFFER*.</param>
public void Initialize(ReadOnlySequence<byte> buffers)
{
int count = 0;
foreach (ReadOnlyMemory<byte> _ in buffers)
{
++count;
}
Reserve(count);
int i = 0;
foreach (ReadOnlyMemory<byte> buffer in buffers)
{
SetBuffer(i++, buffer);
}
}
/// <summary>
/// Initializes QUIC_BUFFER* with the provided buffers.
/// Note that the struct either needs to be freshly created via new or previously cleaned up with Reset.
/// </summary>
/// <param name="buffers">Buffers to be passed to MsQuic as QUIC_BUFFER*.</param>
public void Initialize(ReadOnlyMemory<ReadOnlyMemory<byte>> buffers)
{
int count = buffers.Length;
Reserve(count);
ReadOnlySpan<ReadOnlyMemory<byte>> span = buffers.Span;
for (int i = 0; i < span.Length; i++)
{
SetBuffer(i, span[i]);
}
}
/// <summary>
/// Unpins the managed memory and allows reuse of this struct.
/// </summary>
......
......@@ -81,7 +81,7 @@ public static MsQuicSafeHandle Create(QuicServerConnectionOptions options, strin
certificate ??= authenticationOptions.ServerCertificate ?? authenticationOptions.ServerCertificateSelectionCallback?.Invoke(authenticationOptions, targetHost);
if (certificate is null)
{
throw new ArgumentException($"Server must provide a certificate in '{nameof(SslServerAuthenticationOptions.ServerCertificate)}' or '{nameof(SslServerAuthenticationOptions.ServerCertificateContext)}' or via '{nameof(SslServerAuthenticationOptions.ServerCertificateSelectionCallback)}' for the connection.", nameof(options));
throw new ArgumentException(SR.Format(SR.net_quic_not_null_ceritifcate, nameof(SslServerAuthenticationOptions.ServerCertificate), nameof(SslServerAuthenticationOptions.ServerCertificateContext), nameof(SslServerAuthenticationOptions.ServerCertificateSelectionCallback)), nameof(options));
}
return Create(options, flags, certificate, intermediates, authenticationOptions.ApplicationProtocols, authenticationOptions.CipherSuitesPolicy, authenticationOptions.EncryptionPolicy);
......@@ -92,7 +92,7 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI
// Validate options and SSL parameters.
if (alpnProtocols is null || alpnProtocols.Count <= 0)
{
throw new ArgumentException($"Expected at least one '{nameof(SslApplicationProtocol)}' for the connection.", nameof(options));
throw new ArgumentException(SR.Format(SR.net_quic_not_null_not_empty_connection, nameof(SslApplicationProtocol)), nameof(options));
}
#pragma warning disable SYSLIB0040 // NoEncryption and AllowNoEncryption are obsolete
......@@ -124,7 +124,8 @@ private static unsafe MsQuicSafeHandle Create(QuicConnectionOptions options, QUI
&settings,
(uint)sizeof(QUIC_SETTINGS),
(void*)IntPtr.Zero,
&handle), "ConfigurationOpen failed");
&handle),
"ConfigurationOpen failed");
MsQuicSafeHandle configurationHandle = new MsQuicSafeHandle(handle, MsQuicApi.Api.ApiTable->ConfigurationClose, SafeHandleType.Configuration);
try
......
......@@ -65,7 +65,8 @@ internal static unsafe T GetMsQuicParameter<T>(MsQuicSafeHandle handle, uint par
handle.QuicHandle,
parameter,
&length,
(byte*)&value));
(byte*)&value),
$"GetParam({handle}, {parameter}) failed");
return value;
}
......@@ -77,6 +78,7 @@ internal static unsafe void SetMsQuicParameter<T>(MsQuicSafeHandle handle, uint
handle.QuicHandle,
parameter,
(uint)sizeof(T),
(byte*)&value));
(byte*)&value),
$"SetParam({handle}, {parameter}) failed");
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using Microsoft.Quic;
namespace System.Net.Quic;
internal struct ReceiveBuffers
{
private const int MaxBufferedBytes = 64 * 1024;
private readonly object _syncRoot;
private MultiArrayBuffer _buffer;
private bool _final;
public ReceiveBuffers()
{
_syncRoot = new object();
_buffer = default;
_final = default;
}
public void SetFinal()
{
lock (_syncRoot)
{
_final = true;
}
}
public int CopyFrom(ReadOnlySpan<QUIC_BUFFER> quicBuffers, int totalLength, bool final)
{
lock (_syncRoot)
{
if (_buffer.ActiveMemory.Length > MaxBufferedBytes - totalLength)
{
totalLength = MaxBufferedBytes - _buffer.ActiveMemory.Length;
final = false;
}
_final = final;
_buffer.EnsureAvailableSpace(totalLength);
int totalCopied = 0;
for (int i = 0; i < quicBuffers.Length; ++i)
{
Span<byte> quicBuffer = quicBuffers[i].Span;
if (totalLength < quicBuffer.Length)
{
quicBuffer = quicBuffer.Slice(0, totalLength);
}
_buffer.AvailableMemory.CopyFrom(quicBuffer);
_buffer.Commit(quicBuffer.Length);
totalCopied += quicBuffer.Length;
totalLength -= quicBuffer.Length;
}
return totalCopied;
}
}
public int CopyTo(Memory<byte> buffer, out bool isCompleted, out bool isEmpty)
{
lock (_syncRoot)
{
int copied = 0;
if (!_buffer.IsEmpty)
{
MultiMemory activeBuffer = _buffer.ActiveMemory;
copied = Math.Min(buffer.Length, activeBuffer.Length);
activeBuffer.Slice(0, copied).CopyTo(buffer.Span);
_buffer.Discard(copied);
}
isCompleted = _buffer.IsEmpty && _final;
isEmpty = _buffer.IsEmpty;
return copied;
}
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
namespace System.Net.Quic;
internal sealed class ResettableValueTaskSource : IValueTaskSource
{
// None -> [TryGetValueTask] -> Awaiting -> [TrySetResult|TrySetException(final: false)] -> Ready -> [GetResult] -> None
// None -> [TrySetResult|TrySetException(final: false)] -> Ready -> [TryGetValueTask] -> [GetResult] -> None
// None|Awaiting -> [TrySetResult|TrySetException(final: true)] -> Final(never leaves this state)
private enum State
{
None,
Awaiting,
Ready,
Completed
}
private State _state;
private ManualResetValueTaskSourceCore<bool> _valueTaskSource;
private CancellationTokenRegistration _cancellationRegistration;
private Action<object?>? _cancellationAction;
private GCHandle _keepAlive;
private FinalTaskSource _finalTaskSource;
public ResettableValueTaskSource(bool runContinuationsAsynchronously = true)
{
_state = State.None;
_valueTaskSource = new ManualResetValueTaskSourceCore<bool>() { RunContinuationsAsynchronously = runContinuationsAsynchronously };
_cancellationRegistration = default;
_keepAlive = default;
_finalTaskSource = new FinalTaskSource(runContinuationsAsynchronously);
}
/// <summary>
/// Allows setting additional cancellation action to be called if token passed to <see cref="TryGetValueTask(out ValueTask, object?, CancellationToken)"/> fires off.
/// The argument for the action is the <c>keepAlive</c> object from the same <see cref="TryGetValueTask(out ValueTask, object?, CancellationToken)"/> call.
/// </summary>
public Action<object?> CancellationAction { init { _cancellationAction = value; } }
/// <summary>
/// Returns <c>true</c> is this task source has entered its final state, i.e. <see cref="TryComplete(Exception?, bool)"/> or <see cref="TrySetException(Exception, bool)"/>
/// was called with <c>final</c> set to <c>true</c> and the result was propagated.
/// </summary>
public bool IsCompleted => (State)Volatile.Read(ref Unsafe.As<State, byte>(ref _state)) == State.Completed;
public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, CancellationToken cancellationToken = default)
{
lock (this)
{
// Cancellation might kick off synchronously, re-entering the lock and changing the state to completed.
if (_state == State.None)
{
// Register cancellation if the token can be cancelled and the task is not completed yet.
if (cancellationToken.CanBeCanceled)
{
_cancellationRegistration = cancellationToken.UnsafeRegister(static (obj, cancellationToken) =>
{
(ResettableValueTaskSource parent, object? target) = ((ResettableValueTaskSource, object?))obj!;
if (parent.TrySetException(new OperationCanceledException(cancellationToken)))
{
parent._cancellationAction?.Invoke(target);
}
}, (this, keepAlive));
}
}
State state = _state;
// None: prepare for the actual operation happening and transition to Awaiting.
if (state == State.None)
{
// Keep alive the caller object until the result is read from the task.
// Used for keeping caller alive during async interop calls.
if (keepAlive is not null)
{
Debug.Assert(!_keepAlive.IsAllocated);
_keepAlive = GCHandle.Alloc(keepAlive);
}
_state = State.Awaiting;
}
// None, Completed, Final: return the current task.
if (state == State.None ||
state == State.Ready ||
state == State.Completed)
{
valueTask = new ValueTask(this, _valueTaskSource.Version);
return true;
}
// Awaiting: forbidden concurrent call.
valueTask = default;
return false;
}
}
public Task GetFinalTask() => _finalTaskSource.Task;
private bool TryComplete(Exception? exception, bool final)
{
CancellationTokenRegistration cancellationRegistration = default;
try
{
lock (this)
{
try
{
State state = _state;
// None,Awaiting: clean up and finish the task source.
if (state == State.Awaiting ||
state == State.None)
{
_state = final ? State.Completed : State.Ready;
// Swap the cancellation registration so the one that's been registered gets eventually Disposed.
// Ideally, we would dispose it here, but if the callbacks kicks in, it tries to take the lock held by this thread leading to deadlock.
cancellationRegistration = _cancellationRegistration;
_cancellationRegistration = default;
// Unblock the current task source and in case of a final also the final task source.
if (exception is not null)
{
// Set up the exception stack strace for the caller.
exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception;
_valueTaskSource.SetException(exception);
}
else
{
_valueTaskSource.SetResult(final);
}
if (final)
{
_finalTaskSource.TryComplete(exception);
_finalTaskSource.TrySignal(out _);
}
return true;
}
// Final: remember the first final result to set it once the current non-final result gets retrieved.
if (final)
{
return _finalTaskSource.TryComplete(exception);
}
return false;
}
finally
{
// Un-root the the kept alive object in all cases.
if (_keepAlive.IsAllocated)
{
_keepAlive.Free();
}
}
}
}
finally
{
// Dispose the cancellation if registered.
// Must be done outside of lock since Dispose will wait on pending cancellation callbacks which require taking the lock.
cancellationRegistration.Dispose();
}
}
public bool TrySetResult(bool final = false)
{
return TryComplete(null, final);
}
public bool TrySetException(Exception exception, bool final = false)
{
return TryComplete(exception, final);
}
ValueTaskSourceStatus IValueTaskSource.GetStatus(short token)
=> _valueTaskSource.GetStatus(token);
void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
=> _valueTaskSource.OnCompleted(continuation, state, token, flags);
void IValueTaskSource.GetResult(short token)
{
try
{
_valueTaskSource.GetResult(token);
}
finally
{
lock (this)
{
State state = _state;
if (state == State.Ready)
{
_valueTaskSource.Reset();
if (_finalTaskSource.TrySignal(out Exception? exception))
{
_state = State.Completed;
if (exception is not null)
{
_valueTaskSource.SetException(exception);
}
else
{
_valueTaskSource.SetResult(true);
}
}
else
{
_state = State.None;
}
}
}
}
}
private struct FinalTaskSource
{
private TaskCompletionSource _finalTaskSource;
private bool _isCompleted;
private Exception? _exception;
public FinalTaskSource(bool runContinuationsAsynchronously = true)
{
// TODO: defer instantiation only after Task is retrieved
_finalTaskSource = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
_isCompleted = false;
_exception = null;
}
public Task Task => _finalTaskSource.Task;
public bool TryComplete(Exception? exception = null)
{
if (_isCompleted)
{
return false;
}
_exception = exception;
_isCompleted = true;
return true;
}
public bool TrySignal(out Exception? exception)
{
if (!_isCompleted)
{
exception = default;
return false;
}
if (_exception is not null)
{
_finalTaskSource.SetException(_exception);
}
else
{
_finalTaskSource.SetResult();
}
exception = _exception;
return true;
}
}
}
......@@ -36,6 +36,7 @@ public ValueTaskSource(bool runContinuationsAsynchronously = true)
}
public bool IsCompleted => (State)Volatile.Read(ref Unsafe.As<State, byte>(ref _state)) == State.Completed;
public bool IsCompletedSuccessfully => IsCompleted && _valueTaskSource.GetStatus(_valueTaskSource.Version) == ValueTaskSourceStatus.Succeeded;
public bool TryInitialize(out ValueTask valueTask, object? keepAlive = null, CancellationToken cancellationToken = default)
{
......@@ -64,6 +65,7 @@ public bool TryInitialize(out ValueTask valueTask, object? keepAlive = null, Can
if (state == State.None)
{
// Keep alive the caller object until the result is read from the task.
// Used for keeping caller alive during async interop calls.
if (keepAlive is not null)
{
Debug.Assert(!_keepAlive.IsAllocated);
......
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
namespace System.Net.Quic;
/// <summary>
/// Specifies direction of the <see cref="QuicStream"/> which is to be <see cref="QuicStream.Abort(QuicAbortDirection, long)">aborted</see>.
/// </summary>
[Flags]
public enum QuicAbortDirection
{
/// <summary>
/// Abort read side of the stream.
/// </summary>
Read = 1,
/// <summary>
/// Abort write side of the stream.
/// </summary>
Write = 2,
/// <summary>
/// Abort both sides of the stream, i.e.: <see cref="Read"/> and <see cref="Write"/>) at the same time.
/// </summary>
Both = Read | Write
}
......@@ -107,17 +107,16 @@ public static async ValueTask<QuicConnection> ConnectAsync(QuicClientConnectionO
/// </summary>
private MsQuicSafeHandle? _configuration;
/// <summary>
/// Set when SHUTDOWN_INITIATED_BY_PEER is received.
/// </summary>
private long _abortErrorCode = -1;
/// <summary>
/// Used by <see cref="AcceptInboundStreamAsync(CancellationToken)" /> to throw in case no stream can be opened from the peer.
/// <c>true</c> when at least one of <see cref="QuicConnectionOptions.MaxInboundBidirectionalStreams" /> or <see cref="QuicConnectionOptions.MaxInboundUnidirectionalStreams" /> is greater than <c>0</c>.
/// </summary>
private bool _canAccept;
/// <summary>
/// From <see cref="QuicConnectionOptions.DefaultStreamErrorCode"/>, passed to newly created <see cref="QuicStream"/>.
/// </summary>
private long _defaultStreamErrorCode;
/// <summary>
/// From <see cref="QuicConnectionOptions.DefaultCloseErrorCode"/>, used to close connection in <see cref="DisposeAsync"/>.
/// </summary>
private long _defaultCloseErrorCode;
......@@ -193,7 +192,8 @@ private unsafe QuicConnection()
MsQuicApi.Api.Registration.QuicHandle,
&NativeCallback,
(void*)GCHandle.ToIntPtr(context),
&handle));
&handle),
"ConnectionOpen failed");
_handle = new MsQuicContextSafeHandle(handle, context, MsQuicApi.Api.ApiTable->ConnectionClose, SafeHandleType.Connection);
}
catch
......@@ -237,11 +237,12 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
if (_connectedTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken))
{
_canAccept = options.MaxInboundBidirectionalStreams > 0 || options.MaxInboundUnidirectionalStreams > 0;
_defaultStreamErrorCode = options.DefaultStreamErrorCode;
_defaultCloseErrorCode = options.DefaultCloseErrorCode;
if (!options.RemoteEndPoint.TryParse(out string? host, out IPAddress? address, out int port))
{
throw new ArgumentException($"Unsupported remote endpoint type '{options.RemoteEndPoint.GetType()}', expected IP or DNS endpoint.", nameof(options));
throw new ArgumentException(SR.Format(SR.net_quic_unsupported_endpoint_type, options.RemoteEndPoint.GetType()), nameof(options));
}
int addressFamily = QUIC_ADDRESS_FAMILY_UNSPEC;
......@@ -306,7 +307,8 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
_configuration.QuicHandle,
(ushort)addressFamily,
(sbyte*)targetHostPtr,
(ushort)port));
(ushort)port),
"ConnectionStart failed");
}
}
finally
......@@ -325,6 +327,7 @@ internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, str
if (_connectedTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken))
{
_canAccept = options.MaxInboundBidirectionalStreams > 0 || options.MaxInboundUnidirectionalStreams > 0;
_defaultStreamErrorCode = options.DefaultStreamErrorCode;
_defaultCloseErrorCode = options.DefaultCloseErrorCode;
_sslConnectionOptions = new SslConnectionOptions(
......@@ -340,7 +343,8 @@ internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, str
{
ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionSetConfiguration(
_handle.QuicHandle,
_configuration.QuicHandle));
_configuration.QuicHandle),
"ConnectionSetConfiguration failed");
}
}
......@@ -359,14 +363,23 @@ public async ValueTask<QuicStream> OpenOutboundStreamAsync(QuicStreamType type,
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
QuicStream stream = new QuicStream(new Implementations.MsQuic.MsQuicStream(_state, _handle, type));
QuicStream? stream = null;
try
{
stream = new QuicStream(_state, _handle, type, _defaultStreamErrorCode);
await stream.StartAsync(cancellationToken).ConfigureAwait(false);
}
catch
{
await stream.DisposeAsync().ConfigureAwait(false);
if (stream is not null)
{
await stream.DisposeAsync().ConfigureAwait(false);
}
// Propagate connection error if present.
if (_acceptQueue.Reader.Completion.IsFaulted)
{
await _acceptQueue.Reader.Completion.ConfigureAwait(false);
}
throw;
}
return stream;
......@@ -449,21 +462,16 @@ private unsafe int HandleEventConnected(ref CONNECTED_DATA data)
}
private unsafe int HandleEventShutdownInitiatedByTransport(ref SHUTDOWN_INITIATED_BY_TRANSPORT_DATA data)
{
_connectedTcs.TrySetException(ThrowHelper.GetExceptionForMsQuicStatus(data.Status));
// To throw QuicConnectionAbortedException (instead of QuicOperationAbortedException) out of AcceptStreamAsync() since
// it wasn't our side who shutdown the connection.
// We should rather keep the Status and propagate it either in a different exception or as a different field of QuicConnectionAbortedException.
// See: https://github.com/dotnet/runtime/issues/60133
_abortErrorCode = 0;
_state.AbortErrorCode = _abortErrorCode;
_acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetConnectionAbortedException(_abortErrorCode)));
_state.AbortErrorCode = 0;
Exception exception = ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetExceptionForMsQuicStatus(data.Status));
_connectedTcs.TrySetException(exception);
_acceptQueue.Writer.TryComplete(exception);
return QUIC_STATUS_SUCCESS;
}
private unsafe int HandleEventShutdownInitiatedByPeer(ref SHUTDOWN_INITIATED_BY_PEER_DATA data)
{
_abortErrorCode = (long)data.ErrorCode;
_state.AbortErrorCode = _abortErrorCode;
_acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetConnectionAbortedException(_abortErrorCode)));
_state.AbortErrorCode = (long)data.ErrorCode;
_acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetConnectionAbortedException((long)data.ErrorCode)));
return QUIC_STATUS_SUCCESS;
}
private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data)
......@@ -484,7 +492,7 @@ private unsafe int HandleEventPeerAddressChanged(ref PEER_ADDRESS_CHANGED_DATA d
}
private unsafe int HandleEventPeerStreamStarted(ref PEER_STREAM_STARTED_DATA data)
{
QuicStream stream = new QuicStream(new Implementations.MsQuic.MsQuicStream(_state, _handle, data.Stream, data.Flags));
QuicStream stream = new QuicStream(_state, _handle, data.Stream, data.Flags, _defaultStreamErrorCode);
if (!_acceptQueue.Writer.TryWrite(stream))
{
if (NetEventSource.Log.IsEnabled())
......@@ -583,6 +591,7 @@ public async ValueTask DisposeAsync()
}
}
// Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released.
await valueTask.ConfigureAwait(false);
_handle.Dispose();
......
......@@ -14,7 +14,7 @@ public enum QuicError
Success,
/// <summary>
/// An internal implementation error has occured.
/// An internal implementation error has occurred.
/// </summary>
InternalError,
......
......@@ -271,6 +271,7 @@ public async ValueTask DisposeAsync()
}
}
// Wait for STOP_COMPLETE, the last event, so that all resources can be safely released.
await valueTask.ConfigureAwait(false);
_handle.Dispose();
......
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
namespace System.Net.Quic;
// Boilerplate implementation of Stream methods.
public partial class QuicStream : Stream
{
// Seek and length.
public override bool CanSeek => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
// Read and Write timeouts.
public override bool CanTimeout => true;
private TimeSpan _readTimeout = Timeout.InfiniteTimeSpan;
private TimeSpan _writeTimeout = Timeout.InfiniteTimeSpan;
public override int ReadTimeout
{
get
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
return (int)_readTimeout.TotalMilliseconds;
}
set
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
if (value <= 0 && value != Timeout.Infinite)
{
throw new ArgumentOutOfRangeException(nameof(value), SR.net_quic_timeout_use_gt_zero);
}
_readTimeout = TimeSpan.FromMilliseconds(value);
}
}
public override int WriteTimeout
{
get
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
return (int)_writeTimeout.TotalMilliseconds;
}
set
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
if (value <= 0 && value != Timeout.Infinite)
{
throw new ArgumentOutOfRangeException(nameof(value), SR.net_quic_timeout_use_gt_zero);
}
_writeTimeout = TimeSpan.FromMilliseconds(value);
}
}
// Read boilerplate.
public override bool CanRead => Volatile.Read(ref _disposed) == 0 && _canRead;
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
=> TaskToApm.Begin(ReadAsync(buffer, offset, count, default), callback, state);
public override int EndRead(IAsyncResult asyncResult)
=> TaskToApm.End<int>(asyncResult);
public override int Read(byte[] buffer, int offset, int count)
{
ValidateBufferArguments(buffer, offset, count);
return Read(buffer.AsSpan(offset, count));
}
public override int ReadByte()
{
byte b = 0;
return Read(MemoryMarshal.CreateSpan(ref b, 1)) != 0 ? b : -1;
}
public override int Read(Span<byte> buffer)
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
byte[] rentedBuffer = ArrayPool<byte>.Shared.Rent(buffer.Length);
CancellationTokenSource? cts = null;
try
{
if (_readTimeout > TimeSpan.Zero)
{
cts = new CancellationTokenSource(_readTimeout);
}
int readLength = ReadAsync(new Memory<byte>(rentedBuffer, 0, buffer.Length), cts?.Token ?? default).AsTask().GetAwaiter().GetResult();
rentedBuffer.AsSpan(0, readLength).CopyTo(buffer);
return readLength;
}
catch (OperationCanceledException) when (cts?.IsCancellationRequested == true)
{
// sync operations do not have Cancellation
throw new IOException(SR.net_quic_timeout);
}
finally
{
ArrayPool<byte>.Shared.Return(rentedBuffer);
cts?.Dispose();
}
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
ValidateBufferArguments(buffer, offset, count);
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
// Write boilerplate.
public override bool CanWrite => Volatile.Read(ref _disposed) == 0 && _canWrite;
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
=> TaskToApm.Begin(WriteAsync(buffer, offset, count, default), callback, state);
public override void EndWrite(IAsyncResult asyncResult)
=> TaskToApm.End(asyncResult);
public override void Write(byte[] buffer, int offset, int count)
{
ValidateBufferArguments(buffer, offset, count);
Write(buffer.AsSpan(offset, count));
}
public override void WriteByte(byte value)
{
Write(MemoryMarshal.CreateReadOnlySpan(ref value, 1));
}
public override void Write(ReadOnlySpan<byte> buffer)
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);
CancellationTokenSource? cts = null;
if (_writeTimeout > TimeSpan.Zero)
{
cts = new CancellationTokenSource(_writeTimeout);
}
try
{
WriteAsync(buffer.ToArray(), cts?.Token ?? default).AsTask().GetAwaiter().GetResult();
}
catch (OperationCanceledException) when (cts?.IsCancellationRequested == true)
{
// sync operations do not have Cancellation
throw new IOException(SR.net_quic_timeout);
}
finally
{
cts?.Dispose();
}
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
ValidateBufferArguments(buffer, offset, count);
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
// Flush.
public override void Flush()
=> FlushAsync().GetAwaiter().GetResult();
public override Task FlushAsync(CancellationToken cancellationToken = default)
// NOP for now
=> Task.CompletedTask;
// Dispose.
protected override void Dispose(bool disposing)
{
DisposeAsync().AsTask().GetAwaiter().GetResult();
base.Dispose(disposing);
}
}
......@@ -440,12 +440,12 @@ public async Task OpenStreamAsync_BlocksUntilAvailable(bool unidirectional)
Assert.False(waitTask.IsCompleted);
// Close the streams, the waitTask should finish as a result.
stream.Dispose();
await stream.DisposeAsync();
QuicStream newStream = await serverConnection.AcceptInboundStreamAsync();
newStream.Dispose();
await newStream.DisposeAsync();
newStream = await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
newStream.Dispose();
await newStream.DisposeAsync();
await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
......@@ -488,13 +488,13 @@ public async Task OpenStreamAsync_Canceled_Throws_OperationCanceledException(boo
Assert.Equal(cts.Token, ex.CancellationToken);
// Close the streams, the waitTask should finish as a result.
stream.Dispose();
await stream.DisposeAsync();
QuicStream newStream = await serverConnection.AcceptInboundStreamAsync();
newStream.Dispose();
await newStream.DisposeAsync();
// next call should work as intended
newStream = await OpenStreamAsync(clientConnection).AsTask().WaitAsync(TimeSpan.FromSeconds(10));
newStream.Dispose();
await newStream.DisposeAsync();
await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
......@@ -588,7 +588,7 @@ public async Task SetListenerTimeoutWorksWithSmallTimeout()
[Theory]
[MemberData(nameof(WriteData))]
public async Task WriteTests(int[][] writes, WriteType writeType)
public async Task WriteTests(int[][] writes)
{
await RunClientServer(
async clientConnection =>
......@@ -597,34 +597,13 @@ public async Task WriteTests(int[][] writes, WriteType writeType)
foreach (int[] bufferLengths in writes)
{
switch (writeType)
foreach (int bufferLength in bufferLengths)
{
case WriteType.SingleBuffer:
foreach (int bufferLength in bufferLengths)
{
await stream.WriteAsync(new byte[bufferLength]);
}
break;
case WriteType.GatheredSequence:
var firstSegment = new BufferSegment(new byte[bufferLengths[0]]);
BufferSegment lastSegment = firstSegment;
foreach (int bufferLength in bufferLengths.Skip(1))
{
lastSegment = lastSegment.Append(new byte[bufferLength]);
}
var buffer = new ReadOnlySequence<byte>(firstSegment, 0, lastSegment, lastSegment.Memory.Length);
await stream.WriteAsync(buffer);
break;
default:
Debug.Fail("Unknown write type.");
break;
await stream.WriteAsync(new byte[bufferLength]);
}
}
stream.Shutdown();
await stream.ShutdownCompleted();
stream.CompleteWrites();
},
async serverConnection =>
{
......@@ -641,8 +620,7 @@ public async Task WriteTests(int[][] writes, WriteType writeType)
int expectedTotalBytes = writes.SelectMany(x => x).Sum();
Assert.Equal(expectedTotalBytes, totalBytes);
stream.Shutdown();
await stream.ShutdownCompleted();
stream.CompleteWrites();
});
}
......@@ -653,7 +631,6 @@ public static IEnumerable<object[]> WriteData()
return
from bufferCount in new[] { 1, 2, 3, 10 }
from writeType in Enum.GetValues<WriteType>()
let writes =
Enumerable.Range(0, 5)
.Select(_ =>
......@@ -661,13 +638,7 @@ public static IEnumerable<object[]> WriteData()
.Select(_ => bufferSizes[r.Next(bufferSizes.Length)])
.ToArray())
.ToArray()
select new object[] { writes, writeType };
}
public enum WriteType
{
SingleBuffer,
GatheredSequence
select new object[] { writes };
}
[Fact]
......@@ -676,12 +647,10 @@ public async Task CallDifferentWriteMethodsWorks()
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
ReadOnlyMemory<byte> helloWorld = "Hello world!"u8.ToArray();
ReadOnlySequence<byte> ros = CreateReadOnlySequenceFromBytes(helloWorld.ToArray());
Assert.False(ros.IsSingleSegment);
using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
ValueTask writeTask = clientStream.WriteAsync(ros);
using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
ValueTask writeTask = clientStream.WriteAsync(helloWorld, completeWrites: true);
await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await writeTask;
byte[] memory = new byte[24];
......@@ -838,9 +807,7 @@ async Task RunTest(byte[] data)
{
await stream.WriteAsync(data[pos..(pos + writeSize)]);
}
await stream.WriteAsync(Memory<byte>.Empty, endStream: true);
await stream.ShutdownCompleted();
await stream.WriteAsync(Memory<byte>.Empty, completeWrites: true);
},
clientFunction: async connection =>
{
......@@ -850,14 +817,12 @@ async Task RunTest(byte[] data)
{
await stream.WriteAsync(data[pos..(pos + writeSize)]);
}
await stream.WriteAsync(Memory<byte>.Empty, endStream: true);
await stream.WriteAsync(Memory<byte>.Empty, completeWrites: true);
byte[] buffer = new byte[data.Length];
int bytesRead = await ReadAll(stream, buffer);
Assert.Equal(data.Length, bytesRead);
AssertExtensions.SequenceEqual(data, buffer);
await stream.ShutdownCompleted();
}
);
}
......@@ -872,8 +837,8 @@ async Task GetStreamIdWithoutStartWorks()
{
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.Equal(0, clientStream.StreamId);
await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.Equal(0, clientStream.Id);
// TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer
await clientConnection.DisposeAsync();
......@@ -892,8 +857,8 @@ async Task GetStreamIdWithoutStartWorks()
{
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.Equal(0, clientStream.StreamId);
await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.Equal(0, clientStream.Id);
// Dispose all connections before the streams;
await clientConnection.DisposeAsync();
......@@ -923,6 +888,7 @@ public async Task Read_ConnectionAbortedByPeer_Throws()
await clientConnection.CloseAsync(ExpectedErrorCode);
byte[] buffer = new byte[100];
await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => clientStream.ReadAsync(buffer).AsTask());
QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => serverStream.ReadAsync(buffer).AsTask());
Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode);
}).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds));
......@@ -931,6 +897,8 @@ public async Task Read_ConnectionAbortedByPeer_Throws()
[Fact]
public async Task Read_ConnectionAbortedByUser_Throws()
{
const int ExpectedErrorCode = 1234;
await Task.Run(async () =>
{
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
......@@ -941,9 +909,11 @@ public async Task Read_ConnectionAbortedByUser_Throws()
await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await serverStream.ReadAsync(new byte[1]);
await serverConnection.CloseAsync(0);
await serverConnection.CloseAsync(ExpectedErrorCode);
byte[] buffer = new byte[100];
QuicException ex = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, () => clientStream.ReadAsync(buffer).AsTask());
Assert.Equal(ExpectedErrorCode, ex.ApplicationErrorCode);
await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => serverStream.ReadAsync(buffer).AsTask());
}).WaitAsync(TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds));
}
......@@ -991,7 +961,7 @@ await using (serverConnection)
if (!closeWithData)
{
serverStream.Shutdown();
serverStream.CompleteWrites();
}
readLength = await clientStream.ReadAsync(actual);
......@@ -1009,34 +979,31 @@ public async Task BasicTest_WithReadsCompletedCheck()
iterations: 100,
serverFunction: async connection =>
{
using QuicStream stream = await connection.AcceptInboundStreamAsync();
Assert.False(stream.ReadsCompleted);
await using QuicStream stream = await connection.AcceptInboundStreamAsync();
Assert.False(stream.ReadsClosed.IsCompleted);
byte[] buffer = new byte[s_data.Length];
int bytesRead = await ReadAll(stream, buffer);
Assert.True(stream.ReadsCompleted);
Assert.True(stream.ReadsClosed.IsCompletedSuccessfully);
Assert.Equal(s_data.Length, bytesRead);
Assert.Equal(s_data, buffer);
await stream.WriteAsync(s_data, endStream: true);
await stream.ShutdownCompleted();
await stream.WriteAsync(s_data, completeWrites: true);
},
clientFunction: async connection =>
{
using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.False(stream.ReadsCompleted);
await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.False(stream.ReadsClosed.IsCompleted);
await stream.WriteAsync(s_data, endStream: true);
await stream.WriteAsync(s_data, completeWrites: true);
byte[] buffer = new byte[s_data.Length];
int bytesRead = await ReadAll(stream, buffer);
Assert.True(stream.ReadsCompleted);
Assert.True(stream.ReadsClosed.IsCompletedSuccessfully);
Assert.Equal(s_data.Length, bytesRead);
Assert.Equal(s_data, buffer);
await stream.ShutdownCompleted();
}
);
}
......@@ -1047,22 +1014,22 @@ public async Task Read_ReadsCompleted_ReportedBeforeReturning0()
await RunBidirectionalClientServer(
async clientStream =>
{
await clientStream.WriteAsync(new byte[1], endStream: true);
await clientStream.WriteAsync(new byte[1], completeWrites: true);
},
async serverStream =>
{
Assert.False(serverStream.ReadsCompleted);
Assert.False(serverStream.ReadsClosed.IsCompleted);
var received = await serverStream.ReadAsync(new byte[1]);
Assert.Equal(1, received);
Assert.True(serverStream.ReadsCompleted);
Assert.True(serverStream.ReadsClosed.IsCompletedSuccessfully);
var task = serverStream.ReadAsync(new byte[1]);
Assert.True(task.IsCompleted);
received = await task;
Assert.Equal(0, received);
Assert.True(serverStream.ReadsCompleted);
Assert.True(serverStream.ReadsClosed.IsCompletedSuccessfully);
});
}
}
......
......@@ -189,7 +189,7 @@ public async Task CloseAsync_WithOpenStream_LocalAndPeerStreamsFailWithQuicOpera
await RunClientServer(
async clientConnection =>
{
using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await DoWrites(clientStream, writesBeforeClose);
// Wait for peer to receive data
......@@ -202,7 +202,7 @@ public async Task CloseAsync_WithOpenStream_LocalAndPeerStreamsFailWithQuicOpera
},
async serverConnection =>
{
using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await DoReads(serverStream, writesBeforeClose);
sync.Release();
......@@ -269,7 +269,7 @@ public async Task Dispose_WithOpenLocalStream_LocalStreamFailsWithQuicOperationA
await RunClientServer(
async clientConnection =>
{
using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await using QuicStream clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await DoWrites(clientStream, writesBeforeClose);
// Wait for peer to receive data
......@@ -282,7 +282,7 @@ public async Task Dispose_WithOpenLocalStream_LocalStreamFailsWithQuicOperationA
},
async serverConnection =>
{
using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await using QuicStream serverStream = await serverConnection.AcceptInboundStreamAsync();
await DoReads(serverStream, writesBeforeClose);
sync.Release();
......
......@@ -201,9 +201,9 @@ internal async Task<(QuicConnection, QuicConnection)> CreateConnectedQuicConnect
internal async Task PingPong(QuicConnection client, QuicConnection server)
{
using QuicStream clientStream = await client.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await using QuicStream clientStream = await client.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
ValueTask t = clientStream.WriteAsync(s_ping);
using QuicStream serverStream = await server.AcceptInboundStreamAsync();
await using QuicStream serverStream = await server.AcceptInboundStreamAsync();
byte[] buffer = new byte[s_ping.Length];
int remains = s_ping.Length;
......@@ -294,8 +294,7 @@ internal async Task RunStreamClientServer(Func<QuicStream, Task> clientFunction,
await clientFunction(stream);
stream.Shutdown();
await stream.ShutdownCompleted();
stream.CompleteWrites();
},
serverFunction: async connection =>
{
......@@ -304,8 +303,7 @@ internal async Task RunStreamClientServer(Func<QuicStream, Task> clientFunction,
await serverFunction(stream);
stream.Shutdown();
await stream.ShutdownCompleted();
stream.CompleteWrites();
},
iterations,
millisecondsTimeout
......
......@@ -268,7 +268,7 @@ public async Task Connect_ViaProxy_ProxyTunnelRequestIssued(string scheme)
// Send non-success error code so that SocketsHttpHandler won't retry.
await connection.SendResponseAsync(statusCode: HttpStatusCode.Forbidden);
connection.Dispose();
await connection.DisposeAsync();
}));
Assert.True(connectionAccepted);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册