diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java index e579c8e11c95de6305d2e663221f4cee9cbee3dd..249ce6106e7b55290d715274ca200b5b9c2deacf 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java @@ -114,22 +114,20 @@ final class DefaultRSocketRequester implements RSocketRequester { private class DefaultRequestSpec implements RequestSpec { - private final MetadataEncoder metadataEncoder; + private final MetadataEncoder metadataEncoder = new MetadataEncoder(metadataMimeType(), strategies); @Nullable - private Mono payloadMono = emptyPayload(); + private Mono payloadMono; @Nullable - private Flux payloadFlux = null; + private Flux payloadFlux; public DefaultRequestSpec(String route, Object... vars) { - this.metadataEncoder = new MetadataEncoder(metadataMimeType(), strategies); this.metadataEncoder.route(route, vars); } public DefaultRequestSpec(Object metadata, @Nullable MimeType mimeType) { - this.metadataEncoder = new MetadataEncoder(metadataMimeType(), strategies); this.metadataEncoder.metadata(metadata, mimeType); } @@ -188,17 +186,14 @@ final class DefaultRSocketRequester implements RSocketRequester { publisher = adapter.toPublisher(input); } else { - this.payloadMono = Mono - .fromCallable(() -> encodeData(input, ResolvableType.forInstance(input), null)) - .map(this::firstPayload) - .doOnDiscard(Payload.class, Payload::release) - .switchIfEmpty(emptyPayload()); + ResolvableType type = ResolvableType.forInstance(input); + this.payloadMono = firstPayload(Mono.fromCallable(() -> encodeData(input, type, null))); this.payloadFlux = null; return; } if (isVoid(elementType) || (adapter != null && adapter.isNoValue())) { - this.payloadMono = Mono.when(publisher).then(emptyPayload()); + this.payloadMono = firstPayload(Mono.when(publisher).then(Mono.just(emptyDataBuffer))); this.payloadFlux = null; return; } @@ -207,10 +202,10 @@ final class DefaultRSocketRequester implements RSocketRequester { strategies.encoder(elementType, dataMimeType) : null; if (adapter != null && !adapter.isMultiValue()) { - this.payloadMono = Mono.from(publisher) + Mono data = Mono.from(publisher) .map(value -> encodeData(value, elementType, encoder)) - .map(this::firstPayload) - .switchIfEmpty(emptyPayload()); + .defaultIfEmpty(emptyDataBuffer); + this.payloadMono = firstPayload(data); this.payloadFlux = null; return; } @@ -218,18 +213,18 @@ final class DefaultRSocketRequester implements RSocketRequester { this.payloadMono = null; this.payloadFlux = Flux.from(publisher) .map(value -> encodeData(value, elementType, encoder)) + .defaultIfEmpty(emptyDataBuffer) .switchOnFirst((signal, inner) -> { DataBuffer data = signal.get(); if (data != null) { - return Mono.fromCallable(() -> firstPayload(data)) + return firstPayload(Mono.fromCallable(() -> data)) .concatWith(inner.skip(1).map(PayloadUtils::createPayload)); } else { return inner.map(PayloadUtils::createPayload); } }) - .doOnDiscard(Payload.class, Payload::release) - .switchIfEmpty(emptyPayload()); + .doOnDiscard(Payload.class, Payload::release); } @SuppressWarnings("unchecked") @@ -242,26 +237,25 @@ final class DefaultRSocketRequester implements RSocketRequester { value, bufferFactory(), elementType, dataMimeType, EMPTY_HINTS); } - private Payload firstPayload(DataBuffer data) { - DataBuffer metadata; - try { - metadata = this.metadataEncoder.encode(); - } - catch (Throwable ex) { - DataBufferUtils.release(data); - throw ex; - } - return PayloadUtils.createPayload(data, metadata); - } - - private Mono emptyPayload() { - return Mono.fromCallable(() -> firstPayload(emptyDataBuffer)); + /** + * Create the 1st request payload with encoded data and metadata. + * @param encodedData the encoded payload data; expected to not be empty! + */ + private Mono firstPayload(Mono encodedData) { + return Mono.zip(encodedData, this.metadataEncoder.encode()) + .map(tuple -> PayloadUtils.createPayload(tuple.getT1(), tuple.getT2())) + .doOnDiscard(DataBuffer.class, DataBufferUtils::release) + .doOnDiscard(Payload.class, Payload::release); } @Override public Mono send() { - Assert.state(this.payloadMono != null, "No RSocket interaction model for one-way send with Flux"); - return this.payloadMono.flatMap(rsocket::fireAndForget); + return getPayloadMonoRequired().flatMap(rsocket::fireAndForget); + } + + private Mono getPayloadMonoRequired() { + Assert.state(this.payloadFlux == null, "No RSocket interaction model for Flux request to Mono response."); + return this.payloadMono != null ? this.payloadMono : firstPayload(Mono.just(emptyDataBuffer)); } @Override @@ -286,8 +280,7 @@ final class DefaultRSocketRequester implements RSocketRequester { @SuppressWarnings("unchecked") private Mono retrieveMono(ResolvableType elementType) { - Assert.notNull(this.payloadMono, "No RSocket interaction model for Flux request to Mono response."); - Mono payloadMono = this.payloadMono.flatMap(rsocket::requestResponse); + Mono payloadMono = getPayloadMonoRequired().flatMap(rsocket::requestResponse); if (isVoid(elementType)) { return (Mono) payloadMono.then(); diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java index 7d634b9f82a7c8095f5e5a01d42f883b2b661d78..24d8026aa5926eb659b430e40bf0ee3d2afa40f7 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java @@ -33,6 +33,7 @@ import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import reactor.core.publisher.Mono; +import org.springframework.core.ReactiveAdapter; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.Encoder; @@ -57,6 +58,8 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { private static final Map HINTS = Collections.emptyMap(); + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + @Nullable private MimeType dataMimeType; @@ -175,50 +178,14 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { factory.dataMimeType(dataMimeType.toString()); factory.metadataMimeType(metaMimeType.toString()); - Payload setupPayload = getSetupPayload(dataMimeType, metaMimeType, rsocketStrategies); - if (setupPayload != null) { - factory.setupPayload(setupPayload); - } - - return factory.transport(transport) - .start() - .map(rsocket -> new DefaultRSocketRequester( - rsocket, dataMimeType, metaMimeType, rsocketStrategies)); - } - - @Nullable - private Payload getSetupPayload(MimeType dataMimeType, MimeType metaMimeType, RSocketStrategies strategies) { - DataBuffer metadata = null; - if (this.setupRoute != null || !CollectionUtils.isEmpty(this.setupMetadata)) { - metadata = new MetadataEncoder(metaMimeType, strategies) - .metadataAndOrRoute(this.setupMetadata, this.setupRoute, this.setupRouteVars) - .encode(); - } - DataBuffer data = null; - if (this.setupData != null) { - try { - ResolvableType type = ResolvableType.forClass(this.setupData.getClass()); - Encoder encoder = strategies.encoder(type, dataMimeType); - Assert.notNull(encoder, () -> "No encoder for " + dataMimeType + ", " + type); - data = encoder.encodeValue(this.setupData, strategies.dataBufferFactory(), type, dataMimeType, HINTS); - } - catch (Throwable ex) { - if (metadata != null) { - DataBufferUtils.release(metadata); - } - throw ex; - } - } - if (metadata == null && data == null) { - return null; - } - metadata = metadata != null ? metadata : emptyBuffer(strategies); - data = data != null ? data : emptyBuffer(strategies); - return PayloadUtils.createPayload(data, metadata); - } - - private DataBuffer emptyBuffer(RSocketStrategies strategies) { - return strategies.dataBufferFactory().wrap(new byte[0]); + return getSetupPayload(dataMimeType, metaMimeType, rsocketStrategies) + .doOnNext(factory::setupPayload) + .then(Mono.defer(() -> + factory.transport(transport) + .start() + .map(rsocket -> new DefaultRSocketRequester( + rsocket, dataMimeType, metaMimeType, rsocketStrategies)) + )); } private RSocketStrategies getRSocketStrategies() { @@ -261,4 +228,45 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { return mimeType.getParameters().isEmpty() ? mimeType : new MimeType(mimeType, Collections.emptyMap()); } + private Mono getSetupPayload( + MimeType dataMimeType, MimeType metaMimeType, RSocketStrategies strategies) { + + Object data = this.setupData; + boolean hasMetadata = (this.setupRoute != null || !CollectionUtils.isEmpty(this.setupMetadata)); + if (!hasMetadata && data == null) { + return Mono.empty(); + } + + Mono dataMono = Mono.empty(); + if (data != null) { + ReactiveAdapter adapter = strategies.reactiveAdapterRegistry().getAdapter(data.getClass()); + Assert.isTrue(adapter == null || !adapter.isMultiValue(), "Expected single value: " + data); + Mono mono = (adapter != null ? Mono.from(adapter.toPublisher(data)) : Mono.just(data)); + dataMono = mono.map(value -> { + ResolvableType type = ResolvableType.forClass(value.getClass()); + Encoder encoder = strategies.encoder(type, dataMimeType); + Assert.notNull(encoder, () -> "No encoder for " + dataMimeType + ", " + type); + return encoder.encodeValue(value, strategies.dataBufferFactory(), type, dataMimeType, HINTS); + }); + } + + Mono metaMono = Mono.empty(); + if (hasMetadata) { + metaMono = new MetadataEncoder(metaMimeType, strategies) + .metadataAndOrRoute(this.setupMetadata, this.setupRoute, this.setupRouteVars) + .encode(); + } + + Mono emptyBuffer = Mono.fromCallable(() -> + strategies.dataBufferFactory().wrap(EMPTY_BYTE_ARRAY)); + + dataMono = dataMono.switchIfEmpty(emptyBuffer); + metaMono = metaMono.switchIfEmpty(emptyBuffer); + + return Mono.zip(dataMono, metaMono) + .map(tuple -> PayloadUtils.createPayload(tuple.getT1(), tuple.getT2())) + .doOnDiscard(DataBuffer.class, DataBufferUtils::release) + .doOnDiscard(Payload.class, Payload::release); + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java index cfe849ef707003506d21e74aeaf6e58d52da6411..9ff1e052740431a5cd91d38e71ed46deb470f80b 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java @@ -15,8 +15,9 @@ */ package org.springframework.messaging.rsocket; +import java.util.ArrayList; import java.util.Collections; -import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -27,7 +28,9 @@ import io.netty.buffer.CompositeByteBuf; import io.rsocket.metadata.CompositeMetadataFlyweight; import io.rsocket.metadata.TaggingMetadataFlyweight; import io.rsocket.metadata.WellKnownMimeType; +import reactor.core.publisher.Mono; +import org.springframework.core.ReactiveAdapter; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Encoder; import org.springframework.core.io.buffer.DataBuffer; @@ -50,6 +53,8 @@ final class MetadataEncoder { /** For route variable replacement. */ private static final Pattern VARS_PATTERN = Pattern.compile("\\{([^/]+?)}"); + private static final Object NO_VALUE = new Object(); + private final MimeType metadataMimeType; @@ -62,7 +67,9 @@ final class MetadataEncoder { @Nullable private String route; - private final Map metadata = new LinkedHashMap<>(4); + private final List metadataEntries = new ArrayList<>(4); + + private boolean hasAsyncValues; MetadataEncoder(MimeType metadataMimeType, RSocketStrategies strategies) { @@ -111,7 +118,7 @@ final class MetadataEncoder { private void assertMetadataEntryCount() { if (!this.isComposite) { - int count = this.route != null ? this.metadata.size() + 1 : this.metadata.size(); + int count = this.route != null ? this.metadataEntries.size() + 1 : this.metadataEntries.size(); Assert.isTrue(count < 2, "Composite metadata required for multiple metadata entries."); } } @@ -128,10 +135,17 @@ final class MetadataEncoder { mimeType = this.metadataMimeType; } else if (!this.metadataMimeType.equals(mimeType)) { - throw new IllegalArgumentException("Mime type is optional (may be null) " + - "but was provided and does not match the connection metadata mime type."); + throw new IllegalArgumentException( + "Mime type is optional when not using composite metadata, but it was provided " + + "and does not match the connection metadata mime type '" + this.metadataMimeType + "'."); } - this.metadata.put(metadata, mimeType); + ReactiveAdapter adapter = this.strategies.reactiveAdapterRegistry().getAdapter(metadata.getClass()); + if (adapter != null) { + Assert.isTrue(!adapter.isMultiValue(), "Expected single value: " + metadata); + metadata = Mono.from(adapter.toPublisher(metadata)).defaultIfEmpty(NO_VALUE); + this.hasAsyncValues = true; + } + this.metadataEntries.add(new MetadataEntry(metadata, mimeType)); assertMetadataEntryCount(); return this; } @@ -159,7 +173,13 @@ final class MetadataEncoder { * Encode the collected metadata entries to a {@code DataBuffer}. * @see PayloadUtils#createPayload(DataBuffer, DataBuffer) */ - public DataBuffer encode() { + public Mono encode() { + return this.hasAsyncValues ? + resolveAsyncMetadata().map(this::encodeEntries) : + Mono.fromCallable(() -> encodeEntries(this.metadataEntries)); + } + + private DataBuffer encodeEntries(List entries) { if (this.isComposite) { CompositeByteBuf composite = this.allocator.compositeBuffer(); try { @@ -167,11 +187,11 @@ final class MetadataEncoder { CompositeMetadataFlyweight.encodeAndAddMetadata(composite, this.allocator, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, encodeRoute()); } - this.metadata.forEach((value, mimeType) -> { - ByteBuf metadata = (value instanceof ByteBuf ? - (ByteBuf) value : PayloadUtils.asByteBuf(encodeEntry(value, mimeType))); + entries.forEach(entry -> { + Object value = entry.value(); CompositeMetadataFlyweight.encodeAndAddMetadata( - composite, this.allocator, mimeType.toString(), metadata); + composite, this.allocator, entry.mimeType().toString(), + value instanceof ByteBuf ? (ByteBuf) value : PayloadUtils.asByteBuf(encodeEntry(entry))); }); return asDataBuffer(composite); } @@ -181,21 +201,21 @@ final class MetadataEncoder { } } else if (this.route != null) { - Assert.isTrue(this.metadata.isEmpty(), "Composite metadata required for route and other entries"); + Assert.isTrue(entries.isEmpty(), "Composite metadata required for route and other entries"); String routingMimeType = WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString(); return this.metadataMimeType.toString().equals(routingMimeType) ? asDataBuffer(encodeRoute()) : encodeEntry(this.route, this.metadataMimeType); } else { - Assert.isTrue(this.metadata.size() == 1, "Composite metadata required for multiple entries"); - Map.Entry entry = this.metadata.entrySet().iterator().next(); - if (!this.metadataMimeType.equals(entry.getValue())) { + Assert.isTrue(entries.size() == 1, "Composite metadata required for multiple entries"); + MetadataEntry entry = entries.get(0); + if (!this.metadataMimeType.equals(entry.mimeType())) { throw new IllegalArgumentException( "Connection configured for metadata mime type " + - "'" + this.metadataMimeType + "', but actual is `" + this.metadata + "`"); + "'" + this.metadataMimeType + "', but actual is `" + entries + "`"); } - return encodeEntry(entry.getKey(), entry.getValue()); + return encodeEntry(entry); } } @@ -204,15 +224,19 @@ final class MetadataEncoder { this.allocator, Collections.singletonList(this.route)).getContent(); } + private DataBuffer encodeEntry(MetadataEntry entry) { + return encodeEntry(entry.value(), entry.mimeType()); + } + @SuppressWarnings("unchecked") - private DataBuffer encodeEntry(Object metadata, MimeType mimeType) { - if (metadata instanceof ByteBuf) { - return asDataBuffer((ByteBuf) metadata); + private DataBuffer encodeEntry(Object value, MimeType mimeType) { + if (value instanceof ByteBuf) { + return asDataBuffer((ByteBuf) value); } - ResolvableType type = ResolvableType.forInstance(metadata); + ResolvableType type = ResolvableType.forInstance(value); Encoder encoder = this.strategies.encoder(type, mimeType); - Assert.notNull(encoder, () -> "No encoder for metadata " + metadata + ", mimeType '" + mimeType + "'"); - return encoder.encodeValue((T) metadata, bufferFactory(), type, mimeType, Collections.emptyMap()); + Assert.notNull(encoder, () -> "No encoder for metadata " + value + ", mimeType '" + mimeType + "'"); + return encoder.encodeValue((T) value, bufferFactory(), type, mimeType, Collections.emptyMap()); } private DataBuffer asDataBuffer(ByteBuf byteBuf) { @@ -225,4 +249,48 @@ final class MetadataEncoder { return buffer; } } + + private Mono> resolveAsyncMetadata() { + Assert.state(this.hasAsyncValues, "No asynchronous values to resolve"); + List> valueMonos = new ArrayList<>(); + this.metadataEntries.forEach(entry -> { + Object v = entry.value(); + valueMonos.add(v instanceof Mono ? (Mono) v : Mono.just(v)); + }); + return Mono.zip(valueMonos, values -> { + List result = new ArrayList<>(values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] != NO_VALUE) { + result.add(new MetadataEntry(values[i], this.metadataEntries.get(i).mimeType())); + } + } + return result; + }); + } + + + /** + * Holder for the metadata value and mime type. + * @since 5.2.2 + */ + private static class MetadataEntry { + + private final Object value; + + private final MimeType mimeType; + + MetadataEntry(Object value, MimeType mimeType) { + this.value = value; + this.mimeType = mimeType; + } + + public Object value() { + return this.value; + } + + public MimeType mimeType() { + return this.mimeType; + } + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java index 4998a1f92f0e3e707b83baecea1d93f5499adb1c..d2981ef41b76ced8fa583ddec8d7e4a1def1a7c2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java @@ -85,7 +85,9 @@ public interface RSocketRequester { RequestSpec route(String route, Object... routeVars); /** - * Begin to specify a new request with the given metadata value. + * Begin to specify a new request with the given metadata value, which can + * be a concrete value or any producer of a single value that can be adapted + * to a {@link Publisher} via {@link ReactiveAdapterRegistry}. * @param metadata the metadata value to encode * @param mimeType the mime type that describes the metadata; * This is required for connection using composite metadata. Otherwise the @@ -143,6 +145,8 @@ public interface RSocketRequester { /** * Set the data for the setup payload. The data will be encoded * according to the configured {@link #dataMimeType(MimeType)}. + * The data be a concrete value or any producer of a single value that + * can be adapted to a {@link Publisher} via {@link ReactiveAdapterRegistry}. *

By default this is not set. */ RSocketRequester.Builder setupData(Object data); @@ -158,7 +162,9 @@ public interface RSocketRequester { /** * Add metadata entry to the setup payload. Composite metadata must be * in use if this is called more than once or in addition to - * {@link #setupRoute(String, Object...)}. + * {@link #setupRoute(String, Object...)}. The metadata value be a + * concrete value or any producer of a single value that can be adapted + * to a {@link Publisher} via {@link ReactiveAdapterRegistry}. */ RSocketRequester.Builder setupMetadata(Object value, @Nullable MimeType mimeType); @@ -335,6 +341,9 @@ public interface RSocketRequester { * Use this to append additional metadata entries when using composite * metadata. An {@link IllegalArgumentException} is raised if this * method is used when not using composite metadata. + * The metadata value be a concrete value or any producer of a single + * value that can be adapted to a {@link Publisher} via + * {@link ReactiveAdapterRegistry}. * @param metadata an Object to be encoded with a suitable * {@link org.springframework.core.codec.Encoder Encoder}, or a * {@link org.springframework.core.io.buffer.DataBuffer DataBuffer} diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java index c12a15160dfa13e94a875e7753ab21d5770c0992..66061ec9c2ff017a4e4eba9a2ce3754c3f450fb1 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java @@ -84,7 +84,7 @@ public class DefaultMetadataExtractorTests { .metadata("html data", TEXT_HTML) .metadata("xml data", TEXT_XML); - DataBuffer metadata = metadataEncoder.encode(); + DataBuffer metadata = metadataEncoder.encode().block(); Payload payload = createPayload(metadata); Map result = this.extractor.extract(payload, COMPOSITE_METADATA); payload.release(); @@ -104,7 +104,7 @@ public class DefaultMetadataExtractorTests { .metadata("html data", TEXT_HTML) .metadata("xml data", TEXT_XML); - DataBuffer metadata = metadataEncoder.encode(); + DataBuffer metadata = metadataEncoder.encode().block(); Payload payload = createPayload(metadata); Map result = this.extractor.extract(payload, COMPOSITE_METADATA); payload.release(); @@ -120,7 +120,7 @@ public class DefaultMetadataExtractorTests { public void route() { MimeType metaMimeType = MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); MetadataEncoder metadataEncoder = new MetadataEncoder(metaMimeType, this.strategies).route("toA"); - DataBuffer metadata = metadataEncoder.encode(); + DataBuffer metadata = metadataEncoder.encode().block(); Payload payload = createPayload(metadata); Map result = this.extractor.extract(payload, metaMimeType); payload.release(); @@ -133,7 +133,7 @@ public class DefaultMetadataExtractorTests { this.extractor.metadataToExtract(TEXT_PLAIN, String.class, ROUTE_KEY); MetadataEncoder metadataEncoder = new MetadataEncoder(TEXT_PLAIN, this.strategies).route("toA"); - DataBuffer metadata = metadataEncoder.encode(); + DataBuffer metadata = metadataEncoder.encode().block(); Payload payload = createPayload(metadata); Map result = this.extractor.extract(payload, TEXT_PLAIN); payload.release(); @@ -151,7 +151,7 @@ public class DefaultMetadataExtractorTests { }); MetadataEncoder encoder = new MetadataEncoder(TEXT_PLAIN, this.strategies).metadata("toA:text data", null); - DataBuffer metadata = encoder.encode(); + DataBuffer metadata = encoder.encode().block(); Payload payload = createPayload(metadata); Map result = this.extractor.extract(payload, TEXT_PLAIN); payload.release(); @@ -167,7 +167,7 @@ public class DefaultMetadataExtractorTests { extractor.metadataToExtract(TEXT_PLAIN, String.class, "name"); MetadataEncoder encoder = new MetadataEncoder(TEXT_PLAIN, this.strategies).metadata("value", null); - DataBuffer metadata = encoder.encode(); + DataBuffer metadata = encoder.encode().block(); Payload payload = createPayload(metadata); Map result = extractor.extract(payload, TEXT_PLAIN); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java index 79fefddc1a74a812924935276ee82c7396db356c..a21f4f91d1b4850f33ce334bd8fd0dbe36a470f2 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java @@ -17,6 +17,7 @@ package org.springframework.messaging.rsocket; import java.lang.reflect.Field; +import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Map; @@ -39,6 +40,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.DecodingException; +import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory; @@ -191,6 +193,39 @@ public class DefaultRSocketRequesterBuilderTests { assertThat(setupPayload.getDataUtf8()).isEqualTo("My data"); } + @Test + public void setupWithAsyncValues() { + + Mono asyncMeta1 = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async Metadata 1"); + Mono asyncMeta2 = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async Metadata 2"); + Mono data = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async data"); + + RSocketRequester.builder() + .dataMimeType(MimeTypeUtils.TEXT_PLAIN) + .setupRoute("toA") + .setupMetadata(asyncMeta1, new MimeType("text", "x.test.metadata1")) + .setupMetadata(asyncMeta2, new MimeType("text", "x.test.metadata2")) + .setupData(data) + .connect(this.transport) + .block(); + + ConnectionSetupPayload payload = Mono.from(this.connection.sentFrames()) + .map(ConnectionSetupPayload::create) + .block(); + + MimeType compositeMimeType = + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(StringDecoder.allMimeTypes()); + extractor.metadataToExtract(new MimeType("text", "x.test.metadata1"), String.class, "asyncMeta1"); + extractor.metadataToExtract(new MimeType("text", "x.test.metadata2"), String.class, "asyncMeta2"); + Map metadataValues = extractor.extract(payload, compositeMimeType); + + assertThat(metadataValues.get("asyncMeta1")).isEqualTo("Async Metadata 1"); + assertThat(metadataValues.get("asyncMeta2")).isEqualTo("Async Metadata 2"); + assertThat(payload.getDataUtf8()).isEqualTo("Async data"); + } + @Test public void frameDecoderMatchesDataBufferFactory() throws Exception { testFrameDecoder(new NettyDataBufferFactory(ByteBufAllocator.DEFAULT), PayloadDecoder.ZERO_COPY); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java index 94055e63f5e5899ae7147677e44d676167eda4ac..532e1d76b31ee36c6e4d3ebcd57be1b264a44111 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java @@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -28,6 +29,7 @@ import io.reactivex.Observable; import io.reactivex.Single; import io.rsocket.AbstractRSocket; import io.rsocket.Payload; +import io.rsocket.metadata.WellKnownMimeType; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; @@ -38,10 +40,12 @@ import reactor.test.StepVerifier; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.lang.Nullable; import org.springframework.messaging.rsocket.RSocketRequester.RequestSpec; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN; /** @@ -131,6 +135,54 @@ public class DefaultRSocketRequesterTests { } } + @Test + public void sendWithoutData() { + this.requester.route("toA").send().block(Duration.ofSeconds(5)); + + assertThat(this.rsocket.getSavedMethodName()).isEqualTo("fireAndForget"); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(""); + } + + @Test + public void sendMonoWithoutData() { + this.requester.route("toA").retrieveMono(String.class).block(Duration.ofSeconds(5)); + + assertThat(this.rsocket.getSavedMethodName()).isEqualTo("requestResponse"); + assertThat(this.rsocket.getSavedPayload().getMetadataUtf8()).isEqualTo("toA"); + assertThat(this.rsocket.getSavedPayload().getDataUtf8()).isEqualTo(""); + } + + @Test + public void testSendWithAsyncMetadata() { + + MimeType compositeMimeType = + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + Mono asyncMeta1 = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async Metadata 1"); + Mono asyncMeta2 = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async Metadata 2"); + + TestRSocket rsocket = new TestRSocket(); + RSocketRequester.wrap(rsocket, TEXT_PLAIN, compositeMimeType, this.strategies) + .route("toA") + .metadata(asyncMeta1, new MimeType("text", "x.test.metadata1")) + .metadata(asyncMeta2, new MimeType("text", "x.test.metadata2")) + .data("data") + .send() + .block(Duration.ofSeconds(5)); + + Payload payload = rsocket.getSavedPayload(); + + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(this.strategies.decoders()); + extractor.metadataToExtract(new MimeType("text", "x.test.metadata1"), String.class, "asyncMeta1"); + extractor.metadataToExtract(new MimeType("text", "x.test.metadata2"), String.class, "asyncMeta2"); + Map metadataValues = extractor.extract(payload, compositeMimeType); + + assertThat(metadataValues.get("asyncMeta1")).isEqualTo("Async Metadata 1"); + assertThat(metadataValues.get("asyncMeta2")).isEqualTo("Async Metadata 2"); + assertThat(payload.getDataUtf8()).isEqualTo("data"); + } + @Test public void retrieveMono() { String value = "bodyA"; @@ -176,7 +228,7 @@ public class DefaultRSocketRequesterTests { @Test public void fluxToMonoIsRejected() { - assertThatIllegalArgumentException() + assertThatIllegalStateException() .isThrownBy(() -> this.requester.route("").data(Flux.just("a", "b")).retrieveMono(String.class)) .withMessage("No RSocket interaction model for Flux request to Mono response."); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java index 6e951de7e9a9bfb559aed86f2398f46d8276ae62..a1ca77d15dda4979a00d6d70888ca455b2be7da6 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java @@ -15,6 +15,7 @@ */ package org.springframework.messaging.rsocket; +import java.time.Duration; import java.util.Collections; import java.util.Iterator; import java.util.Map; @@ -26,6 +27,7 @@ import io.rsocket.metadata.CompositeMetadata; import io.rsocket.metadata.RoutingMetadata; import io.rsocket.metadata.WellKnownMimeType; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; @@ -56,11 +58,17 @@ public class MetadataEncoderTests { @Test public void compositeMetadata() { + Mono asyncMeta1 = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async Metadata 1"); + Mono asyncMeta2 = Mono.delay(Duration.ofMillis(1)).map(aLong -> "Async Metadata 2"); + DataBuffer buffer = new MetadataEncoder(COMPOSITE_METADATA, this.strategies) .route("toA") .metadata("My metadata", MimeTypeUtils.TEXT_PLAIN) + .metadata(asyncMeta1, new MimeType("text", "x.test.metadata1")) .metadata(Unpooled.wrappedBuffer("Raw data".getBytes(UTF_8)), MimeTypeUtils.APPLICATION_OCTET_STREAM) - .encode(); + .metadata(asyncMeta2, new MimeType("text", "x.test.metadata2")) + .encode() + .block(); CompositeMetadata entries = new CompositeMetadata(((NettyDataBuffer) buffer).getNativeBuffer(), false); Iterator iterator = entries.iterator(); @@ -75,11 +83,21 @@ public class MetadataEncoderTests { assertThat(entry.getMimeType()).isEqualTo(MimeTypeUtils.TEXT_PLAIN_VALUE); assertThat(entry.getContent().toString(UTF_8)).isEqualTo("My metadata"); + assertThat(iterator.hasNext()).isTrue(); + entry = iterator.next(); + assertThat(entry.getMimeType()).isEqualTo("text/x.test.metadata1"); + assertThat(entry.getContent().toString(UTF_8)).isEqualTo("Async Metadata 1"); + assertThat(iterator.hasNext()).isTrue(); entry = iterator.next(); assertThat(entry.getMimeType()).isEqualTo(MimeTypeUtils.APPLICATION_OCTET_STREAM_VALUE); assertThat(entry.getContent().toString(UTF_8)).isEqualTo("Raw data"); + assertThat(iterator.hasNext()).isTrue(); + entry = iterator.next(); + assertThat(entry.getMimeType()).isEqualTo("text/x.test.metadata2"); + assertThat(entry.getContent().toString(UTF_8)).isEqualTo("Async Metadata 2"); + assertThat(iterator.hasNext()).isFalse(); } @@ -92,7 +110,8 @@ public class MetadataEncoderTests { DataBuffer buffer = new MetadataEncoder(mimeType, this.strategies) .route("toA") - .encode(); + .encode() + .block(); assertRoute("toA", ((NettyDataBuffer) buffer).getNativeBuffer()); } @@ -102,7 +121,8 @@ public class MetadataEncoderTests { DataBuffer buffer = new MetadataEncoder(MimeTypeUtils.TEXT_PLAIN, this.strategies) .route("toA") - .encode(); + .encode() + .block(); assertThat(dumpString(buffer)).isEqualTo("toA"); } @@ -112,7 +132,8 @@ public class MetadataEncoderTests { DataBuffer buffer = new MetadataEncoder(MimeTypeUtils.TEXT_PLAIN, this.strategies) .route("a.{b}.{c}", "BBB", "C.C.C") - .encode(); + .encode() + .block(); assertThat(dumpString(buffer)).isEqualTo("a.BBB.C%2EC%2EC"); } @@ -122,7 +143,8 @@ public class MetadataEncoderTests { DataBuffer buffer = new MetadataEncoder(MimeTypeUtils.TEXT_PLAIN, this.strategies) .metadata(Unpooled.wrappedBuffer("Raw data".getBytes(UTF_8)), null) - .encode(); + .encode() + .block(); assertThat(dumpString(buffer)).isEqualTo("Raw data"); } @@ -132,7 +154,8 @@ public class MetadataEncoderTests { DataBuffer buffer = new MetadataEncoder(MimeTypeUtils.TEXT_PLAIN, this.strategies) .metadata("toA", null) - .encode(); + .encode() + .block(); assertThat(dumpString(buffer)).isEqualTo("toA"); } @@ -175,8 +198,8 @@ public class MetadataEncoderTests { MetadataEncoder encoder = new MetadataEncoder(MimeTypeUtils.TEXT_PLAIN, this.strategies); assertThatThrownBy(() -> encoder.metadata("toA", MimeTypeUtils.APPLICATION_JSON)) - .hasMessage("Mime type is optional (may be null) " + - "but was provided and does not match the connection metadata mime type."); + .hasMessage("Mime type is optional when not using composite metadata, " + + "but it was provided and does not match the connection metadata mime type 'text/plain'."); } @Test @@ -186,7 +209,8 @@ public class MetadataEncoderTests { DataBuffer buffer = new MetadataEncoder(COMPOSITE_METADATA, strategies) .route("toA") - .encode(); + .encode() + .block(); ByteBuf byteBuf = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT) .wrap(buffer.asByteBuffer())