未验证 提交 68dec6ac 编写于 作者: S Stephen Toub 提交者: GitHub

Fix Deflate/Brotli/CryptoStream handling of partial and zero-byte reads (#53644)

Stream.Read{Async} is supposed to return once at least a byte of data is available, and in particular, if there's any data already available, it shouldn't block.  But Read{Async} on DeflateStream (and thus also GZipStream and ZLibStream), BrotliStream, and CryptoStream won't return until either it hits the end of the stream or the caller's buffer is filled.  This makes it behave very unexpectedly when used in a context where the app is using a large read buffer but expects to be able to process data as it's available, e.g. in networked streaming scenarios where messages are being sent as part of bidirectional communication.

This fixes that by stopping looping once any data is consumed.  Just doing that, though, caused problems for zero-byte reads.  Zero-byte reads are typically used by code that's trying to delay-allocate a buffer for the read data until data will be available to read.  At present, however, zero-byte reads return immediately regardless of whether data is available to be consumed.  I've changed the flow to make it so that zero-byte reads don't return until there's at least some data available as input to the inflater/transform (this, though, doesn't 100% guarantee the inflater/transform will be able to produce output data).

Note that both of these changes have the potential to introduce breaks into an app that erroneously depended on these implementation details:
- If an app passing in a buffer of size N to Read{Async} depended on that call always producing the requested number of bytes (rather than what the Stream contract defines), they might experience behavioral changes.
- If an app passing in a zero-byte buffer expected it to return immediately, it might instead end up waiting until data was actually available.
上级 ffcef4af
......@@ -1562,11 +1562,6 @@ public abstract class ConnectedStreamConformanceTests : StreamConformanceTests
/// Gets whether the stream guarantees that all data written to it will be flushed as part of Flush{Async}.
/// </summary>
protected virtual bool FlushGuaranteesAllDataWritten => true;
/// <summary>
/// Gets whether a stream implements an aggressive read that tries to fill the supplied buffer and only
/// stops when it does so or hits EOF.
/// </summary>
protected virtual bool ReadsMayBlockUntilBufferFullOrEOF => false;
/// <summary>Gets whether reads for a count of 0 bytes block if no bytes are available to read.</summary>
protected virtual bool BlocksOnZeroByteReads => false;
/// <summary>
......@@ -1709,6 +1704,10 @@ public virtual async Task ReadWriteByte_Success()
}
}
public static IEnumerable<object[]> ReadWrite_Modes =>
from mode in Enum.GetValues<ReadWriteMode>()
select new object[] { mode };
public static IEnumerable<object[]> ReadWrite_Success_MemberData() =>
from mode in Enum.GetValues<ReadWriteMode>()
from writeSize in new[] { 1, 42, 10 * 1024 }
......@@ -1785,6 +1784,54 @@ public virtual async Task ReadWrite_Success(ReadWriteMode mode, int writeSize, b
}
}
[Theory]
[MemberData(nameof(ReadWrite_Modes))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/51371", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)]
public virtual async Task ReadWrite_MessagesSmallerThanReadBuffer_Success(ReadWriteMode mode)
{
if (!FlushGuaranteesAllDataWritten)
{
return;
}
foreach (CancellationToken nonCanceledToken in new[] { CancellationToken.None, new CancellationTokenSource().Token })
{
using StreamPair streams = await CreateConnectedStreamsAsync();
foreach ((Stream writeable, Stream readable) in GetReadWritePairs(streams))
{
byte[] writerBytes = RandomNumberGenerator.GetBytes(512);
var readerBytes = new byte[writerBytes.Length * 2];
// Repeatedly write then read a message smaller in size than the read buffer
for (int i = 0; i < 5; i++)
{
Task writes = Task.Run(async () =>
{
await WriteAsync(mode, writeable, writerBytes, 0, writerBytes.Length, nonCanceledToken);
if (FlushRequiredToWriteData)
{
await writeable.FlushAsync();
}
});
int n = 0;
while (n < writerBytes.Length)
{
int r = await ReadAsync(mode, readable, readerBytes, n, readerBytes.Length - n);
Assert.InRange(r, 1, writerBytes.Length - n);
n += r;
}
Assert.Equal(writerBytes.Length, n);
AssertExtensions.SequenceEqual(writerBytes, readerBytes.AsSpan(0, writerBytes.Length));
await writes;
}
}
}
}
[Theory]
[MemberData(nameof(AllReadWriteModesAndValue), false)]
[MemberData(nameof(AllReadWriteModesAndValue), true)]
......@@ -2160,6 +2207,10 @@ public virtual async Task ZeroByteRead_BlocksUntilDataAvailableOrNops(ReadWriteM
});
Assert.Equal(0, await zeroByteRead);
// Perform a second zero-byte read.
await Task.Run(() => ReadAsync(mode, readable, Array.Empty<byte>(), 0, 0));
// Now consume all the data.
var readBytes = new byte[5];
int count = 0;
while (count < readBytes.Length)
......@@ -2684,7 +2735,7 @@ public virtual async Task Flush_FlushesUnderlyingStream(bool flushAsync)
[InlineData(true, true)]
public virtual async Task Dispose_Flushes(bool useAsync, bool leaveOpen)
{
if (leaveOpen && (!SupportsLeaveOpen || ReadsMayBlockUntilBufferFullOrEOF))
if (leaveOpen && !SupportsLeaveOpen)
{
return;
}
......
......@@ -54,6 +54,6 @@ protected override Task<StreamPair> CreateConnectedStreamsAsync()
protected override Type UnsupportedReadWriteExceptionType => typeof(InvalidOperationException);
protected override bool WrappedUsableAfterClose => false;
protected override bool FlushRequiredToWriteData => true;
protected override bool FlushGuaranteesAllDataWritten => false;
protected override bool BlocksOnZeroByteReads => true;
}
}
......@@ -65,6 +65,17 @@ public static void ReadBytes(Stream stream, byte[] buffer, long bytesToRead)
}
}
public static int ReadAllBytes(Stream stream, byte[] buffer, int offset, int count)
{
int bytesRead;
int totalRead = 0;
while ((bytesRead = stream.Read(buffer, offset + totalRead, count - totalRead)) != 0)
{
totalRead += bytesRead;
}
return totalRead;
}
public static bool ArraysEqual<T>(T[] a, T[] b) where T : IComparable<T>
{
if (a.Length != b.Length) return false;
......@@ -111,8 +122,8 @@ public static void StreamsEqual(Stream ast, Stream bst, int blocksToRead)
if (blocksToRead != -1 && blocksRead >= blocksToRead)
break;
ac = ast.Read(ad, 0, 4096);
bc = bst.Read(bd, 0, 4096);
ac = ReadAllBytes(ast, ad, 0, 4096);
bc = ReadAllBytes(bst, bd, 0, 4096);
if (ac != bc)
{
......@@ -170,7 +181,7 @@ public static void IsZipSameAsDir(Stream archiveFile, string directory, ZipArchi
var buffer = new byte[entry.Length];
using (Stream entrystream = entry.Open())
{
entrystream.Read(buffer, 0, buffer.Length);
ReadAllBytes(entrystream, buffer, 0, buffer.Length);
#if NETCOREAPP
uint zipcrc = entry.Crc32;
Assert.Equal(CRC.CalculateCRC(buffer), zipcrc);
......
......@@ -173,7 +173,7 @@ private void EnsureNoActiveAsyncOperation()
private void AsyncOperationStarting()
{
if (Interlocked.CompareExchange(ref _activeAsyncOperation, 1, 0) != 0)
if (Interlocked.Exchange(ref _activeAsyncOperation, 1) != 0)
{
ThrowInvalidBeginCall();
}
......@@ -181,13 +181,11 @@ private void AsyncOperationStarting()
private void AsyncOperationCompleting()
{
int oldValue = Interlocked.CompareExchange(ref _activeAsyncOperation, 0, 1);
Debug.Assert(oldValue == 1, $"Expected {nameof(_activeAsyncOperation)} to be 1, got {oldValue}");
Debug.Assert(_activeAsyncOperation == 1);
Volatile.Write(ref _activeAsyncOperation, 0);
}
private static void ThrowInvalidBeginCall()
{
private static void ThrowInvalidBeginCall() =>
throw new InvalidOperationException(SR.InvalidBeginCall);
}
}
}
......@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
......@@ -42,8 +43,8 @@ public override int Read(byte[] buffer, int offset, int count)
public override int ReadByte()
{
byte b = default;
int numRead = Read(MemoryMarshal.CreateSpan(ref b, 1));
return numRead != 0 ? b : -1;
int bytesRead = Read(MemoryMarshal.CreateSpan(ref b, 1));
return bytesRead != 0 ? b : -1;
}
/// <summary>Reads a sequence of bytes from the current Brotli stream to a byte span and advances the position within the Brotli stream by the number of bytes read.</summary>
......@@ -57,59 +58,25 @@ public override int Read(Span<byte> buffer)
if (_mode != CompressionMode.Decompress)
throw new InvalidOperationException(SR.BrotliStream_Compress_UnsupportedOperation);
EnsureNotDisposed();
int totalWritten = 0;
OperationStatus lastResult = OperationStatus.DestinationTooSmall;
// We want to continue calling Decompress until we're either out of space for output or until Decompress indicates it is finished.
while (buffer.Length > 0 && lastResult != OperationStatus.Done)
int bytesWritten;
while (!TryDecompress(buffer, out bytesWritten))
{
if (lastResult == OperationStatus.NeedMoreData)
int bytesRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount);
if (bytesRead <= 0)
{
// Ensure any left over data is at the beginning of the array so we can fill the remainder.
if (_bufferCount > 0 && _bufferOffset != 0)
{
_buffer.AsSpan(_bufferOffset, _bufferCount).CopyTo(_buffer);
}
_bufferOffset = 0;
int numRead = 0;
while (_bufferCount < _buffer.Length && ((numRead = _stream.Read(_buffer, _bufferCount, _buffer.Length - _bufferCount)) > 0))
{
_bufferCount += numRead;
if (_bufferCount > _buffer.Length)
{
// The stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
}
}
if (_bufferCount <= 0)
{
break;
}
}
lastResult = _decoder.Decompress(new ReadOnlySpan<byte>(_buffer, _bufferOffset, _bufferCount), buffer, out int bytesConsumed, out int bytesWritten);
if (lastResult == OperationStatus.InvalidData)
{
throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData);
break;
}
if (bytesConsumed > 0)
{
_bufferOffset += bytesConsumed;
_bufferCount -= bytesConsumed;
}
_bufferCount += bytesRead;
if (bytesWritten > 0)
if (_bufferCount > _buffer.Length)
{
totalWritten += bytesWritten;
buffer = buffer.Slice(bytesWritten);
ThrowInvalidStream();
}
}
return totalWritten;
return bytesWritten;
}
/// <summary>Begins an asynchronous read operation. (Consider using the <see cref="System.IO.Stream.ReadAsync(byte[],int,int)" /> method instead.)</summary>
......@@ -169,73 +136,100 @@ public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken
{
return ValueTask.FromCanceled<int>(cancellationToken);
}
return FinishReadAsyncMemory(buffer, cancellationToken);
}
private async ValueTask<int> FinishReadAsyncMemory(Memory<byte> buffer, CancellationToken cancellationToken)
{
AsyncOperationStarting();
try
return Core(buffer, cancellationToken);
async ValueTask<int> Core(Memory<byte> buffer, CancellationToken cancellationToken)
{
int totalWritten = 0;
OperationStatus lastResult = OperationStatus.DestinationTooSmall;
// We want to continue calling Decompress until we're either out of space for output or until Decompress indicates it is finished.
while (buffer.Length > 0 && lastResult != OperationStatus.Done)
AsyncOperationStarting();
try
{
if (lastResult == OperationStatus.NeedMoreData)
int bytesWritten;
while (!TryDecompress(buffer.Span, out bytesWritten))
{
// Ensure any left over data is at the beginning of the array so we can fill the remainder.
if (_bufferCount > 0 && _bufferOffset != 0)
int bytesRead = await _stream.ReadAsync(_buffer.AsMemory(_bufferCount), cancellationToken).ConfigureAwait(false);
if (bytesRead <= 0)
{
_buffer.AsSpan(_bufferOffset, _bufferCount).CopyTo(_buffer);
break;
}
_bufferOffset = 0;
int numRead = 0;
while (_bufferCount < _buffer.Length &&
((numRead = await _stream.ReadAsync(new Memory<byte>(_buffer, _bufferCount, _buffer.Length - _bufferCount), cancellationToken).ConfigureAwait(false)) > 0))
{
_bufferCount += numRead;
if (_bufferCount > _buffer.Length)
{
// The stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
}
}
_bufferCount += bytesRead;
if (_bufferCount <= 0)
if (_bufferCount > _buffer.Length)
{
break;
ThrowInvalidStream();
}
}
cancellationToken.ThrowIfCancellationRequested();
lastResult = _decoder.Decompress(new ReadOnlySpan<byte>(_buffer, _bufferOffset, _bufferCount), buffer.Span, out int bytesConsumed, out int bytesWritten);
if (lastResult == OperationStatus.InvalidData)
{
throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData);
}
return bytesWritten;
}
finally
{
AsyncOperationCompleting();
}
}
}
if (bytesConsumed > 0)
{
_bufferOffset += bytesConsumed;
_bufferCount -= bytesConsumed;
}
/// <summary>Tries to decode available data into the destination buffer.</summary>
/// <param name="destination">The destination buffer for the decompressed data.</param>
/// <param name="bytesWritten">The number of bytes written to destination.</param>
/// <returns>true if the caller should consider the read operation completed; otherwise, false.</returns>
private bool TryDecompress(Span<byte> destination, out int bytesWritten)
{
// Decompress any data we may have in our buffer.
OperationStatus lastResult = _decoder.Decompress(new ReadOnlySpan<byte>(_buffer, _bufferOffset, _bufferCount), destination, out int bytesConsumed, out bytesWritten);
if (lastResult == OperationStatus.InvalidData)
{
throw new InvalidOperationException(SR.BrotliStream_Decompress_InvalidData);
}
if (bytesWritten > 0)
{
totalWritten += bytesWritten;
buffer = buffer.Slice(bytesWritten);
}
}
if (bytesConsumed != 0)
{
_bufferOffset += bytesConsumed;
_bufferCount -= bytesConsumed;
}
// If we successfully decompressed any bytes, or if we've reached the end of the decompression, we're done.
if (bytesWritten != 0 || lastResult == OperationStatus.Done)
{
return true;
}
return totalWritten;
if (destination.IsEmpty)
{
// The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting
// a buffer until data is known to be available. We don't have perfect knowledge here, as _decoder.Decompress
// will return DestinationTooSmall whether or not more data is required. As such, we assume that if there's
// any data in our input buffer, it would have been decompressible into at least one byte of output, and
// otherwise we need to do a read on the underlying stream. This isn't perfect, because having input data
// doesn't necessarily mean it'll decompress into at least one byte of output, but it's a reasonable approximation
// for the 99% case. If it's wrong, it just means that a caller using zero-byte reads as a way to delay
// getting a buffer to use for a subsequent call may end up getting one earlier than otherwise preferred.
Debug.Assert(lastResult == OperationStatus.DestinationTooSmall);
if (_bufferCount != 0)
{
Debug.Assert(bytesWritten == 0);
return true;
}
}
finally
Debug.Assert(
lastResult == OperationStatus.NeedMoreData ||
(lastResult == OperationStatus.DestinationTooSmall && destination.IsEmpty && _bufferCount == 0), $"{nameof(lastResult)} == {lastResult}, {nameof(destination.Length)} == {destination.Length}");
// Ensure any left over data is at the beginning of the array so we can fill the remainder.
if (_bufferCount != 0 && _bufferOffset != 0)
{
AsyncOperationCompleting();
new ReadOnlySpan<byte>(_buffer, _bufferOffset, _bufferCount).CopyTo(_buffer);
}
_bufferOffset = 0;
return false;
}
private static void ThrowInvalidStream() =>
// The stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.BrotliStream_Decompress_InvalidStream);
}
}
......@@ -68,8 +68,8 @@ internal void WriteCore(ReadOnlySpan<byte> buffer, bool isFinalBlock = false)
Span<byte> output = new Span<byte>(_buffer);
while (lastResult == OperationStatus.DestinationTooSmall)
{
int bytesConsumed = 0;
int bytesWritten = 0;
int bytesConsumed;
int bytesWritten;
lastResult = _encoder.Compress(buffer, output, out bytesConsumed, out bytesWritten, isFinalBlock);
if (lastResult == OperationStatus.InvalidData)
throw new InvalidOperationException(SR.BrotliStream_Compress_InvalidData);
......@@ -176,7 +176,7 @@ public override void Flush()
Span<byte> output = new Span<byte>(_buffer);
while (lastResult == OperationStatus.DestinationTooSmall)
{
int bytesWritten = 0;
int bytesWritten;
lastResult = _encoder.Flush(output, out bytesWritten);
if (lastResult == OperationStatus.InvalidData)
throw new InvalidDataException(SR.BrotliStream_Compress_InvalidData);
......
......@@ -14,7 +14,8 @@ public class BrotliStreamUnitTests : CompressionStreamUnitTestBase
public override Stream CreateStream(Stream stream, CompressionLevel level) => new BrotliStream(stream, level);
public override Stream CreateStream(Stream stream, CompressionLevel level, bool leaveOpen) => new BrotliStream(stream, level, leaveOpen);
public override Stream BaseStream(Stream stream) => ((BrotliStream)stream).BaseStream;
protected override bool ReadsMayBlockUntilBufferFullOrEOF => true;
protected override bool FlushGuaranteesAllDataWritten => false;
// The tests are relying on an implementation detail of BrotliStream, using knowledge of its internal buffer size
// in various test calculations. Currently the implementation is using the ArrayPool, which will round up to a
......
......@@ -104,12 +104,14 @@ internal void InitializeDeflater(Stream stream, bool leaveOpen, int windowBits,
InitializeBuffer();
}
[MemberNotNull(nameof(_buffer))]
private void InitializeBuffer()
{
Debug.Assert(_buffer == null);
_buffer = ArrayPool<byte>.Shared.Rent(DefaultBufferSize);
}
[MemberNotNull(nameof(_buffer))]
private void EnsureBufferInitialized()
{
if (_buffer == null)
......@@ -259,83 +261,94 @@ internal int ReadCore(Span<byte> buffer)
EnsureDecompressionMode();
EnsureNotDisposed();
EnsureBufferInitialized();
int totalRead = 0;
Debug.Assert(_inflater != null);
int bytesRead;
while (true)
{
int bytesRead = _inflater.Inflate(buffer.Slice(totalRead));
totalRead += bytesRead;
if (totalRead == buffer.Length)
{
break;
}
// If the stream is finished then we have a few potential cases here:
// 1. DeflateStream => return
// 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input
// 3. GZipStream that is finished and appended with garbage => return
if (_inflater.Finished() && (!_inflater.IsGzipStream() || !_inflater.NeedsInput()))
// Try to decompress any data from the inflater into the caller's buffer.
// If we're able to decompress any bytes, or if decompression is completed, we're done.
bytesRead = _inflater.Inflate(buffer);
if (bytesRead != 0 || InflatorIsFinished)
{
break;
}
// We were unable to decompress any data. If the inflater needs additional input
// data to proceed, read some to populate it.
if (_inflater.NeedsInput())
{
Debug.Assert(_buffer != null);
int bytes = _stream.Read(_buffer, 0, _buffer.Length);
if (bytes <= 0)
int n = _stream.Read(_buffer, 0, _buffer.Length);
if (n <= 0)
{
break;
}
else if (bytes > _buffer.Length)
else if (n > _buffer.Length)
{
ThrowGenericInvalidData();
}
else
{
// The stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.GenericInvalidData);
_inflater.SetInput(_buffer, 0, n);
}
}
_inflater.SetInput(_buffer, 0, bytes);
if (buffer.IsEmpty)
{
// The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting
// a buffer until data is known to be available. We don't have perfect knowledge here, as _inflater.Inflate
// will return 0 whether or not more data is required, and having input data doesn't necessarily mean it'll
// decompress into at least one byte of output, but it's a reasonable approximation for the 99% case. If it's
// wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a
// subsequent call may end up getting one earlier than otherwise preferred.
Debug.Assert(bytesRead == 0);
break;
}
}
return totalRead;
return bytesRead;
}
private bool InflatorIsFinished =>
// If the stream is finished then we have a few potential cases here:
// 1. DeflateStream => return
// 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input
// 3. GZipStream that is finished and appended with garbage => return
_inflater!.Finished() &&
(!_inflater.IsGzipStream() || !_inflater.NeedsInput());
private void EnsureNotDisposed()
{
if (_stream == null)
ThrowStreamClosedException();
}
private static void ThrowStreamClosedException()
{
throw new ObjectDisposedException(nameof(DeflateStream), SR.ObjectDisposed_StreamClosed);
static void ThrowStreamClosedException() =>
throw new ObjectDisposedException(nameof(DeflateStream), SR.ObjectDisposed_StreamClosed);
}
private void EnsureDecompressionMode()
{
if (_mode != CompressionMode.Decompress)
ThrowCannotReadFromDeflateStreamException();
}
private static void ThrowCannotReadFromDeflateStreamException()
{
throw new InvalidOperationException(SR.CannotReadFromDeflateStream);
static void ThrowCannotReadFromDeflateStreamException() =>
throw new InvalidOperationException(SR.CannotReadFromDeflateStream);
}
private void EnsureCompressionMode()
{
if (_mode != CompressionMode.Compress)
ThrowCannotWriteToDeflateStreamException();
}
private static void ThrowCannotWriteToDeflateStreamException()
{
throw new InvalidOperationException(SR.CannotWriteToDeflateStream);
static void ThrowCannotWriteToDeflateStreamException() =>
throw new InvalidOperationException(SR.CannotWriteToDeflateStream);
}
private static void ThrowGenericInvalidData() =>
// The stream is either malicious or poorly implemented and returned a number of
// bytes < 0 || > than the buffer supplied to it.
throw new InvalidDataException(SR.GenericInvalidData);
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState) =>
TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), asyncCallback, asyncState);
......@@ -378,6 +391,7 @@ internal ValueTask<int> ReadAsyncMemory(Memory<byte> buffer, CancellationToken c
}
EnsureBufferInitialized();
Debug.Assert(_inflater != null);
return Core(buffer, cancellationToken);
......@@ -386,48 +400,49 @@ async ValueTask<int> Core(Memory<byte> buffer, CancellationToken cancellationTok
AsyncOperationStarting();
try
{
int totalRead = 0;
Debug.Assert(_inflater != null);
int bytesRead;
while (true)
{
int bytesRead = _inflater.Inflate(buffer.Span.Slice(totalRead));
totalRead += bytesRead;
if (totalRead == buffer.Length)
{
break;
}
// If the stream is finished then we have a few potential cases here:
// 1. DeflateStream => return
// 2. GZipStream that is finished but may have an additional GZipStream appended => feed more input
// 3. GZipStream that is finished and appended with garbage => return
if (_inflater.Finished() && (!_inflater.IsGzipStream() || !_inflater.NeedsInput()))
// Try to decompress any data from the inflater into the caller's buffer.
// If we're able to decompress any bytes, or if decompression is completed, we're done.
bytesRead = _inflater.Inflate(buffer.Span);
if (bytesRead != 0 || InflatorIsFinished)
{
break;
}
// We were unable to decompress any data. If the inflater needs additional input
// data to proceed, read some to populate it.
if (_inflater.NeedsInput())
{
Debug.Assert(_buffer != null);
int bytes = await _stream.ReadAsync(_buffer, cancellationToken).ConfigureAwait(false);
EnsureNotDisposed();
if (bytes <= 0)
int n = await _stream.ReadAsync(new Memory<byte>(_buffer, 0, _buffer.Length), cancellationToken).ConfigureAwait(false);
if (n <= 0)
{
break;
}
else if (bytes > _buffer.Length)
else if (n > _buffer.Length)
{
// The stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.GenericInvalidData);
ThrowGenericInvalidData();
}
else
{
_inflater.SetInput(_buffer, 0, n);
}
}
_inflater.SetInput(_buffer, 0, bytes);
if (buffer.IsEmpty)
{
// The caller provided a zero-byte buffer. This is typically done in order to avoid allocating/renting
// a buffer until data is known to be available. We don't have perfect knowledge here, as _inflater.Inflate
// will return 0 whether or not more data is required, and having input data doesn't necessarily mean it'll
// decompress into at least one byte of output, but it's a reasonable approximation for the 99% case. If it's
// wrong, it just means that a caller using zero-byte reads as a way to delay getting a buffer to use for a
// subsequent call may end up getting one earlier than otherwise preferred.
break;
}
}
return totalRead;
return bytesRead;
}
finally
{
......@@ -1014,21 +1029,16 @@ private void EnsureNoActiveAsyncOperation()
private void AsyncOperationStarting()
{
if (Interlocked.CompareExchange(ref _activeAsyncOperation, 1, 0) != 0)
if (Interlocked.Exchange(ref _activeAsyncOperation, 1) != 0)
{
ThrowInvalidBeginCall();
}
}
private void AsyncOperationCompleting()
{
int oldValue = Interlocked.CompareExchange(ref _activeAsyncOperation, 0, 1);
Debug.Assert(oldValue == 1, $"Expected {nameof(_activeAsyncOperation)} to be 1, got {oldValue}");
}
private void AsyncOperationCompleting() =>
Volatile.Write(ref _activeAsyncOperation, 0);
private static void ThrowInvalidBeginCall()
{
private static void ThrowInvalidBeginCall() =>
throw new InvalidOperationException(SR.InvalidBeginCall);
}
}
}
......@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using Xunit;
......@@ -123,7 +124,7 @@ private static void ValidateCryptoStream(string expected, string data, ICryptoTr
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Read))
{
int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
int bytesRead = ReadAll(cs, outputBytes);
string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead);
Assert.Equal(expected, outputString);
}
......@@ -195,7 +196,7 @@ public static void ValidateFromBase64_NoPadding(string data)
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, transform, CryptoStreamMode.Read))
{
int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
int bytesRead = ReadAll(cs, outputBytes);
// Missing padding bytes not supported (no exception, however)
Assert.NotEqual(inputBytes.Length, bytesRead);
......@@ -230,7 +231,7 @@ public static void ValidateWhitespace(string expected, string data)
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, base64Transform, CryptoStreamMode.Read))
{
int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
int bytesRead = ReadAll(cs, outputBytes);
string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead);
Assert.Equal(expected, outputString);
}
......@@ -240,7 +241,7 @@ public static void ValidateWhitespace(string expected, string data)
using (var ms = new MemoryStream(inputBytes))
using (var cs = new CryptoStream(ms, base64Transform, CryptoStreamMode.Read))
{
int bytesRead = cs.Read(outputBytes, 0, outputBytes.Length);
int bytesRead = ReadAll(cs, outputBytes);
string outputString = Text.Encoding.ASCII.GetString(outputBytes, 0, bytesRead);
Assert.Equal(expected, outputString);
}
......@@ -293,5 +294,22 @@ public void TransformUsageFlags_FromBase64Transform()
Assert.True(transform.CanReuseTransform);
}
}
private static int ReadAll(Stream stream, Span<byte> buffer)
{
int totalRead = 0;
while (totalRead < buffer.Length)
{
int bytesRead = stream.Read(buffer.Slice(totalRead));
if (bytesRead == 0)
{
break;
}
totalRead += bytesRead;
}
return totalRead;
}
}
}
......@@ -27,6 +27,7 @@ protected override Task<StreamPair> CreateWrappedConnectedStreamsAsync(StreamPai
}
protected override Type UnsupportedConcurrentExceptionType => null;
protected override bool BlocksOnZeroByteReads => true;
[ActiveIssue("https://github.com/dotnet/runtime/issues/45080")]
[Theory]
......@@ -37,7 +38,7 @@ protected override Task<StreamPair> CreateWrappedConnectedStreamsAsync(StreamPai
public static void Ctor()
{
var transform = new IdentityTransform(1, 1, true);
AssertExtensions.Throws<ArgumentException>(null, () => new CryptoStream(new MemoryStream(), transform, (CryptoStreamMode)12345));
AssertExtensions.Throws<ArgumentException>("mode", () => new CryptoStream(new MemoryStream(), transform, (CryptoStreamMode)12345));
AssertExtensions.Throws<ArgumentException>(null, "stream", () => new CryptoStream(new MemoryStream(new byte[0], writable: false), transform, CryptoStreamMode.Write));
AssertExtensions.Throws<ArgumentException>(null, "stream", () => new CryptoStream(new CryptoStream(new MemoryStream(new byte[0]), transform, CryptoStreamMode.Write), transform, CryptoStreamMode.Read));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册