提交 f9e6ea54 编写于 作者: R Rossen Stoyanchev

MvcResult returns asyncResult after asyncDispatch

Issue: SPR-16648
上级 e6020ed3
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -65,7 +65,14 @@ public class MockAsyncContext implements AsyncContext {
public void addDispatchHandler(Runnable handler) {
Assert.notNull(handler, "Dispatch handler must not be null");
this.dispatchHandlers.add(handler);
synchronized (this) {
if (this.dispatchedPath == null) {
this.dispatchHandlers.add(handler);
}
else {
handler.run();
}
}
}
@Override
......@@ -96,9 +103,9 @@ public class MockAsyncContext implements AsyncContext {
@Override
public void dispatch(@Nullable ServletContext context, String path) {
this.dispatchedPath = path;
for (Runnable r : this.dispatchHandlers) {
r.run();
synchronized (this) {
this.dispatchedPath = path;
this.dispatchHandlers.forEach(Runnable::run);
}
}
......
......@@ -16,11 +16,14 @@
package org.springframework.test.web.servlet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.Assert;
import org.springframework.web.servlet.FlashMap;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
......@@ -56,6 +59,9 @@ class DefaultMvcResult implements MvcResult {
private final AtomicReference<Object> asyncResult = new AtomicReference<>(RESULT_NONE);
@Nullable
private CountDownLatch asyncDispatchLatch;
/**
* Create a new instance with the given request and response.
......@@ -135,27 +141,31 @@ class DefaultMvcResult implements MvcResult {
if (this.mockRequest.getAsyncContext() != null) {
timeToWait = (timeToWait == -1 ? this.mockRequest.getAsyncContext().getTimeout() : timeToWait);
}
if (timeToWait > 0) {
long endTime = System.currentTimeMillis() + timeToWait;
while (System.currentTimeMillis() < endTime && this.asyncResult.get() == RESULT_NONE) {
try {
Thread.sleep(100);
}
catch (InterruptedException ex) {
Thread.currentThread().interrupt();
throw new IllegalStateException("Interrupted while waiting for " +
"async result to be set for handler [" + this.handler + "]", ex);
}
}
if (!awaitAsyncDispatch(timeToWait)) {
throw new IllegalStateException("Async result for handler [" + this.handler + "]" +
" was not set during the specified timeToWait=" + timeToWait);
}
Object result = this.asyncResult.get();
if (result == RESULT_NONE) {
throw new IllegalStateException("Async result for handler [" + this.handler + "] " +
"was not set during the specified timeToWait=" + timeToWait);
Assert.state(result != RESULT_NONE, "Async result for handler [" + this.handler + "] was not set");
return this.asyncResult.get();
}
/**
* True if is there a latch was not set, or the latch count reached 0.
*/
private boolean awaitAsyncDispatch(long timeout) {
Assert.state(this.asyncDispatchLatch != null,
"The asynDispatch CountDownLatch was not set by the TestDispatcherServlet.\n");
try {
return this.asyncDispatchLatch.await(timeout, TimeUnit.MILLISECONDS);
}
catch (InterruptedException e) {
return false;
}
return result;
}
void setAsyncDispatchLatch(CountDownLatch asyncDispatchLatch) {
this.asyncDispatchLatch = asyncDispatchLatch;
}
}
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -18,18 +18,19 @@ package org.springframework.test.web.servlet;
import java.io.IOException;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockAsyncContext;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.CallableProcessingInterceptor;
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor;
import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
......@@ -63,23 +64,34 @@ final class TestDispatcherServlet extends DispatcherServlet {
throws ServletException, IOException {
registerAsyncResultInterceptors(request);
super.service(request, response);
if (request.getAsyncContext() != null) {
CountDownLatch dispatchLatch = new CountDownLatch(1);
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(dispatchLatch::countDown);
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
}
}
private void registerAsyncResultInterceptors(final HttpServletRequest request) {
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
asyncManager.registerCallableInterceptor(KEY, new CallableProcessingInterceptor() {
@Override
public <T> void postProcess(NativeWebRequest r, Callable<T> task, Object value) throws Exception {
getMvcResult(request).setAsyncResult(value);
}
});
asyncManager.registerDeferredResultInterceptor(KEY, new DeferredResultProcessingInterceptor() {
@Override
public <T> void postProcess(NativeWebRequest r, DeferredResult<T> result, Object value) throws Exception {
getMvcResult(request).setAsyncResult(value);
}
});
WebAsyncUtils.getAsyncManager(request).registerCallableInterceptor(KEY,
new CallableProcessingInterceptor() {
@Override
public <T> void postProcess(NativeWebRequest r, Callable<T> task, Object value) {
// We got the result, must also wait for the dispatch
getMvcResult(request).setAsyncResult(value);
}
});
WebAsyncUtils.getAsyncManager(request).registerDeferredResultInterceptor(KEY,
new DeferredResultProcessingInterceptor() {
@Override
public <T> void postProcess(NativeWebRequest r, DeferredResult<T> result, Object value) {
getMvcResult(request).setAsyncResult(value);
}
});
}
protected DefaultMvcResult getMvcResult(ServletRequest request) {
......
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -15,6 +15,8 @@
*/
package org.springframework.test.web.servlet;
import java.util.concurrent.CountDownLatch;
import org.junit.Before;
import org.junit.Test;
......@@ -38,13 +40,14 @@ public class DefaultMvcResultTests {
}
@Test
public void getAsyncResultSuccess() throws Exception {
public void getAsyncResultSuccess() {
this.mvcResult.setAsyncResult("Foo");
assertEquals("Foo", this.mvcResult.getAsyncResult());
this.mvcResult.setAsyncDispatchLatch(new CountDownLatch(0));
this.mvcResult.getAsyncResult();
}
@Test(expected = IllegalStateException.class)
public void getAsyncResultFailure() throws Exception {
public void getAsyncResultFailure() {
this.mvcResult.getAsyncResult(0);
}
......
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -30,6 +30,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.BeanUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.util.WebUtils;
......@@ -43,10 +44,12 @@ public class MockAsyncContext implements AsyncContext {
private final HttpServletRequest request;
@Nullable
private final HttpServletResponse response;
private final List<AsyncListener> listeners = new ArrayList<>();
@Nullable
private String dispatchedPath;
private long timeout = 10 * 1000L; // 10 seconds is Tomcat's default
......@@ -54,7 +57,7 @@ public class MockAsyncContext implements AsyncContext {
private final List<Runnable> dispatchHandlers = new ArrayList<>();
public MockAsyncContext(ServletRequest request, ServletResponse response) {
public MockAsyncContext(ServletRequest request, @Nullable ServletResponse response) {
this.request = (HttpServletRequest) request;
this.response = (HttpServletResponse) response;
}
......@@ -62,7 +65,14 @@ public class MockAsyncContext implements AsyncContext {
public void addDispatchHandler(Runnable handler) {
Assert.notNull(handler, "Dispatch handler must not be null");
this.dispatchHandlers.add(handler);
synchronized (this) {
if (this.dispatchedPath == null) {
this.dispatchHandlers.add(handler);
}
else {
handler.run();
}
}
}
@Override
......@@ -71,6 +81,7 @@ public class MockAsyncContext implements AsyncContext {
}
@Override
@Nullable
public ServletResponse getResponse() {
return this.response;
}
......@@ -91,13 +102,14 @@ public class MockAsyncContext implements AsyncContext {
}
@Override
public void dispatch(ServletContext context, String path) {
this.dispatchedPath = path;
for (Runnable r : this.dispatchHandlers) {
r.run();
public void dispatch(@Nullable ServletContext context, String path) {
synchronized (this) {
this.dispatchedPath = path;
this.dispatchHandlers.forEach(Runnable::run);
}
}
@Nullable
public String getDispatchedPath() {
return this.dispatchedPath;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册