未验证 提交 d9f1ade5 编写于 作者: G Geoff Kizer 提交者: GitHub

fix read abort handling and revert CanRead/CanWrite to previous behavior (#55341)

Co-authored-by: NGeoffrey Kizer <geoffrek@windows.microsoft.com>
上级 31c2bed4
......@@ -70,7 +70,7 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> buffer, Cancellati
int bytesRead = await streamBuffer.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
long errorCode = _isInitiator ? _streamState._inboundErrorCode : _streamState._outboundErrorCode;
long errorCode = _isInitiator ? _streamState._inboundReadErrorCode : _streamState._outboundReadErrorCode;
if (errorCode != 0)
{
throw new QuicStreamAbortedException(errorCode);
......@@ -121,6 +121,12 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
throw new NotSupportedException();
}
long errorCode = _isInitiator ? _streamState._inboundWriteErrorCode : _streamState._outboundWriteErrorCode;
if (errorCode != 0)
{
throw new QuicStreamAbortedException(errorCode);
}
using var registration = cancellationToken.UnsafeRegister(static s =>
{
var stream = (MockStream)s!;
......@@ -171,18 +177,27 @@ internal override Task FlushAsync(CancellationToken cancellationToken)
internal override void AbortRead(long errorCode)
{
throw new NotImplementedException();
if (_isInitiator)
{
_streamState._outboundWriteErrorCode = errorCode;
}
else
{
_streamState._inboundWriteErrorCode = errorCode;
}
ReadStreamBuffer?.AbortRead();
}
internal override void AbortWrite(long errorCode)
{
if (_isInitiator)
{
_streamState._outboundErrorCode = errorCode;
_streamState._outboundReadErrorCode = errorCode;
}
else
{
_streamState._inboundErrorCode = errorCode;
_streamState._inboundReadErrorCode = errorCode;
}
WriteStreamBuffer?.EndWrite();
......@@ -255,8 +270,10 @@ internal sealed class StreamState
public readonly long _streamId;
public StreamBuffer _outboundStreamBuffer;
public StreamBuffer? _inboundStreamBuffer;
public long _outboundErrorCode;
public long _inboundErrorCode;
public long _outboundReadErrorCode;
public long _inboundReadErrorCode;
public long _outboundWriteErrorCode;
public long _inboundWriteErrorCode;
private const int InitialBufferSize =
#if DEBUG
......
......@@ -20,6 +20,9 @@ internal sealed class MsQuicStream : QuicStreamProvider
private readonly State _state = new State();
private readonly bool _canRead;
private readonly bool _canWrite;
// Backing for StreamId
private long _streamId = -1;
......@@ -80,8 +83,10 @@ public void Cleanup()
internal MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags)
{
_state.Handle = streamHandle;
_canRead = true;
_canWrite = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
_started = true;
if (flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL))
if (!_canWrite)
{
_state.SendState = SendState.Closed;
}
......@@ -122,8 +127,11 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F
{
Debug.Assert(connectionState.Handle != null);
_canRead = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
_canWrite = true;
_state.StateGCHandle = GCHandle.Alloc(_state);
if (flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL))
if (!_canRead)
{
_state.ReadState = ReadState.Closed;
}
......@@ -167,9 +175,9 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F
}
}
internal override bool CanRead => _disposed == 0 && _state.ReadState < ReadState.Aborted;
internal override bool CanRead => _disposed == 0 && _canRead;
internal override bool CanWrite => _disposed == 0 && _state.SendState < SendState.Aborted;
internal override bool CanWrite => _disposed == 0 && _canWrite;
internal override long StreamId
{
......@@ -242,6 +250,11 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
}
else if ( _state.SendState == SendState.Aborted)
{
if (_state.SendErrorCode != -1)
{
throw new QuicStreamAbortedException(_state.SendErrorCode);
}
throw new OperationCanceledException(cancellationToken);
}
......@@ -292,6 +305,12 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
if (_state.SendState == SendState.Aborted)
{
cancellationToken.ThrowIfCancellationRequested();
if (_state.SendErrorCode != -1)
{
throw new QuicStreamAbortedException(_state.SendErrorCode);
}
throw new OperationCanceledException(SR.net_quic_sending_aborted);
}
else if (_state.SendState == SendState.ConnectionClosed)
......
......@@ -437,17 +437,15 @@ public async Task Read_StreamAborted_Throws()
}
[Fact]
public async Task StreamAbortedWithoutWriting_ReadThrows()
public async Task WriteAbortedWithoutWriting_ReadThrows()
{
long expectedErrorCode = 1234;
const long expectedErrorCode = 1234;
await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenUnidirectionalStream();
stream.AbortWrite(expectedErrorCode);
await stream.ShutdownCompleted();
},
serverFunction: async connection =>
{
......@@ -458,7 +456,32 @@ public async Task StreamAbortedWithoutWriting_ReadThrows()
QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadAll(stream, buffer));
Assert.Equal(expectedErrorCode, ex.ErrorCode);
await stream.ShutdownCompleted();
// We should still return true from CanRead, even though the read has been aborted.
Assert.True(stream.CanRead);
}
);
}
[Fact]
public async Task ReadAbortedWithoutReading_WriteThrows()
{
const long expectedErrorCode = 1234;
await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenBidirectionalStream();
stream.AbortRead(expectedErrorCode);
},
serverFunction: async connection =>
{
await using QuicStream stream = await connection.AcceptStreamAsync();
QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => WriteForever(stream));
Assert.Equal(expectedErrorCode, ex.ErrorCode);
// We should still return true from CanWrite, even though the write has been aborted.
Assert.True(stream.CanWrite);
}
);
}
......@@ -466,7 +489,7 @@ public async Task StreamAbortedWithoutWriting_ReadThrows()
[Fact]
public async Task WritePreCanceled_Throws()
{
long expectedErrorCode = 1234;
const long expectedErrorCode = 1234;
await RunClientServer(
clientFunction: async connection =>
......@@ -502,7 +525,7 @@ public async Task WritePreCanceled_Throws()
[Fact]
public async Task WriteCanceled_NextWriteThrows()
{
long expectedErrorCode = 1234;
const long expectedErrorCode = 1234;
await RunClientServer(
clientFunction: async connection =>
......
......@@ -130,6 +130,15 @@ internal static async Task<int> ReadAll(QuicStream stream, byte[] buffer)
return bytesRead;
}
internal static async Task<int> WriteForever(QuicStream stream)
{
Memory<byte> buffer = new byte[] { 123 };
while (true)
{
await stream.WriteAsync(buffer);
}
}
internal static void AssertArrayEqual(byte[] expected, byte[] actual)
{
for (int i = 0; i < expected.Length; ++i)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册