Unverified commit a1ad3a63, authored by sneaxiy and committed by GitHub

Fix CUDA Graph H2D bug by restoring host memory (#37774)

* fix CUDA Graph H2D bug again

* fix missing-return bug
Parent 9a2d327c
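Background for the diff below: during CUDA Graph stream capture, an H2D cudaMemcpyAsync is recorded into the graph together with its source host pointer, and every later replay of the graph reads from that same pointer. The pointer/offset tables that ConcatFunctor and SplitFunctor copy to the device are transient host buffers, so replays could read freed or overwritten memory. The new RestoreHostMemIfCapturingCUDAGraph helper snapshots such buffers into heap memory that is released only when the captured graph is reset. What follows is a minimal standalone sketch of that failure mode and fix, not Paddle code; it assumes a CUDA 11-era toolkit for the five-argument cudaGraphInstantiate signature.

// sketch_cuda_graph_h2d.cu -- hypothetical standalone example, not Paddle code.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* dev = nullptr;
  cudaMalloc(&dev, sizeof(int));

  // `transient` plays the role of inputs_data / inputs_col: a host buffer
  // that only lives for the duration of the launch call.
  int* transient = new int(42);

  // What RestoreHostMemIfCapturingCUDAGraph does while capturing: take a heap
  // snapshot whose lifetime is tied to the captured graph, not to the caller.
  int* snapshot = new int(*transient);

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // The H2D copy is recorded with `snapshot` as its fixed source pointer;
  // recording `transient` here instead is the bug this commit fixes.
  cudaMemcpyAsync(dev, snapshot, sizeof(int), cudaMemcpyHostToDevice, stream);
  cudaStreamEndCapture(stream, &graph);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  delete transient;  // the original buffer may legally die after capture ...
  for (int i = 0; i < 3; ++i) {
    cudaGraphLaunch(exec, stream);  // ... but every replay re-reads `snapshot`
  }
  cudaStreamSynchronize(stream);

  int out = 0;
  cudaMemcpy(&out, dev, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("replayed value: %d\n", out);  // 42 on every replay

  delete snapshot;  // in Paddle this happens in the graph reset callback
  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(dev);
  cudaStreamDestroy(stream);
  return 0;
}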
@@ -287,13 +287,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
       tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
+      auto* restored =
+          platform::RestoreHostMemIfCapturingCUDAGraph(inputs_data, in_num);
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
-                   context.stream());
-      }
+                   tmp_dev_ins_data->ptr(), platform::CPUPlace(), restored,
+                   in_num * sizeof(T*), context.stream());
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
@@ -317,13 +315,12 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, inputs_col_num * sizeof(int64_t));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
+      auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(
+          inputs_col, inputs_col_num);
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_col),
+                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), restored,
                    inputs_col_num * sizeof(int64_t), context.stream());
-      }
       int64_t* dev_ins_col_data =
           static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
@@ -422,13 +419,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
       tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
+      auto* restored =
+          platform::RestoreHostMemIfCapturingCUDAGraph(outputs_data, o_num);
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
-                   context.stream());
-      }
+                   tmp_dev_outs_data->ptr(), platform::CPUPlace(), restored,
+                   o_num * sizeof(T*), context.stream());
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
@@ -452,13 +447,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
+      auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(
+          outputs_cols, outputs_cols_num);
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_cols),
+                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), restored,
                    outputs_cols_num * sizeof(int64_t), context.stream());
-      }
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
......
@@ -60,6 +60,23 @@ inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
   callback();
 }
+template <typename T>
+inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
+  static_assert(std::is_trivial<T>::value, "T must be trivial type");
+  static_assert(!std::is_same<T, void>::value, "T cannot be void");
+#ifdef PADDLE_WITH_CUDA
+  if (UNLIKELY(IsCUDAGraphCapturing())) {
+    size_t nbytes = size * sizeof(T);
+    void *new_host_mem = new uint8_t[nbytes];
+    std::memcpy(new_host_mem, host_mem, nbytes);
+    AddResetCallbackIfCapturingCUDAGraph(
+        [new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
+    return reinterpret_cast<T *>(new_host_mem);
+  }
+#endif
+  return host_mem;
+}
 class SkipCUDAGraphCaptureGuard {
   DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
......
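A note on the new helper's contract, with a plain-C++ sketch (no GPU required) of the lifetime pattern it relies on: because the snapshot is taken with a raw std::memcpy, T must be a trivial, non-void type (hence the static_asserts), and the heap copy is released by the callback registered with AddResetCallbackIfCapturingCUDAGraph, i.e. when the captured graph is reset rather than when the calling functor returns, since the graph may be replayed many times in between. The names GraphResetCallbacks and SnapshotForCapture below are hypothetical stand-ins, not Paddle APIs.

#include <cstdint>
#include <cstring>
#include <functional>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for the reset-callback machinery: everything
// registered here runs when the captured graph is reset/destroyed.
static std::vector<std::function<void()>> GraphResetCallbacks;

template <typename T>
T *SnapshotForCapture(T *host_mem, size_t n) {
  static_assert(std::is_trivial<T>::value, "T must be trivial type");
  T *copy = new T[n];
  std::memcpy(copy, host_mem, n * sizeof(T));
  GraphResetCallbacks.push_back([copy] { delete[] copy; });
  return copy;
}

int main() {
  {
    int64_t cols[3] = {0, 128, 256};  // transient, like inputs_col
    int64_t *stable = SnapshotForCapture(cols, 3);
    // ... an H2D copy captured here would record `stable`; `cols` may go out
    // of scope while the graph keeps being replayed ...
    (void)stable;
  }
  // Graph reset: only now is it safe to free the snapshots.
  for (auto &cb : GraphResetCallbacks) cb();
  GraphResetCallbacks.clear();
  return 0;
}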