提交 b0de99bc 编写于 作者: A Arjen Poutsma

Refactor ResponseBodySubscriber to Processor

This commit changes the AbstractResponseBodySubscriber into a
AbstractResponseBodyProcessor<DataBuffer, Void>, so that the processor
can be used as a return value for writeWith.

Additional, this commit no longer closes the response after an eror
occurred.

This fixes #59.
上级 c85d1dc1
...@@ -52,6 +52,8 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> { ...@@ -52,6 +52,8 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> {
private Subscriber<? super DataBuffer> subscriber; private Subscriber<? super DataBuffer> subscriber;
private volatile boolean dataAvailable;
@Override @Override
public void subscribe(Subscriber<? super DataBuffer> subscriber) { public void subscribe(Subscriber<? super DataBuffer> subscriber) {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
...@@ -199,7 +201,9 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> { ...@@ -199,7 +201,9 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> {
void subscribe(AbstractRequestBodyPublisher publisher, void subscribe(AbstractRequestBodyPublisher publisher,
Subscriber<? super DataBuffer> subscriber) { Subscriber<? super DataBuffer> subscriber) {
Objects.requireNonNull(subscriber); Objects.requireNonNull(subscriber);
if (publisher.changeState(this, DATA_UNAVAILABLE)) { State newState =
publisher.dataAvailable ? DATA_AVAILABLE : DATA_UNAVAILABLE;
if (publisher.changeState(this, newState)) {
Subscription subscription = new RequestBodySubscription( Subscription subscription = new RequestBodySubscription(
publisher); publisher);
publisher.subscriber = subscriber; publisher.subscriber = subscriber;
...@@ -209,6 +213,11 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> { ...@@ -209,6 +213,11 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> {
throw new IllegalStateException(toString()); throw new IllegalStateException(toString());
} }
} }
@Override
void onDataAvailable(AbstractRequestBodyPublisher publisher) {
publisher.dataAvailable = true;
}
}, },
/** /**
* State that gets entered when there is no data to be read. Responds to {@link * State that gets entered when there is no data to be read. Responds to {@link
...@@ -252,20 +261,11 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> { ...@@ -252,20 +261,11 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> {
} }
} }
@Override
void onDataAvailable(AbstractRequestBodyPublisher publisher) {
// ignore
}
}, },
/** /**
* The terminal completed state. Does not respond to any events. * The terminal completed state. Does not respond to any events.
*/ */
COMPLETED { COMPLETED {
@Override
void subscribe(AbstractRequestBodyPublisher publisher,
Subscriber<? super DataBuffer> subscriber) {
// ignore
}
@Override @Override
void request(AbstractRequestBodyPublisher publisher, long n) { void request(AbstractRequestBodyPublisher publisher, long n) {
...@@ -277,11 +277,6 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> { ...@@ -277,11 +277,6 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> {
// ignore // ignore
} }
@Override
void onDataAvailable(AbstractRequestBodyPublisher publisher) {
// ignore
}
@Override @Override
void onAllDataRead(AbstractRequestBodyPublisher publisher) { void onAllDataRead(AbstractRequestBodyPublisher publisher) {
// ignore // ignore
...@@ -309,7 +304,7 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> { ...@@ -309,7 +304,7 @@ abstract class AbstractRequestBodyPublisher implements Publisher<DataBuffer> {
} }
void onDataAvailable(AbstractRequestBodyPublisher publisher) { void onDataAvailable(AbstractRequestBodyPublisher publisher) {
throw new IllegalStateException(toString()); // ignore
} }
void onAllDataRead(AbstractRequestBodyPublisher publisher) { void onAllDataRead(AbstractRequestBodyPublisher publisher) {
......
...@@ -24,8 +24,10 @@ import javax.servlet.WriteListener; ...@@ -24,8 +24,10 @@ import javax.servlet.WriteListener;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Processor;
import org.reactivestreams.Subscriber; import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import reactor.core.util.BackpressureUtils;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.FlushingDataBuffer; import org.springframework.core.io.buffer.FlushingDataBuffer;
...@@ -40,58 +42,95 @@ import org.springframework.util.Assert; ...@@ -40,58 +42,95 @@ import org.springframework.util.Assert;
* @see ServletServerHttpRequest * @see ServletServerHttpRequest
* @see UndertowHttpHandlerAdapter * @see UndertowHttpHandlerAdapter
*/ */
abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> { abstract class AbstractResponseBodyProcessor implements Processor<DataBuffer, Void> {
protected final Log logger = LogFactory.getLog(getClass()); protected final Log logger = LogFactory.getLog(getClass());
private final AtomicReference<State> state = private final AtomicReference<SubscriberState> subscriberState =
new AtomicReference<>(State.UNSUBSCRIBED); new AtomicReference<>(SubscriberState.UNSUBSCRIBED);
private final AtomicReference<PublisherState> publisherState =
new AtomicReference<>(PublisherState.UNSUBSCRIBED);
private volatile DataBuffer currentBuffer; private volatile DataBuffer currentBuffer;
private volatile boolean subscriptionCompleted; private volatile boolean subscriberCompleted;
private volatile boolean publisherCompleted;
private volatile Throwable publisherError;
private Subscription subscription; private Subscription subscription;
private Subscriber<? super Void> subscriber;
// Subscriber
@Override @Override
public final void onSubscribe(Subscription subscription) { public final void onSubscribe(Subscription subscription) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace(this.state + " onSubscribe: " + subscription); logger.trace("SUB " + this.subscriberState + " onSubscribe: " + subscription);
} }
this.state.get().onSubscribe(this, subscription); this.subscriberState.get().onSubscribe(this, subscription);
} }
@Override @Override
public final void onNext(DataBuffer dataBuffer) { public final void onNext(DataBuffer dataBuffer) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace(this.state + " onNext: " + dataBuffer); logger.trace("SUB " + this.subscriberState + " onNext: " + dataBuffer);
} }
this.state.get().onNext(this, dataBuffer); this.subscriberState.get().onNext(this, dataBuffer);
} }
@Override @Override
public final void onError(Throwable t) { public final void onError(Throwable t) {
if (logger.isErrorEnabled()) { if (logger.isErrorEnabled()) {
logger.error(this.state + " onError: " + t, t); logger.error("SUB " + this.subscriberState + " publishError: " + t, t);
} }
this.state.get().onError(this, t); this.subscriberState.get().onError(this, t);
} }
@Override @Override
public final void onComplete() { public final void onComplete() {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace(this.state + " onComplete"); logger.trace("SUB " + this.subscriberState + " onComplete");
} }
this.state.get().onComplete(this); this.subscriberState.get().onComplete(this);
} }
// Publisher
@Override
public final void subscribe(Subscriber<? super Void> subscriber) {
if (logger.isTraceEnabled()) {
logger.trace("PUB " + this.publisherState + " subscribe: " + subscriber);
}
this.publisherState.get().subscribe(this, subscriber);
}
private void publishComplete() {
if (logger.isTraceEnabled()) {
logger.trace("PUB " + this.publisherState + " publishComplete");
}
this.publisherState.get().publishComplete(this);
}
private void publishError(Throwable t) {
if (logger.isTraceEnabled()) {
logger.trace("PUB " + this.publisherState + " publishError: " + t);
}
this.publisherState.get().publishError(this, t);
}
// listener methods
/** /**
* Called via a listener interface to indicate that writing is possible. * Called via a listener interface to indicate that writing is possible.
* @see WriteListener#onWritePossible() * @see WriteListener#onWritePossible()
* @see org.xnio.ChannelListener#handleEvent(Channel) * @see org.xnio.ChannelListener#handleEvent(Channel)
*/ */
protected final void onWritePossible() { protected final void onWritePossible() {
this.state.get().onWritePossible(this); this.subscriberState.get().onWritePossible(this);
} }
/** /**
...@@ -134,11 +173,6 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -134,11 +173,6 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
protected abstract boolean write(DataBuffer dataBuffer) throws IOException; protected abstract boolean write(DataBuffer dataBuffer) throws IOException;
/**
* Writes the given exception to the output.
*/
protected abstract void writeError(Throwable t);
/** /**
* Flushes the output. * Flushes the output.
*/ */
...@@ -149,8 +183,44 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -149,8 +183,44 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
protected abstract void close(); protected abstract void close();
private boolean changeState(State oldState, State newState) { private boolean changeSubscriberState(SubscriberState oldState,
return this.state.compareAndSet(oldState, newState); SubscriberState newState) {
return this.subscriberState.compareAndSet(oldState, newState);
}
private boolean changePublisherState(PublisherState oldState,
PublisherState newState) {
return this.publisherState.compareAndSet(oldState, newState);
}
private static final class ResponseBodySubscription implements Subscription {
private final AbstractResponseBodyProcessor processor;
public ResponseBodySubscription(AbstractResponseBodyProcessor processor) {
this.processor = processor;
}
@Override
public final void request(long n) {
if (this.processor.logger.isTraceEnabled()) {
this.processor.logger.trace("PUB " + state() + " request: " + n);
}
state().request(this.processor, n);
}
@Override
public final void cancel() {
if (this.processor.logger.isTraceEnabled()) {
this.processor.logger.trace("PUB " + state() + " cancel");
}
state().cancel(this.processor);
}
private PublisherState state() {
return this.processor.publisherState.get();
}
} }
/** /**
...@@ -171,7 +241,7 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -171,7 +241,7 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
* </pre> * </pre>
* Refer to the individual states for more information. * Refer to the individual states for more information.
*/ */
private enum State { private enum SubscriberState {
/** /**
* The initial unsubscribed state. Will respond to {@code onSubscribe} by * The initial unsubscribed state. Will respond to {@code onSubscribe} by
...@@ -180,15 +250,15 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -180,15 +250,15 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
UNSUBSCRIBED { UNSUBSCRIBED {
@Override @Override
void onSubscribe(AbstractResponseBodySubscriber subscriber, void onSubscribe(AbstractResponseBodyProcessor processor,
Subscription subscription) { Subscription subscription) {
Objects.requireNonNull(subscription, "Subscription cannot be null"); Objects.requireNonNull(subscription, "Subscription cannot be null");
if (subscriber.changeState(this, REQUESTED)) { if (processor.changeSubscriberState(this, REQUESTED)) {
subscriber.subscription = subscription; processor.subscription = subscription;
subscription.request(1); subscription.request(1);
} }
else { else {
super.onSubscribe(subscriber, subscription); super.onSubscribe(processor, subscription);
} }
} }
}, },
...@@ -200,18 +270,18 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -200,18 +270,18 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
REQUESTED { REQUESTED {
@Override @Override
void onNext(AbstractResponseBodySubscriber subscriber, void onNext(AbstractResponseBodyProcessor processor, DataBuffer dataBuffer) {
DataBuffer dataBuffer) { if (processor.changeSubscriberState(this, RECEIVED)) {
if (subscriber.changeState(this, RECEIVED)) { processor.receiveBuffer(dataBuffer);
subscriber.receiveBuffer(dataBuffer);
} }
} }
@Override @Override
void onComplete(AbstractResponseBodySubscriber subscriber) { void onComplete(AbstractResponseBodyProcessor processor) {
if (subscriber.changeState(this, COMPLETED)) { if (processor.changeSubscriberState(this, COMPLETED)) {
subscriber.subscriptionCompleted = true; processor.subscriberCompleted = true;
subscriber.close(); processor.close();
processor.publishComplete();
} }
} }
}, },
...@@ -226,40 +296,40 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -226,40 +296,40 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
RECEIVED { RECEIVED {
@Override @Override
void onWritePossible(AbstractResponseBodySubscriber subscriber) { void onWritePossible(AbstractResponseBodyProcessor processor) {
if (subscriber.changeState(this, WRITING)) { if (processor.changeSubscriberState(this, WRITING)) {
DataBuffer dataBuffer = subscriber.currentBuffer; DataBuffer dataBuffer = processor.currentBuffer;
try { try {
boolean writeCompleted = subscriber.write(dataBuffer); boolean writeCompleted = processor.write(dataBuffer);
if (writeCompleted) { if (writeCompleted) {
if (dataBuffer instanceof FlushingDataBuffer) { if (dataBuffer instanceof FlushingDataBuffer) {
subscriber.flush(); processor.flush();
} }
subscriber.releaseBuffer(); processor.releaseBuffer();
boolean subscriptionCompleted = subscriber.subscriptionCompleted; if (!processor.subscriberCompleted) {
if (!subscriptionCompleted) { processor.changeSubscriberState(WRITING, REQUESTED);
subscriber.changeState(WRITING, REQUESTED); processor.subscription.request(1);
subscriber.subscription.request(1);
} }
else { else {
subscriber.changeState(WRITING, COMPLETED); processor.changeSubscriberState(WRITING, COMPLETED);
subscriber.close(); processor.close();
processor.publishComplete();
} }
} }
else { else {
subscriber.changeState(WRITING, RECEIVED); processor.changeSubscriberState(WRITING, RECEIVED);
subscriber.checkOnWritePossible(); processor.checkOnWritePossible();
} }
} }
catch (IOException ex) { catch (IOException ex) {
subscriber.onError(ex); processor.onError(ex);
} }
} }
} }
@Override @Override
void onComplete(AbstractResponseBodySubscriber subscriber) { void onComplete(AbstractResponseBodyProcessor processor) {
subscriber.subscriptionCompleted = true; processor.subscriberCompleted = true;
} }
}, },
/** /**
...@@ -268,8 +338,8 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -268,8 +338,8 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
WRITING { WRITING {
@Override @Override
void onComplete(AbstractResponseBodySubscriber subscriber) { void onComplete(AbstractResponseBodyProcessor processor) {
subscriber.subscriptionCompleted = true; processor.subscriberCompleted = true;
} }
}, },
/** /**
...@@ -277,44 +347,139 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer> ...@@ -277,44 +347,139 @@ abstract class AbstractResponseBodySubscriber implements Subscriber<DataBuffer>
*/ */
COMPLETED { COMPLETED {
@Override @Override
void onNext(AbstractResponseBodySubscriber subscriber, void onNext(AbstractResponseBodyProcessor processor, DataBuffer dataBuffer) {
DataBuffer dataBuffer) {
// ignore // ignore
} }
@Override @Override
void onError(AbstractResponseBodySubscriber subscriber, Throwable t) { void onError(AbstractResponseBodyProcessor processor, Throwable t) {
// ignore // ignore
} }
@Override @Override
void onComplete(AbstractResponseBodySubscriber subscriber) { void onComplete(AbstractResponseBodyProcessor processor) {
// ignore // ignore
} }
}; };
void onSubscribe(AbstractResponseBodySubscriber subscriber, Subscription s) { void onSubscribe(AbstractResponseBodyProcessor processor, Subscription s) {
s.cancel(); s.cancel();
} }
void onNext(AbstractResponseBodySubscriber subscriber, DataBuffer dataBuffer) { void onNext(AbstractResponseBodyProcessor processor, DataBuffer dataBuffer) {
throw new IllegalStateException(toString()); throw new IllegalStateException(toString());
} }
void onError(AbstractResponseBodySubscriber subscriber, Throwable t) { void onError(AbstractResponseBodyProcessor processor, Throwable t) {
if (subscriber.changeState(this, COMPLETED)) { if (processor.changeSubscriberState(this, COMPLETED)) {
subscriber.writeError(t); processor.publishError(t);
subscriber.close();
} }
} }
void onComplete(AbstractResponseBodySubscriber subscriber) { void onComplete(AbstractResponseBodyProcessor processor) {
throw new IllegalStateException(toString()); throw new IllegalStateException(toString());
} }
void onWritePossible(AbstractResponseBodySubscriber subscriber) { void onWritePossible(AbstractResponseBodyProcessor processor) {
// ignore // ignore
} }
}
private enum PublisherState {
UNSUBSCRIBED {
@Override
void subscribe(AbstractResponseBodyProcessor processor,
Subscriber<? super Void> subscriber) {
Objects.requireNonNull(subscriber);
if (processor.changePublisherState(this, SUBSCRIBED)) {
Subscription subscription = new ResponseBodySubscription(processor);
processor.subscriber = subscriber;
subscriber.onSubscribe(subscription);
if (processor.publisherCompleted) {
processor.publishComplete();
}
else if (processor.publisherError != null) {
processor.publishError(processor.publisherError);
}
}
else {
throw new IllegalStateException(toString());
}
}
@Override
void publishComplete(AbstractResponseBodyProcessor processor) {
processor.publisherCompleted = true;
}
@Override
void publishError(AbstractResponseBodyProcessor processor, Throwable t) {
processor.publisherError = t;
}
},
SUBSCRIBED {
@Override
void request(AbstractResponseBodyProcessor processor, long n) {
BackpressureUtils.checkRequest(n, processor.subscriber);
}
@Override
void publishComplete(AbstractResponseBodyProcessor processor) {
if (processor.changePublisherState(this, COMPLETED)) {
processor.subscriber.onComplete();
}
}
@Override
void publishError(AbstractResponseBodyProcessor processor, Throwable t) {
if (processor.changePublisherState(this, COMPLETED)) {
processor.subscriber.onError(t);
}
}
},
COMPLETED {
@Override
void request(AbstractResponseBodyProcessor processor, long n) {
// ignore
}
@Override
void cancel(AbstractResponseBodyProcessor processor) {
// ignore
}
@Override
void publishComplete(AbstractResponseBodyProcessor processor) {
// ignore
}
@Override
void publishError(AbstractResponseBodyProcessor processor, Throwable t) {
// ignore
}
};
void subscribe(AbstractResponseBodyProcessor processor,
Subscriber<? super Void> subscriber) {
throw new IllegalStateException(toString());
}
void request(AbstractResponseBodyProcessor processor, long n) {
throw new IllegalStateException(toString());
}
void cancel(AbstractResponseBodyProcessor processor) {
processor.changePublisherState(this, COMPLETED);
}
void publishComplete(AbstractResponseBodyProcessor processor) {
throw new IllegalStateException(toString());
}
void publishError(AbstractResponseBodyProcessor processor, Throwable t) {
throw new IllegalStateException(toString());
}
} }
......
...@@ -38,7 +38,6 @@ import reactor.core.publisher.Mono; ...@@ -38,7 +38,6 @@ import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
...@@ -91,13 +90,15 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ...@@ -91,13 +90,15 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
ServletServerHttpRequest request = ServletServerHttpRequest request =
new ServletServerHttpRequest(servletRequest, requestBody); new ServletServerHttpRequest(servletRequest, requestBody);
ResponseBodySubscriber responseBody = ResponseBodyProcessor responseBody =
new ResponseBodySubscriber(synchronizer, this.bufferSize); new ResponseBodyProcessor(synchronizer, this.bufferSize);
responseBody.registerListener(); responseBody.registerListener();
ServletServerHttpResponse response = ServletServerHttpResponse response =
new ServletServerHttpResponse(servletResponse, this.dataBufferFactory, new ServletServerHttpResponse(servletResponse, this.dataBufferFactory,
publisher -> Mono publisher -> Mono.from(subscriber -> {
.from(subscriber -> publisher.subscribe(responseBody))); publisher.subscribe(responseBody);
responseBody.subscribe(subscriber);
}));
HandlerResultSubscriber resultSubscriber = HandlerResultSubscriber resultSubscriber =
new HandlerResultSubscriber(synchronizer); new HandlerResultSubscriber(synchronizer);
...@@ -129,7 +130,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ...@@ -129,7 +130,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
logger.error("Error from request handling. Completing the request.", ex); logger.error("Error from request handling. Completing the request.", ex);
HttpServletResponse response = HttpServletResponse response =
(HttpServletResponse) this.synchronizer.getResponse(); (HttpServletResponse) this.synchronizer.getResponse();
response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value()); response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
this.synchronizer.complete(); this.synchronizer.complete();
} }
...@@ -206,8 +207,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ...@@ -206,8 +207,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
} }
} }
private static class ResponseBodyProcessor extends AbstractResponseBodyProcessor {
private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber {
private final ResponseBodyWriteListener writeListener = private final ResponseBodyWriteListener writeListener =
new ResponseBodyWriteListener(); new ResponseBodyWriteListener();
...@@ -218,7 +218,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ...@@ -218,7 +218,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
private volatile boolean flushOnNext; private volatile boolean flushOnNext;
public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer, public ResponseBodyProcessor(ServletAsyncContextSynchronizer synchronizer,
int bufferSize) { int bufferSize) {
this.synchronizer = synchronizer; this.synchronizer = synchronizer;
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
...@@ -272,13 +272,6 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ...@@ -272,13 +272,6 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
} }
} }
@Override
protected void writeError(Throwable t) {
HttpServletResponse response =
(HttpServletResponse) this.synchronizer.getResponse();
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
}
@Override @Override
protected void flush() throws IOException { protected void flush() throws IOException {
ServletOutputStream output = outputStream(); ServletOutputStream output = outputStream();
...@@ -324,14 +317,14 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ...@@ -324,14 +317,14 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
@Override @Override
public void onWritePossible() throws IOException { public void onWritePossible() throws IOException {
ResponseBodySubscriber.this.onWritePossible(); ResponseBodyProcessor.this.onWritePossible();
} }
@Override @Override
public void onError(Throwable ex) { public void onError(Throwable ex) {
// Error on writing to the HTTP stream, so any further writes will probably // Error on writing to the HTTP stream, so any further writes will probably
// fail. Let's log instead of calling {@link #writeError}. // fail. Let's log instead of calling {@link #writeError}.
ResponseBodySubscriber.this.logger ResponseBodyProcessor.this.logger
.error("ResponseBodyWriteListener error", ex); .error("ResponseBodyWriteListener error", ex);
} }
} }
......
...@@ -70,11 +70,8 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle ...@@ -70,11 +70,8 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle
@Override @Override
public void onError(Throwable ex) { public void onError(Throwable ex) {
if (exchange.isResponseStarted() || exchange.getStatusCode() > 500) { logger.error("Error from request handling. Completing the request.", ex);
logger.error("Error from request handling. Completing the request.", if (!exchange.isResponseStarted() && exchange.getStatusCode() <= 500) {
ex);
}
else {
exchange.setStatusCode(500); exchange.setStatusCode(500);
} }
exchange.endExchange(); exchange.endExchange();
......
...@@ -79,13 +79,14 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse ...@@ -79,13 +79,14 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse
@Override @Override
protected Mono<Void> writeWithInternal(Publisher<DataBuffer> publisher) { protected Mono<Void> writeWithInternal(Publisher<DataBuffer> publisher) {
return Mono.from(s -> { // lazily create Subscriber, since calling
// lazily create Subscriber, since calling // {@link HttpServerExchange#getResponseChannel} as done in the
// {@link HttpServerExchange#getResponseChannel} as done in the // ResponseBodyProcessor constructor commits the response status and headers
// ResponseBodySubscriber constructor commits the response status and headers return Mono.from(subscriber -> {
ResponseBodySubscriber subscriber = new ResponseBodySubscriber(this.exchange); ResponseBodyProcessor processor = new ResponseBodyProcessor(this.exchange);
subscriber.registerListener(); processor.registerListener();
publisher.subscribe(subscriber); publisher.subscribe(processor);
processor.subscribe(subscriber);
}); });
} }
...@@ -137,7 +138,7 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse ...@@ -137,7 +138,7 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse
} }
} }
private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { private static class ResponseBodyProcessor extends AbstractResponseBodyProcessor {
private final ChannelListener<StreamSinkChannel> listener = new WriteListener(); private final ChannelListener<StreamSinkChannel> listener = new WriteListener();
...@@ -147,7 +148,7 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse ...@@ -147,7 +148,7 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse
private volatile ByteBuffer byteBuffer; private volatile ByteBuffer byteBuffer;
public ResponseBodySubscriber(HttpServerExchange exchange) { public ResponseBodyProcessor(HttpServerExchange exchange) {
this.exchange = exchange; this.exchange = exchange;
this.responseChannel = exchange.getResponseChannel(); this.responseChannel = exchange.getResponseChannel();
} }
...@@ -157,14 +158,6 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse ...@@ -157,14 +158,6 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse
this.responseChannel.resumeWrites(); this.responseChannel.resumeWrites();
} }
@Override
protected void writeError(Throwable t) {
if (!this.exchange.isResponseStarted() &&
this.exchange.getStatusCode() < 500) {
this.exchange.setStatusCode(500);
}
}
@Override @Override
protected void flush() throws IOException { protected void flush() throws IOException {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
......
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.server.reactive;
import java.io.IOException;
import java.net.URI;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.server.reactive.boot.ReactorHttpServer;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeFalse;
/**
* @author Arjen Poutsma
*/
public class ErrorHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests {
private ErrorHandler handler = new ErrorHandler();
@Override
protected HttpHandler createHttpHandler() {
return handler;
}
@Test
public void response() throws Exception {
// TODO: fix Reactor
assumeFalse(server instanceof ReactorHttpServer);
RestTemplate restTemplate = new RestTemplate();
restTemplate.setErrorHandler(NO_OP_ERROR_HANDLER);
ResponseEntity<String> response = restTemplate
.getForEntity(new URI("http://localhost:" + port + "/response"),
String.class);
assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode());
}
@Test
public void returnValue() throws Exception {
// TODO: fix Reactor
assumeFalse(server instanceof ReactorHttpServer);
RestTemplate restTemplate = new RestTemplate();
restTemplate.setErrorHandler(NO_OP_ERROR_HANDLER);
ResponseEntity<String> response = restTemplate
.getForEntity(new URI("http://localhost:" + port + "/returnValue"),
String.class);
assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode());
}
private static class ErrorHandler implements HttpHandler {
@Override
public Mono<Void> handle(ServerHttpRequest request, ServerHttpResponse response) {
Exception error = new UnsupportedOperationException();
String path = request.getURI().getPath();
if (path.endsWith("response")) {
return response.writeWith(Mono.error(error));
}
else if (path.endsWith("returnValue")) {
return Mono.error(error);
}
else {
return Mono.empty();
}
}
}
private static final ResponseErrorHandler NO_OP_ERROR_HANDLER =
new ResponseErrorHandler() {
@Override
public boolean hasError(ClientHttpResponse response) throws IOException {
return false;
}
@Override
public void handleError(ClientHttpResponse response) throws IOException {
}
};
}
...@@ -21,8 +21,6 @@ import java.util.Random; ...@@ -21,8 +21,6 @@ import java.util.Random;
import org.junit.Test; import org.junit.Test;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -31,12 +29,9 @@ import org.springframework.core.io.buffer.DataBufferFactory; ...@@ -31,12 +29,9 @@ import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.RequestEntity; import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.http.server.reactive.boot.ReactorHttpServer;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assume.assumeFalse;
public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests { public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests {
...@@ -60,7 +55,6 @@ public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegratio ...@@ -60,7 +55,6 @@ public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegratio
@Test @Test
public void random() throws Throwable { public void random() throws Throwable {
// TODO: fix Reactor support // TODO: fix Reactor support
assumeFalse(server instanceof ReactorHttpServer);
RestTemplate restTemplate = new RestTemplate(); RestTemplate restTemplate = new RestTemplate();
...@@ -72,14 +66,6 @@ public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegratio ...@@ -72,14 +66,6 @@ public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegratio
assertEquals(RESPONSE_SIZE, assertEquals(RESPONSE_SIZE,
response.getHeaders().getContentLength()); response.getHeaders().getContentLength());
assertEquals(RESPONSE_SIZE, response.getBody().length); assertEquals(RESPONSE_SIZE, response.getBody().length);
while (!handler.requestComplete) {
Thread.sleep(100);
}
if (handler.requestError != null) {
throw handler.requestError;
}
assertEquals(REQUEST_SIZE, handler.requestSize);
} }
...@@ -93,45 +79,21 @@ public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegratio ...@@ -93,45 +79,21 @@ public class RandomHandlerIntegrationTests extends AbstractHttpHandlerIntegratio
public static final int CHUNKS = 16; public static final int CHUNKS = 16;
private volatile boolean requestComplete;
private int requestSize;
private Throwable requestError;
@Override @Override
public Mono<Void> handle(ServerHttpRequest request, ServerHttpResponse response) { public Mono<Void> handle(ServerHttpRequest request, ServerHttpResponse response) {
requestError = null; Mono<Integer> requestSizeMono = request.getBody().
reduce(0, (integer, dataBuffer) -> integer +
request.getBody().subscribe(new Subscriber<DataBuffer>() { dataBuffer.readableByteCount()).
doAfterTerminate((size, throwable) -> {
@Override assertNull(throwable);
public void onSubscribe(Subscription s) { assertEquals(REQUEST_SIZE, (long) size);
requestComplete = false; });
requestSize = 0;
requestError = null;
s.request(Long.MAX_VALUE);
}
@Override
public void onNext(DataBuffer bytes) {
requestSize += bytes.readableByteCount();
}
@Override
public void onError(Throwable t) {
requestComplete = true;
requestError = t;
}
@Override
public void onComplete() {
requestComplete = true;
}
});
response.getHeaders().setContentLength(RESPONSE_SIZE); response.getHeaders().setContentLength(RESPONSE_SIZE);
return response.writeWith(multipleChunks());
return requestSizeMono.then(response.writeWith(multipleChunks()));
} }
private Publisher<DataBuffer> singleChunk() { private Publisher<DataBuffer> singleChunk() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册