提交 1bec52f5 编写于 作者: J JiabinYang

test=develop, fix CPU running error

上级 bfcb5e52
...@@ -39,9 +39,6 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -39,9 +39,6 @@ void prefetch_with_reconstruct(const std::string& id_name,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope, const framework::Scope& scope,
framework::LoDTensor* original) { framework::LoDTensor* original) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
prefetch(id_name, out_name, table_names, epmap, height_sections, context, prefetch(id_name, out_name, table_names, epmap, height_sections, context,
scope); scope);
auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>(); auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
...@@ -54,23 +51,30 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -54,23 +51,30 @@ void prefetch_with_reconstruct(const std::string& id_name,
if (!platform::is_cpu_place(ids.place())) { if (!platform::is_cpu_place(ids.place())) {
is_on_cpu_place = false; is_on_cpu_place = false;
} }
if (is_on_cpu_place) {
for (int64_t i = 0; i < ids.numel(); i++) { for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i; const T* out_rows = out_value + original_width * i;
T* original_row = original_value + original_width * ids.data<int64_t>()[i]; T* original_row =
if (is_on_cpu_place) { original_value + original_width * ids.data<int64_t>()[i];
std::memcpy(original_row, out_rows, original_width * sizeof(T)); std::memcpy(original_row, out_rows, original_width * sizeof(T));
}
} else { } else {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!"); PADDLE_THROW("paddle is not compiled with CUDA!");
#else #else
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i;
T* original_row =
original_value + original_width * ids.data<int64_t>()[i];
auto stream = auto stream =
static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream(); static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row, memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
platform::CPUPlace(), out_rows, original_width * sizeof(T), platform::CPUPlace(), out_rows, original_width * sizeof(T),
stream); stream);
#endif
} }
#endif
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册