From 34b209cffa81593092a308e2ffe0536b475e81e6 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 20 Aug 2018 16:33:04 +0800 Subject: [PATCH] Complete sequence_padding GPU kernel --- paddle/fluid/operators/CMakeLists.txt | 1 + .../fluid/operators/math/sequence_padding.cc | 26 +-- .../fluid/operators/math/sequence_padding.cu | 151 ++++++++---------- .../fluid/operators/math/sequence_padding.h | 6 +- .../operators/math/sequence_padding_test.cc | 13 +- paddle/fluid/operators/sequence_pad_op.h | 5 +- paddle/fluid/operators/warpctc_op.h | 15 +- 7 files changed, 113 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ff0e989464..2179a5acdb 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -277,6 +277,7 @@ op_library(unsqueeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op) op_library(extract_rows_op DEPS memory) op_library(flatten_op DEPS reshape_op) +op_library(sequence_pad_op DEPS sequence_padding) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index d3dab64f60..02ede3edce 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,8 +18,6 @@ namespace paddle { namespace operators { namespace math { -enum CopyType { kSeqToPad, kPadToSeq }; - template void CopyValidData(framework::Tensor* dst_tensor, const framework::Tensor* src_tensor, @@ -67,7 +65,7 @@ class PaddingLoDTensorFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, framework::LoDTensor* pad_tensor, - std::vector pad_value = {0}, int pad_seq_len = -1, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth) { auto seq_lod = seq_tensor.lod(); @@ -81,19 +79,21 @@ class PaddingLoDTensorFunctor { CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, step_width, layout); - PADDLE_ENFORCE(pad_value.size() == 1 || - static_cast(pad_value.size()) == step_width, - "The size of 'pad_value' can only be 1 or be equal to the " + PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, + "The numel of 'pad_value' can only be 1 or be equal to the " "'step_width'."); - if (pad_value.size() == 1) { - pad_value = std::vector(step_width, pad_value[0]); - } - // fill padding value T* pad_data = pad_tensor->data(); - for (int i = 0; i < pad_tensor->numel(); i += step_width) { - memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T)); + const T* pad_value_data = pad_value.data(); + if (pad_value.numel() == 1) { + for (int i = 0; i < pad_tensor->numel(); ++i) { + pad_data[i] = *pad_value_data; + } + } else { + for (int i = 0; i < pad_tensor->numel(); i += step_width) { + memcpy(pad_data + i, pad_value_data, step_width * sizeof(T)); + } } CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, @@ -117,7 +117,7 @@ class UnpaddingLoDTensorFunctor { const framework::LoDTensor& pad_tensor, framework::LoDTensor* seq_tensor, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, - const PadLayout& layout = kBatchLengthWidth) { + const PadLayout layout = kBatchLengthWidth) { auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; const auto& seq_tensor_dims = seq_tensor->dims(); const auto& pad_tensor_dims = 
pad_tensor.dims(); diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 20e3e3de2a..3b1a44a457 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -19,46 +19,32 @@ namespace paddle { namespace operators { namespace math { -template +template __global__ void SequencePaddingKernel( - T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num, - const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times, - const T& pad_value, const OutputLayout& output_layout) { + T* dst, const T* src, const T* pad_value, bool is_constant_pad, + const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len, + const size_t& step_width, bool norm_by_len, const PadLayout& layout) { size_t seq_idx = blockIdx.y; - size_t seq_start = seq_offset[seq_idx]; - size_t seq_len = seq_offset[seq_idx + 1] - seq_start; - - size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y; - - size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width; - - size_t pad_data_offset = 0; - - if (output_layout == kLengthBatchWidth) { - pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width; - } else { - pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width; - } - - if (seq_step_idx < seq_len) { - T scale = norm_by_times ? (1.0f / static_cast(seq_len)) : 1.0f; - if (Padding) { - /* seq -> pad */ - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i]; - } - } else { - /* pad -> seq */ - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i]; - } + size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + + size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width; + size_t pad_data_offset = layout == kBatchLengthWidth + ? (seq_idx * pad_seq_len + step_idx) * step_width + : (step_idx * seq_num + seq_idx) * step_width; + + T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset); + const T* src_data = + src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset); + + if (step_idx < seq_len) { + float scale = norm_by_len ? (1.0f / static_cast(seq_len)) : 1.0f; + for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) { + dst_data[i] = scale * src_data[i]; } - } else if (seq_step_idx < max_seq_len) { - if (Padding) { - /* seq -> pad */ - for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - pad_data[pad_data_offset + i] = pad_value; - } + } else if (step_idx < pad_seq_len && Type == kSeqToPad) { + for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { + dst_data[i] = is_constant_pad ? 
pad_value[0] : pad_value[i]; } } } @@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::LoDTensor& seq_tensor, framework::Tensor* pad_tensor, - T pad_value = static_cast(0), bool norm_by_times = false, - size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(seq_tensor, lod_level); - - auto& lod = seq_tensor.lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - - auto seq_tensor_dims = seq_tensor.dims(); - auto pad_tensor_dims = pad_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; + int seq_num = seq_offset.size() - 1; - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width, + "The numel of 'pad_value' can only be 1 or be equal to the " + "'step_width'."); - if (!norm_by_times && seq_num == 1UL) { + if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) { TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor); pad_tensor->Resize(pad_tensor_dims); return; @@ -98,21 +86,22 @@ class PaddingLoDTensorFunctor { * and at least 8 elements for each thread. 
*/ size_t block_dim_x = - std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); const T* seq_data = seq_tensor.data(); T* pad_data = pad_tensor->data(); + const T* pad_value_data = pad_value.data(); - SequencePaddingKernel<<>>( - pad_data, const_cast(seq_data), - seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, norm_by_times, pad_value, output_layout); + SequencePaddingKernel<<>>( + pad_data, seq_data, pad_value_data, pad_value.numel() == 1, + seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); } }; @@ -120,25 +109,23 @@ template class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, - framework::LoDTensor* seq_tensor, - const framework::Tensor& pad_tensor, - bool norm_by_times = false, size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(*seq_tensor, lod_level); - - auto& lod = seq_tensor->lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - - auto seq_tensor_dims = seq_tensor->dims(); - auto pad_tensor_dims = pad_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + int seq_num = seq_offset.size() - 1; - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); - if (!norm_by_times && seq_num == 1UL) { + if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) { TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); seq_tensor->Resize(seq_tensor_dims); return; @@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor { * and at least 8 elements for each thread. 
*/ size_t block_dim_x = - std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); size_t block_dim_y = kBlockSize / block_dim_x; dim3 threads(block_dim_x, block_dim_y); - size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; size_t grid_dim_y = seq_num; dim3 grid(grid_dim_x, grid_dim_y); const T* pad_data = pad_tensor.data(); T* seq_data = seq_tensor->data(); - SequencePaddingKernel<<>>( - const_cast(pad_data), seq_data, - seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, norm_by_times, static_cast(0), output_layout); + SequencePaddingKernel<<>>( + seq_data, pad_data, nullptr, false, + seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); } }; diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 9b8c892c53..3fb5859e3b 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -25,6 +25,8 @@ namespace math { enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; +enum CopyType { kSeqToPad, kPadToSeq }; + inline static size_t MaximumSequenceLength( const framework::Vector& seq_offset) { size_t seq_num = seq_offset.size() - 1; @@ -82,7 +84,7 @@ class PaddingLoDTensorFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, framework::LoDTensor* pad_tensor, - std::vector pad_value = {0}, int pad_seq_len = -1, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth); }; @@ -94,7 +96,7 @@ class UnpaddingLoDTensorFunctor { const framework::LoDTensor& pad_tensor, framework::LoDTensor* seq_tensor, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, - const PadLayout& layout = kBatchLengthWidth); + const PadLayout layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index 3171c7c33e..4f61b1029c 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -24,6 +24,8 @@ void TestSequencePadding(const paddle::framework::LoD& lod, paddle::framework::LoDTensor seq; paddle::framework::LoDTensor seq_back; paddle::framework::LoDTensor padding; + paddle::framework::LoDTensor cpu_pad_value; + paddle::framework::LoDTensor pad_value; const size_t level = lod.size() - 1; auto seq_dims = @@ -55,8 +57,17 @@ void TestSequencePadding(const paddle::framework::LoD& lod, padding.mutable_data(padding_dims, *place); + T* pad_value_data = + cpu_pad_value.mutable_data({1}, paddle::platform::CPUPlace()); + *pad_value_data = static_cast(0); + if (paddle::platform::is_cpu_place(*place)) { + pad_value = cpu_pad_value; + } else { + TensorCopySync(cpu_pad_value, *place, &pad_value); + } + paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, {0}, -1, 0, false, + *context, seq, &padding, pad_value, -1, 0, false, paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h index 44aff30879..5fc9da69d7 100644 --- a/paddle/fluid/operators/sequence_pad_op.h +++ b/paddle/fluid/operators/sequence_pad_op.h @@ 
-35,14 +35,11 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(ctx.GetPlace());
 
     const auto* pad_value = ctx.Input<LoDTensor>("PadValue");
-    const T* pad_value_data = pad_value->data<T>();
-    std::vector<T> pad_value_vec(pad_value_data,
-                                 pad_value_data + pad_value->numel());
 
     int padded_length = ctx.Attr<int>("padded_length");
 
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *x, out, pad_value_vec,
+        ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
         padded_length, 0, false, math::kBatchLengthWidth);
   }
 };
diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h
index 6cbf985039..444265f58d 100644
--- a/paddle/fluid/operators/warpctc_op.h
+++ b/paddle/fluid/operators/warpctc_op.h
@@ -161,10 +161,21 @@ class WarpCTCKernel : public framework::OpKernel<T> {
                                    static_cast<int64_t>(num_sequences),
                                    static_cast<int64_t>(sequence_width)});
     warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
+
+    LoDTensor cpu_pad_value;
+    T* pad_value_data =
+        cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
+    *pad_value_data = static_cast<T>(0);
+    LoDTensor pad_value;
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      pad_value = cpu_pad_value;
+    } else {
+      TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
+    }
+
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
-        {static_cast<T>(0)}, -1, 0, false /* norm_by_times */,
-        math::kLengthBatchWidth);
+        pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth);
 
     const T* warpctc_logits_data = warpctc_logits.data<T>();
     std::vector<int> warpctc_label_lengths(num_sequences);
--
GitLab
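
Note on the unified SequencePaddingKernel: the following is a minimal host-side sketch of the indexing scheme the new kernel implements, useful when checking the layout math in review. It is not the Paddle API: the function name SeqToPad and its plain std::vector signature are invented for this sketch, LoDTensor/Vector are replaced with vectors, the norm_by_len scaling and the pad-to-seq direction are omitted, and kBatchLengthWidth / kLengthBatchWidth are taken to mean [batch, pad_seq_len, width] and [pad_seq_len, batch, width] ordering, as in sequence_padding.h.

#include <cstddef>
#include <iostream>
#include <vector>

enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth };

// Seq-to-pad copy with pad-value broadcast, mirroring the kernel's offset math.
// pad_value holds either 1 element (constant pad) or step_width elements.
void SeqToPad(const std::vector<float>& seq, std::vector<float>* pad,
              const std::vector<size_t>& seq_offsets,  // absolute offsets, size seq_num + 1
              size_t pad_seq_len, size_t step_width,
              const std::vector<float>& pad_value, PadLayout layout) {
  size_t seq_num = seq_offsets.size() - 1;
  for (size_t seq_idx = 0; seq_idx < seq_num; ++seq_idx) {
    size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
    for (size_t step_idx = 0; step_idx < pad_seq_len; ++step_idx) {
      // Destination offset depends on the layout, as in the kernel.
      size_t pad_offset = layout == kBatchLengthWidth
                              ? (seq_idx * pad_seq_len + step_idx) * step_width
                              : (step_idx * seq_num + seq_idx) * step_width;
      size_t seq_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
      for (size_t i = 0; i < step_width; ++i) {
        if (step_idx < seq_len) {
          (*pad)[pad_offset + i] = seq[seq_offset + i];
        } else {
          // Steps beyond the sequence length get the pad value, broadcast
          // when only a single scalar is given.
          (*pad)[pad_offset + i] =
              pad_value.size() == 1 ? pad_value[0] : pad_value[i];
        }
      }
    }
  }
}

int main() {
  // Two sequences of lengths 2 and 3, step_width = 2, padded to length 3.
  std::vector<size_t> offsets = {0, 2, 5};
  std::vector<float> seq = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5};
  std::vector<float> pad(2 * 3 * 2, -1.f);
  SeqToPad(seq, &pad, offsets, /*pad_seq_len=*/3, /*step_width=*/2,
           /*pad_value=*/{0.f}, kBatchLengthWidth);
  for (float v : pad) std::cout << v << " ";  // 1 1 2 2 0 0 3 3 4 4 5 5
  std::cout << "\n";
  return 0;
}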
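
Note on the launch configuration kept by both functors: the expression std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize) is ceil(step_width / 8) rounded up to a whole warp of 32 threads and capped at kBlockSize (the comment's "at least 8 elements for each thread"); block_dim_y then takes the remaining kBlockSize / block_dim_x threads for the step dimension. Below is a small self-contained check of that equivalence; the helper name BlockDimX is invented here and kBlockSize is passed as a placeholder parameter rather than the constant defined in sequence_padding.cu.

#include <algorithm>
#include <cassert>
#include <cstddef>

// ceil(step_width / 8) threads, rounded up to a multiple of 32 (one warp),
// capped at kBlockSize -- the readable form of the shift expression above.
size_t BlockDimX(size_t step_width, size_t kBlockSize) {
  size_t threads_needed = (step_width + 7) / 8;           // >= 8 elements per thread
  size_t warp_aligned = (threads_needed + 31) / 32 * 32;  // round up to warp size
  return std::min(warp_aligned, kBlockSize);
}

int main() {
  const size_t kBlockSize = 512;  // placeholder value for this check
  for (size_t w = 1; w <= 4096; ++w) {
    size_t bit_form = std::min(((((w + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    assert(bit_form == BlockDimX(w, kBlockSize));
  }
  return 0;
}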