diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 0aa38699931612782e9387651ef80c3d07a37ee9..74544ffb73b1214b5dd00813237673f8db37ab2d 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -396,14 +396,11 @@ public abstract class DataBufferUtils { AtomicLong countDown = new AtomicLong(maxByteCount); return Flux.from(publisher) - .takeWhile(buffer -> { - int delta = -buffer.readableByteCount(); - return countDown.getAndAdd(delta) >= 0; - }) .map(buffer -> { - long count = countDown.get(); + long count = countDown.addAndGet(-buffer.readableByteCount()); return count >= 0 ? buffer : buffer.slice(0, buffer.readableByteCount() + (int) count); - }); + }) + .takeUntil(buffer -> countDown.get() <= 0); } /** diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 8f19ea31085246128d91c4dd3727474926f73a04..4b7ddb7dc39fc493236769ea9e5bda8b83971c61 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -226,19 +226,32 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @Test public void takeUntilByteCount() { - DataBuffer foo = stringBuffer("foo"); - DataBuffer bar = stringBuffer("bar"); - DataBuffer baz = stringBuffer("baz"); - Flux flux = Flux.just(foo, bar, baz); - Flux result = DataBufferUtils.takeUntilByteCount(flux, 5L); + + Flux result = DataBufferUtils.takeUntilByteCount( + Flux.just(stringBuffer("foo"), stringBuffer("bar")), 5L); StepVerifier.create(result) .consumeNextWith(stringConsumer("foo")) .consumeNextWith(stringConsumer("ba")) .expectComplete() .verify(Duration.ofSeconds(5)); + } + + @Test + public void takeUntilByteCountExact() { + + DataBuffer extraBuffer = stringBuffer("baz"); + + Flux result = DataBufferUtils.takeUntilByteCount( + Flux.just(stringBuffer("foo"), stringBuffer("bar"), extraBuffer), 6L); + + StepVerifier.create(result) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectComplete() + .verify(Duration.ofSeconds(5)); - release(baz); + release(extraBuffer); } @Test