提交 837da796 编写于 作者: Y Yang Yu

Merge branch 'feature/enhance_dev_ctx_pool' into feature/is_nan

......@@ -63,9 +63,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace()));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance;
auto dev_ctx = static_cast<const platform::CUDADeviceContext *>(
pool.Borrow(tensor.place()));
pool.Get(tensor.place()));
paddle::platform::GpuMemcpyAsync(
dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(),
......@@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray(
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place));
static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
cudaMemcpyHostToDevice, dev_ctx->stream());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册