提交 bfcb5e52 编写于 作者: J JiabinYang

test=develop, fix gpu compile error on prefetch, and fix hs/nce ut failed on gpu

上级 4877f5d7
...@@ -39,6 +39,9 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -39,6 +39,9 @@ void prefetch_with_reconstruct(const std::string& id_name,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope, const framework::Scope& scope,
framework::LoDTensor* original) { framework::LoDTensor* original) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
prefetch(id_name, out_name, table_names, epmap, height_sections, context, prefetch(id_name, out_name, table_names, epmap, height_sections, context,
scope); scope);
auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>(); auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
...@@ -62,9 +65,10 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -62,9 +65,10 @@ void prefetch_with_reconstruct(const std::string& id_name,
PADDLE_THROW("paddle is not compiled with CUDA!"); PADDLE_THROW("paddle is not compiled with CUDA!");
#else #else
auto stream = auto stream =
static_cast<platform::CUDADeviceContext*>(actual_ctx)->stream(); static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), out_rows, memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
cpu_place, original_row, original_width * sizeof(T), stream); platform::CPUPlace(), out_rows, original_width * sizeof(T),
stream);
#endif #endif
} }
} }
......
...@@ -253,8 +253,6 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -253,8 +253,6 @@ class TestListenAndServOp(unittest.TestCase):
port1 = self._get_pserver_port(p1.pid) port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places: for place in places:
self._run_hsigmoid_op_one_pserver(place, port0) self._run_hsigmoid_op_one_pserver(place, port0)
......
...@@ -221,8 +221,6 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -221,8 +221,6 @@ class TestListenAndServOp(unittest.TestCase):
port1 = self._get_pserver_port(p1.pid) port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places: for place in places:
self._run_nce_op_two_pserver(place, port0, port1) self._run_nce_op_two_pserver(place, port0, port1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册