From 20da7703897fe4cd6946f86bc8f713c016cf15a8 Mon Sep 17 00:00:00 2001
From: xiayanming <41795079@qq.com>
Date: Wed, 7 Jul 2021 16:38:44 +0800
Subject: [PATCH] =?UTF-8?q?[HIP]=20=E8=A7=A3=E5=86=B3hipMemcpy=E6=97=A0?=
 =?UTF-8?q?=E6=B3=95overlap=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=8C=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9=E5=90=8EAMD=20GPU=E6=80=A7=E8=83=BD=E6=8F=90=E5=8D=87?=
 =?UTF-8?q?=E5=A4=A7=E4=BA=8E10%=20(#33982)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../fluid/operators/math/concat_and_split.cu  | 105 ++++++++++++++----
 1 file changed, 83 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu
index d62c1e42d3b..58f936788a3 100644
--- a/paddle/fluid/operators/math/concat_and_split.cu
+++ b/paddle/fluid/operators/math/concat_and_split.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <algorithm>
 #include <vector>
+#include "gflags/gflags.h"
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
@@ -242,8 +243,28 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int in_col = input[0].numel() / in_row;
     int out_row = in_row, out_col = 0;
 
-    std::vector<const T*> inputs_data(in_num);
-    std::vector<int> inputs_col(in_num + 1);
+    int inputs_col_num = in_num + 1;
+    std::vector<const T*> inputs_data_vec(in_num);
+    std::vector<int> inputs_col_vec(inputs_col_num);
+    const T** inputs_data = inputs_data_vec.data();
+    int* inputs_col = inputs_col_vec.data();
+
+// There are some differences between hip runtime and NV runtime.
+// In NV, when the pageable memory data less than 64K is transferred from
+// hosttodevice, it will be automatically asynchronous.
+// However, only pinned memory in hip can copy asynchronously
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+// 3.2.6.1. Concurrent Execution between Host and Device
+// Memory copies from host to device of a memory block of 64 KB or less
+#ifdef PADDLE_WITH_HIP
+    memory::AllocationPtr data_alloc, col_alloc;
+    data_alloc =
+        memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*));
+    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
+    col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
+                              inputs_col_num * sizeof(int));
+    inputs_col = reinterpret_cast<int*>(col_alloc->ptr());
+#endif
 
     inputs_col[0] = 0;
     bool has_same_shape = true;
@@ -264,12 +285,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     memory::allocation::AllocationPtr tmp_dev_ins_data;
     const T** dev_ins_data = nullptr;
     if (!has_same_shape || in_num < 2 || in_num > 4) {
-      tmp_dev_ins_data =
-          memory::Alloc(context, inputs_data.size() * sizeof(T*));
+      tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_data.data()),
-                   inputs_data.size() * sizeof(T*), context.stream());
+                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
+                   context.stream());
       dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
     }
 
@@ -292,17 +312,29 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          memory::Alloc(context, inputs_col.size() * sizeof(int));
+          memory::Alloc(context, inputs_col_num * sizeof(int));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   static_cast<void*>(inputs_col.data()),
-                   inputs_col.size() * sizeof(int), context.stream());
+                   static_cast<void*>(inputs_col), inputs_col_num * sizeof(int),
+                   context.stream());
       int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
 
       ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
-          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
+          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
           out_row, out_col, output->data<T>());
     }
+#ifdef PADDLE_WITH_HIP
+    // Prevent the pinned memory value from being covered and release the memory
+    // after the launch kernel of the stream is executed (reapply pinned memory
+    // next time)
+    auto* data_alloc_released = data_alloc.release();
+    auto* col_alloc_released = col_alloc.release();
+    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
+      memory::allocation::AllocationDeleter deleter;
+      deleter(data_alloc_released);
+      deleter(col_alloc_released);
+    });
+#endif
   }
 };
 
@@ -313,6 +345,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 class SplitFunctor<platform::CUDADeviceContext, T> {
  public:
+  SplitFunctor();
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
@@ -329,8 +362,27 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     int64_t in_col = 0, in_row = out_row;
     bool has_same_shape = true;
 
-    std::vector<T*> outputs_data(o_num);
-    std::vector<int64_t> outputs_cols(o_num + 1);
+    int outputs_cols_num = o_num + 1;
+    std::vector<T*> outputs_data_vec(o_num);
+    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
+    T** outputs_data = outputs_data_vec.data();
+    int64_t* outputs_cols = outputs_cols_vec.data();
+
+// There are some differences between hip runtime and NV runtime.
+// In NV, when the pageable memory data less than 64K is transferred from
+// hosttodevice, it will be automatically asynchronous.
+// However, only pinned memory in hip can copy asynchronously
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
+// 3.2.6.1. Concurrent Execution between Host and Device
+// Memory copies from host to device of a memory block of 64 KB or less
+#ifdef PADDLE_WITH_HIP
+    memory::AllocationPtr data_alloc, cols_alloc;
+    data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*));
+    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
+    cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
+                               (outputs_cols_num) * sizeof(int64_t));
+    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
+#endif
 
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
@@ -354,12 +406,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
     memory::allocation::AllocationPtr tmp_dev_outs_data;
     T** dev_out_gpu_data = nullptr;
     if (!has_same_shape || o_num < 2 || o_num > 4) {
-      tmp_dev_outs_data =
-          memory::Alloc(context, outputs_data.size() * sizeof(T*));
+      tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_outs_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_data.data()),
-                   outputs_data.size() * sizeof(T*), context.stream());
+                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
+                   context.stream());
       dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
     }
 
@@ -382,20 +433,30 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
       }
     } else {
       auto tmp_dev_ins_col_data =
-          memory::Alloc(context,
-
-                        outputs_cols.size() * sizeof(int64_t));
+          memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
-                   reinterpret_cast<void*>(outputs_cols.data()),
-                   outputs_cols.size() * sizeof(int64_t), context.stream());
+                   reinterpret_cast<void*>(outputs_cols),
+                   outputs_cols_num * sizeof(int64_t), context.stream());
       int64_t* dev_outs_col_data =
           reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
 
       SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
-          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
+          static_cast<int>(outputs_cols_num), dev_out_gpu_data);
     }
+#ifdef PADDLE_WITH_HIP
+    // Prevent the pinned memory value from being covered and release the memory
+    // after the launch kernel of the stream is executed (reapply pinned memory
+    // next time)
+    auto* data_alloc_released = data_alloc.release();
+    auto* cols_alloc_released = cols_alloc.release();
+    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
+      memory::allocation::AllocationDeleter deleter;
+      deleter(data_alloc_released);
+      deleter(cols_alloc_released);
+    });
+#endif
   }
 };
-- 
GitLab