diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h index 89671bd741e0a08873cf727150a1ad3a91211f66..47d082c4af54a2a6591432dbed617f74ada240b9 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ b/paddle/fluid/operators/distributed/parameter_prefetch.h @@ -39,6 +39,9 @@ void prefetch_with_reconstruct(const std::string& id_name, const framework::ExecutionContext& context, const framework::Scope& scope, framework::LoDTensor* original) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& actual_ctx = *pool.Get(context.GetPlace()); + prefetch(id_name, out_name, table_names, epmap, height_sections, context, scope); auto& out = scope.FindVar(out_name)->Get(); @@ -62,9 +65,10 @@ void prefetch_with_reconstruct(const std::string& id_name, PADDLE_THROW("paddle is not compiled with CUDA!"); #else auto stream = - static_cast(actual_ctx)->stream(); - memory::Copy(boost::get(ids.place()), out_rows, - cpu_place, original_row, original_width * sizeof(T), stream); + static_cast(&actual_ctx)->stream(); + memory::Copy(boost::get(ids.place()), original_row, + platform::CPUPlace(), out_rows, original_width * sizeof(T), + stream); #endif } } diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py index 9ed6c94bd2066231772571db1c6376d771f9aed2..da343dd503a62e83f431dd0ffb02a7e70be7d0d5 100644 --- a/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py @@ -253,8 +253,6 @@ class TestListenAndServOp(unittest.TestCase): port1 = self._get_pserver_port(p1.pid) places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) for place in places: self._run_hsigmoid_op_one_pserver(place, port0) diff --git a/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py b/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py index b5f93f93a1bc8440ea0fe5893073b66e864b3884..cc6f40de86e302605a416c48790c74cbb431b2e3 100644 --- a/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py @@ -221,8 +221,6 @@ class TestListenAndServOp(unittest.TestCase): port1 = self._get_pserver_port(p1.pid) places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) for place in places: self._run_nce_op_two_pserver(place, port0, port1)