未验证 提交 10abdb8f 编写于 作者: K kangguangli 提交者: GitHub

fix memcpy_h2d bug related to cuda stream setting when allocate memory (#45450)

* fix memcpy_h2d bug related to cuda stream setting when allocate memory

* add header file

* fix compile error for cpu only
上级 8b24c795
...@@ -52,7 +52,6 @@ void Copy(const Context& dev_ctx, ...@@ -52,7 +52,6 @@ void Copy(const Context& dev_ctx,
<< dst_place; << dst_place;
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->mutable_data(dst_place);
void* dst_ptr = nullptr; void* dst_ptr = nullptr;
if (paddle::platform::is_cpu_place(dst_place)) { if (paddle::platform::is_cpu_place(dst_place)) {
......
...@@ -16,13 +16,41 @@ ...@@ -16,13 +16,41 @@
#include <vector> #include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/stream.h"
namespace phi { namespace phi {
static constexpr size_t WAIT_THRESHOLD = 64 * 1024; static constexpr size_t WAIT_THRESHOLD = 64 * 1024;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <>
void MemcpyH2DKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out) {
PADDLE_ENFORCE_GE(
dst_place_type,
0,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
PADDLE_ENFORCE_LE(
dst_place_type,
3,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
auto stream = dev_ctx.stream();
out->mutable_data(dev_ctx.GetPlace(),
x.dtype(),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
#endif
template <typename Context> template <typename Context>
void MemcpyH2DKernel(const Context& dev_ctx, void MemcpyH2DKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -39,7 +67,6 @@ void MemcpyH2DKernel(const Context& dev_ctx, ...@@ -39,7 +67,6 @@ void MemcpyH2DKernel(const Context& dev_ctx,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d", errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type)); dst_place_type));
// Copy will set the stream of the tensor while setting blocking to false
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
} }
...@@ -48,9 +75,12 @@ void MemcpyD2HKernel(const Context& dev_ctx, ...@@ -48,9 +75,12 @@ void MemcpyD2HKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int dst_place_type, int dst_place_type,
DenseTensor* out) { DenseTensor* out) {
// Copy will set the stream of the tensor while setting blocking to false
switch (dst_place_type) { switch (dst_place_type) {
case 0: case 0:
// NOTE(lvyongkang): phi::Copy will use DeviceContext.zero_allocator to
// alloc and assign DeviceContext.place to out, which causes place check
// fails. So we specify out's place here.
out->mutable_data(CPUPlace());
Copy(dev_ctx, x, CPUPlace(), false, out); Copy(dev_ctx, x, CPUPlace(), false, out);
// NOTE(copy from Aurelius84): host <-> device memory copies of a memory // NOTE(copy from Aurelius84): host <-> device memory copies of a memory
// block of 64 KB or less are asynchronous. See // block of 64 KB or less are asynchronous. See
...@@ -61,6 +91,10 @@ void MemcpyD2HKernel(const Context& dev_ctx, ...@@ -61,6 +91,10 @@ void MemcpyD2HKernel(const Context& dev_ctx,
break; break;
case 1: case 1:
// NOTE(lvyongkang): phi::Copy will use DeviceContext.zero_allocator to
// alloc and assign DeviceContext.place to out, which causes place check
// fails. So we specify out's place here.
out->mutable_data(GPUPinnedPlace());
Copy(dev_ctx, x, GPUPinnedPlace(), false, out); Copy(dev_ctx, x, GPUPinnedPlace(), false, out);
// paddle::memory::Copy use async copy for GPUPinnedPlace // paddle::memory::Copy use async copy for GPUPinnedPlace
dev_ctx.Wait(); dev_ctx.Wait();
...@@ -89,9 +123,9 @@ void MemcpyD2HMultiIOKernel(const Context& dev_ctx, ...@@ -89,9 +123,9 @@ void MemcpyD2HMultiIOKernel(const Context& dev_ctx,
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
array[i], array[i],
errors::PreconditionNotMet("input tesnor %d should not be nullptr", i)); errors::PreconditionNotMet("input tesnor %d should not be nullptr", i));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(out_array[i],
out_array[i], errors::PreconditionNotMet(
errors::PreconditionNotMet("input tesnor %d should not be nullptr", i)); "output tesnor %d should not be nullptr", i));
const auto& x = *(array[i]); const auto& x = *(array[i]);
MemcpyD2HKernel<Context>(dev_ctx, x, dst_place_type, out_array[i]); MemcpyD2HKernel<Context>(dev_ctx, x, dst_place_type, out_array[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册