Unverified commit a1ad3a63, authored by sneaxiy and committed by GitHub

Fix CUDA Graph H2D bug by restoring host memory (#37774)

* fix CUDA Graph H2D bug again

* fix missing-return bug
Parent 9a2d327c
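Background for this fix: during CUDA Graph capture, an asynchronous H2D memcpy records only the host source pointer, and the copy re-runs on every graph replay. If that source is a temporary host buffer that has been freed by replay time, the replayed copy reads dangling memory. The sketch below is a minimal, self-contained illustration of that pattern and of the keep-the-host-buffer-alive remedy this commit applies; it is not Paddle code, CaptureH2DCopy and keep_alive are made-up names, and error checking is omitted.

#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

cudaGraphExec_t CaptureH2DCopy(void* dev_dst, const void* host_src,
                               size_t nbytes, cudaStream_t stream,
                               std::vector<uint8_t>* keep_alive) {
  // Snapshot the host buffer so the pointer baked into the graph stays valid
  // until keep_alive (and the graph) are destroyed -- the same idea as the
  // RestoreHostMemIfCapturingCUDAGraph helper added in this commit.
  keep_alive->assign(static_cast<const uint8_t*>(host_src),
                     static_cast<const uint8_t*>(host_src) + nbytes);

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // The copy is recorded into the graph and re-executed on every replay, so
  // it must read from memory that is still alive at replay time.
  cudaMemcpyAsync(dev_dst, keep_alive->data(), nbytes, cudaMemcpyHostToDevice,
                  stream);
  cudaGraph_t graph;
  cudaStreamEndCapture(stream, &graph);

  cudaGraphExec_t graph_exec;
  // CUDA 10/11-style instantiate signature, matching the era of this commit.
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);
  cudaGraphDestroy(graph);
  return graph_exec;  // replay later with cudaGraphLaunch(graph_exec, stream)
}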
@@ -287,13 +287,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
       tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
-        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                     tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                     static_cast<void*>(inputs_data), in_num * sizeof(T*),
-                     context.stream());
-      }
+      auto* restored =
+          platform::RestoreHostMemIfCapturingCUDAGraph(inputs_data, in_num);
+      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                   tmp_dev_ins_data->ptr(), platform::CPUPlace(), restored,
+                   in_num * sizeof(T*), context.stream());
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
@@ -317,13 +315,12 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, inputs_col_num * sizeof(int64_t));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
-        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                     static_cast<void*>(inputs_col),
-                     inputs_col_num * sizeof(int64_t), context.stream());
-      }
+
+      auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(
+          inputs_col, inputs_col_num);
+      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), restored,
+                   inputs_col_num * sizeof(int64_t), context.stream());
       int64_t* dev_ins_col_data =
           static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
@@ -422,13 +419,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
       tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
-        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                     tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                     reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
-                     context.stream());
-      }
+      auto* restored =
+          platform::RestoreHostMemIfCapturingCUDAGraph(outputs_data, o_num);
+      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                   tmp_dev_outs_data->ptr(), platform::CPUPlace(), restored,
+                   o_num * sizeof(T*), context.stream());
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
@@ -452,13 +447,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     } else {
       auto tmp_dev_ins_col_data =
           memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
-      {
-        platform::SkipCUDAGraphCaptureGuard guard;
-        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
-                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                     reinterpret_cast<void*>(outputs_cols),
-                     outputs_cols_num * sizeof(int64_t), context.stream());
-      }
+      auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(
+          outputs_cols, outputs_cols_num);
+      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
+                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), restored,
+                   outputs_cols_num * sizeof(int64_t), context.stream());
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
...
@@ -60,6 +60,23 @@ inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
   callback();
 }
 
+template <typename T>
+inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
+  static_assert(std::is_trivial<T>::value, "T must be trivial type");
+  static_assert(!std::is_same<T, void>::value, "T cannot be void");
+#ifdef PADDLE_WITH_CUDA
+  if (UNLIKELY(IsCUDAGraphCapturing())) {
+    size_t nbytes = size * sizeof(T);
+    void *new_host_mem = new uint8_t[nbytes];
+    std::memcpy(new_host_mem, host_mem, nbytes);
+    AddResetCallbackIfCapturingCUDAGraph(
+        [new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
+    return reinterpret_cast<T *>(new_host_mem);
+  }
+#endif
+  return host_mem;
+}
+
 class SkipCUDAGraphCaptureGuard {
   DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
...
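For illustration, here is a hypothetical call site (not part of this commit) that mirrors the pattern used in the hunks above: snapshot the host buffer first, then hand the returned pointer to the asynchronous copy, so a captured graph never replays a copy from a stack array that has since gone out of scope. The wrapper name CopyHostArrayToDevice is made up, and the Paddle headers declaring memory::Copy, platform::CUDAPlace, and RestoreHostMemIfCapturingCUDAGraph are assumed to be included.

// Hypothetical helper sketch; mirrors the call pattern shown in the diff.
template <typename T>
void CopyHostArrayToDevice(const platform::CUDADeviceContext& context,
                           T* host_data, size_t n, void* dev_dst) {
  // Returns host_data unchanged when no graph is being captured; otherwise
  // returns a heap snapshot that is freed by the graph's reset callback.
  auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(host_data, n);
  memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
               dev_dst, platform::CPUPlace(), restored, n * sizeof(T),
               context.stream());
}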