提交 6b6384a0 编写于 作者: S Sebastien Deleuze

Improve WebFlux Protobuf support

 - Update javadoc for decoding default instances
 - Refactor and simplify tests
 - Add missing tests
 - Refactor decoding with flatMapIterable instead of
   concatMap and avoid recursive call

Issue: SPR-15776
上级 8e571dec
......@@ -18,6 +18,7 @@ package org.springframework.http.codec.protobuf;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
......@@ -44,13 +45,17 @@ import org.springframework.util.MimeType;
* A {@code Decoder} that reads {@link com.google.protobuf.Message}s
* using <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
*
* Flux deserialized via
* <p>Flux deserialized via
* {@link #decode(Publisher, ResolvableType, MimeType, Map)} are expected to use
* <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">delimited Protobuf messages</a>
* with the size of each message specified before the message itself. Single values deserialized
* via {@link #decodeToMono(Publisher, ResolvableType, MimeType, Map)} are expected to use
* regular Protobuf message format (without the size prepended before the message).
*
* <p>Notice that default instance of Protobuf message produces empty byte array, so
* {@code Mono.just(Msg.getDefaultInstance())} sent over the network will be deserialized
* as an empty {@link Mono}.
*
* <p>To generate {@code Message} Java classes, you need to install the {@code protoc} binary.
*
* <p>This decoder requires Protobuf 3 or higher, and supports
......@@ -108,7 +113,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
return Flux.from(inputStream)
.concatMap(new MessageDecoderFunction(elementType, this.maxMessageSize));
.flatMapIterable(new MessageDecoderFunction(elementType, this.maxMessageSize));
}
@Override
......@@ -152,7 +157,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
}
private class MessageDecoderFunction implements Function<DataBuffer, Publisher<? extends Message>> {
private class MessageDecoderFunction implements Function<DataBuffer, Iterable<? extends Message>> {
private final ResolvableType elementType;
......@@ -163,55 +168,59 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
private int messageBytesToRead;
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) {
this.elementType = elementType;
this.maxMessageSize = maxMessageSize;
}
// TODO Instead of the recursive call, loop over the current DataBuffer,
// produce a list of as many messages as are contained, and save any remaining bytes with flatMapIterable
@Override
public Publisher<? extends Message> apply(DataBuffer input) {
public Iterable<? extends Message> apply(DataBuffer input) {
try {
if (this.output == null) {
int firstByte = input.read();
if (firstByte == -1) {
return Flux.error(new DecodingException("Can't parse message size"));
List<Message> messages = new ArrayList<>();
int remainingBytesToRead;
int chunkBytesToRead;
do {
if (this.output == null) {
int firstByte = input.read();
if (firstByte == -1) {
throw new DecodingException("Can't parse message size");
}
this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream());
if (this.messageBytesToRead > this.maxMessageSize) {
throw new DecodingException(
"The number of bytes to read parsed in the incoming stream (" +
this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")");
}
this.output = input.factory().allocateBuffer(this.messageBytesToRead);
}
this.messageBytesToRead = CodedInputStream.readRawVarint32(firstByte, input.asInputStream());
if (this.messageBytesToRead > this.maxMessageSize) {
return Flux.error(new DecodingException(
"The number of bytes to read parsed in the incoming stream (" +
this.messageBytesToRead + ") exceeds the configured limit (" + this.maxMessageSize + ")"));
chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ?
input.readableByteCount() : this.messageBytesToRead;
remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
byte[] bytesToWrite = new byte[chunkBytesToRead];
input.read(bytesToWrite, 0, chunkBytesToRead);
this.output.write(bytesToWrite);
this.messageBytesToRead -= chunkBytesToRead;
if (this.messageBytesToRead == 0) {
Message.Builder builder = getMessageBuilder(this.elementType.toClass());
builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry);
messages.add(builder.build());
DataBufferUtils.release(this.output);
this.output = null;
}
this.output = input.factory().allocateBuffer(this.messageBytesToRead);
}
int chunkBytesToRead = this.messageBytesToRead >= input.readableByteCount() ?
input.readableByteCount() : this.messageBytesToRead;
int remainingBytesToRead = input.readableByteCount() - chunkBytesToRead;
this.output.write(input.slice(input.readPosition(), chunkBytesToRead));
this.messageBytesToRead -= chunkBytesToRead;
Message message = null;
if (this.messageBytesToRead == 0) {
Message.Builder builder = getMessageBuilder(this.elementType.toClass());
builder.mergeFrom(CodedInputStream.newInstance(this.output.asByteBuffer()), extensionRegistry);
message = builder.build();
DataBufferUtils.release(this.output);
this.output = null;
}
if (remainingBytesToRead > 0) {
return Mono.justOrEmpty(message).concatWith(
apply(input.slice(input.readPosition() + chunkBytesToRead, remainingBytesToRead)));
}
else {
return Mono.justOrEmpty(message);
}
} while (remainingBytesToRead > 0);
return messages;
}
catch (IOException ex) {
return Flux.error(new DecodingException("I/O error while parsing input stream", ex));
throw new DecodingException("I/O error while parsing input stream", ex);
}
catch (Exception ex) {
return Flux.error(new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex));
throw new DecodingException("Could not read Protobuf message: " + ex.getMessage(), ex);
}
}
}
......
......@@ -40,7 +40,7 @@ import org.springframework.util.MimeType;
* An {@code Encoder} that writes {@link com.google.protobuf.Message}s
* using <a href="https://developers.google.com/protocol-buffers/">Google Protocol Buffers</a>.
*
* Flux are serialized using
* <p>Flux are serialized using
* <a href="https://developers.google.com/protocol-buffers/docs/techniques?hl=en#streaming">delimited Protobuf messages</a>
* with the size of each message specified before the message itself. Single values are
* serialized using regular Protobuf message format (without the size prepended before the message).
......
......@@ -16,12 +16,7 @@
package org.springframework.http.codec.protobuf;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.google.protobuf.Message;
import org.junit.Before;
......@@ -47,8 +42,6 @@ import static org.springframework.core.ResolvableType.forClass;
/**
* Unit tests for {@link ProtobufDecoder}.
* TODO Make tests more readable
* TODO Add a test where an input DataBuffer is larger than a message
*
* @author Sebastien Deleuze
*/
......@@ -56,7 +49,13 @@ public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase {
private final static MimeType PROTOBUF_MIME_TYPE = new MimeType("application", "x-protobuf");
private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build();
private final SecondMsg secondMsg = SecondMsg.newBuilder().setBlah(123).build();
private final Msg testMsg = Msg.newBuilder().setFoo("Foo").setBlah(secondMsg).build();
private final SecondMsg secondMsg2 = SecondMsg.newBuilder().setBlah(456).build();
private final Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(secondMsg2).build();
private ProtobufDecoder decoder;
......@@ -82,51 +81,59 @@ public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase {
@Test
public void decodeToMono() {
byte[] body = this.testMsg.toByteArray();
Flux<DataBuffer> source = Flux.just(this.bufferFactory.wrap(body));
DataBuffer data = this.bufferFactory.wrap(testMsg.toByteArray());
ResolvableType elementType = forClass(Msg.class);
Mono<Message> mono = this.decoder.decodeToMono(source, elementType, null,
emptyMap());
Mono<Message> mono = this.decoder.decodeToMono(Flux.just(data), elementType, null, emptyMap());
StepVerifier.create(mono)
.expectNext(this.testMsg)
.expectNext(testMsg)
.verifyComplete();
}
@Test
public void decodeToMonoWithLargerDataBuffer() {
DataBuffer buffer = this.bufferFactory.allocateBuffer(1024);
buffer.write(testMsg.toByteArray());
ResolvableType elementType = forClass(Msg.class);
Mono<Message> mono = this.decoder.decodeToMono(Flux.just(buffer), elementType, null, emptyMap());
StepVerifier.create(mono)
.expectNext(testMsg)
.verifyComplete();
}
@Test
public void decodeChunksToMono() {
byte[] body = this.testMsg.toByteArray();
List<DataBuffer> chunks = new ArrayList<>();
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 0, 4)));
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(body, 4, body.length)));
Flux<DataBuffer> source = Flux.fromIterable(chunks);
DataBuffer buffer = this.bufferFactory.wrap(testMsg.toByteArray());
Flux<DataBuffer> chunks = Flux.just(
buffer.slice(0, 4),
buffer.slice(4, buffer.readableByteCount() - 4));
DataBufferUtils.retain(buffer);
ResolvableType elementType = forClass(Msg.class);
Mono<Message> mono = this.decoder.decodeToMono(source, elementType, null,
Mono<Message> mono = this.decoder.decodeToMono(chunks, elementType, null,
emptyMap());
StepVerifier.create(mono)
.expectNext(this.testMsg)
.expectNext(testMsg)
.verifyComplete();
}
@Test
public void decode() throws IOException {
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
DataBuffer buffer = bufferFactory.allocateBuffer();
OutputStream outputStream = buffer.asOutputStream();
this.testMsg.writeDelimitedTo(outputStream);
testMsg.writeDelimitedTo(buffer.asOutputStream());
DataBuffer buffer2 = bufferFactory.allocateBuffer();
OutputStream outputStream2 = buffer2.asOutputStream();
testMsg2.writeDelimitedTo(outputStream2);
testMsg2.writeDelimitedTo(buffer2.asOutputStream());
Flux<DataBuffer> source = Flux.just(buffer, buffer2);
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(source, elementType, null, emptyMap());
StepVerifier.create(messages)
.expectNext(this.testMsg)
.expectNext(testMsg)
.expectNext(testMsg2)
.verifyComplete();
......@@ -135,42 +142,50 @@ public class ProtobufDecoderTests extends AbstractDataBufferAllocatingTestCase {
}
@Test
public void decodeChunks() throws IOException {
Msg testMsg2 = Msg.newBuilder().setFoo("Bar").setBlah(SecondMsg.newBuilder().setBlah(456).build()).build();
List<DataBuffer> chunks = new ArrayList<>();
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
this.testMsg.writeDelimitedTo(outputStream);
byte[] byteArray = outputStream.toByteArray();
ByteArrayOutputStream outputStream2 = new ByteArrayOutputStream();
testMsg2.writeDelimitedTo(outputStream2);
byte[] byteArray2 = outputStream2.toByteArray();
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray, 0, 4)));
byte[] chunk2 = Arrays.copyOfRange(byteArray, 4, byteArray.length);
byte[] chunk3 = Arrays.copyOfRange(byteArray2, 0, 4);
byte[] combined = new byte[chunk2.length + chunk3.length];
for (int i = 0; i < combined.length; ++i)
{
combined[i] = i < chunk2.length ? chunk2[i] : chunk3[i - chunk2.length];
}
chunks.add(this.bufferFactory.wrap(combined));
chunks.add(this.bufferFactory.wrap(Arrays.copyOfRange(byteArray2, 4, byteArray2.length)));
Flux<DataBuffer> source = Flux.fromIterable(chunks);
public void decodeSplitChunks() throws IOException {
DataBuffer buffer = bufferFactory.allocateBuffer();
testMsg.writeDelimitedTo(buffer.asOutputStream());
DataBuffer buffer2 = bufferFactory.allocateBuffer();
testMsg2.writeDelimitedTo(buffer2.asOutputStream());
Flux<DataBuffer> chunks = Flux.just(
buffer.slice(0, 4),
buffer.slice(4, buffer.readableByteCount() - 4),
buffer2.slice(0, 2),
buffer2.slice(2, buffer2.readableByteCount() - 2));
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(source, elementType, null, emptyMap());
Flux<Message> messages = this.decoder.decode(chunks, elementType, null, emptyMap());
StepVerifier.create(messages)
.expectNext(this.testMsg)
.expectNext(testMsg)
.expectNext(testMsg2)
.verifyComplete();
DataBufferUtils.release(buffer);
DataBufferUtils.release(buffer2);
}
@Test
public void decodeMergedChunks() throws IOException {
DataBuffer buffer = bufferFactory.allocateBuffer();
testMsg.writeDelimitedTo(buffer.asOutputStream());
testMsg.writeDelimitedTo(buffer.asOutputStream());
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(Mono.just(buffer), elementType, null, emptyMap());
StepVerifier.create(messages)
.expectNext(testMsg)
.expectNext(testMsg)
.verifyComplete();
DataBufferUtils.release(buffer);
}
@Test
public void exceedMaxSize() {
this.decoder.setMaxMessageSize(1);
byte[] body = this.testMsg.toByteArray();
byte[] body = testMsg.toByteArray();
Flux<DataBuffer> source = Flux.just(this.bufferFactory.wrap(body));
ResolvableType elementType = forClass(Msg.class);
Flux<Message> messages = this.decoder.decode(source, elementType, null,
......
......@@ -129,6 +129,17 @@ public class ProtobufIntegrationTests extends AbstractRequestMappingIntegrationT
.verifyComplete();
}
@Test
public void defaultInstance() {
Mono<Msg> result = this.webClient.get()
.uri("/default-instance")
.retrieve()
.bodyToMono(Msg.class);
StepVerifier.create(result)
.verifyComplete();
}
@RestController
@SuppressWarnings("unused")
static class ProtobufController {
......@@ -153,6 +164,11 @@ public class ProtobufIntegrationTests extends AbstractRequestMappingIntegrationT
return Mono.empty();
}
@GetMapping("default-instance")
Mono<Msg> defaultInstance() {
return Mono.just(Msg.getDefaultInstance());
}
}
@Configuration
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册