Commit ef14d76d authored by Rossen Stoyanchev

Merge limits on input in codecs

......@@ -48,12 +48,39 @@ import org.springframework.util.MimeType;
@SuppressWarnings("deprecation")
public abstract class AbstractDataBufferDecoder<T> extends AbstractDecoder<T> {
private int maxInMemorySize = 256 * 1024;
protected AbstractDataBufferDecoder(MimeType... supportedMimeTypes) {
super(supportedMimeTypes);
}
/**
* Configure a limit on the number of bytes that can be buffered whenever
* the input stream needs to be aggregated. This can be a result of
* decoding to a single {@code DataBuffer},
* {@link java.nio.ByteBuffer ByteBuffer}, {@code byte[]},
* {@link org.springframework.core.io.Resource Resource}, {@code String}, etc.
* It can also occur when splitting the input stream, e.g. delimited text,
* in which case the limit applies to data buffered between delimiters.
* <p>By default this is set to 256K.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.11
*/
public void setMaxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
}
/**
* Return the {@link #setMaxInMemorySize configured} byte count limit.
* @since 5.1.11
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
@Override
public Flux<T> decode(Publisher<DataBuffer> input, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
......@@ -65,7 +92,7 @@ public abstract class AbstractDataBufferDecoder<T> extends AbstractDecoder<T> {
public Mono<T> decodeToMono(Publisher<DataBuffer> input, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return DataBufferUtils.join(input)
return DataBufferUtils.join(input, this.maxInMemorySize)
.map(buffer -> decodeDataBuffer(buffer, elementType, mimeType, hints));
}
......
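A minimal usage sketch of the new decoder-level limit; the decoder choice and the 64K value are illustrative, not part of the commit:

    // Any AbstractDataBufferDecoder subclass now exposes the limit; StringDecoder is one example.
    StringDecoder decoder = StringDecoder.textPlainOnly();
    decoder.setMaxInMemorySize(64 * 1024); // cap aggregation at 64K
    // decodeToMono(..) now fails with DataBufferLimitException once input exceeds the cap.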
......@@ -25,15 +25,18 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DataBufferWrapper;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.buffer.LimitedDataBufferList;
import org.springframework.core.io.buffer.PooledDataBuffer;
import org.springframework.core.log.LogFormatUtils;
import org.springframework.lang.Nullable;
......@@ -91,12 +94,18 @@ public final class StringDecoder extends AbstractDataBufferDecoder<String> {
byte[][] delimiterBytes = getDelimiterBytes(mimeType);
// TODO: Drop Consumer and use bufferUntil with Supplier<LimitedDataBufferList> (reactor-core#1925)
// TODO: Drop doOnDiscard(LimitedDataBufferList.class, ...) (reactor-core#1924)
LimitedDataBufferConsumer limiter = new LimitedDataBufferConsumer(getMaxInMemorySize());
Flux<DataBuffer> inputFlux = Flux.defer(() -> {
DataBufferUtils.Matcher matcher = DataBufferUtils.matcher(delimiterBytes);
return Flux.from(input)
.concatMapIterable(buffer -> endFrameAfterDelimiter(buffer, matcher))
.doOnNext(limiter)
.bufferUntil(buffer -> buffer instanceof EndFrameBuffer)
.map(buffers -> joinAndStrip(buffers, this.stripDelimiter))
.doOnDiscard(LimitedDataBufferList.class, LimitedDataBufferList::releaseAndClear)
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
});
......@@ -279,4 +288,34 @@ public final class StringDecoder extends AbstractDataBufferDecoder<String> {
}
/**
* Temporary measure for reactor-core#1925.
* Consumer that adds to a {@link LimitedDataBufferList} to enforce limits.
*/
private static class LimitedDataBufferConsumer implements Consumer<DataBuffer> {
private final LimitedDataBufferList bufferList;
public LimitedDataBufferConsumer(int maxInMemorySize) {
this.bufferList = new LimitedDataBufferList(maxInMemorySize);
}
@Override
public void accept(DataBuffer buffer) {
if (buffer instanceof EndFrameBuffer) {
this.bufferList.clear();
}
else {
try {
this.bufferList.add(buffer);
}
catch (DataBufferLimitException ex) {
DataBufferUtils.release(buffer);
throw ex;
}
}
}
}
}
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.core.io.buffer;
/**
* Exception that indicates the cumulative number of bytes consumed from a
* stream of {@link DataBuffer DataBuffer}s exceeded some pre-configured limit.
* This can be raised when data buffers are cached and aggregated, e.g.
* {@link DataBufferUtils#join}. It can also be raised when data buffers have
* been released but a parsed representation is still being aggregated, e.g.
* async parsing with Jackson.
*
* @author Rossen Stoyanchev
* @since 5.1.11
*/
@SuppressWarnings("serial")
public class DataBufferLimitException extends IllegalStateException {
public DataBufferLimitException(String message) {
super(message);
}
}
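A hedged sketch of how a caller might surface this exception; mapping it to an HTTP 413 is an assumption about typical server-side use, and "request" here is an assumed ServerHttpRequest, not something this commit defines:

    // Hypothetical handling: translate the limit violation into an HTTP 413 response.
    Mono<DataBuffer> joined = DataBufferUtils.join(request.getBody(), 256 * 1024)
            .onErrorMap(DataBufferLimitException.class, ex ->
                    new ResponseStatusException(HttpStatus.PAYLOAD_TOO_LARGE, "Request body too large", ex));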
......@@ -525,16 +525,35 @@ public abstract class DataBufferUtils {
*/
@SuppressWarnings("unchecked")
public static Mono<DataBuffer> join(Publisher<? extends DataBuffer> dataBuffers) {
Assert.notNull(dataBuffers, "'dataBuffers' must not be null");
return join(dataBuffers, -1);
}
/**
* Variant of {@link #join(Publisher)} that behaves the same way up until
* the specified max number of bytes to buffer. Once the limit is exceeded,
* {@link DataBufferLimitException} is raised.
* @param buffers the data buffers that are to be composed
* @param maxByteCount the max number of bytes to buffer, or -1 for unlimited
* @return a buffer with the aggregated content, or an empty Mono if the
* input is empty
* @throws DataBufferLimitException if maxByteCount is exceeded
* @since 5.1.11
*/
@SuppressWarnings("unchecked")
public static Mono<DataBuffer> join(Publisher<? extends DataBuffer> buffers, int maxByteCount) {
Assert.notNull(buffers, "'buffers' must not be null");
if (dataBuffers instanceof Mono) {
return (Mono<DataBuffer>) dataBuffers;
if (buffers instanceof Mono) {
return (Mono<DataBuffer>) buffers;
}
return Flux.from(dataBuffers)
.collectList()
// TODO: Drop doOnDiscard(LimitedDataBufferList.class, ...) (reactor-core#1924)
return Flux.from(buffers)
.collect(() -> new LimitedDataBufferList(maxByteCount), LimitedDataBufferList::add)
.filter(list -> !list.isEmpty())
.map(list -> list.get(0).factory().join(list))
.doOnDiscard(LimitedDataBufferList.class, LimitedDataBufferList::releaseAndClear)
.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
}
......
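The following sketch mirrors the joinWithLimit test further below; the buffer contents and the 4-byte cap are illustrative:

    DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
    Flux<DataBuffer> input = Flux.just(
            factory.wrap("foo".getBytes(StandardCharsets.UTF_8)),
            factory.wrap("bar".getBytes(StandardCharsets.UTF_8)));
    // "foo" + "bar" totals 6 bytes, so a 4-byte cap terminates the Mono with DataBufferLimitException.
    StepVerifier.create(DataBufferUtils.join(input, 4))
            .verifyError(DataBufferLimitException.class);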
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.core.io.buffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Predicate;
import reactor.core.publisher.Flux;
/**
* Custom {@link List} for collecting data buffers while enforcing a
* limit on the total number of bytes buffered. For use with "collect" or
* other buffering operators in declarative APIs, e.g. {@link Flux}.
*
* <p>Adding elements increases the byte count and if the limit is exceeded,
* {@link DataBufferLimitException} is raised. {@link #clear()} resets the
* count. Remove and set are not supported.
*
* <p><strong>Note:</strong> This class does not automatically release the
* buffers it contains. It is usually preferable to use hooks such as
* {@link Flux#doOnDiscard} that also take care of cancel and error signals,
* or otherwise {@link #releaseAndClear()} can be used.
*
* @author Rossen Stoyanchev
* @since 5.1.11
*/
@SuppressWarnings("serial")
public class LimitedDataBufferList extends ArrayList<DataBuffer> {
private final int maxByteCount;
private int byteCount;
public LimitedDataBufferList(int maxByteCount) {
this.maxByteCount = maxByteCount;
}
@Override
public boolean add(DataBuffer buffer) {
boolean result = super.add(buffer);
if (result) {
updateCount(buffer.readableByteCount());
}
return result;
}
@Override
public void add(int index, DataBuffer buffer) {
super.add(index, buffer);
updateCount(buffer.readableByteCount());
}
@Override
public boolean addAll(Collection<? extends DataBuffer> collection) {
boolean result = super.addAll(collection);
collection.forEach(buffer -> updateCount(buffer.readableByteCount()));
return result;
}
@Override
public boolean addAll(int index, Collection<? extends DataBuffer> collection) {
boolean result = super.addAll(index, collection);
collection.forEach(buffer -> updateCount(buffer.readableByteCount()));
return result;
}
private void updateCount(int bytesToAdd) {
if (this.maxByteCount < 0) {
return;
}
if (bytesToAdd > Integer.MAX_VALUE - this.byteCount) {
raiseLimitException();
}
else {
this.byteCount += bytesToAdd;
if (this.byteCount > this.maxByteCount) {
raiseLimitException();
}
}
}
private void raiseLimitException() {
// Do not release here; discarded buffers are expected to be released via doOnDiscard.
throw new DataBufferLimitException(
"Exceeded limit on max bytes to buffer: " + this.maxByteCount);
}
@Override
public DataBuffer remove(int index) {
throw new UnsupportedOperationException();
}
@Override
public boolean remove(Object o) {
throw new UnsupportedOperationException();
}
@Override
protected void removeRange(int fromIndex, int toIndex) {
throw new UnsupportedOperationException();
}
@Override
public boolean removeAll(Collection<?> c) {
throw new UnsupportedOperationException();
}
@Override
public boolean removeIf(Predicate<? super DataBuffer> filter) {
throw new UnsupportedOperationException();
}
@Override
public DataBuffer set(int index, DataBuffer element) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
this.byteCount = 0;
super.clear();
}
/**
* Shortcut to {@link DataBufferUtils#release release} all data buffers and
* then {@link #clear()}.
*/
public void releaseAndClear() {
forEach(buf -> {
try {
DataBufferUtils.release(buf);
}
catch (Throwable ex) {
// Keep going..
}
});
clear();
}
}
......@@ -19,7 +19,6 @@ package org.springframework.core.codec;
import java.util.Collections;
import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
......@@ -30,9 +29,9 @@ import reactor.test.StepVerifier;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.buffer.AbstractLeakCheckingTests;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.LeakAwareDataBufferFactory;
import org.springframework.core.io.buffer.support.DataBufferTestUtils;
import org.springframework.core.io.support.ResourceRegion;
import org.springframework.util.MimeType;
......@@ -45,18 +44,10 @@ import static org.assertj.core.api.Assertions.assertThat;
* Test cases for {@link ResourceRegionEncoder} class.
* @author Brian Clozel
*/
class ResourceRegionEncoderTests {
class ResourceRegionEncoderTests extends AbstractLeakCheckingTests {
private ResourceRegionEncoder encoder = new ResourceRegionEncoder();
private LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory();
@AfterEach
void tearDown() throws Exception {
this.bufferFactory.checkForLeaks();
}
@Test
void canEncode() {
ResolvableType resourceRegion = ResolvableType.forClass(ResourceRegion.class);
......
......@@ -29,6 +29,7 @@ import reactor.test.StepVerifier;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
......@@ -127,6 +128,20 @@ class StringDecoderTests extends AbstractDecoderTests<StringDecoder> {
.verify());
}
@Test
void decodeNewLineWithLimit() {
Flux<DataBuffer> input = Flux.just(
stringBuffer("abc\n"),
stringBuffer("defg\n"),
stringBuffer("hijkl\n")
);
this.decoder.setMaxInMemorySize(5);
testDecode(input, String.class, step ->
step.expectNext("abc", "defg")
.verifyError(DataBufferLimitException.class));
}
@Test
void decodeNewLineIncludeDelimiters() {
this.decoder = StringDecoder.allMimeTypes(StringDecoder.DEFAULT_DELIMITERS, false);
......
......@@ -813,13 +813,27 @@ class DataBufferUtilsTests extends AbstractDataBufferAllocatingTests {
Mono<DataBuffer> result = DataBufferUtils.join(flux);
StepVerifier.create(result)
.consumeNextWith(dataBuffer -> {
assertThat(DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)).isEqualTo("foobarbaz");
release(dataBuffer);
.consumeNextWith(buf -> {
assertThat(DataBufferTestUtils.dumpString(buf, StandardCharsets.UTF_8)).isEqualTo("foobarbaz");
release(buf);
})
.verifyComplete();
}
@ParameterizedDataBufferAllocatingTest
void joinWithLimit(String displayName, DataBufferFactory bufferFactory) {
super.bufferFactory = bufferFactory;
DataBuffer foo = stringBuffer("foo");
DataBuffer bar = stringBuffer("bar");
DataBuffer baz = stringBuffer("baz");
Flux<DataBuffer> flux = Flux.just(foo, bar, baz);
Mono<DataBuffer> result = DataBufferUtils.join(flux, 8);
StepVerifier.create(result)
.verifyError(DataBufferLimitException.class);
}
@ParameterizedDataBufferAllocatingTest
void joinErrors(String displayName, DataBufferFactory bufferFactory) {
super.bufferFactory = bufferFactory;
......
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.core.io.buffer;
import java.nio.charset.StandardCharsets;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
/**
* Unit tests for {@link LimitedDataBufferList}.
* @author Rossen Stoyanchev
* @since 5.1.11
*/
public class LimitedDataBufferListTests {
private final static DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
@Test
void limitEnforced() {
Assertions.assertThatThrownBy(() -> new LimitedDataBufferList(5).add(toDataBuffer("123456")))
.isInstanceOf(DataBufferLimitException.class);
}
@Test
void limitIgnored() {
new LimitedDataBufferList(-1).add(toDataBuffer("123456"));
}
@Test
void clearResetsCount() {
LimitedDataBufferList list = new LimitedDataBufferList(5);
list.add(toDataBuffer("12345"));
list.clear();
list.add(toDataBuffer("12345"));
}
private static DataBuffer toDataBuffer(String value) {
return bufferFactory.wrap(value.getBytes(StandardCharsets.UTF_8));
}
}
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -143,6 +143,20 @@ public interface CodecConfigurer {
*/
void jaxb2Encoder(Encoder<?> encoder);
/**
* Configure a limit on the number of bytes that can be buffered whenever
* the input stream needs to be aggregated. This can be a result of
* decoding to a single {@code DataBuffer},
* {@link java.nio.ByteBuffer ByteBuffer}, {@code byte[]},
* {@link org.springframework.core.io.Resource Resource}, {@code String}, etc.
* It can also occur when splitting the input stream, e.g. delimited text,
* in which case the limit applies to data buffered between delimiters.
* <p>By default this is not set, in which case individual codec defaults
* apply. All codecs are limited to 256K by default.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.11
*/
void maxInMemorySize(int byteCount);
/**
* Whether to log form data at DEBUG level, and headers at TRACE level.
* Both may contain sensitive information.
......
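A sketch of setting the limit once for all default codecs via this new method; the 2MB value is illustrative:

    ServerCodecConfigurer configurer = ServerCodecConfigurer.create();
    configurer.defaultCodecs().maxInMemorySize(2 * 1024 * 1024);
    // When left unset, each codec falls back to its own 256K default.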
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -30,6 +30,7 @@ import reactor.core.publisher.Mono;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Hints;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.log.LogFormatUtils;
import org.springframework.http.MediaType;
......@@ -62,6 +63,8 @@ public class FormHttpMessageReader extends LoggingCodecSupport
private Charset defaultCharset = DEFAULT_CHARSET;
private int maxInMemorySize = 256 * 1024;
/**
* Set the default character set to use for reading form data when the
......@@ -80,6 +83,26 @@ public class FormHttpMessageReader extends LoggingCodecSupport
return this.defaultCharset;
}
/**
* Set the max number of bytes for input form data. As form data is buffered
* before it is parsed, this helps to limit the amount of buffering. Once
* the limit is exceeded, {@link DataBufferLimitException} is raised.
* <p>By default this is set to 256K.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.11
*/
public void setMaxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
}
/**
* Return the {@link #setMaxInMemorySize configured} byte count limit.
* @since 5.1.11
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
@Override
public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) {
......@@ -105,7 +128,7 @@ public class FormHttpMessageReader extends LoggingCodecSupport
MediaType contentType = message.getHeaders().getContentType();
Charset charset = getMediaTypeCharset(contentType);
return DataBufferUtils.join(message.getBody())
return DataBufferUtils.join(message.getBody(), this.maxInMemorySize)
.map(buffer -> {
CharBuffer charBuffer = charset.decode(buffer.asByteBuffer());
String body = charBuffer.toString();
......
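For form data the limit can also be set directly on the reader; the 128K value is illustrative:

    FormHttpMessageReader formReader = new FormHttpMessageReader();
    formReader.setMaxInMemorySize(128 * 1024); // form bodies are buffered in full before parsing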
......@@ -76,6 +76,20 @@ public interface ServerCodecConfigurer extends CodecConfigurer {
*/
interface ServerDefaultCodecs extends DefaultCodecs {
/**
* Configure the {@code HttpMessageReader} to use for multipart requests.
* <p>By default, if
* <a href="https://github.com/synchronoss/nio-multipart">Synchronoss NIO Multipart</a>
* is present, this is set to
* {@link org.springframework.http.codec.multipart.MultipartHttpMessageReader
* MultipartHttpMessageReader} created with an instance of
* {@link org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader
* SynchronossPartHttpMessageReader}.
* @param reader the message reader to use for multipart requests.
* @since 5.1.11
*/
void multipartReader(HttpMessageReader<?> reader);
/**
* Configure the {@code Encoder} to use for Server-Sent Events.
* <p>By default if this is not set, and Jackson is available, the
......
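A sketch of plugging in a custom multipart reader through this new hook; wrapping the Synchronoss reader as shown is one plausible setup, not mandated by the commit:

    ServerCodecConfigurer configurer = ServerCodecConfigurer.create();
    SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader();
    partReader.setMaxInMemorySize(512 * 1024);
    configurer.defaultCodecs().multipartReader(new MultipartHttpMessageReader(partReader));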
......@@ -37,6 +37,7 @@ import org.springframework.core.codec.CodecException;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.codec.Hints;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.log.LogFormatUtils;
import org.springframework.http.codec.HttpMessageDecoder;
......@@ -59,6 +60,9 @@ import org.springframework.util.MimeType;
*/
public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport implements HttpMessageDecoder<Object> {
private int maxInMemorySize = 256 * 1024;
/**
* Constructor with a Jackson {@link ObjectMapper} to use.
*/
......@@ -67,6 +71,28 @@ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport imple
}
/**
* Set the max number of bytes that can be buffered by this decoder. This
* is either the size of the entire input when decoding as a whole, or the
* size of one top-level JSON object within a JSON stream. When the limit
* is exceeded, {@link DataBufferLimitException} is raised.
* <p>By default this is set to 256K.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.11
*/
public void setMaxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
}
/**
* Return the {@link #setMaxInMemorySize configured} byte count limit.
* @since 5.1.11
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
@Override
public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) {
JavaType javaType = getObjectMapper().getTypeFactory().constructType(elementType.getType());
......@@ -81,7 +107,7 @@ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport imple
ObjectMapper mapper = getObjectMapper();
Flux<TokenBuffer> tokens = Jackson2Tokenizer.tokenize(
Flux.from(input), mapper.getFactory(), mapper, true);
Flux.from(input), mapper.getFactory(), mapper, true, getMaxInMemorySize());
ObjectReader reader = getObjectReader(elementType, hints);
......@@ -103,7 +129,7 @@ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport imple
public Mono<Object> decodeToMono(Publisher<DataBuffer> input, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return DataBufferUtils.join(input)
return DataBufferUtils.join(input, this.maxInMemorySize)
.map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints));
}
......
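For JSON the same property caps either the whole input or, in a JSON stream, a single top-level object; the 512K value is illustrative:

    Jackson2JsonDecoder decoder = new Jackson2JsonDecoder();
    decoder.setMaxInMemorySize(512 * 1024); // per document, or per object when streaming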
......@@ -35,6 +35,7 @@ import reactor.core.publisher.Flux;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
/**
......@@ -61,30 +62,39 @@ final class Jackson2Tokenizer {
private int arrayDepth;
private final int maxInMemorySize;
private int byteCount;
// TODO: change to ByteBufferFeeder when supported by Jackson
// See https://github.com/FasterXML/jackson-core/issues/478
private final ByteArrayFeeder inputFeeder;
private Jackson2Tokenizer(
JsonParser parser, DeserializationContext deserializationContext, boolean tokenizeArrayElements) {
private Jackson2Tokenizer(JsonParser parser, DeserializationContext deserializationContext,
boolean tokenizeArrayElements, int maxInMemorySize) {
this.parser = parser;
this.deserializationContext = deserializationContext;
this.tokenizeArrayElements = tokenizeArrayElements;
this.tokenBuffer = new TokenBuffer(parser, deserializationContext);
this.inputFeeder = (ByteArrayFeeder) this.parser.getNonBlockingInputFeeder();
this.maxInMemorySize = maxInMemorySize;
}
private List<TokenBuffer> tokenize(DataBuffer dataBuffer) {
byte[] bytes = new byte[dataBuffer.readableByteCount()];
int bufferSize = dataBuffer.readableByteCount();
byte[] bytes = new byte[bufferSize];
dataBuffer.read(bytes);
DataBufferUtils.release(dataBuffer);
try {
this.inputFeeder.feedInput(bytes, 0, bytes.length);
return parseTokenBufferFlux();
List<TokenBuffer> result = parseTokenBufferFlux();
assertInMemorySize(bufferSize, result);
return result;
}
catch (JsonProcessingException ex) {
throw new DecodingException("JSON decoding error: " + ex.getOriginalMessage(), ex);
......@@ -174,18 +184,40 @@ final class Jackson2Tokenizer {
(token == JsonToken.END_ARRAY && this.arrayDepth == 0));
}
private void assertInMemorySize(int currentBufferSize, List<TokenBuffer> result) {
if (this.maxInMemorySize >= 0) {
if (!result.isEmpty()) {
this.byteCount = 0;
}
else if (currentBufferSize > Integer.MAX_VALUE - this.byteCount) {
raiseLimitException();
}
else {
this.byteCount += currentBufferSize;
if (this.byteCount > this.maxInMemorySize) {
raiseLimitException();
}
}
}
}
private void raiseLimitException() {
throw new DataBufferLimitException(
"Exceeded limit on max bytes per JSON object: " + this.maxInMemorySize);
}
/**
* Tokenize the given {@code Flux<DataBuffer>} into {@code Flux<TokenBuffer>}.
* @param dataBuffers the source data buffers
* @param jsonFactory the factory to use
* @param objectMapper the current mapper instance
* @param tokenizeArrayElements if {@code true} and the "top level" JSON object is
* @param tokenizeArrays if {@code true} and the "top level" JSON object is
* an array, each element is returned individually immediately after it is received
* @return the resulting token buffers
*/
public static Flux<TokenBuffer> tokenize(Flux<DataBuffer> dataBuffers, JsonFactory jsonFactory,
ObjectMapper objectMapper, boolean tokenizeArrayElements) {
ObjectMapper objectMapper, boolean tokenizeArrays, int maxInMemorySize) {
try {
JsonParser parser = jsonFactory.createNonBlockingByteArrayParser();
......@@ -194,7 +226,7 @@ final class Jackson2Tokenizer {
context = ((DefaultDeserializationContext) context).createInstance(
objectMapper.getDeserializationConfig(), parser, objectMapper.getInjectableValues());
}
Jackson2Tokenizer tokenizer = new Jackson2Tokenizer(parser, context, tokenizeArrayElements);
Jackson2Tokenizer tokenizer = new Jackson2Tokenizer(parser, context, tokenizeArrays, maxInMemorySize);
return dataBuffers.concatMapIterable(tokenizer::tokenize).concatWith(tokenizer.endOfInput());
}
catch (IOException ex) {
......
......@@ -65,6 +65,14 @@ public class MultipartHttpMessageReader extends LoggingCodecSupport
}
/**
* Return the configured parts reader.
* @since 5.1.11
*/
public HttpMessageReader<Part> getPartReader() {
return this.partReader;
}
@Override
public List<MediaType> getReadableMediaTypes() {
return Collections.singletonList(MediaType.MULTIPART_FORM_DATA);
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -40,14 +40,18 @@ import org.synchronoss.cloud.nio.multipart.NioMultipartParser;
import org.synchronoss.cloud.nio.multipart.NioMultipartParserListener;
import org.synchronoss.cloud.nio.multipart.PartBodyStreamStorageFactory;
import org.synchronoss.cloud.nio.stream.storage.StreamStorage;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.codec.Hints;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.log.LogFormatUtils;
......@@ -69,15 +73,82 @@ import org.springframework.util.Assert;
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
* @author Arjen Poutsma
* @author Brian Clozel
* @since 5.0
* @see <a href="https://github.com/synchronoss/nio-multipart">Synchronoss NIO Multipart</a>
* @see MultipartHttpMessageReader
*/
public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implements HttpMessageReader<Part> {
private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
// Static DataBufferFactory to copy from FileInputStream or wrap bytes[].
private static final DataBufferFactory bufferFactory = new DefaultDataBufferFactory();
private final PartBodyStreamStorageFactory streamStorageFactory = new DefaultPartBodyStreamStorageFactory();
private int maxInMemorySize = 256 * 1024;
private long maxDiskUsagePerPart = -1;
private long maxParts = -1;
/**
* Configure the maximum amount of memory allowed per part.
* When the limit is exceeded:
* <ul>
* <li>file parts are written to a temporary file.
* <li>non-file parts are rejected with {@link DataBufferLimitException}.
* </ul>
* <p>By default this is set to 256K.
* @param byteCount the in-memory limit in bytes; if set to -1 this limit is
* not enforced, and all parts may be written to disk and are limited only
* by the {@link #setMaxDiskUsagePerPart(long) maxDiskUsagePerPart} property.
* @since 5.1.11
*/
public void setMaxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
}
/**
* Get the {@link #setMaxInMemorySize configured} maximum in-memory size.
* @since 5.1.11
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
/**
* Configure the maximum amount of disk space allowed for file parts.
* <p>By default this is set to -1.
* @param maxDiskUsagePerPart the disk limit in bytes, or -1 for unlimited
* @since 5.1.11
*/
public void setMaxDiskUsagePerPart(long maxDiskUsagePerPart) {
this.maxDiskUsagePerPart = maxDiskUsagePerPart;
}
/**
* Get the {@link #setMaxDiskUsagePerPart configured} maximum disk usage.
* @since 5.1.11
*/
public long getMaxDiskUsagePerPart() {
return this.maxDiskUsagePerPart;
}
/**
* Specify the maximum number of parts allowed in a given multipart request.
* @since 5.1.11
*/
public void setMaxParts(long maxParts) {
this.maxParts = maxParts;
}
/**
* Return the {@link #setMaxParts configured} limit on the number of parts.
* @since 5.1.11
*/
public long getMaxParts() {
return this.maxParts;
}
@Override
......@@ -94,7 +165,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
@Override
public Flux<Part> read(ResolvableType elementType, ReactiveHttpInputMessage message, Map<String, Object> hints) {
return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory, this.streamStorageFactory))
return Flux.create(new SynchronossPartGenerator(message))
.doOnNext(part -> {
if (!Hints.isLoggingSuppressed(hints)) {
LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Parsed " +
......@@ -107,33 +178,36 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
@Override
public Mono<Part> readMono(ResolvableType elementType, ReactiveHttpInputMessage message, Map<String, Object> hints) {
return Mono.error(new UnsupportedOperationException("Cannot read multipart request body into single Part"));
public Mono<Part> readMono(
ResolvableType elementType, ReactiveHttpInputMessage message, Map<String, Object> hints) {
return Mono.error(new UnsupportedOperationException(
"Cannot read multipart request body into single Part"));
}
/**
* Consume and feed input to the Synchronoss parser, then listen for parser
* output events and adapt to {@code Flux<Sink<Part>>}.
* Subscribe to the input stream and feed the Synchronoss parser. Then listen
* for parser output, creating parts, and pushing them into the FluxSink.
*/
private static class SynchronossPartGenerator implements Consumer<FluxSink<Part>> {
private class SynchronossPartGenerator extends BaseSubscriber<DataBuffer> implements Consumer<FluxSink<Part>> {
private final ReactiveHttpInputMessage inputMessage;
private final DataBufferFactory bufferFactory;
private final LimitedPartBodyStreamStorageFactory storageFactory = new LimitedPartBodyStreamStorageFactory();
private final PartBodyStreamStorageFactory streamStorageFactory;
private NioMultipartParserListener listener;
SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory bufferFactory,
PartBodyStreamStorageFactory streamStorageFactory) {
private NioMultipartParser parser;
public SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage) {
this.inputMessage = inputMessage;
this.bufferFactory = bufferFactory;
this.streamStorageFactory = streamStorageFactory;
}
@Override
public void accept(FluxSink<Part> emitter) {
public void accept(FluxSink<Part> sink) {
HttpHeaders headers = this.inputMessage.getHeaders();
MediaType mediaType = headers.getContentType();
Assert.state(mediaType != null, "No content type set");
......@@ -142,40 +216,57 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
Charset charset = Optional.ofNullable(mediaType.getCharset()).orElse(StandardCharsets.UTF_8);
MultipartContext context = new MultipartContext(mediaType.toString(), length, charset.name());
NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context);
NioMultipartParser parser = Multipart
this.listener = new FluxSinkAdapterListener(sink, context, this.storageFactory);
this.parser = Multipart
.multipart(context)
.usePartBodyStreamStorageFactory(this.streamStorageFactory)
.forNIO(listener);
this.inputMessage.getBody().subscribe(buffer -> {
byte[] resultBytes = new byte[buffer.readableByteCount()];
buffer.read(resultBytes);
try {
parser.write(resultBytes);
}
catch (IOException ex) {
listener.onError("Exception thrown providing input to the parser", ex);
}
finally {
DataBufferUtils.release(buffer);
}
}, ex -> {
try {
listener.onError("Request body input error", ex);
parser.close();
}
catch (IOException ex2) {
listener.onError("Exception thrown while closing the parser", ex2);
}
}, () -> {
try {
parser.close();
}
catch (IOException ex) {
listener.onError("Exception thrown while closing the parser", ex);
}
});
.usePartBodyStreamStorageFactory(this.storageFactory)
.forNIO(this.listener);
this.inputMessage.getBody().subscribe(this);
}
@Override
protected void hookOnNext(DataBuffer buffer) {
int size = buffer.readableByteCount();
this.storageFactory.increaseByteCount(size);
byte[] resultBytes = new byte[size];
buffer.read(resultBytes);
try {
this.parser.write(resultBytes);
}
catch (IOException ex) {
cancel();
int index = this.storageFactory.getCurrentPartIndex();
this.listener.onError("Parser error for part [" + index + "]", ex);
}
finally {
DataBufferUtils.release(buffer);
}
}
@Override
protected void hookOnError(Throwable ex) {
try {
this.parser.close();
}
catch (IOException ex2) {
// ignore
}
finally {
int index = this.storageFactory.getCurrentPartIndex();
this.listener.onError("Failure while parsing part[" + index + "]", ex);
}
}
@Override
protected void hookFinally(SignalType type) {
try {
this.parser.close();
}
catch (IOException ex) {
this.listener.onError("Error while closing parser", ex);
}
}
private int getContentLength(HttpHeaders headers) {
......@@ -186,6 +277,54 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
}
private class LimitedPartBodyStreamStorageFactory implements PartBodyStreamStorageFactory {
private final PartBodyStreamStorageFactory storageFactory = maxInMemorySize > 0 ?
new DefaultPartBodyStreamStorageFactory(maxInMemorySize) :
new DefaultPartBodyStreamStorageFactory();
private int index = 1;
private boolean isFilePart;
private long partSize;
public int getCurrentPartIndex() {
return this.index;
}
@Override
public StreamStorage newStreamStorageForPartBody(Map<String, List<String>> headers, int index) {
this.index = index;
this.isFilePart = (MultipartUtils.getFileName(headers) != null);
this.partSize = 0;
if (maxParts > 0 && index > maxParts) {
throw new DecodingException("Too many parts (" + index + " allowed)");
}
return this.storageFactory.newStreamStorageForPartBody(headers, index);
}
public void increaseByteCount(long byteCount) {
this.partSize += byteCount;
if (maxInMemorySize > 0 && !this.isFilePart && this.partSize >= maxInMemorySize) {
throw new DataBufferLimitException("Part[" + this.index + "] " +
"exceeded the in-memory limit of " + maxInMemorySize + " bytes");
}
if (maxDiskUsagePerPart > 0 && this.isFilePart && this.partSize > maxDiskUsagePerPart) {
throw new DecodingException("Part[" + this.index + "] " +
"exceeded the disk usage limit of " + maxDiskUsagePerPart + " bytes");
}
}
public void partFinished() {
this.index++;
this.isFilePart = false;
this.partSize = 0;
}
}
/**
* Listen for parser output and adapt to {@code Flux<Sink<Part>>}.
*/
......@@ -193,43 +332,48 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final FluxSink<Part> sink;
private final DataBufferFactory bufferFactory;
private final MultipartContext context;
private final LimitedPartBodyStreamStorageFactory storageFactory;
private final AtomicInteger terminated = new AtomicInteger(0);
FluxSinkAdapterListener(FluxSink<Part> sink, DataBufferFactory factory, MultipartContext context) {
FluxSinkAdapterListener(
FluxSink<Part> sink, MultipartContext context, LimitedPartBodyStreamStorageFactory factory) {
this.sink = sink;
this.bufferFactory = factory;
this.context = context;
this.storageFactory = factory;
}
@Override
public void onPartFinished(StreamStorage storage, Map<String, List<String>> headers) {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(headers);
this.storageFactory.partFinished();
this.sink.next(createPart(storage, httpHeaders));
}
private Part createPart(StreamStorage storage, HttpHeaders httpHeaders) {
String filename = MultipartUtils.getFileName(httpHeaders);
if (filename != null) {
return new SynchronossFilePart(httpHeaders, filename, storage, this.bufferFactory);
return new SynchronossFilePart(httpHeaders, filename, storage);
}
else if (MultipartUtils.isFormField(httpHeaders, this.context)) {
String value = MultipartUtils.readFormParameterValue(storage, httpHeaders);
return new SynchronossFormFieldPart(httpHeaders, this.bufferFactory, value);
return new SynchronossFormFieldPart(httpHeaders, value);
}
else {
return new SynchronossPart(httpHeaders, storage, this.bufferFactory);
return new SynchronossPart(httpHeaders, storage);
}
}
@Override
public void onError(String message, Throwable cause) {
if (this.terminated.getAndIncrement() == 0) {
this.sink.error(new RuntimeException(message, cause));
this.sink.error(new DecodingException(message, cause));
}
}
......@@ -256,14 +400,10 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final HttpHeaders headers;
private final DataBufferFactory bufferFactory;
AbstractSynchronossPart(HttpHeaders headers, DataBufferFactory bufferFactory) {
AbstractSynchronossPart(HttpHeaders headers) {
Assert.notNull(headers, "HttpHeaders is required");
Assert.notNull(bufferFactory, "DataBufferFactory is required");
this.name = MultipartUtils.getFieldName(headers);
this.headers = headers;
this.bufferFactory = bufferFactory;
}
@Override
......@@ -276,10 +416,6 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
return this.headers;
}
DataBufferFactory getBufferFactory() {
return this.bufferFactory;
}
@Override
public String toString() {
return "Part '" + this.name + "', headers=" + this.headers;
......@@ -291,15 +427,15 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final StreamStorage storage;
SynchronossPart(HttpHeaders headers, StreamStorage storage, DataBufferFactory factory) {
super(headers, factory);
SynchronossPart(HttpHeaders headers, StreamStorage storage) {
super(headers);
Assert.notNull(storage, "StreamStorage is required");
this.storage = storage;
}
@Override
public Flux<DataBuffer> content() {
return DataBufferUtils.readInputStream(getStorage()::getInputStream, getBufferFactory(), 4096);
return DataBufferUtils.readInputStream(getStorage()::getInputStream, bufferFactory, 4096);
}
protected StreamStorage getStorage() {
......@@ -315,8 +451,8 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final String filename;
SynchronossFilePart(HttpHeaders headers, String filename, StreamStorage storage, DataBufferFactory factory) {
super(headers, storage, factory);
SynchronossFilePart(HttpHeaders headers, String filename, StreamStorage storage) {
super(headers, storage);
this.filename = filename;
}
......@@ -375,8 +511,8 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final String content;
SynchronossFormFieldPart(HttpHeaders headers, DataBufferFactory bufferFactory, String content) {
super(headers, bufferFactory);
SynchronossFormFieldPart(HttpHeaders headers, String content) {
super(headers);
this.content = content;
}
......@@ -388,9 +524,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
@Override
public Flux<DataBuffer> content() {
byte[] bytes = this.content.getBytes(getCharset());
DataBuffer buffer = getBufferFactory().allocateBuffer(bytes.length);
buffer.write(bytes);
return Flux.just(buffer);
return Flux.just(bufferFactory.wrap(bytes));
}
private Charset getCharset() {
......
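Putting the three new multipart knobs together; all values are illustrative:

    SynchronossPartHttpMessageReader reader = new SynchronossPartHttpMessageReader();
    reader.setMaxInMemorySize(256 * 1024);           // non-file parts beyond this are rejected
    reader.setMaxDiskUsagePerPart(10 * 1024 * 1024); // file parts may spill up to 10MB to disk
    reader.setMaxParts(16);                          // a 17th part fails with DecodingException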
......@@ -36,6 +36,7 @@ import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
......@@ -73,7 +74,7 @@ import org.springframework.util.MimeType;
public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Message> {
/** The default max size for aggregating messages. */
protected static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024;
protected static final int DEFAULT_MESSAGE_MAX_SIZE = 256 * 1024;
private static final ConcurrentMap<Class<?>, Method> methodCache = new ConcurrentReferenceHashMap<>();
......@@ -101,10 +102,23 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
}
/**
* The max size allowed per message.
* <p>By default, this is set to 256K.
* @param maxMessageSize the max size per message, or -1 for unlimited
*/
public void setMaxMessageSize(int maxMessageSize) {
this.maxMessageSize = maxMessageSize;
}
/**
* Return the {@link #setMaxMessageSize configured} message size limit.
* @since 5.1.11
*/
public int getMaxMessageSize() {
return this.maxMessageSize;
}
@Override
public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) {
......@@ -127,7 +141,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return DataBufferUtils.join(inputStream)
return DataBufferUtils.join(inputStream, this.maxMessageSize)
.map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints));
}
......@@ -204,9 +218,9 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
if (!readMessageSize(input)) {
return messages;
}
if (this.messageBytesToRead > this.maxMessageSize) {
throw new DecodingException(
"The number of bytes to read from the incoming stream " +
if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) {
throw new DataBufferLimitException(
"The number of bytes to read for message " +
"(" + this.messageBytesToRead + ") exceeds " +
"the configured limit (" + this.maxMessageSize + ")");
}
......
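The Protobuf decoder keeps its own property name for the same concept; a sketch with an illustrative value:

    ProtobufDecoder decoder = new ProtobufDecoder();
    decoder.setMaxMessageSize(256 * 1024); // matches the new DEFAULT_MESSAGE_MAX_SIZE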
......@@ -20,6 +20,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.springframework.core.codec.AbstractDataBufferDecoder;
import org.springframework.core.codec.ByteArrayDecoder;
import org.springframework.core.codec.ByteArrayEncoder;
import org.springframework.core.codec.ByteBufferDecoder;
......@@ -29,6 +30,7 @@ import org.springframework.core.codec.DataBufferDecoder;
import org.springframework.core.codec.DataBufferEncoder;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.codec.ResourceDecoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.http.codec.DecoderHttpMessageReader;
......@@ -38,6 +40,7 @@ import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.codec.ResourceHttpMessageReader;
import org.springframework.http.codec.ResourceHttpMessageWriter;
import org.springframework.http.codec.json.AbstractJackson2Decoder;
import org.springframework.http.codec.json.Jackson2JsonDecoder;
import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.json.Jackson2SmileDecoder;
......@@ -95,6 +98,9 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs {
@Nullable
private Encoder<?> jaxb2Encoder;
@Nullable
private Integer maxInMemorySize;
private boolean enableLoggingRequestDetails = false;
private boolean registerDefaults = true;
......@@ -130,6 +136,16 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs {
this.jaxb2Encoder = encoder;
}
@Override
public void maxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
}
@Nullable
protected Integer maxInMemorySize() {
return this.maxInMemorySize;
}
@Override
public void enableLoggingRequestDetails(boolean enable) {
this.enableLoggingRequestDetails = enable;
......@@ -155,17 +171,20 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs {
return Collections.emptyList();
}
List<HttpMessageReader<?>> readers = new ArrayList<>();
readers.add(new DecoderHttpMessageReader<>(new ByteArrayDecoder()));
readers.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder()));
readers.add(new DecoderHttpMessageReader<>(new DataBufferDecoder()));
readers.add(new ResourceHttpMessageReader());
readers.add(new DecoderHttpMessageReader<>(StringDecoder.textPlainOnly()));
readers.add(new DecoderHttpMessageReader<>(init(new ByteArrayDecoder())));
readers.add(new DecoderHttpMessageReader<>(init(new ByteBufferDecoder())));
readers.add(new DecoderHttpMessageReader<>(init(new DataBufferDecoder())));
readers.add(new ResourceHttpMessageReader(init(new ResourceDecoder())));
readers.add(new DecoderHttpMessageReader<>(init(StringDecoder.textPlainOnly())));
if (protobufPresent) {
Decoder<?> decoder = this.protobufDecoder != null ? this.protobufDecoder : new ProtobufDecoder();
Decoder<?> decoder = this.protobufDecoder != null ? this.protobufDecoder : init(new ProtobufDecoder());
readers.add(new DecoderHttpMessageReader<>(decoder));
}
FormHttpMessageReader formReader = new FormHttpMessageReader();
if (this.maxInMemorySize != null) {
formReader.setMaxInMemorySize(this.maxInMemorySize);
}
formReader.setEnableLoggingRequestDetails(this.enableLoggingRequestDetails);
readers.add(formReader);
......@@ -174,6 +193,28 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs {
return readers;
}
private <T extends Decoder<?>> T init(T decoder) {
if (this.maxInMemorySize != null) {
if (decoder instanceof AbstractDataBufferDecoder) {
((AbstractDataBufferDecoder<?>) decoder).setMaxInMemorySize(this.maxInMemorySize);
}
if (decoder instanceof ProtobufDecoder) {
((ProtobufDecoder) decoder).setMaxMessageSize(this.maxInMemorySize);
}
if (jackson2Present) {
if (decoder instanceof AbstractJackson2Decoder) {
((AbstractJackson2Decoder) decoder).setMaxInMemorySize(this.maxInMemorySize);
}
}
if (jaxb2Present) {
if (decoder instanceof Jaxb2XmlDecoder) {
((Jaxb2XmlDecoder) decoder).setMaxInMemorySize(this.maxInMemorySize);
}
}
}
return decoder;
}
/**
* Hook for client or server specific typed readers.
*/
......@@ -189,13 +230,13 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs {
}
List<HttpMessageReader<?>> readers = new ArrayList<>();
if (jackson2Present) {
readers.add(new DecoderHttpMessageReader<>(getJackson2JsonDecoder()));
readers.add(new DecoderHttpMessageReader<>(init(getJackson2JsonDecoder())));
}
if (jackson2SmilePresent) {
readers.add(new DecoderHttpMessageReader<>(new Jackson2SmileDecoder()));
readers.add(new DecoderHttpMessageReader<>(init(new Jackson2SmileDecoder())));
}
if (jaxb2Present) {
Decoder<?> decoder = this.jaxb2Decoder != null ? this.jaxb2Decoder : new Jaxb2XmlDecoder();
Decoder<?> decoder = this.jaxb2Decoder != null ? this.jaxb2Decoder : init(new Jaxb2XmlDecoder());
readers.add(new DecoderHttpMessageReader<>(decoder));
}
extendObjectReaders(readers);
......@@ -216,7 +257,7 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs {
return Collections.emptyList();
}
List<HttpMessageReader<?>> result = new ArrayList<>();
result.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes()));
result.add(new DecoderHttpMessageReader<>(init(StringDecoder.allMimeTypes())));
return result;
}
......
......@@ -39,10 +39,18 @@ class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecCo
DefaultServerCodecConfigurer.class.getClassLoader());
@Nullable
private HttpMessageReader<?> multipartReader;
@Nullable
private Encoder<?> sseEncoder;
@Override
public void multipartReader(HttpMessageReader<?> reader) {
this.multipartReader = reader;
}
@Override
public void serverSentEventEncoder(Encoder<?> encoder) {
this.sseEncoder = encoder;
......@@ -51,10 +59,18 @@ class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecCo
@Override
protected void extendTypedReaders(List<HttpMessageReader<?>> typedReaders) {
if (this.multipartReader != null) {
typedReaders.add(this.multipartReader);
return;
}
if (synchronossMultipartPresent) {
boolean enable = isEnableLoggingRequestDetails();
SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader();
Integer size = maxInMemorySize();
if (size != null) {
partReader.setMaxInMemorySize(size);
}
partReader.setEnableLoggingRequestDetails(enable);
typedReaders.add(partReader);
......
......@@ -49,6 +49,7 @@ import org.springframework.core.codec.CodecException;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.codec.Hints;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.log.LogFormatUtils;
import org.springframework.lang.Nullable;
......@@ -87,6 +88,8 @@ public class Jaxb2XmlDecoder extends AbstractDecoder<Object> {
private Function<Unmarshaller, Unmarshaller> unmarshallerProcessor = Function.identity();
private int maxInMemorySize = 256 * 1024;
public Jaxb2XmlDecoder() {
super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML);
......@@ -119,6 +122,28 @@ public class Jaxb2XmlDecoder extends AbstractDecoder<Object> {
return this.unmarshallerProcessor;
}
/**
* Set the max number of bytes that can be buffered by this decoder.
* This is either the size of the entire input when decoding as a whole, or,
* when using async parsing with Aalto XML, the size of one top-level XML tree.
* When the limit is exceeded, {@link DataBufferLimitException} is raised.
* <p>By default this is set to 256K.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.11
*/
public void setMaxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
this.xmlEventDecoder.setMaxInMemorySize(byteCount);
}
/**
* Return the {@link #setMaxInMemorySize configured} byte count limit.
* @since 5.1.11
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
@Override
public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) {
......@@ -153,7 +178,7 @@ public class Jaxb2XmlDecoder extends AbstractDecoder<Object> {
public Mono<Object> decodeToMono(Publisher<DataBuffer> input, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return DataBufferUtils.join(input)
return DataBufferUtils.join(input, this.maxInMemorySize)
.map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints));
}
......
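For XML the limit set on the JAXB decoder propagates to the nested XmlEventDecoder, so a single call covers both; the 1MB value is illustrative:

    Jaxb2XmlDecoder decoder = new Jaxb2XmlDecoder();
    decoder.setMaxInMemorySize(1024 * 1024); // also applied to the internal XmlEventDecoder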
......@@ -40,6 +40,7 @@ import reactor.core.publisher.Flux;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.AbstractDecoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
......@@ -89,26 +90,50 @@ public class XmlEventDecoder extends AbstractDecoder<XMLEvent> {
boolean useAalto = aaltoPresent;
private int maxInMemorySize = 256 * 1024;
public XmlEventDecoder() {
super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML);
}
/**
* Set the max number of bytes that can be buffered by this decoder. This
* is either the size of the entire input when decoding as a whole, or, when
* using async parsing via Aalto XML, the size of one top-level XML tree.
* When the limit is exceeded, {@link DataBufferLimitException} is raised.
* <p>By default this is set to 256K.
* @param byteCount the max number of bytes to buffer, or -1 for unlimited
* @since 5.1.11
*/
public void setMaxInMemorySize(int byteCount) {
this.maxInMemorySize = byteCount;
}
/**
* Return the {@link #setMaxInMemorySize configured} byte count limit.
* @since 5.1.11
*/
public int getMaxInMemorySize() {
return this.maxInMemorySize;
}
@Override
@SuppressWarnings({"rawtypes", "unchecked", "cast"}) // XMLEventReader is Iterator<Object> on JDK 9
public Flux<XMLEvent> decode(Publisher<DataBuffer> input, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
if (this.useAalto) {
AaltoDataBufferToXmlEvent mapper = new AaltoDataBufferToXmlEvent();
AaltoDataBufferToXmlEvent mapper = new AaltoDataBufferToXmlEvent(this.maxInMemorySize);
return Flux.from(input)
.flatMapIterable(mapper)
.doFinally(signalType -> mapper.endOfInput());
}
else {
return DataBufferUtils.join(input).
flatMapIterable(buffer -> {
return DataBufferUtils.join(input, this.maxInMemorySize)
.flatMapIterable(buffer -> {
try {
InputStream is = buffer.asInputStream();
Iterator eventReader = inputFactory.createXMLEventReader(is);
......@@ -140,10 +165,22 @@ public class XmlEventDecoder extends AbstractDecoder<XMLEvent> {
private final XMLEventAllocator eventAllocator = EventAllocatorImpl.getDefaultInstance();
private final int maxInMemorySize;
private int byteCount;
private int elementDepth;
public AaltoDataBufferToXmlEvent(int maxInMemorySize) {
this.maxInMemorySize = maxInMemorySize;
}
@Override
public List<? extends XMLEvent> apply(DataBuffer dataBuffer) {
try {
increaseByteCount(dataBuffer);
this.streamReader.getInputFeeder().feedInput(dataBuffer.asByteBuffer());
List<XMLEvent> events = new ArrayList<>();
while (true) {
......@@ -157,8 +194,12 @@ public class XmlEventDecoder extends AbstractDecoder<XMLEvent> {
if (event.isEndDocument()) {
break;
}
checkDepthAndResetByteCount(event);
}
}
if (this.maxInMemorySize > 0 && this.byteCount > this.maxInMemorySize) {
raiseLimitException();
}
return events;
}
catch (XMLStreamException ex) {
......@@ -169,9 +210,40 @@ public class XmlEventDecoder extends AbstractDecoder<XMLEvent> {
}
}
private void increaseByteCount(DataBuffer dataBuffer) {
if (this.maxInMemorySize > 0) {
if (dataBuffer.readableByteCount() > Integer.MAX_VALUE - this.byteCount) {
raiseLimitException();
}
else {
this.byteCount += dataBuffer.readableByteCount();
}
}
}
private void checkDepthAndResetByteCount(XMLEvent event) {
if (this.maxInMemorySize > 0) {
if (event.isStartElement()) {
this.byteCount = this.elementDepth == 1 ? 0 : this.byteCount;
this.elementDepth++;
}
else if (event.isEndElement()) {
this.elementDepth--;
this.byteCount = this.elementDepth == 1 ? 0 : this.byteCount;
}
}
}
private void raiseLimitException() {
throw new DataBufferLimitException(
"Exceeded limit on max bytes per XML top-level node: " + this.maxInMemorySize);
}
public void endOfInput() {
this.streamReader.getInputFeeder().endOfInput();
}
}
}
......@@ -20,7 +20,6 @@ import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.function.Consumer;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.TreeNode;
......@@ -36,6 +35,7 @@ import reactor.test.StepVerifier;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.AbstractLeakCheckingTests;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
......@@ -181,11 +181,68 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests {
testTokenize(asList("[1", ",2,", "3]"), asList("1", "2", "3"), true);
}
private void testTokenize(List<String> input, List<String> output, boolean tokenize) {
StepVerifier.FirstStep<String> builder = StepVerifier.create(decode(input, tokenize, -1));
output.forEach(expected -> builder.assertNext(actual -> {
try {
JSONAssert.assertEquals(expected, actual, true);
}
catch (JSONException ex) {
throw new RuntimeException(ex);
}
}));
builder.verifyComplete();
}
@Test
public void testLimit() {
List<String> source = asList("[",
"{", "\"id\":1,\"name\":\"Dan\"", "},",
"{", "\"id\":2,\"name\":\"Ron\"", "},",
"{", "\"id\":3,\"name\":\"Bartholomew\"", "}",
"]");
String expected = String.join("", source);
int maxInMemorySize = expected.length();
StepVerifier.create(decode(source, false, maxInMemorySize))
.expectNext(expected)
.verifyComplete();
StepVerifier.create(decode(source, false, maxInMemorySize - 1))
.expectError(DataBufferLimitException.class)
.verify();
}
@Test
public void testLimitTokenized() {
List<String> source = asList("[",
"{", "\"id\":1, \"name\":\"Dan\"", "},",
"{", "\"id\":2, \"name\":\"Ron\"", "},",
"{", "\"id\":3, \"name\":\"Bartholomew\"", "}",
"]");
String expected = "{\"id\":3,\"name\":\"Bartholomew\"}";
int maxInMemorySize = expected.length();
StepVerifier.create(decode(source, true, maxInMemorySize))
.expectNext("{\"id\":1,\"name\":\"Dan\"}")
.expectNext("{\"id\":2,\"name\":\"Ron\"}")
.expectNext(expected)
.verifyComplete();
StepVerifier.create(decode(source, true, maxInMemorySize - 1))
.expectNext("{\"id\":1,\"name\":\"Dan\"}")
.expectNext("{\"id\":2,\"name\":\"Ron\"}")
.verifyError(DataBufferLimitException.class);
}
@Test
public void errorInStream() {
DataBuffer buffer = stringBuffer("{\"id\":1,\"name\":");
Flux<DataBuffer> source = Flux.just(buffer).concatWith(Flux.error(new RuntimeException()));
Flux<TokenBuffer> result = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, true);
Flux<TokenBuffer> result = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, true, -1);
StepVerifier.create(result)
.expectError(RuntimeException.class)
......@@ -195,7 +252,7 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests {
@Test // SPR-16521
public void jsonEOFExceptionIsWrappedAsDecodingError() {
Flux<DataBuffer> source = Flux.just(stringBuffer("{\"status\": \"noClosingQuote}"));
Flux<TokenBuffer> tokens = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, false);
Flux<TokenBuffer> tokens = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, false, -1);
StepVerifier.create(tokens)
.expectError(DecodingException.class)
......@@ -203,12 +260,13 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests {
}
private void testTokenize(List<String> source, List<String> expected, boolean tokenizeArrayElements) {
private Flux<String> decode(List<String> source, boolean tokenize, int maxInMemorySize) {
Flux<TokenBuffer> tokens = Jackson2Tokenizer.tokenize(
Flux.fromIterable(source).map(this::stringBuffer),
this.jsonFactory, this.objectMapper, tokenizeArrayElements);
this.jsonFactory, this.objectMapper, tokenize, maxInMemorySize);
Flux<String> result = tokens
return tokens
.map(tokenBuffer -> {
try {
TreeNode root = this.objectMapper.readTree(tokenBuffer.asParser());
......@@ -218,10 +276,6 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests {
throw new UncheckedIOException(ex);
}
});
StepVerifier.FirstStep<String> builder = StepVerifier.create(result);
expected.forEach(s -> builder.assertNext(new JSONAssertConsumer(s)));
builder.verifyComplete();
}
private DataBuffer stringBuffer(String value) {
......@@ -231,24 +285,4 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests {
return buffer;
}
private static class JSONAssertConsumer implements Consumer<String> {
private final String expected;
JSONAssertConsumer(String expected) {
this.expected = expected;
}
@Override
public void accept(String s) {
try {
JSONAssert.assertEquals(this.expected, s, true);
}
catch (JSONException ex) {
throw new RuntimeException(ex);
}
}
}
}
......@@ -17,19 +17,26 @@
package org.springframework.http.codec.multipart;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.buffer.AbstractLeakCheckingTests;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.buffer.support.DataBufferTestUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.MultipartBodyBuilder;
......@@ -49,17 +56,20 @@ import static org.springframework.http.MediaType.MULTIPART_FORM_DATA;
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
* @author Brian Clozel
*/
public class SynchronossPartHttpMessageReaderTests {
public class SynchronossPartHttpMessageReaderTests extends AbstractLeakCheckingTests {
private final MultipartHttpMessageReader reader =
new MultipartHttpMessageReader(new SynchronossPartHttpMessageReader());
private static final ResolvableType PARTS_ELEMENT_TYPE =
forClassWithGenerics(MultiValueMap.class, String.class, Part.class);
@Test
public void canRead() {
void canRead() {
assertThat(this.reader.canRead(
forClassWithGenerics(MultiValueMap.class, String.class, Part.class),
PARTS_ELEMENT_TYPE,
MediaType.MULTIPART_FORM_DATA)).isTrue();
assertThat(this.reader.canRead(
......@@ -80,43 +90,36 @@ public class SynchronossPartHttpMessageReaderTests {
}
@Test
public void resolveParts() {
void resolveParts() {
ServerHttpRequest request = generateMultipartRequest();
ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class);
MultiValueMap<String, Part> parts = this.reader.readMono(elementType, request, emptyMap()).block();
assertThat(parts.size()).isEqualTo(2);
assertThat(parts.containsKey("fooPart")).isTrue();
Part part = parts.getFirst("fooPart");
boolean condition1 = part instanceof FilePart;
assertThat(condition1).isTrue();
assertThat(part.name()).isEqualTo("fooPart");
MultiValueMap<String, Part> parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block();
assertThat(parts).containsOnlyKeys("filePart", "textPart");
Part part = parts.getFirst("filePart");
assertThat(part).isInstanceOf(FilePart.class);
assertThat(part.name()).isEqualTo("filePart");
assertThat(((FilePart) part).filename()).isEqualTo("foo.txt");
DataBuffer buffer = DataBufferUtils.join(part.content()).block();
assertThat(buffer.readableByteCount()).isEqualTo(12);
byte[] byteContent = new byte[12];
buffer.read(byteContent);
assertThat(new String(byteContent)).isEqualTo("Lorem Ipsum.");
assertThat(parts.containsKey("barPart")).isTrue();
part = parts.getFirst("barPart");
boolean condition = part instanceof FormFieldPart;
assertThat(condition).isTrue();
assertThat(part.name()).isEqualTo("barPart");
assertThat(((FormFieldPart) part).value()).isEqualTo("bar");
assertThat(DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8)).isEqualTo("Lorem Ipsum.");
DataBufferUtils.release(buffer);
part = parts.getFirst("textPart");
assertThat(part).isInstanceOf(FormFieldPart.class);
assertThat(part.name()).isEqualTo("textPart");
assertThat(((FormFieldPart) part).value()).isEqualTo("sample-text");
}
@Test // SPR-16545
public void transferTo() {
void transferTo() throws IOException {
ServerHttpRequest request = generateMultipartRequest();
ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class);
MultiValueMap<String, Part> parts = this.reader.readMono(elementType, request, emptyMap()).block();
MultiValueMap<String, Part> parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block();
assertThat(parts).isNotNull();
FilePart part = (FilePart) parts.getFirst("fooPart");
FilePart part = (FilePart) parts.getFirst("filePart");
assertThat(part).isNotNull();
File dest = new File(System.getProperty("java.io.tmpdir") + "/" + part.filename());
File dest = File.createTempFile(part.filename(), "multipart");
part.transferTo(dest).block(Duration.ofSeconds(5));
assertThat(dest.exists()).isTrue();
......@@ -125,33 +128,95 @@ public class SynchronossPartHttpMessageReaderTests {
}
@Test
public void bodyError() {
void bodyError() {
ServerHttpRequest request = generateErrorMultipartRequest();
ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class);
StepVerifier.create(this.reader.readMono(elementType, request, emptyMap())).verifyError();
StepVerifier.create(this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap())).verifyError();
}
@Test
void readPartsWithoutDemand() {
ServerHttpRequest request = generateMultipartRequest();
Mono<MultiValueMap<String, Part>> parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap());
ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber();
parts.subscribe(subscriber);
subscriber.cancel();
}
@Test
void readTooManyParts() {
testMultipartExceptions(reader -> reader.setMaxParts(1), ex -> {
assertThat(ex)
.isInstanceOf(DecodingException.class)
.hasMessageStartingWith("Failure while parsing part[2]");
assertThat(ex.getCause())
.hasMessage("Too many parts (2 allowed)");
}
);
}
private ServerHttpRequest generateMultipartRequest() {
@Test
void readFilePartTooBig() {
testMultipartExceptions(reader -> reader.setMaxDiskUsagePerPart(5), ex -> {
assertThat(ex)
.isInstanceOf(DecodingException.class)
.hasMessageStartingWith("Failure while parsing part[1]");
assertThat(ex.getCause())
.hasMessage("Part[1] exceeded the disk usage limit of 5 bytes");
}
);
}
@Test
void readPartHeadersTooBig() {
testMultipartExceptions(reader -> reader.setMaxInMemorySize(1), ex -> {
assertThat(ex)
.isInstanceOf(DecodingException.class)
.hasMessageStartingWith("Failure while parsing part[1]");
assertThat(ex.getCause())
.hasMessage("Part[1] exceeded the in-memory limit of 1 bytes");
}
);
}
private void testMultipartExceptions(
Consumer<SynchronossPartHttpMessageReader> configurer, Consumer<Throwable> assertions) {
SynchronossPartHttpMessageReader reader = new SynchronossPartHttpMessageReader();
configurer.accept(reader);
MultipartHttpMessageReader multipartReader = new MultipartHttpMessageReader(reader);
StepVerifier.create(multipartReader.readMono(PARTS_ELEMENT_TYPE, generateMultipartRequest(), emptyMap()))
.consumeErrorWith(assertions)
.verify();
}
private ServerHttpRequest generateMultipartRequest() {
MultipartBodyBuilder partsBuilder = new MultipartBodyBuilder();
partsBuilder.part("fooPart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"));
partsBuilder.part("barPart", "bar");
partsBuilder.part("filePart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"));
partsBuilder.part("textPart", "sample-text");
MockClientHttpRequest outputMessage = new MockClientHttpRequest(HttpMethod.POST, "/");
new MultipartHttpMessageWriter()
.write(Mono.just(partsBuilder.build()), null, MediaType.MULTIPART_FORM_DATA, outputMessage, null)
.block(Duration.ofSeconds(5));
Flux<DataBuffer> requestBody = outputMessage.getBody()
.map(buffer -> this.bufferFactory.wrap(buffer.asByteBuffer()));
return MockServerHttpRequest.post("/")
.contentType(outputMessage.getHeaders().getContentType())
.body(outputMessage.getBody());
.body(requestBody);
}
private ServerHttpRequest generateErrorMultipartRequest() {
return MockServerHttpRequest.post("/")
.header(CONTENT_TYPE, MULTIPART_FORM_DATA.toString())
.body(Flux.just(new DefaultDataBufferFactory().wrap("invalid content".getBytes())));
.body(Flux.just(this.bufferFactory.wrap("invalid content".getBytes())));
}
private static class ZeroDemandSubscriber extends BaseSubscriber<MultiValueMap<String, Part>> {
@Override
protected void hookOnSubscribe(Subscription subscription) {
// Just subscribe without requesting
}
}
}
......@@ -36,6 +36,7 @@ import org.springframework.core.codec.DataBufferDecoder;
import org.springframework.core.codec.DataBufferEncoder;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.codec.ResourceDecoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.MediaType;
......@@ -124,13 +125,45 @@ public class ServerCodecConfigurerTests {
.filter(e -> e == encoder).orElse(null)).isSameAs(encoder);
}
@Test
public void maxInMemorySize() {
int size = 99;
this.configurer.defaultCodecs().maxInMemorySize(size);
List<HttpMessageReader<?>> readers = this.configurer.getReaders();
assertThat(readers.size()).isEqualTo(13);
assertThat(((ByteArrayDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((ByteBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((DataBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
ResourceHttpMessageReader resourceReader = (ResourceHttpMessageReader) nextReader(readers);
ResourceDecoder decoder = (ResourceDecoder) resourceReader.getDecoder();
assertThat(decoder.getMaxInMemorySize()).isEqualTo(size);
assertThat(((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((ProtobufDecoder) getNextDecoder(readers)).getMaxMessageSize()).isEqualTo(size);
assertThat(((FormHttpMessageReader) nextReader(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((SynchronossPartHttpMessageReader) nextReader(readers)).getMaxInMemorySize()).isEqualTo(size);
MultipartHttpMessageReader multipartReader = (MultipartHttpMessageReader) nextReader(readers);
SynchronossPartHttpMessageReader reader = (SynchronossPartHttpMessageReader) multipartReader.getPartReader();
assertThat(reader.getMaxInMemorySize()).isEqualTo(size);
assertThat(((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((Jackson2SmileDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((Jaxb2XmlDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
assertThat(((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);
}
private Decoder<?> getNextDecoder(List<HttpMessageReader<?>> readers) {
HttpMessageReader<?> reader = readers.get(this.index.getAndIncrement());
HttpMessageReader<?> reader = nextReader(readers);
assertThat(reader.getClass()).isEqualTo(DecoderHttpMessageReader.class);
return ((DecoderHttpMessageReader<?>) reader).getDecoder();
}
private HttpMessageReader<?> nextReader(List<HttpMessageReader<?>> readers) {
return readers.get(this.index.getAndIncrement());
}
private Encoder<?> getNextEncoder(List<HttpMessageWriter<?>> writers) {
HttpMessageWriter<?> writer = writers.get(this.index.getAndIncrement());
assertThat(writer.getClass()).isEqualTo(EncoderHttpMessageWriter.class);
......
......@@ -28,6 +28,7 @@ import reactor.test.StepVerifier;
import org.springframework.core.io.buffer.AbstractLeakCheckingTests;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -44,11 +45,12 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests {
private XmlEventDecoder decoder = new XmlEventDecoder();
@Test
public void toXMLEventsAalto() {
Flux<XMLEvent> events =
this.decoder.decode(stringBuffer(XML), null, null, Collections.emptyMap());
this.decoder.decode(stringBufferMono(XML), null, null, Collections.emptyMap());
StepVerifier.create(events)
.consumeNextWith(e -> assertThat(e.isStartDocument()).isTrue())
......@@ -69,7 +71,7 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests {
decoder.useAalto = false;
Flux<XMLEvent> events =
this.decoder.decode(stringBuffer(XML), null, null, Collections.emptyMap());
this.decoder.decode(stringBufferMono(XML), null, null, Collections.emptyMap());
StepVerifier.create(events)
.consumeNextWith(e -> assertThat(e.isStartDocument()).isTrue())
......@@ -86,10 +88,32 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests {
.verify();
}
@Test
public void toXMLEventsWithLimit() {
this.decoder.setMaxInMemorySize(6);
Flux<String> source = Flux.just(
"<pojo>", "<foo>", "foofoo", "</foo>", "<bar>", "barbarbar", "</bar>", "</pojo>");
Flux<XMLEvent> events = this.decoder.decode(
source.map(this::stringBuffer), null, null, Collections.emptyMap());
StepVerifier.create(events)
.consumeNextWith(e -> assertThat(e.isStartDocument()).isTrue())
.consumeNextWith(e -> assertStartElement(e, "pojo"))
.consumeNextWith(e -> assertStartElement(e, "foo"))
.consumeNextWith(e -> assertCharacters(e, "foofoo"))
.consumeNextWith(e -> assertEndElement(e, "foo"))
.consumeNextWith(e -> assertStartElement(e, "bar"))
.expectError(DataBufferLimitException.class)
.verify();
}
@Test
public void decodeErrorAalto() {
Flux<DataBuffer> source = Flux.concat(
stringBuffer("<pojo>"),
stringBufferMono("<pojo>"),
Flux.error(new RuntimeException()));
Flux<XMLEvent> events =
......@@ -107,7 +131,7 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests {
decoder.useAalto = false;
Flux<DataBuffer> source = Flux.concat(
stringBuffer("<pojo>"),
stringBufferMono("<pojo>"),
Flux.error(new RuntimeException()));
Flux<XMLEvent> events =
......@@ -133,13 +157,15 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests {
assertThat(event.asCharacters().getData()).isEqualTo(expectedData);
}
private Mono<DataBuffer> stringBuffer(String value) {
return Mono.defer(() -> {
byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length);
buffer.write(bytes);
return Mono.just(buffer);
});
private DataBuffer stringBuffer(String value) {
byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length);
buffer.write(bytes);
return buffer;
}
private Mono<DataBuffer> stringBufferMono(String value) {
return Mono.defer(() -> Mono.just(stringBuffer(value)));
}
}
......@@ -818,6 +818,33 @@ for repeated, map-like access to parts, or otherwise rely on the
`SynchronossPartHttpMessageReader` for a one-time access to `Flux<Part>`.
[[webflux-codecs-limits]]
==== Limits
`Decoder` and `HttpMessageReader` implementations that buffer some or all of the input
stream can be configured with a limit on the maximum number of bytes to buffer in memory.
In some cases buffering occurs because input is aggregated and represented as a single
object, e.g. a controller method with `@RequestBody byte[]`, `x-www-form-urlencoded` data,
and so on. Buffering can also occur with streaming, when splitting the input stream,
e.g. delimited text, a stream of JSON objects, and so on. For those streaming cases, the
limit applies to the number of bytes associated with one object in the stream.
To configure buffer sizes, check whether a given `Decoder` or `HttpMessageReader`
exposes a `maxInMemorySize` property; if so, its Javadoc documents the default value.
In WebFlux, `ServerCodecConfigurer` provides a
<<webflux-config-message-codecs,single place>> from which to set all codecs, through the
`maxInMemorySize` property for default codecs, as sketched below.
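A minimal sketch, assuming a standard `WebFluxConfigurer` setup (the 2 MB value is
illustrative, not a recommendation):

[source,java]
----
@Configuration
@EnableWebFlux
public class WebConfig implements WebFluxConfigurer {

	@Override
	public void configureHttpMessageCodecs(ServerCodecConfigurer configurer) {
		// Raise the default 256K limit for all default codecs
		configurer.defaultCodecs().maxInMemorySize(2 * 1024 * 1024);
	}
}
----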
For <<webflux-codecs-multipart,Multipart parsing>> the `maxInMemorySize` property limits
the size of non-file parts. For file parts it determines the threshold at which the part
is written to disk. For file parts written to disk, there is an additional
`maxDiskUsagePerPart` property to limit the amount of disk space per part. There is also
a `maxParts` property to limit the overall number of parts in a multipart request.
To configure all three in WebFlux, supply a pre-configured instance of
`MultipartHttpMessageReader` to `ServerCodecConfigurer`, as in the sketch below.
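A minimal sketch using the part reader's `setMaxParts`, `setMaxDiskUsagePerPart`, and
`setMaxInMemorySize` properties (the specific values are illustrative, and registration
via `multipartReader(..)` on the default codecs is assumed):

[source,java]
----
@Override
public void configureHttpMessageCodecs(ServerCodecConfigurer configurer) {
	SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader();
	partReader.setMaxParts(10);                           // max parts per request
	partReader.setMaxDiskUsagePerPart(10L * 1024 * 1024); // per file part, on disk
	partReader.setMaxInMemorySize(512 * 1024);            // non-file parts; also the
	                                                      // threshold for writing
	                                                      // file parts to disk

	MultipartHttpMessageReader multipartReader = new MultipartHttpMessageReader(partReader);
	configurer.defaultCodecs().multipartReader(multipartReader);
}
----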
[[webflux-codecs-streaming]]
==== Streaming
[.small]#<<web.adoc#mvc-ann-async-http-streaming, Same as in Spring MVC>>#
......