diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index 32bb479e00517e91a4c61ae698e03b8e3a03bca4..bc2d496a3e76a8fa620dcf17a0cb4818516ab302 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -287,13 +287,11 @@ class ConcatFunctor { const T** dev_ins_data = nullptr; if (!has_same_shape || in_num < 2 || in_num > 4) { tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*)); - { - platform::SkipCUDAGraphCaptureGuard guard; - memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), - tmp_dev_ins_data->ptr(), platform::CPUPlace(), - static_cast(inputs_data), in_num * sizeof(T*), - context.stream()); - } + auto* restored = + platform::RestoreHostMemIfCapturingCUDAGraph(inputs_data, in_num); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + tmp_dev_ins_data->ptr(), platform::CPUPlace(), restored, + in_num * sizeof(T*), context.stream()); dev_ins_data = reinterpret_cast(tmp_dev_ins_data->ptr()); } @@ -317,13 +315,12 @@ class ConcatFunctor { } else { auto tmp_dev_ins_col_data = memory::Alloc(context, inputs_col_num * sizeof(int64_t)); - { - platform::SkipCUDAGraphCaptureGuard guard; - memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), - tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), - static_cast(inputs_col), - inputs_col_num * sizeof(int64_t), context.stream()); - } + + auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph( + inputs_col, inputs_col_num); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), restored, + inputs_col_num * sizeof(int64_t), context.stream()); int64_t* dev_ins_col_data = static_cast(tmp_dev_ins_col_data->ptr()); @@ -422,13 +419,11 @@ class SplitFunctor { T** dev_out_gpu_data = nullptr; if (!has_same_shape || o_num < 2 || o_num > 4) { tmp_dev_outs_data = 
memory::Alloc(context, o_num * sizeof(T*)); - { - platform::SkipCUDAGraphCaptureGuard guard; - memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), - tmp_dev_outs_data->ptr(), platform::CPUPlace(), - reinterpret_cast(outputs_data), o_num * sizeof(T*), - context.stream()); - } + auto* restored = + platform::RestoreHostMemIfCapturingCUDAGraph(outputs_data, o_num); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + tmp_dev_outs_data->ptr(), platform::CPUPlace(), restored, + o_num * sizeof(T*), context.stream()); dev_out_gpu_data = reinterpret_cast(tmp_dev_outs_data->ptr()); } @@ -452,13 +447,11 @@ class SplitFunctor { } else { auto tmp_dev_ins_col_data = memory::Alloc(context, outputs_cols_num * sizeof(int64_t)); - { - platform::SkipCUDAGraphCaptureGuard guard; - memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), - tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), - reinterpret_cast(outputs_cols), - outputs_cols_num * sizeof(int64_t), context.stream()); - } + auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph( + outputs_cols, outputs_cols_num); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), restored, + outputs_cols_num * sizeof(int64_t), context.stream()); int64_t* dev_outs_col_data = reinterpret_cast(tmp_dev_ins_col_data->ptr()); diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h index fe082c850aa4d278d5d8c05d6c506dac22485676..7a9e1a3a1419ca62b794e53df0bd34b45dae8b9e 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.h +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h @@ -60,6 +60,23 @@ inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) { callback(); } /* RestoreHostMemIfCapturingCUDAGraph(host_mem, size): returns a host pointer to `size` elements of T. `size` is an element count, not a byte count (the byte length is computed as size * sizeof(T)). When built with PADDLE_WITH_CUDA and IsCUDAGraphCapturing() is true, the elements are memcpy'd into a freshly allocated heap buffer, a callback that delete[]s that buffer is passed to AddResetCallbackIfCapturingCUDAGraph, and the copy is returned; in every other case host_mem itself is returned unchanged. The static_asserts reject non-trivial T and void. NOTE(review): presumably the copy exists so that work recorded during capture does not refer to the caller's short-lived host buffer - confirm against the CUDA graph capture code. NOTE(review): template argument lists look stripped in this view (bare "template", "std::is_trivial::value", "reinterpret_cast(") - confirm against the real header. */ +template +inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) { + static_assert(std::is_trivial::value, "T must be trivial type"); + 
static_assert(!std::is_same::value, "T cannot be void"); +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(IsCUDAGraphCapturing())) { + size_t nbytes = size * sizeof(T); + void *new_host_mem = new uint8_t[nbytes]; + std::memcpy(new_host_mem, host_mem, nbytes); + /* The copy is not freed here: the lambda below captures the pointer by value and delete[]s it when the registered callback is invoked. NOTE(review): when that callback fires is decided by AddResetCallbackIfCapturingCUDAGraph - verify there. */ AddResetCallbackIfCapturingCUDAGraph( + [new_host_mem] { delete[] reinterpret_cast(new_host_mem); }); + return reinterpret_cast(new_host_mem); + } +#endif + /* Not capturing, or built without PADDLE_WITH_CUDA: hand back the caller's own pointer, no copy made. */ return host_mem; +} + class SkipCUDAGraphCaptureGuard { DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);