提交 1d48e7c5 编写于 作者: A Arjen Poutsma

Allow to set response status on Undertow

Refactored Undertow support to register a response listener only when
the body is written to, as opposed to registering it at startup. The
reason for this is that getting the response channel from the
HttpServerExchange commits the status and response, making it impossible
to change them after the fact.

Fixed issue #119.
上级 2418ff0a
......@@ -16,29 +16,19 @@
package org.springframework.http.server.reactive;
import java.io.IOException;
import java.nio.ByteBuffer;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.server.HttpServerExchange;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.channels.StreamSourceChannel;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.util.Assert;
/**
* @author Marek Hawrylczak
* @author Rossen Stoyanchev
* @author Arjen Poutsma
*/
public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandler {
......@@ -60,20 +50,11 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle
@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
RequestBodyPublisher requestBody =
new RequestBodyPublisher(exchange, this.dataBufferFactory);
requestBody.registerListener();
ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody);
ServerHttpRequest request =
new UndertowServerHttpRequest(exchange, this.dataBufferFactory);
StreamSinkChannel responseChannel = exchange.getResponseChannel();
ResponseBodySubscriber responseBody =
new ResponseBodySubscriber(exchange, responseChannel);
responseBody.registerListener();
ServerHttpResponse response =
new UndertowServerHttpResponse(exchange, responseChannel,
publisher -> Mono
.from(subscriber -> publisher.subscribe(responseBody)),
this.dataBufferFactory);
new UndertowServerHttpResponse(exchange, this.dataBufferFactory);
this.delegate.handle(request, response).subscribe(new Subscriber<Void>() {
......@@ -106,183 +87,4 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle
});
}
private static class RequestBodyPublisher extends AbstractRequestBodyPublisher {
private final ChannelListener<StreamSourceChannel> readListener =
new ReadListener();
private final ChannelListener<StreamSourceChannel> closeListener =
new CloseListener();
private final StreamSourceChannel requestChannel;
private final DataBufferFactory dataBufferFactory;
private final PooledByteBuffer pooledByteBuffer;
public RequestBodyPublisher(HttpServerExchange exchange,
DataBufferFactory dataBufferFactory) {
this.requestChannel = exchange.getRequestChannel();
this.pooledByteBuffer =
exchange.getConnection().getByteBufferPool().allocate();
this.dataBufferFactory = dataBufferFactory;
}
public void registerListener() {
this.requestChannel.getReadSetter().set(this.readListener);
this.requestChannel.getCloseSetter().set(this.closeListener);
this.requestChannel.resumeReads();
}
@Override
protected DataBuffer read() throws IOException {
ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer();
int read = this.requestChannel.read(byteBuffer);
if (logger.isTraceEnabled()) {
logger.trace("read:" + read);
}
if (read > 0) {
byteBuffer.flip();
return this.dataBufferFactory.wrap(byteBuffer);
}
else if (read == -1) {
onAllDataRead();
}
return null;
}
@Override
protected void close() {
if (this.pooledByteBuffer != null) {
IoUtils.safeClose(this.pooledByteBuffer);
}
if (this.requestChannel != null) {
IoUtils.safeClose(this.requestChannel);
}
}
private class ReadListener implements ChannelListener<StreamSourceChannel> {
@Override
public void handleEvent(StreamSourceChannel channel) {
onDataAvailable();
}
}
private class CloseListener implements ChannelListener<StreamSourceChannel> {
@Override
public void handleEvent(StreamSourceChannel channel) {
onAllDataRead();
}
}
}
private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber {
private final ChannelListener<StreamSinkChannel> listener =
new ResponseBodyListener();
private final HttpServerExchange exchange;
private final StreamSinkChannel responseChannel;
private volatile ByteBuffer byteBuffer;
public ResponseBodySubscriber(HttpServerExchange exchange,
StreamSinkChannel responseChannel) {
this.exchange = exchange;
this.responseChannel = responseChannel;
}
public void registerListener() {
this.responseChannel.getWriteSetter().set(this.listener);
this.responseChannel.resumeWrites();
}
@Override
protected void writeError(Throwable t) {
if (!this.exchange.isResponseStarted() &&
this.exchange.getStatusCode() < 500) {
this.exchange.setStatusCode(500);
}
}
@Override
protected void flush() throws IOException {
if (logger.isTraceEnabled()) {
logger.trace("flush");
}
this.responseChannel.flush();
}
@Override
protected boolean write(DataBuffer dataBuffer) throws IOException {
if (this.byteBuffer == null) {
return false;
}
if (logger.isTraceEnabled()) {
logger.trace("write: " + dataBuffer);
}
int total = this.byteBuffer.remaining();
int written = writeByteBuffer(this.byteBuffer);
if (logger.isTraceEnabled()) {
logger.trace("written: " + written + " total: " + total);
}
return written == total;
}
private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException {
int written;
int totalWritten = 0;
do {
written = this.responseChannel.write(byteBuffer);
totalWritten += written;
}
while (byteBuffer.hasRemaining() && written > 0);
return totalWritten;
}
@Override
protected void receiveBuffer(DataBuffer dataBuffer) {
super.receiveBuffer(dataBuffer);
this.byteBuffer = dataBuffer.asByteBuffer();
}
@Override
protected void releaseBuffer() {
super.releaseBuffer();
this.byteBuffer = null;
}
@Override
protected void close() {
try {
this.responseChannel.shutdownWrites();
if (!this.responseChannel.flush()) {
this.responseChannel.getWriteSetter().set(ChannelListeners
.flushingChannelListener(
o -> IoUtils.safeClose(this.responseChannel),
ChannelListeners.closingChannelExceptionHandler()));
this.responseChannel.resumeWrites();
}
}
catch (IOException ignored) {
}
}
private class ResponseBodyListener implements ChannelListener<StreamSinkChannel> {
@Override
public void handleEvent(StreamSinkChannel channel) {
onWritePossible();
}
}
}
}
\ No newline at end of file
......@@ -16,16 +16,22 @@
package org.springframework.http.server.reactive;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.Cookie;
import io.undertow.util.HeaderValues;
import org.reactivestreams.Publisher;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.channels.StreamSourceChannel;
import reactor.core.publisher.Flux;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
......@@ -43,14 +49,14 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest {
private final HttpServerExchange exchange;
private final Flux<DataBuffer> body;
private final RequestBodyPublisher body;
public UndertowServerHttpRequest(HttpServerExchange exchange,
Publisher<DataBuffer> body) {
DataBufferFactory dataBufferFactory) {
Assert.notNull(exchange, "'exchange' is required.");
Assert.notNull(exchange, "'body' is required.");
this.exchange = exchange;
this.body = Flux.from(body);
this.body = new RequestBodyPublisher(exchange, dataBufferFactory);
this.body.registerListener();
}
......@@ -92,7 +98,79 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest {
@Override
public Flux<DataBuffer> getBody() {
return this.body;
return Flux.from(this.body);
}
private static class RequestBodyPublisher extends AbstractRequestBodyPublisher {
private final ChannelListener<StreamSourceChannel> readListener =
new ReadListener();
private final ChannelListener<StreamSourceChannel> closeListener =
new CloseListener();
private final StreamSourceChannel requestChannel;
private final DataBufferFactory dataBufferFactory;
private final PooledByteBuffer pooledByteBuffer;
public RequestBodyPublisher(HttpServerExchange exchange,
DataBufferFactory dataBufferFactory) {
this.requestChannel = exchange.getRequestChannel();
this.pooledByteBuffer =
exchange.getConnection().getByteBufferPool().allocate();
this.dataBufferFactory = dataBufferFactory;
}
private void registerListener() {
this.requestChannel.getReadSetter().set(this.readListener);
this.requestChannel.getCloseSetter().set(this.closeListener);
this.requestChannel.resumeReads();
}
@Override
protected DataBuffer read() throws IOException {
ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer();
int read = this.requestChannel.read(byteBuffer);
if (logger.isTraceEnabled()) {
logger.trace("read:" + read);
}
if (read > 0) {
byteBuffer.flip();
return this.dataBufferFactory.wrap(byteBuffer);
}
else if (read == -1) {
onAllDataRead();
}
return null;
}
@Override
protected void close() {
if (this.pooledByteBuffer != null) {
IoUtils.safeClose(this.pooledByteBuffer);
}
if (this.requestChannel != null) {
IoUtils.safeClose(this.requestChannel);
}
}
private class ReadListener implements ChannelListener<StreamSourceChannel> {
@Override
public void handleEvent(StreamSourceChannel channel) {
onDataAvailable();
}
}
private class CloseListener implements ChannelListener<StreamSourceChannel> {
@Override
public void handleEvent(StreamSourceChannel channel) {
onAllDataRead();
}
}
}
}
......@@ -19,16 +19,19 @@ package org.springframework.http.server.reactive;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.Cookie;
import io.undertow.server.handlers.CookieImpl;
import io.undertow.util.HttpString;
import org.reactivestreams.Publisher;
import org.xnio.ChannelListener;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.channels.StreamSinkChannel;
import reactor.core.publisher.Mono;
......@@ -44,27 +47,18 @@ import org.springframework.util.Assert;
*
* @author Marek Hawrylczak
* @author Rossen Stoyanchev
* @author Arjen Poutsma
*/
public class UndertowServerHttpResponse extends AbstractServerHttpResponse
implements ZeroCopyHttpOutputMessage {
private final HttpServerExchange exchange;
private final StreamSinkChannel responseChannel;
private final Function<Publisher<DataBuffer>, Mono<Void>> responseBodyWriter;
public UndertowServerHttpResponse(HttpServerExchange exchange,
StreamSinkChannel responseChannel,
Function<Publisher<DataBuffer>, Mono<Void>> responseBodyWriter,
DataBufferFactory dataBufferFactory) {
super(dataBufferFactory);
Assert.notNull(exchange, "'exchange' is required.");
Assert.notNull(responseChannel, "'responseChannel' must not be null");
Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null");
this.exchange = exchange;
this.responseChannel = responseChannel;
this.responseBodyWriter = responseBodyWriter;
}
......@@ -80,16 +74,26 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse
@Override
protected Mono<Void> writeWithInternal(Publisher<DataBuffer> publisher) {
return this.responseBodyWriter.apply(publisher);
return Mono.from(s -> {
// lazily create Subscriber, since calling
// {@link HttpServerExchange#getResponseChannel} as done in the
// ResponseBodySubscriber constructor commits the response status and headers
ResponseBodySubscriber subscriber = new ResponseBodySubscriber(this.exchange);
subscriber.registerListener();
publisher.subscribe(subscriber);
});
}
@Override
public Mono<Void> writeWith(File file, long position, long count) {
writeHeaders();
writeCookies();
try {
StreamSinkChannel responseChannel =
getUndertowExchange().getResponseChannel();
FileChannel in = new FileInputStream(file).getChannel();
long result = this.responseChannel.transferFrom(in, position, count);
long result = responseChannel.transferFrom(in, position, count);
if (result < count) {
return Mono.error(new IOException("Could only write " + result +
" out of " + count + " bytes"));
......@@ -128,4 +132,107 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse
}
}
private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber {
private final ChannelListener<StreamSinkChannel> listener = new WriteListener();
private final HttpServerExchange exchange;
private final StreamSinkChannel responseChannel;
private volatile ByteBuffer byteBuffer;
public ResponseBodySubscriber(HttpServerExchange exchange) {
this.exchange = exchange;
this.responseChannel = exchange.getResponseChannel();
}
public void registerListener() {
this.responseChannel.getWriteSetter().set(this.listener);
this.responseChannel.resumeWrites();
}
@Override
protected void writeError(Throwable t) {
if (!this.exchange.isResponseStarted() &&
this.exchange.getStatusCode() < 500) {
this.exchange.setStatusCode(500);
}
}
@Override
protected void flush() throws IOException {
if (logger.isTraceEnabled()) {
logger.trace("flush");
}
this.responseChannel.flush();
}
@Override
protected boolean write(DataBuffer dataBuffer) throws IOException {
if (this.byteBuffer == null) {
return false;
}
if (logger.isTraceEnabled()) {
logger.trace("write: " + dataBuffer);
}
int total = this.byteBuffer.remaining();
int written = writeByteBuffer(this.byteBuffer);
if (logger.isTraceEnabled()) {
logger.trace("written: " + written + " total: " + total);
}
return written == total;
}
private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException {
int written;
int totalWritten = 0;
do {
written = this.responseChannel.write(byteBuffer);
totalWritten += written;
}
while (byteBuffer.hasRemaining() && written > 0);
return totalWritten;
}
@Override
protected void receiveBuffer(DataBuffer dataBuffer) {
super.receiveBuffer(dataBuffer);
this.byteBuffer = dataBuffer.asByteBuffer();
}
@Override
protected void releaseBuffer() {
super.releaseBuffer();
this.byteBuffer = null;
}
@Override
protected void close() {
try {
this.responseChannel.shutdownWrites();
if (!this.responseChannel.flush()) {
this.responseChannel.getWriteSetter().set(ChannelListeners
.flushingChannelListener(
o -> IoUtils.safeClose(this.responseChannel),
ChannelListeners.closingChannelExceptionHandler()));
this.responseChannel.resumeWrites();
}
}
catch (IOException ignored) {
}
}
private class WriteListener implements ChannelListener<StreamSinkChannel> {
@Override
public void handleEvent(StreamSinkChannel channel) {
onWritePossible();
}
}
}
}
......@@ -67,9 +67,7 @@ import org.springframework.web.reactive.config.WebReactiveConfiguration;
import org.springframework.web.reactive.result.view.freemarker.FreeMarkerConfigurer;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
......@@ -173,7 +171,6 @@ public class RequestMappingIntegrationTests extends AbstractHttpHandlerIntegrati
}
@Test
@Ignore // Issue #119
public void serializeAsMonoResponseEntity() throws Exception {
serializeAsPojo("http://localhost:" + port + "/monoResponseEntity");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册