提交 d9221fb8 编写于 作者: R Rossen Stoyanchev

Async boundary for Spring MVC reactive type streaming

Issue: SPR-15365
上级 af6f6881
......@@ -19,7 +19,12 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
......@@ -27,6 +32,7 @@ import org.reactivestreams.Subscription;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.http.server.ServerHttpResponse;
......@@ -60,18 +66,26 @@ import org.springframework.web.servlet.HandlerMapping;
*/
class ReactiveTypeHandler {
private static Log logger = LogFactory.getLog(ReactiveTypeHandler.class);
private static final MediaType JSON_TYPE = new MediaType("application", "*+json");
private final ReactiveAdapterRegistry reactiveRegistry;
private final TaskExecutor taskExecutor;
private final ContentNegotiationManager contentNegotiationManager;
ReactiveTypeHandler(ReactiveAdapterRegistry registry, ContentNegotiationManager manager) {
ReactiveTypeHandler(ReactiveAdapterRegistry registry, TaskExecutor executor,
ContentNegotiationManager manager) {
Assert.notNull(registry, "ReactiveAdapterRegistry is required");
Assert.notNull(executor, "TaskExecutor is required");
Assert.notNull(manager, "ContentNegotiationManager is required");
this.reactiveRegistry = registry;
this.taskExecutor = executor;
this.contentNegotiationManager = manager;
}
......@@ -108,17 +122,17 @@ class ReactiveTypeHandler {
if (mediaTypes.stream().anyMatch(MediaType.TEXT_EVENT_STREAM::includes) ||
ServerSentEvent.class.isAssignableFrom(elementType)) {
SseEmitter emitter = new SseEmitter();
new SseEmitterSubscriber(emitter).connect(adapter, returnValue);
new SseEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
return emitter;
}
if (mediaTypes.stream().anyMatch(MediaType.APPLICATION_STREAM_JSON::includes)) {
ResponseBodyEmitter emitter = getEmitter(MediaType.APPLICATION_STREAM_JSON);
new JsonEmitterSubscriber(emitter).connect(adapter, returnValue);
new JsonEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
return emitter;
}
if (CharSequence.class.isAssignableFrom(elementType) && !jsonArrayOfStrings) {
ResponseBodyEmitter emitter = getEmitter(mediaType.orElse(MediaType.TEXT_PLAIN));
new TextEmitterSubscriber(emitter).connect(adapter, returnValue);
new TextEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
return emitter;
}
}
......@@ -161,13 +175,25 @@ class ReactiveTypeHandler {
private static abstract class AbstractEmitterSubscriber implements Subscriber<Object> {
private static final Object COMPLETE_SIGNAL = new Object();
private final ResponseBodyEmitter emitter;
private final TaskExecutor taskExecutor;
private Subscription subscription;
private final Queue<Object> queue = new ConcurrentLinkedQueue<>();
private final AtomicBoolean executing = new AtomicBoolean(false);
private volatile boolean done;
protected AbstractEmitterSubscriber(ResponseBodyEmitter emitter) {
protected AbstractEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
this.emitter = emitter;
this.taskExecutor = executor;
}
......@@ -181,42 +207,113 @@ class ReactiveTypeHandler {
return this.emitter;
}
@Override
public void onSubscribe(Subscription subscription) {
public final void onSubscribe(Subscription subscription) {
this.subscription = subscription;
this.emitter.onTimeout(subscription::cancel);
if (logger.isDebugEnabled()) {
logger.debug("Subscribed to Publisher for " + this.emitter);
}
this.emitter.onTimeout(() -> {
if (logger.isDebugEnabled()) {
logger.debug("Connection timed out for " + this.emitter);
}
terminate();
this.emitter.complete();
});
subscription.request(1);
}
@Override
public void onNext(Object element) {
try {
send(element);
this.subscription.request(1);
public final void onNext(Object element) {
this.queue.offer(element);
trySchedule();
}
@Override
public final void onError(Throwable ex) {
this.queue.offer(ex);
trySchedule();
}
@Override
public final void onComplete() {
this.queue.offer(COMPLETE_SIGNAL);
trySchedule();
}
private void trySchedule() {
if (this.executing.compareAndSet(false, true)) {
try {
this.taskExecutor.execute(() -> {
try {
Object signal = this.queue.poll();
if (!this.done) {
handle(signal);
}
}
finally {
this.executing.set(false);
if(!this.queue.isEmpty())
trySchedule();
}
});
}
catch (Throwable ex) {
try {
terminate();
}
finally {
this.executing.set(false);
this.queue.clear();
}
}
}
}
private void handle(Object signal) {
if (signal instanceof Throwable) {
if (logger.isDebugEnabled()) {
logger.debug("Publisher error for " + this.emitter, (Throwable) signal);
}
this.done = true;
this.emitter.completeWithError((Throwable) signal);
}
catch (IOException ex) {
this.subscription.cancel();
else if (signal == COMPLETE_SIGNAL) {
if (logger.isDebugEnabled()) {
logger.debug("Publishing completed for " + this.emitter);
}
this.done = true;
this.emitter.complete();
}
else {
try {
send(signal);
this.subscription.request(1);
}
catch (final Throwable ex) {
if (logger.isDebugEnabled()) {
logger.debug("Send error for " + this.emitter, ex);
}
terminate();
}
}
}
protected abstract void send(Object element) throws IOException;
@Override
public void onError(Throwable ex) {
this.emitter.completeWithError(ex);
private void terminate() {
this.done = true;
this.subscription.cancel();
}
@Override
public void onComplete() {
this.emitter.complete();
}
}
private static class SseEmitterSubscriber extends AbstractEmitterSubscriber {
SseEmitterSubscriber(SseEmitter sseEmitter) {
super(sseEmitter);
SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor) {
super(sseEmitter, executor);
}
@Override
......@@ -243,8 +340,8 @@ class ReactiveTypeHandler {
private static class JsonEmitterSubscriber extends AbstractEmitterSubscriber {
JsonEmitterSubscriber(ResponseBodyEmitter emitter) {
super(emitter);
JsonEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
super(emitter, executor);
}
@Override
......@@ -257,8 +354,8 @@ class ReactiveTypeHandler {
private static class TextEmitterSubscriber extends AbstractEmitterSubscriber {
TextEmitterSubscriber(ResponseBodyEmitter emitter) {
super(emitter);
TextEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
super(emitter, executor);
}
@Override
......
......@@ -683,7 +683,7 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter
handlers.add(new ModelMethodProcessor());
handlers.add(new ViewMethodReturnValueHandler());
handlers.add(new ResponseBodyEmitterReturnValueHandler(getMessageConverters(),
this.reactiveRegistry, this.contentNegotiationManager));
this.reactiveRegistry, this.taskExecutor, this.contentNegotiationManager));
handlers.add(new StreamingResponseBodyReturnValueHandler());
handlers.add(new HttpEntityMethodProcessor(getMessageConverters(),
this.contentNegotiationManager, this.requestResponseBodyAdvice));
......
......@@ -23,6 +23,7 @@ import java.util.Set;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
* A controller method return value type for asynchronous request processing
......@@ -224,6 +225,12 @@ public class ResponseBodyEmitter {
}
@Override
public String toString() {
return "ResponseBodyEmitter@" + ObjectUtils.getIdentityHexString(this);
}
/**
* Handle sent objects and complete request processing.
*/
......
......@@ -28,6 +28,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
......@@ -63,11 +64,12 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur
public ResponseBodyEmitterReturnValueHandler(List<HttpMessageConverter<?>> messageConverters,
ReactiveAdapterRegistry reactiveRegistry, ContentNegotiationManager manager) {
ReactiveAdapterRegistry reactiveRegistry, TaskExecutor executor,
ContentNegotiationManager manager) {
Assert.notEmpty(messageConverters, "HttpMessageConverter List must not be empty");
this.messageConverters = messageConverters;
this.reactiveHandler = new ReactiveTypeHandler(reactiveRegistry, manager);
this.reactiveHandler = new ReactiveTypeHandler(reactiveRegistry, executor, manager);
}
......@@ -158,13 +160,13 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur
@SuppressWarnings("unchecked")
private <T> void sendInternal(T data, MediaType mediaType) throws IOException {
if (logger.isTraceEnabled()) {
logger.trace("Writing [" + data + "]");
}
for (HttpMessageConverter<?> converter : ResponseBodyEmitterReturnValueHandler.this.messageConverters) {
if (converter.canWrite(data.getClass(), mediaType)) {
((HttpMessageConverter<T>) converter).write(data, mediaType, this.outputMessage);
this.outputMessage.flush();
if (logger.isDebugEnabled()) {
logger.debug("Written [" + data + "] using [" + converter + "]");
}
return;
}
}
......
......@@ -25,6 +25,7 @@ import java.util.Set;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
/**
......@@ -128,6 +129,11 @@ public class SseEmitter extends ResponseBodyEmitter {
}
}
@Override
public String toString() {
return "SseEmitter@" + ObjectUtils.getIdentityHexString(this);
}
public static SseEventBuilder event() {
return new SseEventBuilderImpl();
......
......@@ -34,6 +34,7 @@ import rx.SingleEmitter;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.http.server.ServletServerHttpResponse;
......@@ -75,7 +76,7 @@ public class ReactiveTypeHandlerTests {
ContentNegotiationManagerFactoryBean factoryBean = new ContentNegotiationManagerFactoryBean();
factoryBean.afterPropertiesSet();
ContentNegotiationManager manager = factoryBean.getObject();
this.handler = new ReactiveTypeHandler(new ReactiveAdapterRegistry(), manager);
this.handler = new ReactiveTypeHandler(new ReactiveAdapterRegistry(), new SyncTaskExecutor(), manager);
resetRequest();
}
......
......@@ -26,6 +26,8 @@ import org.junit.Test;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
......@@ -66,8 +68,9 @@ public class ResponseBodyEmitterReturnValueHandlerTests {
new StringHttpMessageConverter(), new MappingJackson2HttpMessageConverter());
ReactiveAdapterRegistry registry = new ReactiveAdapterRegistry();
TaskExecutor executor = new SyncTaskExecutor();
ContentNegotiationManager manager = new ContentNegotiationManager();
this.handler = new ResponseBodyEmitterReturnValueHandler(converters, registry, manager);
this.handler = new ResponseBodyEmitterReturnValueHandler(converters, registry, executor, manager);
this.request = new MockHttpServletRequest();
this.response = new MockHttpServletResponse();
this.webRequest = new ServletWebRequest(this.request, this.response);
......
......@@ -30,6 +30,7 @@ import reactor.core.publisher.Flux;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.HttpMessageConverter;
......@@ -254,9 +255,11 @@ public class ServletInvocableHandlerMethodTests {
@Test
public void wrapConcurrentResult_ResponseBodyEmitter() throws Exception {
ReactiveAdapterRegistry registry = new ReactiveAdapterRegistry();
ContentNegotiationManager manager = new ContentNegotiationManager();
this.returnValueHandlers.addHandler(new ResponseBodyEmitterReturnValueHandler(this.converters, registry, manager));
this.returnValueHandlers.addHandler(
new ResponseBodyEmitterReturnValueHandler(this.converters,
new ReactiveAdapterRegistry(), new SyncTaskExecutor(), new ContentNegotiationManager()));
ServletInvocableHandlerMethod handlerMethod = getHandlerMethod(new StreamingHandler(), "handleEmitter");
handlerMethod = handlerMethod.wrapConcurrentResult(null);
handlerMethod.invokeAndHandle(this.webRequest, this.mavContainer);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册