未验证 提交 a822d390 编写于 作者: M Marie Píchová 提交者: GitHub

QUIC stream limits (#52704)

Implements the 3rd option Allowing the caller to perform their own wait from #32079 (comment)
Adds WaitForAvailable(Bidi|Uni)rectionalStreamsAsync:
- triggered by peer announcement about new streams (QUIC_CONNECTION_EVENT_TYPE.STREAMS_AVAILABLE)
- if the connection is closed/disposed, the method throws QuicConnectionAbortedException which fitted our H3 better than boolean (can be changed)
Changes stream limit type to int
上级 a9d2f032
......@@ -20,26 +20,32 @@ public sealed class Http3LoopbackServer : GenericLoopbackServer
public override Uri Address => new Uri($"https://{_listener.ListenEndPoint}/");
public Http3LoopbackServer(QuicImplementationProvider quicImplementationProvider = null, GenericLoopbackOptions options = null)
public Http3LoopbackServer(QuicImplementationProvider quicImplementationProvider = null, Http3Options options = null)
{
options ??= new GenericLoopbackOptions();
options ??= new Http3Options();
_cert = Configuration.Certificates.GetServerCertificate();
var sslOpts = new SslServerAuthenticationOptions
var listenerOptions = new QuicListenerOptions()
{
EnabledSslProtocols = options.SslProtocols,
ApplicationProtocols = new List<SslApplicationProtocol>
ListenEndPoint = new IPEndPoint(options.Address, 0),
ServerAuthenticationOptions = new SslServerAuthenticationOptions
{
new SslApplicationProtocol("h3-31"),
new SslApplicationProtocol("h3-30"),
new SslApplicationProtocol("h3-29")
EnabledSslProtocols = options.SslProtocols,
ApplicationProtocols = new List<SslApplicationProtocol>
{
new SslApplicationProtocol("h3-31"),
new SslApplicationProtocol("h3-30"),
new SslApplicationProtocol("h3-29")
},
ServerCertificate = _cert,
ClientCertificateRequired = false
},
ServerCertificate = _cert,
ClientCertificateRequired = false
MaxUnidirectionalStreams = options.MaxUnidirectionalStreams,
MaxBidirectionalStreams = options.MaxBidirectionalStreams,
};
_listener = new QuicListener(quicImplementationProvider ?? QuicImplementationProviders.Default, new IPEndPoint(options.Address, 0), sslOpts);
_listener = new QuicListener(quicImplementationProvider ?? QuicImplementationProviders.Default, listenerOptions);
}
public override void Dispose()
......@@ -82,7 +88,7 @@ public Http3LoopbackServerFactory(QuicImplementationProvider quicImplementationP
public override GenericLoopbackServer CreateServer(GenericLoopbackOptions options = null)
{
return new Http3LoopbackServer(_quicImplementationProvider, options);
return new Http3LoopbackServer(_quicImplementationProvider, CreateOptions(options));
}
public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Task> funcAsync, int millisecondsTimeout = 60000, GenericLoopbackOptions options = null)
......@@ -97,5 +103,29 @@ public override Task<GenericLoopbackConnection> CreateConnectionAsync(Socket soc
// This method is always unacceptable to call for HTTP/3.
throw new NotImplementedException("HTTP/3 does not operate over a Socket.");
}
private static Http3Options CreateOptions(GenericLoopbackOptions options)
{
Http3Options http3Options = new Http3Options();
if (options != null)
{
http3Options.Address = options.Address;
http3Options.UseSsl = options.UseSsl;
http3Options.SslProtocols = options.SslProtocols;
http3Options.ListenBacklog = options.ListenBacklog;
}
return http3Options;
}
}
public class Http3Options : GenericLoopbackOptions
{
public int MaxUnidirectionalStreams {get; set; }
public int MaxBidirectionalStreams {get; set; }
public Http3Options()
{
MaxUnidirectionalStreams = 100;
MaxBidirectionalStreams = 100;
}
}
}
......@@ -49,11 +49,6 @@ internal sealed class Http3Connection : HttpConnectionBase, IDisposable
private int _haveServerQpackDecodeStream;
private int _haveServerQpackEncodeStream;
// Manages MAX_STREAM count from server.
private long _maximumRequestStreams;
private long _requestStreamsRemaining;
private readonly Queue<TaskCompletionSourceWithCancellation<bool>> _waitingRequests = new Queue<TaskCompletionSourceWithCancellation<bool>>();
// A connection-level error will abort any future operations.
private Exception? _abortException;
......@@ -87,8 +82,6 @@ public Http3Connection(HttpConnectionPool pool, HttpAuthority? origin, HttpAutho
string altUsedValue = altUsedDefaultPort ? authority.IdnHost : authority.IdnHost + ":" + authority.Port.ToString(Globalization.CultureInfo.InvariantCulture);
_altUsedEncodedHeader = QPack.QPackEncoder.EncodeLiteralHeaderFieldWithoutNameReferenceToArray(KnownHeaders.AltUsed.Name, altUsedValue);
_maximumRequestStreams = _requestStreamsRemaining = connection.GetRemoteAvailableBidirectionalStreamCount();
// Errors are observed via Abort().
_ = SendSettingsAsync();
......@@ -166,45 +159,34 @@ public override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage req
{
Debug.Assert(async);
// Wait for an available stream (based on QUIC MAX_STREAMS) if there isn't one available yet.
TaskCompletionSourceWithCancellation<bool>? waitForAvailableStreamTcs = null;
lock (SyncObj)
{
long remaining = _requestStreamsRemaining;
if (remaining > 0)
{
_requestStreamsRemaining = remaining - 1;
}
else
{
waitForAvailableStreamTcs = new TaskCompletionSourceWithCancellation<bool>();
_waitingRequests.Enqueue(waitForAvailableStreamTcs);
}
}
if (waitForAvailableStreamTcs != null)
{
await waitForAvailableStreamTcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
}
// Allocate an active request
QuicStream? quicStream = null;
Http3RequestStream? requestStream = null;
ValueTask waitTask = default;
try
{
lock (SyncObj)
while (true)
{
if (_connection != null)
lock (SyncObj)
{
quicStream = _connection.OpenBidirectionalStream();
requestStream = new Http3RequestStream(request, this, quicStream);
_activeRequests.Add(quicStream, requestStream);
if (_connection == null)
{
break;
}
if (_connection.GetRemoteAvailableBidirectionalStreamCount() > 0)
{
quicStream = _connection.OpenBidirectionalStream();
requestStream = new Http3RequestStream(request, this, quicStream);
_activeRequests.Add(quicStream, requestStream);
break;
}
waitTask = _connection.WaitForAvailableBidirectionalStreamsAsync(cancellationToken);
}
// Wait for an available stream (based on QUIC MAX_STREAMS) if there isn't one available yet.
await waitTask.ConfigureAwait(false);
}
if (quicStream == null)
......@@ -212,8 +194,6 @@ public override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage req
throw new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure);
}
// 0-byte write to force QUIC to allocate a stream ID.
await quicStream.WriteAsync(Array.Empty<byte>(), cancellationToken).ConfigureAwait(false);
requestStream!.StreamId = quicStream.StreamId;
bool goAway;
......@@ -246,76 +226,6 @@ public override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage req
}
}
/// <summary>
/// Waits for MAX_STREAMS to be raised by the server.
/// </summary>
private Task WaitForAvailableRequestStreamAsync(CancellationToken cancellationToken)
{
TaskCompletionSourceWithCancellation<bool> tcs;
lock (SyncObj)
{
long remaining = _requestStreamsRemaining;
if (remaining > 0)
{
_requestStreamsRemaining = remaining - 1;
return Task.CompletedTask;
}
tcs = new TaskCompletionSourceWithCancellation<bool>();
_waitingRequests.Enqueue(tcs);
}
// Note: cancellation on connection shutdown is handled in CancelWaiters.
return tcs.WaitWithCancellationAsync(cancellationToken).AsTask();
}
/// <summary>
/// Cancels any waiting SendAsync calls.
/// </summary>
/// <remarks>Requires <see cref="SyncObj"/> to be held.</remarks>
private void CancelWaiters()
{
Debug.Assert(Monitor.IsEntered(SyncObj));
while (_waitingRequests.TryDequeue(out TaskCompletionSourceWithCancellation<bool>? tcs))
{
tcs.TrySetException(new HttpRequestException(SR.net_http_request_aborted, null, RequestRetryType.RetryOnConnectionFailure));
}
}
// TODO: how do we get this event? -> HandleEventStreamsAvailable reports currently available Uni/Bi streams
private void OnMaximumStreamCountIncrease(long newMaximumStreamCount)
{
lock (SyncObj)
{
if (newMaximumStreamCount <= _maximumRequestStreams)
{
return;
}
IncreaseRemainingStreamCount(newMaximumStreamCount - _maximumRequestStreams);
_maximumRequestStreams = newMaximumStreamCount;
}
}
private void IncreaseRemainingStreamCount(long delta)
{
Debug.Assert(Monitor.IsEntered(SyncObj));
Debug.Assert(delta > 0);
_requestStreamsRemaining += delta;
while (_requestStreamsRemaining != 0 && _waitingRequests.TryDequeue(out TaskCompletionSourceWithCancellation<bool>? tcs))
{
if (tcs.TrySetResult(true))
{
--_requestStreamsRemaining;
}
}
}
/// <summary>
/// Aborts the connection with an error.
/// </summary>
......@@ -358,7 +268,6 @@ internal Exception Abort(Exception abortException)
_connectionClosedTask = _connection.CloseAsync((long)connectionResetErrorCode).AsTask();
}
CancelWaiters();
CheckForShutdown();
}
......@@ -396,7 +305,6 @@ private void OnServerGoAway(long lastProcessedStreamId)
}
}
CancelWaiters();
CheckForShutdown();
}
......@@ -414,8 +322,6 @@ public void RemoveStream(QuicStream stream)
bool removed = _activeRequests.Remove(stream);
Debug.Assert(removed == true);
IncreaseRemainingStreamCount(1);
if (ShuttingDown)
{
CheckForShutdown();
......
......@@ -79,10 +79,12 @@ public async Task ClientSettingsReceived_Success(int headerSizeLimit)
}
[Theory]
[InlineData(10)]
[InlineData(100)]
[InlineData(1000)]
public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit)
{
using Http3LoopbackServer server = CreateHttp3LoopbackServer();
using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit });
Task serverTask = Task.Run(async () =>
{
......@@ -100,7 +102,7 @@ public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit)
for (int i = 0; i < streamLimit + 1; ++i)
{
using HttpRequestMessage request = new()
HttpRequestMessage request = new()
{
Method = HttpMethod.Get,
RequestUri = server.Address,
......@@ -114,6 +116,162 @@ public async Task SendMoreThanStreamLimitRequests_Succeeds(int streamLimit)
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}
[Theory]
[InlineData(10)]
[InlineData(100)]
[InlineData(1000)]
public async Task SendStreamLimitRequestsConcurrently_Succeeds(int streamLimit)
{
using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit });
Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
for (int i = 0; i < streamLimit; ++i)
{
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
}
});
Task clientTask = Task.Run(async () =>
{
using HttpClient client = CreateHttpClient();
var tasks = new Task<HttpResponseMessage>[streamLimit];
Parallel.For(0, streamLimit, i =>
{
HttpRequestMessage request = new()
{
Method = HttpMethod.Get,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact
};
tasks[i] = client.SendAsync(request);
});
var responses = await Task.WhenAll(tasks);
foreach (var response in responses)
{
response.Dispose();
}
});
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}
[Theory]
[InlineData(10)]
[InlineData(100)]
[InlineData(1000)]
public async Task SendMoreThanStreamLimitRequestsConcurrently_LastWaits(int streamLimit)
{
// This combination leads to a hang manifesting in CI only. Disabling it until there's more time to investigate.
// [ActiveIssue("https://github.com/dotnet/runtime/issues/53688")]
if (streamLimit == 10 && this.UseQuicImplementationProvider == QuicImplementationProviders.Mock)
{
return;
}
using Http3LoopbackServer server = CreateHttp3LoopbackServer(new Http3Options(){ MaxBidirectionalStreams = streamLimit });
var lastRequestContentStarted = new TaskCompletionSource();
Task serverTask = Task.Run(async () =>
{
// Read the first streamLimit requests, keep the streams open to make the last one wait.
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
var streams = new Http3LoopbackStream[streamLimit];
for (int i = 0; i < streamLimit; ++i)
{
Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
var body = await stream.ReadRequestDataAsync();
streams[i] = stream;
}
// Make the last request running independently.
var lastRequest = Task.Run(async () => {
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
await stream.HandleRequestAsync();
});
// All the initial streamLimit streams are still opened so the last request cannot started yet.
Assert.False(lastRequestContentStarted.Task.IsCompleted);
// Reply to the first streamLimit requests.
for (int i = 0; i < streamLimit; ++i)
{
await streams[i].SendResponseAsync();
streams[i].Dispose();
// After the first request is fully processed, the last request should unblock and get processed.
if (i == 0)
{
await lastRequestContentStarted.Task;
}
}
await lastRequest;
});
Task clientTask = Task.Run(async () =>
{
using HttpClient client = CreateHttpClient();
// Fire out the first streamLimit requests in parallel, no waiting for the responses yet.
var countdown = new CountdownEvent(streamLimit);
var tasks = new Task<HttpResponseMessage>[streamLimit];
Parallel.For(0, streamLimit, i =>
{
HttpRequestMessage request = new()
{
Method = HttpMethod.Post,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact,
Content = new StreamContent(new DelegateStream(
canReadFunc: () => true,
readFunc: (buffer, offset, count) =>
{
countdown.Signal();
return 0;
}))
};
tasks[i] = client.SendAsync(request);
});
// Wait for the first streamLimit request to get started.
countdown.Wait();
// Fire out the last request, that should wait until the server fully handles at least one request.
HttpRequestMessage last = new()
{
Method = HttpMethod.Post,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact,
Content = new StreamContent(new DelegateStream(
canReadFunc: () => true,
readFunc: (buffer, offset, count) =>
{
lastRequestContentStarted.SetResult();
return 0;
}))
};
var lastTask = client.SendAsync(last);
// Wait for all requests to finish. Whether the last request was pending is checked on the server side.
var responses = await Task.WhenAll(tasks);
foreach (var response in responses)
{
response.Dispose();
}
await lastTask;
});
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/53090")]
public async Task ReservedFrameType_Throws()
......
......@@ -52,9 +52,9 @@ protected static HttpClientHandler CreateHttpClientHandler(Version useVersion =
return handler;
}
protected Http3LoopbackServer CreateHttp3LoopbackServer()
protected Http3LoopbackServer CreateHttp3LoopbackServer(Http3Options options = default)
{
return new Http3LoopbackServer(UseQuicImplementationProvider);
return new Http3LoopbackServer(UseQuicImplementationProvider, options);
}
protected HttpClientHandler CreateHttpClientHandler() => CreateHttpClientHandler(UseVersion, UseQuicImplementationProvider);
......@@ -97,7 +97,7 @@ protected static LoopbackServerFactory GetFactoryForVersion(Version useVersion,
internal class VersionHttpClientHandler : HttpClientHandler
{
private readonly Version _useVersion;
public VersionHttpClientHandler(Version useVersion)
{
_useVersion = useVersion;
......@@ -120,7 +120,7 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
{
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
}
return base.SendAsync(request, cancellationToken);
}
......
......@@ -27,10 +27,12 @@ public sealed partial class QuicConnection : System.IDisposable
public System.Threading.Tasks.ValueTask CloseAsync(long errorCode, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public void Dispose() { }
public long GetRemoteAvailableBidirectionalStreamCount() { throw null; }
public long GetRemoteAvailableUnidirectionalStreamCount() { throw null; }
public int GetRemoteAvailableBidirectionalStreamCount() { throw null; }
public int GetRemoteAvailableUnidirectionalStreamCount() { throw null; }
public System.Net.Quic.QuicStream OpenBidirectionalStream() { throw null; }
public System.Net.Quic.QuicStream OpenUnidirectionalStream() { throw null; }
public System.Threading.Tasks.ValueTask WaitForAvailableBidirectionalStreamsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask WaitForAvailableUnidirectionalStreamsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
}
public partial class QuicConnectionAbortedException : System.Net.Quic.QuicException
{
......@@ -73,8 +75,8 @@ public partial class QuicOptions
{
public QuicOptions() { }
public System.TimeSpan IdleTimeout { get { throw null; } set { } }
public long MaxBidirectionalStreams { get { throw null; } set { } }
public long MaxUnidirectionalStreams { get { throw null; } set { } }
public int MaxBidirectionalStreams { get { throw null; } set { } }
public int MaxUnidirectionalStreams { get { throw null; } set { } }
}
public sealed partial class QuicStream : System.IO.Stream
{
......@@ -101,8 +103,8 @@ public sealed partial class QuicStream : System.IO.Stream
public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
public override void SetLength(long value) { }
public void Shutdown() { }
public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask ShutdownWriteCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
......
......@@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Net;
using System.Net.Security;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
......@@ -20,11 +21,16 @@ internal sealed class MockConnection : QuicConnectionProvider
private object _syncObject = new object();
private long _nextOutboundBidirectionalStream;
private long _nextOutboundUnidirectionalStream;
private readonly int _maxUnidirectionalStreams;
private readonly int _maxBidirectionalStreams;
private ConnectionState? _state;
internal PeerStreamLimit? LocalStreamLimit => _isClient ? _state?._clientStreamLimit : _state?._serverStreamLimit;
internal PeerStreamLimit? RemoteStreamLimit => _isClient ? _state?._serverStreamLimit : _state?._clientStreamLimit;
// Constructor for outbound connections
internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions? sslClientAuthenticationOptions, IPEndPoint? localEndPoint = null)
internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions? sslClientAuthenticationOptions, IPEndPoint? localEndPoint = null, int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100)
{
if (remoteEndPoint is null)
{
......@@ -43,6 +49,8 @@ internal MockConnection(EndPoint? remoteEndPoint, SslClientAuthenticationOptions
_sslClientAuthenticationOptions = sslClientAuthenticationOptions;
_nextOutboundBidirectionalStream = 0;
_nextOutboundUnidirectionalStream = 2;
_maxUnidirectionalStreams = maxUnidirectionalStreams;
_maxBidirectionalStreams = maxBidirectionalStreams;
// _state is not initialized until ConnectAsync
}
......@@ -129,7 +137,10 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
}
// TODO: deal with protocol negotiation
_state = new ConnectionState(_sslClientAuthenticationOptions!.ApplicationProtocols![0]);
_state = new ConnectionState(_sslClientAuthenticationOptions!.ApplicationProtocols![0])
{
_clientStreamLimit = new PeerStreamLimit(_maxUnidirectionalStreams, _maxBidirectionalStreams)
};
if (!listener.TryConnect(_state))
{
throw new QuicException("Connection refused");
......@@ -138,8 +149,41 @@ internal override ValueTask ConnectAsync(CancellationToken cancellationToken = d
return ValueTask.CompletedTask;
}
internal override ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default)
{
PeerStreamLimit? streamLimit = RemoteStreamLimit;
if (streamLimit is null)
{
throw new InvalidOperationException("Not connected");
}
return streamLimit.Unidirectional.WaitForAvailableStreams(cancellationToken);
}
internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default)
{
PeerStreamLimit? streamLimit = RemoteStreamLimit;
if (streamLimit is null)
{
throw new InvalidOperationException("Not connected");
}
return streamLimit.Bidirectional.WaitForAvailableStreams(cancellationToken);
}
internal override QuicStreamProvider OpenUnidirectionalStream()
{
PeerStreamLimit? streamLimit = RemoteStreamLimit;
if (streamLimit is null)
{
throw new InvalidOperationException("Not connected");
}
if (!streamLimit.Unidirectional.TryIncrement())
{
throw new QuicException("No available unidirectional stream");
}
long streamId;
lock (_syncObject)
{
......@@ -152,6 +196,17 @@ internal override QuicStreamProvider OpenUnidirectionalStream()
internal override QuicStreamProvider OpenBidirectionalStream()
{
PeerStreamLimit? streamLimit = RemoteStreamLimit;
if (streamLimit is null)
{
throw new InvalidOperationException("Not connected");
}
if (!streamLimit.Bidirectional.TryIncrement())
{
throw new QuicException("No available bidirectional stream");
}
long streamId;
lock (_syncObject)
{
......@@ -174,12 +229,30 @@ internal MockStream OpenStream(long streamId, bool bidirectional)
Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
streamChannel.Writer.TryWrite(streamState);
return new MockStream(streamState, true);
return new MockStream(this, streamState, true);
}
internal override long GetRemoteAvailableUnidirectionalStreamCount() => long.MaxValue;
internal override int GetRemoteAvailableUnidirectionalStreamCount()
{
PeerStreamLimit? streamLimit = RemoteStreamLimit;
if (streamLimit is null)
{
throw new InvalidOperationException("Not connected");
}
return streamLimit.Unidirectional.AvailableCount;
}
internal override int GetRemoteAvailableBidirectionalStreamCount()
{
PeerStreamLimit? streamLimit = RemoteStreamLimit;
if (streamLimit is null)
{
throw new InvalidOperationException("Not connected");
}
internal override long GetRemoteAvailableBidirectionalStreamCount() => long.MaxValue;
return streamLimit.Bidirectional.AvailableCount;
}
internal override async ValueTask<QuicStreamProvider> AcceptStreamAsync(CancellationToken cancellationToken = default)
{
......@@ -196,7 +269,7 @@ internal override async ValueTask<QuicStreamProvider> AcceptStreamAsync(Cancella
try
{
MockStream.StreamState streamState = await streamChannel.Reader.ReadAsync(cancellationToken).ConfigureAwait(false);
return new MockStream(streamState, false);
return new MockStream(this, streamState, false);
}
catch (ChannelClosedException)
{
......@@ -251,6 +324,14 @@ private void Dispose(bool disposing)
Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
streamChannel.Writer.Complete();
}
PeerStreamLimit? streamLimit = LocalStreamLimit;
if (streamLimit is not null)
{
streamLimit.Unidirectional.CloseWaiters();
streamLimit.Bidirectional.CloseWaiters();
}
}
// TODO: free unmanaged resources (unmanaged objects) and override a finalizer below.
......@@ -271,11 +352,77 @@ public override void Dispose()
GC.SuppressFinalize(this);
}
internal sealed class StreamLimit
{
public readonly int MaxCount;
private int _actualCount;
// Since this is mock, we don't need to be conservative with the allocations.
// We keep the TCSes allocated all the time for the simplicity of the code.
private TaskCompletionSource _availableTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly object _syncRoot = new object();
public StreamLimit(int maxCount)
{
MaxCount = maxCount;
}
public int AvailableCount => MaxCount - _actualCount;
public void Decrement()
{
lock (_syncRoot)
{
--_actualCount;
if (!_availableTcs.Task.IsCompleted)
{
_availableTcs.SetResult();
_availableTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
}
}
public bool TryIncrement()
{
lock (_syncRoot)
{
if (_actualCount < MaxCount)
{
++_actualCount;
return true;
}
return false;
}
}
public ValueTask WaitForAvailableStreams(CancellationToken cancellationToken)
=> new ValueTask(_availableTcs.Task.WaitAsync(cancellationToken));
public void CloseWaiters()
=> _availableTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
}
internal class PeerStreamLimit
{
public readonly StreamLimit Unidirectional;
public readonly StreamLimit Bidirectional;
public PeerStreamLimit(int maxUnidirectional, int maxBidirectional)
{
Unidirectional = new StreamLimit(maxUnidirectional);
Bidirectional = new StreamLimit(maxBidirectional);
}
}
internal sealed class ConnectionState
{
public readonly SslApplicationProtocol _applicationProtocol;
public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
public PeerStreamLimit? _clientStreamLimit;
public PeerStreamLimit? _serverStreamLimit;
public long _clientErrorCode;
public long _serverErrorCode;
public bool _closed;
......
......@@ -16,7 +16,11 @@ internal override QuicListenerProvider CreateListener(QuicListenerOptions option
internal override QuicConnectionProvider CreateConnection(QuicClientConnectionOptions options)
{
return new MockConnection(options.RemoteEndPoint, options.ClientAuthenticationOptions, options.LocalEndPoint);
return new MockConnection(options.RemoteEndPoint,
options.ClientAuthenticationOptions,
options.LocalEndPoint,
options.MaxUnidirectionalStreams,
options.MaxBidirectionalStreams);
}
}
}
......@@ -69,6 +69,7 @@ internal override async ValueTask<QuicConnectionProvider> AcceptConnectionAsync(
// Returns false if backlog queue is full.
internal bool TryConnect(MockConnection.ConnectionState state)
{
state._serverStreamLimit = new MockConnection.PeerStreamLimit(_options.MaxUnidirectionalStreams, _options.MaxBidirectionalStreams);
return _listenQueue.Writer.TryWrite(state);
}
......
......@@ -14,12 +14,14 @@ internal sealed class MockStream : QuicStreamProvider
{
private bool _disposed;
private readonly bool _isInitiator;
private readonly MockConnection _connection;
private readonly StreamState _streamState;
private bool _writesCanceled;
internal MockStream(StreamState streamState, bool isInitiator)
internal MockStream(MockConnection connection, StreamState streamState, bool isInitiator)
{
_connection = connection;
_streamState = streamState;
_isInitiator = isInitiator;
}
......@@ -186,7 +188,6 @@ internal override void AbortWrite(long errorCode)
WriteStreamBuffer?.EndWrite();
}
internal override ValueTask ShutdownWriteCompleted(CancellationToken cancellationToken = default)
{
CheckDisposed();
......@@ -208,6 +209,15 @@ internal override void Shutdown()
// This seems to mean shutdown send, in particular, not both.
WriteStreamBuffer?.EndWrite();
if (_streamState._inboundStreamBuffer is null) // unidirectional stream
{
_connection.LocalStreamLimit!.Unidirectional.Decrement();
}
else
{
_connection.LocalStreamLimit!.Bidirectional.Decrement();
}
}
private void CheckDisposed()
......
......@@ -51,6 +51,12 @@ internal sealed class State
public readonly TaskCompletionSource<uint> ConnectTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
public readonly TaskCompletionSource<uint> ShutdownTcs = new TaskCompletionSource<uint>(TaskCreationOptions.RunContinuationsAsynchronously);
// Note that there's no such thing as resetable TCS, so we cannot reuse the same instance after we've set the result.
// We also cannot use solutions like ManualResetValueTaskSourceCore, since we can have multiple waiters on the same TCS.
// As a result, we allocate a new TCS when needed, which is when someone explicitely asks for them in WaitForAvailableStreamsAsync.
public TaskCompletionSource? NewUnidirectionalStreamsAvailable;
public TaskCompletionSource? NewBidirectionalStreamsAvailable;
public bool Connected;
public long AbortErrorCode = -1;
......@@ -192,6 +198,26 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
// Stop accepting new streams.
state.AcceptQueue.Writer.Complete();
// Stop notifying about available streams.
TaskCompletionSource? unidirectionalTcs = null;
TaskCompletionSource? bidirectionalTcs = null;
lock (state)
{
unidirectionalTcs = state.NewBidirectionalStreamsAvailable;
bidirectionalTcs = state.NewBidirectionalStreamsAvailable;
state.NewUnidirectionalStreamsAvailable = null;
state.NewBidirectionalStreamsAvailable = null;
}
if (unidirectionalTcs is not null)
{
unidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
}
if (bidirectionalTcs is not null)
{
bidirectionalTcs.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException()));
}
return MsQuicStatusCodes.Success;
}
......@@ -206,6 +232,32 @@ private static uint HandleEventNewStream(State state, ref ConnectionEvent connec
private static uint HandleEventStreamsAvailable(State state, ref ConnectionEvent connectionEvent)
{
TaskCompletionSource? unidirectionalTcs = null;
TaskCompletionSource? bidirectionalTcs = null;
lock (state)
{
if (connectionEvent.Data.StreamsAvailable.UniDirectionalCount > 0)
{
unidirectionalTcs = state.NewUnidirectionalStreamsAvailable;
state.NewUnidirectionalStreamsAvailable = null;
}
if (connectionEvent.Data.StreamsAvailable.BiDirectionalCount > 0)
{
bidirectionalTcs = state.NewBidirectionalStreamsAvailable;
state.NewBidirectionalStreamsAvailable = null;
}
}
if (unidirectionalTcs is not null)
{
unidirectionalTcs.SetResult();
}
if (bidirectionalTcs is not null)
{
bidirectionalTcs.SetResult();
}
return MsQuicStatusCodes.Success;
}
......@@ -329,24 +381,82 @@ internal override async ValueTask<QuicStreamProvider> AcceptStreamAsync(Cancella
return stream;
}
internal override ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default)
{
TaskCompletionSource? tcs = _state.NewUnidirectionalStreamsAvailable;
if (tcs is null)
{
lock (_state)
{
if (_state.NewUnidirectionalStreamsAvailable is null)
{
if (_state.ShutdownTcs.Task.IsCompleted)
{
throw new QuicOperationAbortedException();
}
if (GetRemoteAvailableUnidirectionalStreamCount() > 0)
{
return ValueTask.CompletedTask;
}
_state.NewUnidirectionalStreamsAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
tcs = _state.NewUnidirectionalStreamsAvailable;
}
}
return new ValueTask(tcs.Task.WaitAsync(cancellationToken));
}
internal override ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default)
{
TaskCompletionSource? tcs = _state.NewBidirectionalStreamsAvailable;
if (tcs is null)
{
lock (_state)
{
if (_state.NewBidirectionalStreamsAvailable is null)
{
if (_state.ShutdownTcs.Task.IsCompleted)
{
throw new QuicOperationAbortedException();
}
if (GetRemoteAvailableBidirectionalStreamCount() > 0)
{
return ValueTask.CompletedTask;
}
_state.NewBidirectionalStreamsAvailable = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
tcs = _state.NewBidirectionalStreamsAvailable;
}
}
return new ValueTask(tcs.Task.WaitAsync(cancellationToken));
}
internal override QuicStreamProvider OpenUnidirectionalStream()
{
ThrowIfDisposed();
return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
}
internal override QuicStreamProvider OpenBidirectionalStream()
{
ThrowIfDisposed();
return new MsQuicStream(_state, QUIC_STREAM_OPEN_FLAGS.NONE);
}
internal override long GetRemoteAvailableUnidirectionalStreamCount()
internal override int GetRemoteAvailableUnidirectionalStreamCount()
{
return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_UNIDI_STREAM_COUNT);
}
internal override long GetRemoteAvailableBidirectionalStreamCount()
internal override int GetRemoteAvailableBidirectionalStreamCount()
{
return MsQuicParameterHelpers.GetUShortParam(MsQuicApi.Api, _state.Handle, QUIC_PARAM_LEVEL.CONNECTION, (uint)QUIC_PARAM_CONN.LOCAL_BIDI_STREAM_COUNT);
}
......
......@@ -64,7 +64,6 @@ private sealed class State
// Set once writes have been shutdown.
public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
public ShutdownState ShutdownState;
// Set once stream have been shutdown.
......@@ -124,7 +123,7 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F
QuicExceptionHelpers.ThrowIfFailed(status, "Failed to open stream to peer.");
status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.ASYNC);
status = MsQuicApi.Api.StreamStartDelegate(_state.Handle, QUIC_STREAM_START_FLAGS.FAIL_BLOCKED);
QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream.");
}
catch
......@@ -492,6 +491,7 @@ internal override async ValueTask ShutdownCompleted(CancellationToken cancellati
internal override void Shutdown()
{
ThrowIfDisposed();
// it is ok to send shutdown several times, MsQuic will ignore it
StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0);
}
......@@ -592,7 +592,7 @@ private static uint HandleEvent(State state, ref StreamEvent evt)
// Stream has started.
// Will only be done for outbound streams (inbound streams have already started)
case QUIC_STREAM_EVENT_TYPE.START_COMPLETE:
return HandleStartComplete(state);
return HandleEventStartComplete(state);
// Received data on the stream
case QUIC_STREAM_EVENT_TYPE.RECEIVE:
return HandleEventRecv(state, ref evt);
......@@ -678,7 +678,7 @@ private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
return MsQuicStatusCodes.Success;
}
private static uint HandleStartComplete(State state)
private static uint HandleEventStartComplete(State state)
{
bool shouldComplete = false;
lock (state)
......
......@@ -16,13 +16,17 @@ internal abstract class QuicConnectionProvider : IDisposable
internal abstract ValueTask ConnectAsync(CancellationToken cancellationToken = default);
internal abstract ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default);
internal abstract ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default);
internal abstract QuicStreamProvider OpenUnidirectionalStream();
internal abstract QuicStreamProvider OpenBidirectionalStream();
internal abstract long GetRemoteAvailableUnidirectionalStreamCount();
internal abstract int GetRemoteAvailableUnidirectionalStreamCount();
internal abstract long GetRemoteAvailableBidirectionalStreamCount();
internal abstract int GetRemoteAvailableBidirectionalStreamCount();
internal abstract ValueTask<QuicStreamProvider> AcceptStreamAsync(CancellationToken cancellationToken = default);
......
......@@ -67,6 +67,18 @@ internal QuicConnection(QuicConnectionProvider provider)
/// <returns></returns>
public ValueTask ConnectAsync(CancellationToken cancellationToken = default) => _provider.ConnectAsync(cancellationToken);
/// <summary>
/// Waits for available unidirectional stream capacity to be announced by the peer. If any capacity is available, returns immediately.
/// </summary>
/// <returns></returns>
public ValueTask WaitForAvailableUnidirectionalStreamsAsync(CancellationToken cancellationToken = default) => _provider.WaitForAvailableUnidirectionalStreamsAsync(cancellationToken);
/// <summary>
/// Waits for available bidirectional stream capacity to be announced by the peer. If any capacity is available, returns immediately.
/// </summary>
/// <returns></returns>
public ValueTask WaitForAvailableBidirectionalStreamsAsync(CancellationToken cancellationToken = default) => _provider.WaitForAvailableBidirectionalStreamsAsync(cancellationToken);
/// <summary>
/// Create an outbound unidirectional stream.
/// </summary>
......@@ -95,11 +107,11 @@ internal QuicConnection(QuicConnectionProvider provider)
/// <summary>
/// Gets the maximum number of bidirectional streams that can be made to the peer.
/// </summary>
public long GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount();
public int GetRemoteAvailableUnidirectionalStreamCount() => _provider.GetRemoteAvailableUnidirectionalStreamCount();
/// <summary>
/// Gets the maximum number of unidirectional streams that can be made to the peer.
/// </summary>
public long GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount();
public int GetRemoteAvailableBidirectionalStreamCount() => _provider.GetRemoteAvailableBidirectionalStreamCount();
}
}
......@@ -19,14 +19,14 @@ public class QuicOptions
/// Default is 100.
/// </summary>
// TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using.
public long MaxBidirectionalStreams { get; set; } = 100;
public int MaxBidirectionalStreams { get; set; } = 100;
/// <summary>
/// Limit on the number of unidirectional streams the remote peer connection can create on an open connection.
/// Default is 100.
/// </summary>
// TODO consider constraining these limits to 0 to whatever the max of the QUIC library we are using.
public long MaxUnidirectionalStreams { get; set; } = 100;
public int MaxUnidirectionalStreams { get; set; } = 100;
/// <summary>
/// Idle timeout for connections, after which the connection will be closed.
......
......@@ -95,6 +95,56 @@ public async Task ConnectWithCertificateChain()
await clientTask;
}
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/52048")]
public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
{
using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1);
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
// No stream openned yet, should return immediately.
Assert.True(clientConnection.WaitForAvailableUnidirectionalStreamsAsync().IsCompletedSuccessfully);
// Open one stream, should wait till it closes.
QuicStream stream = clientConnection.OpenUnidirectionalStream();
ValueTask waitTask = clientConnection.WaitForAvailableUnidirectionalStreamsAsync();
Assert.False(waitTask.IsCompleted);
Assert.Throws<QuicException>(() => clientConnection.OpenUnidirectionalStream());
// Close the stream, the waitTask should finish as a result.
stream.Dispose();
await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
}
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/52048")]
public async Task WaitForAvailableBidirectionStreamsAsyncWorks()
{
using QuicListener listener = CreateQuicListener(maxBidirectionalStreams: 1);
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);
ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
// No stream openned yet, should return immediately.
Assert.True(clientConnection.WaitForAvailableBidirectionalStreamsAsync().IsCompletedSuccessfully);
// Open one stream, should wait till it closes.
QuicStream stream = clientConnection.OpenBidirectionalStream();
ValueTask waitTask = clientConnection.WaitForAvailableBidirectionalStreamsAsync();
Assert.False(waitTask.IsCompleted);
Assert.Throws<QuicException>(() => clientConnection.OpenBidirectionalStream());
// Close the stream, the waitTask should finish as a result.
stream.Dispose();
await waitTask.AsTask().WaitAsync(TimeSpan.FromSeconds(10));
}
[Fact]
[OuterLoop("May take several seconds")]
public async Task SetListenerTimeoutWorksWithSmallTimeout()
......@@ -234,7 +284,7 @@ public async Task CallDifferentWriteMethodsWorks()
int res = await serverStream.ReadAsync(memory);
Assert.Equal(12, res);
ReadOnlyMemory<ReadOnlyMemory<byte>> romrom = new ReadOnlyMemory<ReadOnlyMemory<byte>>(new ReadOnlyMemory<byte>[] { helloWorld, helloWorld });
await clientStream.WriteAsync(romrom);
res = await serverStream.ReadAsync(memory);
......@@ -254,7 +304,7 @@ public async Task CloseAsync_ByServer_AcceptThrows()
{
var acceptTask = serverConnection.AcceptStreamAsync();
await serverConnection.CloseAsync(errorCode: 0);
// make sure
// make sure
await Assert.ThrowsAsync<QuicOperationAbortedException>(() => acceptTask.AsTask());
});
}
......
......@@ -92,7 +92,7 @@ public SslServerAuthenticationOptions GetSslServerAuthenticationOptions()
ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate()
};
}
protected abstract QuicImplementationProvider Provider { get; }
protected override async Task<StreamPair> CreateConnectedStreamsAsync()
......
......@@ -53,16 +53,30 @@ internal QuicConnection CreateQuicConnection(IPEndPoint endpoint)
return new QuicConnection(ImplementationProvider, endpoint, GetSslClientAuthenticationOptions());
}
internal QuicListener CreateQuicListener()
internal QuicListener CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100)
{
return CreateQuicListener(new IPEndPoint(IPAddress.Loopback, 0));
var options = new QuicListenerOptions()
{
ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0),
ServerAuthenticationOptions = GetSslServerAuthenticationOptions(),
MaxUnidirectionalStreams = maxUnidirectionalStreams,
MaxBidirectionalStreams = maxBidirectionalStreams
};
return CreateQuicListener(options);
}
internal QuicListener CreateQuicListener(IPEndPoint endpoint)
{
return new QuicListener(ImplementationProvider, endpoint, GetSslServerAuthenticationOptions());
var options = new QuicListenerOptions()
{
ListenEndPoint = endpoint,
ServerAuthenticationOptions = GetSslServerAuthenticationOptions()
};
return CreateQuicListener(options);
}
private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options);
internal async Task RunClientServer(Func<QuicConnection, Task> clientFunction, Func<QuicConnection, Task> serverFunction, int iterations = 1, int millisecondsTimeout = 10_000)
{
using QuicListener listener = CreateQuicListener();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册