提交 4877f5d7 编写于 作者: J JiabinYang

test=develop, fix compile error under GPU mode

上级 8515ee3a
...@@ -47,10 +47,26 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -47,10 +47,26 @@ void prefetch_with_reconstruct(const std::string& id_name,
auto* out_value = out.data<T>(); auto* out_value = out.data<T>();
size_t original_width = original->numel() / original->dims()[0]; size_t original_width = original->numel() / original->dims()[0];
bool is_on_cpu_place = true;
if (!platform::is_cpu_place(ids.place())) {
is_on_cpu_place = false;
}
for (int64_t i = 0; i < ids.numel(); i++) { for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i; const T* out_rows = out_value + original_width * i;
T* original_row = original_value + original_width * ids.data<int64_t>()[i]; T* original_row = original_value + original_width * ids.data<int64_t>()[i];
if (is_on_cpu_place) {
std::memcpy(original_row, out_rows, original_width * sizeof(T)); std::memcpy(original_row, out_rows, original_width * sizeof(T));
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto stream =
static_cast<platform::CUDADeviceContext*>(actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), out_rows,
cpu_place, original_row, original_width * sizeof(T), stream);
#endif
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册