提交 5fc18064 编写于 作者: R Rossen Stoyanchev

Use encode with an Object value where feasible

Closes gh-22782
上级 181482fa
......@@ -80,10 +80,14 @@ public class DefaultDataBuffer implements DataBuffer {
/**
* Directly exposes the native {@code ByteBuffer} that this buffer is based on.
* Directly exposes the native {@code ByteBuffer} that this buffer is based
* on also updating the {@code ByteBuffer's} position and limit to match
* the current {@link #readPosition()} and {@link #readableByteCount()}.
* @return the wrapped byte buffer
*/
public ByteBuffer getNativeBuffer() {
this.byteBuffer.position(this.readPosition);
this.byteBuffer.limit(readableByteCount());
return this.byteBuffer;
}
......
......@@ -33,7 +33,6 @@ import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
......@@ -148,7 +147,7 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
Encoder<?> encoder = getEncoder(elementType, mimeType);
return Flux.from((Publisher) publisher).concatMap(value ->
return Flux.from((Publisher) publisher).map(value ->
encodeValue(value, elementType, encoder, bufferFactory, mimeType, hints));
}
......@@ -176,7 +175,7 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
}
@SuppressWarnings("unchecked")
private <T> Mono<DataBuffer> encodeValue(
private <T> DataBuffer encodeValue(
Object element, ResolvableType elementType, @Nullable Encoder<T> encoder,
DataBufferFactory bufferFactory, @Nullable MimeType mimeType,
@Nullable Map<String, Object> hints) {
......@@ -184,13 +183,11 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
if (encoder == null) {
encoder = getEncoder(ResolvableType.forInstance(element), mimeType);
if (encoder == null) {
return Mono.error(new MessagingException(
"No encoder for " + elementType + ", current value type is " + element.getClass()));
throw new MessagingException(
"No encoder for " + elementType + ", current value type is " + element.getClass());
}
}
Mono<T> mono = Mono.just((T) element);
Flux<DataBuffer> dataBuffers = encoder.encode(mono, bufferFactory, elementType, mimeType, hints);
return DataBufferUtils.join(dataBuffers);
return encoder.encodeValue((T) element, bufferFactory, elementType, mimeType, hints);
}
/**
......
......@@ -32,7 +32,6 @@ import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
......@@ -124,8 +123,10 @@ final class DefaultRSocketRequester implements RSocketRequester {
publisher = adapter.toPublisher(input);
}
else {
Mono<Payload> payloadMono = encodeValue(input, ResolvableType.forInstance(input), null)
Mono<Payload> payloadMono = Mono
.fromCallable(() -> encodeValue(input, ResolvableType.forInstance(input), null))
.map(this::firstPayload)
.doOnDiscard(Payload.class, Payload::release)
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadMono);
}
......@@ -140,36 +141,36 @@ final class DefaultRSocketRequester implements RSocketRequester {
if (adapter != null && !adapter.isMultiValue()) {
Mono<Payload> payloadMono = Mono.from(publisher)
.flatMap(value -> encodeValue(value, dataType, encoder))
.map(value -> encodeValue(value, dataType, encoder))
.map(this::firstPayload)
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadMono);
}
Flux<Payload> payloadFlux = Flux.from(publisher)
.concatMap(value -> encodeValue(value, dataType, encoder))
.map(value -> encodeValue(value, dataType, encoder))
.switchOnFirst((signal, inner) -> {
DataBuffer data = signal.get();
if (data != null) {
return Flux.concat(
Mono.just(firstPayload(data)),
inner.skip(1).map(PayloadUtils::createPayload));
return Mono.fromCallable(() -> firstPayload(data))
.concatWith(inner.skip(1).map(PayloadUtils::createPayload));
}
else {
return inner.map(PayloadUtils::createPayload);
}
})
.doOnDiscard(Payload.class, Payload::release)
.switchIfEmpty(emptyPayload());
return new DefaultResponseSpec(payloadFlux);
}
@SuppressWarnings("unchecked")
private <T> Mono<DataBuffer> encodeValue(T value, ResolvableType valueType, @Nullable Encoder<?> encoder) {
private <T> DataBuffer encodeValue(T value, ResolvableType valueType, @Nullable Encoder<?> encoder) {
if (encoder == null) {
encoder = strategies.encoder(ResolvableType.forInstance(value), dataMimeType);
}
return DataBufferUtils.join(((Encoder<T>) encoder).encode(
Mono.just(value), strategies.dataBufferFactory(), valueType, dataMimeType, EMPTY_HINTS));
return ((Encoder<T>) encoder).encodeValue(
value, strategies.dataBufferFactory(), valueType, dataMimeType, EMPTY_HINTS);
}
private Payload firstPayload(DataBuffer data) {
......
......@@ -81,7 +81,7 @@ public class MessageMappingMessageHandlerTests {
@Test
public void handleFluxString() {
MessageMappingMessageHandler messsageHandler = initMesssageHandler();
messsageHandler.handleMessage(message("fluxString", "abc\ndef\nghi")).block(Duration.ofSeconds(5));
messsageHandler.handleMessage(message("fluxString", "abc", "def", "ghi")).block(Duration.ofSeconds(5));
verifyOutputContent(Arrays.asList("abc::response", "def::response", "ghi::response"));
}
......
......@@ -129,9 +129,10 @@ public class PayloadMethodArgumentResolverTests {
@Test
public void validateStringMono() {
TestValidator validator = new TestValidator();
ResolvableType type = ResolvableType.forClassWithGenerics(Mono.class, String.class);
MethodParameter param = this.testMethod.arg(type);
Mono<Object> mono = resolveValue(param, Mono.just(toDataBuffer("12345")), new TestValidator());
Mono<Object> mono = resolveValue(param, Mono.just(toDataBuffer("12345")), validator);
StepVerifier.create(mono).expectNextCount(0)
.expectError(MethodArgumentNotValidException.class).verify();
......@@ -139,9 +140,11 @@ public class PayloadMethodArgumentResolverTests {
@Test
public void validateStringFlux() {
TestValidator validator = new TestValidator();
ResolvableType type = ResolvableType.forClassWithGenerics(Flux.class, String.class);
MethodParameter param = this.testMethod.arg(type);
Flux<Object> flux = resolveValue(param, Mono.just(toDataBuffer("12345678\n12345")), new TestValidator());
Flux<DataBuffer> content = Flux.just(toDataBuffer("12345678"), toDataBuffer("12345"));
Flux<Object> flux = resolveValue(param, content, validator);
StepVerifier.create(flux)
.expectNext("12345678")
......
......@@ -18,6 +18,7 @@ package org.springframework.http.codec;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
......@@ -111,9 +112,9 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter<Objec
}
private Flux<Publisher<DataBuffer>> encode(Publisher<?> input, ResolvableType elementType,
MediaType mediaType, DataBufferFactory factory, Map<String, Object> hints) {
MediaType mediaType, DataBufferFactory bufferFactory, Map<String, Object> hints) {
ResolvableType valueType = (ServerSentEvent.class.isAssignableFrom(elementType.toClass()) ?
ResolvableType dataType = (ServerSentEvent.class.isAssignableFrom(elementType.toClass()) ?
elementType.getGeneric() : elementType);
return Flux.from(input).map(element -> {
......@@ -143,12 +144,10 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter<Objec
sb.append("data:");
}
Flux<DataBuffer> flux = Flux.concat(
encodeText(sb, mediaType, factory),
encodeData(data, valueType, mediaType, factory, hints),
encodeText("\n", mediaType, factory));
Mono<DataBuffer> bufferMono = Mono.fromCallable(() ->
bufferFactory.join(encodeEvent(sb, data, dataType, mediaType, bufferFactory, hints)));
return flux.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
return bufferMono.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release);
});
}
......@@ -160,31 +159,32 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter<Objec
}
@SuppressWarnings("unchecked")
private <T> Flux<DataBuffer> encodeData(@Nullable T dataValue, ResolvableType valueType,
private <T> List<DataBuffer> encodeEvent(CharSequence markup, @Nullable T data, ResolvableType dataType,
MediaType mediaType, DataBufferFactory factory, Map<String, Object> hints) {
if (dataValue == null) {
return Flux.empty();
}
if (dataValue instanceof String) {
String text = (String) dataValue;
return Flux.from(encodeText(StringUtils.replace(text, "\n", "\ndata:") + "\n", mediaType, factory));
}
if (this.encoder == null) {
return Flux.error(new CodecException("No SSE encoder configured and the data is not String."));
List<DataBuffer> result = new ArrayList<>(4);
result.add(encodeText(markup, mediaType, factory));
if (data != null) {
if (data instanceof String) {
String dataLine = StringUtils.replace((String) data, "\n", "\ndata:") + "\n";
result.add(encodeText(dataLine, mediaType, factory));
}
else if (this.encoder == null) {
throw new CodecException("No SSE encoder configured and the data is not String.");
}
else {
result.add(((Encoder<T>) this.encoder).encodeValue(data, factory, dataType, mediaType, hints));
result.add(encodeText("\n", mediaType, factory));
}
}
return ((Encoder<T>) this.encoder)
.encode(Mono.just(dataValue), factory, valueType, mediaType, hints)
.concatWith(encodeText("\n", mediaType, factory));
result.add(encodeText("\n", mediaType, factory));
return result;
}
private Mono<DataBuffer> encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) {
private DataBuffer encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) {
Assert.notNull(mediaType.getCharset(), "Expected MediaType with charset");
byte[] bytes = text.toString().getBytes(mediaType.getCharset());
return Mono.just(bufferFactory.wrap(bytes)); // wrapping, not allocating
return bufferFactory.wrap(bytes); // wrapping, not allocating
}
@Override
......
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -40,7 +40,7 @@ import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import static org.junit.Assert.*;
import static org.springframework.core.ResolvableType.forClass;
import static org.springframework.core.ResolvableType.*;
/**
* Unit tests for {@link ServerSentEventHttpMessageWriter}.
......@@ -88,9 +88,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
testWrite(source, outputMessage, ServerSentEvent.class);
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(stringConsumer("id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:"))
.consumeNextWith(stringConsumer("bar\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer(
"id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:bar\n\n"))
.expectComplete()
.verify();
}
......@@ -101,12 +100,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
testWrite(source, outputMessage, String.class);
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("foo\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("bar\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:foo\n\n"))
.consumeNextWith(stringConsumer("data:bar\n\n"))
.expectComplete()
.verify();
}
......@@ -117,12 +112,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
testWrite(source, outputMessage, String.class);
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("foo\ndata:bar\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("foo\ndata:baz\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:foo\ndata:bar\n\n"))
.consumeNextWith(stringConsumer("data:foo\ndata:baz\n\n"))
.expectComplete()
.verify();
}
......@@ -136,14 +127,11 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
assertEquals(mediaType, outputMessage.getHeaders().getContentType());
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(dataBuffer -> {
String value =
DataBufferTestUtils.dumpString(dataBuffer, charset);
String value = DataBufferTestUtils.dumpString(dataBuffer, charset);
DataBufferUtils.release(dataBuffer);
assertEquals("\u00A3\n", value);
assertEquals("data:\u00A3\n\n", value);
})
.consumeNextWith(stringConsumer("\n"))
.expectComplete()
.verify();
}
......@@ -154,14 +142,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
testWrite(source, outputMessage, Pojo.class);
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:{\"foo\":\"foofoo\",\"bar\":\"barbar\"}\n\n"))
.consumeNextWith(stringConsumer("data:{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}\n\n"))
.expectComplete()
.verify();
}
......@@ -175,18 +157,12 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
testWrite(source, outputMessage, Pojo.class);
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("{\n" +
.consumeNextWith(stringConsumer("data:{\n" +
"data: \"foo\" : \"foofoo\",\n" +
"data: \"bar\" : \"barbar\"\n" + "data:}"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("data:"))
.consumeNextWith(stringConsumer("{\n" +
"data: \"bar\" : \"barbar\"\n" + "data:}\n\n"))
.consumeNextWith(stringConsumer("data:{\n" +
"data: \"foo\" : \"foofoofoo\",\n" +
"data: \"bar\" : \"barbarbar\"\n" + "data:}"))
.consumeNextWith(stringConsumer("\n"))
.consumeNextWith(stringConsumer("\n"))
"data: \"bar\" : \"barbarbar\"\n" + "data:}\n\n"))
.expectComplete()
.verify();
}
......@@ -200,28 +176,10 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll
assertEquals(mediaType, outputMessage.getHeaders().getContentType());
StepVerifier.create(outputMessage.getBody())
.consumeNextWith(dataBuffer1 -> {
String value1 =
DataBufferTestUtils.dumpString(dataBuffer1, charset);
DataBufferUtils.release(dataBuffer1);
assertEquals("data:", value1);
})
.consumeNextWith(dataBuffer -> {
String value = DataBufferTestUtils.dumpString(dataBuffer, charset);
DataBufferUtils.release(dataBuffer);
assertEquals("{\"foo\":\"foo\uD834\uDD1E\",\"bar\":\"bar\uD834\uDD1E\"}", value);
})
.consumeNextWith(dataBuffer2 -> {
String value2 =
DataBufferTestUtils.dumpString(dataBuffer2, charset);
DataBufferUtils.release(dataBuffer2);
assertEquals("\n", value2);
})
.consumeNextWith(dataBuffer3 -> {
String value3 =
DataBufferTestUtils.dumpString(dataBuffer3, charset);
DataBufferUtils.release(dataBuffer3);
assertEquals("\n", value3);
assertEquals("data:{\"foo\":\"foo\uD834\uDD1E\",\"bar\":\"bar\uD834\uDD1E\"}\n\n", value);
})
.expectComplete()
.verify();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册