提交 30af01fd 编写于 作者: A Arjen Poutsma

Use DataBufferUtils.write in DefaultFilePart.transferTo

This commit makes sure that in DefaultMultipartMessageReader's
DefaultFilePart, the file is not closed before all bytes are written,
by using DataBufferUtils.write (see c1b6885191d6a50347aeaa14da994f0db88f26fe).

The commit also improves on the logging of the
DefaultMultipartMessageReader.

Closes gh-23130
上级 f08656c6
......@@ -16,14 +16,9 @@
package org.springframework.http.codec.multipart;
import java.io.IOException;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.channels.Channel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Collections;
import java.util.List;
import java.util.Map;
......@@ -100,10 +95,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
return Flux.error(new CodecException("No multipart boundary found in Content-Type: \"" +
message.getHeaders().getContentType() + "\""));
}
if (logger.isTraceEnabled()) {
logger.trace("Boundary: " + toString(boundary));
}
byte[] boundaryNeedle = concat(BOUNDARY_PREFIX, boundary);
Flux<DataBuffer> body = skipUntilFirstBoundary(message.getBody(), boundary);
......@@ -148,8 +139,10 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
DataBuffer slice = dataBuffer.retainedSlice(endIdx + 1, length);
DataBufferUtils.release(dataBuffer);
if (logger.isTraceEnabled()) {
logger.trace("Found first boundary at " + endIdx + " in " + toString(dataBuffer));
}
logger.trace(
"Found last byte of first boundary (" + toString(boundary)
+ ") at " + endIdx);
}
return Mono.just(slice);
}
else {
......@@ -188,14 +181,14 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
}
}
if (logger.isTraceEnabled()) {
logger.trace("Part data: " + toString(dataBuffer));
}
int endIdx = HEADER_MATCHER.match(dataBuffer);
HttpHeaders headers;
DataBuffer body;
if (endIdx > 0) {
if (logger.isTraceEnabled()) {
logger.trace("Found last byte of part header at " + endIdx );
}
readPosition = dataBuffer.readPosition();
int headersLength = endIdx + 1 - (readPosition + HEADER_BODY_SEPARATOR.length);
DataBuffer headersBuffer = dataBuffer.retainedSlice(readPosition, headersLength);
......@@ -204,6 +197,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
headers = toHeaders(headersBuffer);
}
else {
if (logger.isTraceEnabled()) {
logger.trace("No header found");
}
headers = new HttpHeaders();
body = DataBufferUtils.retain(dataBuffer);
}
......@@ -252,16 +248,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
return result;
}
private static String toString(DataBuffer dataBuffer) {
byte[] bytes = new byte[dataBuffer.readableByteCount()];
int j = 0;
for (int i = dataBuffer.readPosition(); i < dataBuffer.writePosition(); i++) {
bytes[j++] = dataBuffer.getByte(i);
}
return toString(bytes);
}
private static String toString(byte[] bytes) {
StringBuilder builder = new StringBuilder();
for (byte b : bytes) {
......@@ -368,10 +354,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
private static class DefaultFilePart extends DefaultPart implements FilePart {
private static final OpenOption[] FILE_CHANNEL_OPTIONS =
{StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE};
public DefaultFilePart(HttpHeaders headers, DataBuffer body) {
super(headers, body);
}
......@@ -385,23 +367,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
@Override
public Mono<Void> transferTo(Path dest) {
return Mono.using(() -> AsynchronousFileChannel.open(dest, FILE_CHANNEL_OPTIONS),
this::writeBody, this::close);
}
private Mono<Void> writeBody(AsynchronousFileChannel channel) {
return DataBufferUtils.write(content(), channel)
.map(DataBufferUtils::release)
.then();
return DataBufferUtils.write(content(), dest);
}
private void close(Channel channel) {
try {
channel.close();
}
catch (IOException ignore) {
}
}
}
}
......@@ -16,6 +16,10 @@
package org.springframework.web.reactive.result.method.annotation;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
......@@ -31,6 +35,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.MultipartBodyBuilder;
......@@ -145,6 +150,34 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
.verifyComplete();
}
@Test
public void transferTo() {
Flux<String> result = webClient
.post()
.uri("/transferTo")
.syncBody(generateBody())
.retrieve()
.bodyToFlux(String.class);
StepVerifier.create(result)
.consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("foo.txt", MultipartHttpMessageReader.class)))
.consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("logo.png", getClass())))
.verifyComplete();
}
private static void verifyContents(Path tempFile, Resource resource) {
try {
byte[] tempBytes = Files.readAllBytes(tempFile);
byte[] resourceBytes = Files.readAllBytes(resource.getFile().toPath());
assertThat(tempBytes).isEqualTo(resourceBytes);
}
catch (IOException ex) {
throw new AssertionError(ex);
}
}
@Test
public void modelAttribute() {
Mono<String> result = webClient
......@@ -217,6 +250,21 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
return partFluxDescription(Flux.from(parts));
}
@PostMapping("/transferTo")
Flux<String> transferTo(@RequestPart("fileParts") Flux<FilePart> parts) {
return parts.flatMap(filePart -> {
try {
Path tempFile = Files.createTempFile("MultipartIntegrationTests", filePart.filename());
return filePart.transferTo(tempFile)
.then(Mono.just(tempFile.toString() + "\n"));
}
catch (IOException e) {
return Mono.error(e);
}
});
}
@PostMapping("/modelAttribute")
String modelAttribute(@ModelAttribute FormBean formBean) {
return formBean.toString();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册