From 171525100d0206addde04beca0723adbd3968136 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 16 Aug 2018 20:53:47 +0800 Subject: [PATCH] update CPU sequence_padding functor --- .../fluid/operators/math/sequence_padding.cc | 149 +++++++++--------- .../fluid/operators/math/sequence_padding.h | 56 +++---- .../operators/math/sequence_padding_test.cc | 6 +- paddle/fluid/operators/warpctc_op.h | 10 +- 4 files changed, 108 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index 5ceb26553..e8ccf006a 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,37 +18,45 @@ namespace paddle { namespace operators { namespace math { +enum CopyType { kSeqToPad, kPadToSeq }; + template -void CopyDataCPU(framework::LoDTensor* seq_tensor, - framework::Tensor* pad_tensor, - const framework::Vector& seq_offset, - const int64_t& max_seq_len, const int64_t& seq_width, - bool seq_to_pad, bool norm_by_len, - OutputLayout output_layout) { - T* seq_data = seq_tensor->data(); - T* pad_data = pad_tensor->data(); - - int64_t seq_num = seq_offset.size() - 1; - - for (int64_t i = 0; i < seq_num; ++i) { - int64_t seq_start = seq_offset[i]; - int64_t seq_len = seq_offset[i + 1] - seq_start; - T scale = norm_by_len ? (1.0f / static_cast(seq_len)) : 1.0f; - for (int64_t j = 0; j < seq_len; ++j) { - for (int64_t k = 0; k < seq_width; ++k) { - size_t pad_data_idx = 0; - size_t seq_data_idx = (seq_start + j) * seq_width + k; - if (output_layout == kBatchLengthWidth) { - pad_data_idx = (i * max_seq_len + j) * seq_width + k; - } else { - pad_data_idx = (j * seq_num + i) * seq_width + k; - } - if (seq_to_pad) { - pad_data[pad_data_idx] = seq_data[seq_data_idx] * scale; - } else { - seq_data[seq_data_idx] = pad_data[pad_data_idx] * scale; +void CopyValidData(framework::Tensor* dst_tensor, + const framework::Tensor* src_tensor, + const framework::Vector& seq_offsets, + int pad_seq_len, int step_width, bool norm_by_len, + CopyType type, PadLayout layout) { + int seq_num = seq_offsets.size() - 1; + const T* src_data = src_tensor->data(); + T* dst_data = dst_tensor->data(); + + int seq_cpy_gap = step_width; + int pad_cpy_gap = + layout == kBatchLengthWidth ? step_width : seq_num * step_width; + for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) { + int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]; + PADDLE_ENFORCE_GE( + pad_seq_len, valid_seq_len, + "The padded sequence length can not be less than its original length."); + int seq_data_offset = seq_offsets[seq_idx] * step_width; + int pad_data_offset = layout == kBatchLengthWidth + ? seq_idx * pad_seq_len * step_width + : seq_idx * step_width; + float scale = 1.0f / static_cast(valid_seq_len); + + for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) { + const T* src = + src_data + (type == kSeqToPad ? seq_data_offset : pad_data_offset); + T* dst = + dst_data + (type == kSeqToPad ? pad_data_offset : seq_data_offset); + memcpy(dst, src, step_width * sizeof(T)); + if (norm_by_len) { + for (int i = 0; i < step_width; ++i) { + *(dst + i) *= scale; } } + seq_data_offset += seq_cpy_gap; + pad_data_offset += pad_cpy_gap; } } } @@ -58,31 +66,37 @@ class PaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* pad_tensor, - T pad_value = static_cast(0), bool norm_by_times = false, - size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(seq_tensor, lod_level); - - auto& lod = seq_tensor.lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - + framework::LoDTensor* pad_tensor, + std::vector pad_value = {0}, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor.lod())[lod_level]; auto seq_tensor_dims = seq_tensor.dims(); auto pad_tensor_dims = pad_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE(pad_value.size() == 1 || + static_cast(pad_value.size()) == step_width, + "The size of 'pad_value' can only be 1 or be equal to the " + "'step_width'."); - T* pad_data = pad_tensor->data(); + if (pad_value.size() == 1) { + pad_value = std::vector(step_width, pad_value[0]); + } - memset(pad_data, pad_value, max_seq_len * seq_num * seq_width * sizeof(T)); + // fill padding value + T* pad_data = pad_tensor->data(); + for (int i = 0; i < pad_tensor->numel() / step_width; ++i) { + memcpy(pad_data, pad_value.data(), step_width * sizeof(T)); + } - CopyDataCPU(const_cast(&seq_tensor), pad_tensor, - seq_offset, max_seq_len, seq_width, true /* seq_to_pad */, - norm_by_times, output_layout); + CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kSeqToPad, layout); } }; @@ -90,30 +104,23 @@ template class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, - framework::LoDTensor* seq_tensor, - const framework::Tensor& pad_tensor, - bool norm_by_times = false, size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth) { - CheckLoD(*seq_tensor, lod_level); - - auto& lod = seq_tensor->lod(); - auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - - auto& seq_tensor_dims = seq_tensor->dims(); - auto& pad_tensor_dims = pad_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(seq_offset); - int64_t seq_num = seq_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; - - CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, - seq_num, seq_width, output_layout); - - T* seq_data = seq_tensor->data(); - memset(seq_data, static_cast(0), seq_tensor->numel() * sizeof(T)); - - CopyDataCPU(seq_tensor, const_cast(&pad_tensor), - seq_offset, max_seq_len, seq_width, false /* seq_to_pad */, - norm_by_times, output_layout); + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout& layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + auto seq_tensor_dims = seq_tensor->dims(); + auto pad_tensor_dims = pad_tensor.dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + + CopyValidData(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kPadToSeq, layout); } }; diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 44d640433..d5790e2ba 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -22,7 +23,7 @@ namespace paddle { namespace operators { namespace math { -enum OutputLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; +enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; inline static size_t MaximumSequenceLength( const framework::Vector& seq_offset) { @@ -34,35 +35,22 @@ inline static size_t MaximumSequenceLength( return max_seq_len; } -inline static void CheckLoD(const framework::LoDTensor& seq_tensor, - const size_t& lod_level) { - PADDLE_ENFORCE(lod_level < seq_tensor.lod().size(), - "Invalid lod level which should be at least 0 and less " - "than maximum lod level of sequence tensor."); -} - inline static void CheckDims(const framework::DDim& seq_tensor_dims, - const size_t& last_offset, const framework::DDim& pad_tensor_dims, - const int64_t& max_seq_len, const int64_t& seq_num, - const int64_t& seq_width, - const OutputLayout& output_layout) { - PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), last_offset, + const framework::Vector& seq_offset, + int64_t padded_seq_len, int64_t step_width, + const PadLayout& layout) { + PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), seq_offset.back(), "Value of 1st dimension of the sequence tensor should be " "equal to sum of lengths of all sequences."); - PADDLE_ENFORCE_EQ(pad_tensor_dims.size(), 3UL, - "Padded tensor should be a 3-D tensor."); + PADDLE_ENFORCE(seq_tensor_dims.size() == 1 || seq_tensor_dims.size() == 2, + "seq_tensor's rank should be 1 or 2."); - if (output_layout == kBatchLengthWidth) { - PADDLE_ENFORCE_EQ(pad_tensor_dims, - framework::make_ddim({seq_num, max_seq_len, seq_width})); - } else if (output_layout == kLengthBatchWidth) { - PADDLE_ENFORCE_EQ(pad_tensor_dims, - framework::make_ddim({max_seq_len, seq_num, seq_width})); - } else { - PADDLE_THROW("Unsupported output layout."); - } + PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() || + seq_tensor_dims.size() == pad_tensor_dims.size(), + "pad_tensor's rank should be 1 greater than seq_tensor's " + "rank, or be equal with it."); } /* @@ -94,22 +82,22 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims, template class PaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, + void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* pad_tensor, - T pad_value = static_cast(0), bool norm_by_times = false, - size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth); + framework::LoDTensor* pad_tensor, + std::vector pad_value = {0}, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth); }; template class UnpaddingLoDTensorFunctor { public: - void operator()(const DeviceContext& context, - framework::LoDTensor* seq_tensor, - const framework::Tensor& pad_tensor, - bool norm_by_times = false, size_t lod_level = 0, - OutputLayout output_layout = kBatchLengthWidth); + void operator()(const platform::CPUDeviceContext& context, + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout& layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index 82459274c..3171c7c33 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -23,7 +23,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod, paddle::framework::LoDTensor cpu_seq_back; paddle::framework::LoDTensor seq; paddle::framework::LoDTensor seq_back; - paddle::framework::Tensor padding; + paddle::framework::LoDTensor padding; const size_t level = lod.size() - 1; auto seq_dims = @@ -56,13 +56,13 @@ void TestSequencePadding(const paddle::framework::LoD& lod, padding.mutable_data(padding_dims, *place); paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, 0, false, 0, + *context, seq, &padding, {0}, -1, 0, false, paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); seq_back.mutable_data(seq_dims, *place); paddle::operators::math::UnpaddingLoDTensorFunctor()( - *context, &seq_back, padding, false, 0, + *context, padding, &seq_back, -1, 0, false, paddle::operators::math::kLengthBatchWidth); if (paddle::platform::is_cpu_place(*place)) { diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index cb56f42a8..6cbf98503 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -153,7 +153,7 @@ class WarpCTCKernel : public framework::OpKernel { framework::make_ddim({static_cast(num_sequences), 1}); // warpctc needs sequences data stored in transposed padding format - Tensor warpctc_logits; + LoDTensor warpctc_logits; const size_t max_sequence_length = math::MaximumSequenceLength(logits_lod[level]); auto warpctc_logits_dims = @@ -163,7 +163,7 @@ class WarpCTCKernel : public framework::OpKernel { warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, - static_cast(0), false /* norm_by_times */, 0, + {static_cast(0)}, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth); const T* warpctc_logits_data = warpctc_logits.data(); @@ -210,15 +210,15 @@ template class WarpCTCGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* warpctc_grad = ctx.Input("WarpCTCGrad"); + auto* warpctc_grad = ctx.Input("WarpCTCGrad"); auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); const Tensor* loss_grad = ctx.Input(framework::GradVarName("Loss")); logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), logits_grad, - *warpctc_grad, norm_by_times, 0, math::kLengthBatchWidth); + ctx.template device_context(), *warpctc_grad, + logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth); const T* loss_grad_data = loss_grad->data(); math::ScaleLoDTensorFunctor()( -- GitLab