未验证 提交 51a9fca3 编写于 作者: Q qingqing01 提交者: GitHub

Async memory copy (#15013)

上级 93870574
......@@ -231,11 +231,14 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
inputs[i].data.length());
} else {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(place_));
auto dst_gpu_place = boost::get<platform::CUDAPlace>(place_);
memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), inputs[i].data.data(),
inputs[i].data.length(),
0); // stream 0 for sync copy
inputs[i].data.length(), dev_ctx->stream());
#else
PADDLE_THROW("Not compile with CUDA, should not reach here.");
#endif
......
......@@ -208,11 +208,14 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
inputs[i].data.length());
} else {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(place_));
auto dst_gpu_place = boost::get<platform::CUDAPlace>(place_);
memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), inputs[i].data.data(),
inputs[i].data.length(),
0); // stream 0 for sync copy
inputs[i].data.length(), dev_ctx->stream());
#else
PADDLE_THROW("Not compile with CUDA, should not reach here.");
#endif
......
......@@ -142,12 +142,13 @@ class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> {
vars->mutable_data<T>(ctx.GetPlace());
framework::Tensor d_temp;
framework::TensorCopySync(h_temp, ctx.GetPlace(), &d_temp);
framework::TensorCopy(h_temp, ctx.GetPlace(), &d_temp);
// At least use 32 threads, at most 512 threads.
// blockx is multiple of 32.
int blockx = std::min(
static_cast<long>(((feature_width * num_priors + 31) >> 5) << 5), 512L);
static_cast<int64_t>(((feature_width * num_priors + 31) >> 5) << 5),
512L);
int gridx = (feature_width * num_priors + blockx - 1) / blockx;
dim3 threads(blockx, 1);
dim3 grids(gridx, feature_height);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册