From 8d8d48a34f9116f5a501d69cc4dbbf9ce13a1446 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 17 Aug 2018 17:58:12 +0800 Subject: [PATCH] Complete sequence_pad_op and its CPU kernel. Add unittests --- .../fluid/operators/math/sequence_padding.cc | 24 +++- .../fluid/operators/math/sequence_padding.h | 3 - paddle/fluid/operators/sequence_pad_op.cc | 105 +++++++------- paddle/fluid/operators/sequence_pad_op.cu | 10 +- paddle/fluid/operators/sequence_pad_op.h | 93 +++--------- .../tests/unittests/test_sequence_pad_op.py | 134 ++++++++++++++++++ 6 files changed, 234 insertions(+), 135 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_sequence_pad_op.py diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index e8ccf006ad..d3dab64f60 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -70,9 +70,10 @@ class PaddingLoDTensorFunctor { std::vector pad_value = {0}, int pad_seq_len = -1, int lod_level = 0, bool norm_by_times = false, const PadLayout layout = kBatchLengthWidth) { - auto seq_offsets = framework::ToAbsOffset(seq_tensor.lod())[lod_level]; - auto seq_tensor_dims = seq_tensor.dims(); - auto pad_tensor_dims = pad_tensor->dims(); + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); if (pad_seq_len == -1) { pad_seq_len = MaximumSequenceLength(seq_offsets); } @@ -91,12 +92,21 @@ class PaddingLoDTensorFunctor { // fill padding value T* pad_data = pad_tensor->data(); - for (int i = 0; i < pad_tensor->numel() / step_width; ++i) { - memcpy(pad_data, pad_value.data(), step_width * sizeof(T)); + for (int i = 0; i < pad_tensor->numel(); i += step_width) { + memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T)); } CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, step_width, norm_by_times, kSeqToPad, layout); + + // Set pad_tensor's lod info if possible + if (layout == kBatchLengthWidth) { + framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end()); + for (size_t i = 0; i < pad_lod[0].size(); ++i) { + pad_lod[0][i] = i * pad_seq_len; + } + pad_tensor->set_lod(pad_lod); + } } }; @@ -109,8 +119,8 @@ class UnpaddingLoDTensorFunctor { int lod_level = 0, bool norm_by_times = false, const PadLayout& layout = kBatchLengthWidth) { auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; - auto seq_tensor_dims = seq_tensor->dims(); - auto pad_tensor_dims = pad_tensor.dims(); + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); if (pad_seq_len == -1) { pad_seq_len = MaximumSequenceLength(seq_offsets); } diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index d5790e2ba2..9b8c892c53 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -44,9 +44,6 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims, "Value of 1st dimension of the sequence tensor should be " "equal to sum of lengths of all sequences."); - PADDLE_ENFORCE(seq_tensor_dims.size() == 1 || seq_tensor_dims.size() == 2, - "seq_tensor's rank should be 1 or 2."); - PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() || seq_tensor_dims.size() == pad_tensor_dims.size(), "pad_tensor's rank should be 1 greater than seq_tensor's " diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index dc79b252c7..f23710cf4d 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -21,82 +21,85 @@ class SequencePadOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequencePadOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("PadValue"), + "Input(PadValue) of SequencePadOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SequencePadOp should not be null."); auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "The rank of Input(x) can't be less than 2."); + auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size()); + auto pad_value_dims = ctx->GetInputDim("PadValue"); + PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) || + pad_value_dims == time_step_dims, + "The Input(PadValue) must be a scalar or a tensor whose " + "shape equals to time steps in sequences"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, - "Only support 2-D tensor, rank of Input(X) should be 2."); - - int lod_level = ctx->Attrs().Get("lod_level"); - - int64_t max_len = -1; - int64_t seq_num = -1; - int x_lod_size = -1; + int batch_dim_size = -1; if (ctx->IsRuntime()) { + // run time framework::Variable* x_var = boost::get(ctx->GetInputVarPtrs("X")[0]); - - auto& x_lod = x_var->Get().lod(); - - x_lod_size = x_lod.size(); - - auto x_abs_offset = framework::ToAbsOffset(x_lod)[lod_level]; - - PADDLE_ENFORCE_EQ(x_dims[0], static_cast(x_abs_offset.back()), - "The first dimension of `X` should be equal to sum " - "of all sequences' length."); - - seq_num = x_abs_offset.size() - 1; - - for (int64_t i = 1; i <= seq_num; ++i) { - int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1]; - max_len = max_len < seq_len ? seq_len : max_len; + const auto& x_lod = x_var->Get().lod(); + PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info."); + const auto& x_lod_0 = x_lod[0]; + PADDLE_ENFORCE_GE(x_lod_0.size(), 2, + "The Input(X)'s lod info is corrupted."); + PADDLE_ENFORCE_EQ( + x_dims[0], static_cast(x_lod_0.back()), + "The Input(X)'s lod info mismatches the actual tensor shape."); + + int seq_num = x_lod_0.size() - 1; + int max_seq_len = math::MaximumSequenceLength(x_lod_0); + int padded_length = ctx->Attrs().Get("padded_length"); + if (padded_length == -1) { + padded_length = max_seq_len; } + PADDLE_ENFORCE_GE(padded_length, max_seq_len, + "The Attr(padded_length) must be -1 or an int greater " + "than the length of the longest original sequence."); + batch_dim_size = padded_length * seq_num; } else { + // compile time framework::VarDesc* x_desc = boost::get(ctx->GetInputVarPtrs("X")[0]); - x_lod_size = x_desc->GetLoDLevel(); + PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1); } - PADDLE_ENFORCE(lod_level >= 0 && lod_level < x_lod_size, - "Invalid `lod_level` which should be at least 0 and less " - "than maximum lod level of `X`"); - - ctx->SetOutputDim("Out", {seq_num, max_len, x_dims[1]}); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + auto out_dims = x_dims; + out_dims[0] = batch_dim_size; + ctx->SetOutputDim("Out", out_dims); } }; class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { public: - SequencePadOpMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { + void Make() override { AddInput("X", "(LoDTensor, default LoDTensor) Input variable which " - "should contain lod information. Length of each sequence would " - "be computed from the most bottom level lod."); - AddOutput("Out", - "(Tensor) Output variable which would be a common tensor " - "without lod. Each sequence would be padded to the maximum " - "length."); - AddAttr("lod_level", - "(int, default 0) Specify which level lod to referred to."); - AddAttr("pad_value", - "(float, default 0.0) Specify which value to be padded to " - "the end of each sequence."); + "should contain lod information."); + AddInput("PadValue", + "(LoDTensor), this Tensor holds values that will be fill into " + "padded steps. It can be a scalar or a tensor whose shape equals " + "to time steps in sequences. If it's a scalar, it will be " + "automatically broadcasted to the shape of time step."); + AddOutput( + "Out", + "(LoDTensor) The output vairable, which contains padded sequences."); + AddAttr( + "padded_length", + "The length of padded sequences. It can be setted to -1 or " + "any positive int. When it is -1, all sequences will be padded up to " + "the length of the longest one among them; when it a certain positive " + "value, it must be greater than the length of the longest original " + "sequence.") + .SetDefault(-1); AddComment(R"DOC( )DOC"); diff --git a/paddle/fluid/operators/sequence_pad_op.cu b/paddle/fluid/operators/sequence_pad_op.cu index a2fa62957e..ff8f81a2f0 100644 --- a/paddle/fluid/operators/sequence_pad_op.cu +++ b/paddle/fluid/operators/sequence_pad_op.cu @@ -17,7 +17,13 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( sequence_pad, - ops::SequencePadOpKernel); + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel, + ops::SequencePadOpKernel); REGISTER_OP_CUDA_KERNEL( sequence_pad_grad, - ops::SequencePadGradOpKernel); + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel, + ops::SequencePadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h index 6d136b65f1..44aff30879 100644 --- a/paddle/fluid/operators/sequence_pad_op.h +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/math/math_function.h" @@ -24,68 +26,24 @@ namespace operators { using LoDTensor = framework::LoDTensor; using LoD = framework::LoD; -template -struct CopyFunctor { - LoDTensor* lod_tensor_; - LoDTensor* pad_tensor_; - const LoD& ref_lod_; - const DeviceContext& ctx_; - bool is_lod_to_pad_; - - CopyFunctor(LoDTensor* lod_tensor, const LoD& ref_lod, LoDTensor* pad_tensor, - const DeviceContext& ctx, bool is_lod_to_pad) - : lod_tensor_(lod_tensor), - pad_tensor_(pad_tensor), - ref_lod_(ref_lod), - ctx_(ctx), - is_lod_to_pad_(is_lod_to_pad) {} - - void operator()() const { - /* - auto seq_num = ref_lod_.size() - 1; - auto max_len = pad_tensor_->dims()[0] / seq_num; - - PADDLE_ENFORCE_EQ(max_len * seq_num, pad_tensor_->dims()[0], - "First dimension of padded tensor should be equal to " - "maximum sequence length mulplied by sequence number."); - - for (size_t i = 1; i < ref_lod_.size(); ++i) { - auto seq_start = ref_lod_[i - 1]; - auto seq_end = ref_lod_[i]; - auto pad_start = (i - 1) * max_len; - auto pad_end = pad_start + (seq_end - seq_start); - auto sub_lod_tensor = lod_tensor_->Slice(seq_start, seq_end); - auto sub_pad_tensor = pad_tensor_->Slice(pad_start, pad_end); - if (is_lod_to_pad_) { - framework::TensorCopy(sub_lod_tensor, ctx.GetPlace(), &sub_pad_tensor); - } else { - framework::TensorCopy(sub_pad_tensor, ctx.GetPlace(), &sub_lod_tensor); - } - } - */ - } -}; - template class SequencePadOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /* - auto* x = ctx.Input("X"); - auto* out_ptr = ctx.Output("Out"); - - out_ptr->mutable_data(ctx.GetPlace()); + const auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); - // Resize(); + const auto* pad_value = ctx.Input("PadValue"); + const T* pad_value_data = pad_value->data(); + std::vector pad_value_vec(pad_value_data, + pad_value_data + pad_value->numel()); - T pad_value = static_cast(ctx.Attr("pad_value")); + int padded_length = ctx.Attr("padded_length"); math::PaddingLoDTensorFunctor()( - ctx.template device_context(), *x, *, false); - - math::SetConstant set_func; - set_func(ctx.template device_context(), out_ptr, pad_value); - */ + ctx.template device_context(), *x, out, pad_value_vec, + padded_length, 0, false, math::kBatchLengthWidth); } }; @@ -93,26 +51,17 @@ template class SequencePadGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - /* - auto* x_ptr = ctx.Input("X"); - auto* g_out_ptr = ctx.Input(framework::GradVarName("Out")); - auto* g_x_ptr = ctx.Output(framework::GradVarName("X")); - - math::SetConstant set_func; - set_func(ctx.template device_context(), - g_x_ptr, - static_cast(0)); + auto* d_x = ctx.Output(framework::GradVarName("X")); + if (d_x) { + const auto* d_out = ctx.Input(framework::GradVarName("Out")); + d_x->mutable_data(ctx.GetPlace()); - auto& x_lod = x_ptr->lod(); - auto& x_last_level_lod = x_lod[x_lod.size() - 1]; + int padded_length = ctx.Attr("padded_length"); - CopyFunctor copy_func(g_out_ptr, - x_last_level_lod, - g_x_ptr, - ctx, - false); - copy_func(); - */ + math::UnpaddingLoDTensorFunctor()( + ctx.template device_context(), *d_out, d_x, + padded_length, 0, false, math::kBatchLengthWidth); + } } }; diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py new file mode 100644 index 0000000000..7b9eedbf52 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestSequencePadOp(OpTest): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = -1 + self.dtype = 'float32' + + def set_data(self): + x_data = np.random.uniform(0.1, 0.5, self.x_shape).astype(self.dtype) + pad_value_data = np.array(self.pad_value).astype(self.dtype) + self.inputs = { + 'X': (x_data, self.x_len_lod), + 'PadValue': pad_value_data + } + self.attrs = {'padded_length': self.padded_length} + + def compute(self): + # get padded length + padded_length = self.padded_length + x_len_lod_0 = self.x_len_lod[0] + if padded_length == -1: + max_seq_len = 0 + for l in x_len_lod_0: + max_seq_len = max(max_seq_len, l) + padded_length = max_seq_len + + # do padding + x_data = self.inputs['X'][0] + pad_value_data = self.inputs['PadValue'] + if pad_value_data.shape == (1, ): + pad_value_data = np.broadcast_to( + pad_value_data, shape=x_data.shape[1:]) + padded_sequences = [] + start_idx = 0 + for l in x_len_lod_0: + end_idx = start_idx + l + seq = x_data[start_idx:end_idx] + to_pad_len = padded_length - l + for _ in range(to_pad_len): + seq = np.append(seq, pad_value_data[np.newaxis, :], axis=0) + padded_sequences.append(seq) + start_idx = end_idx + + out_len_lod = self.x_len_lod[:] + out_len_lod_0 = [padded_length] * len(x_len_lod_0) + out_len_lod[0] = out_len_lod_0 + out_data = np.concatenate(padded_sequences, axis=0) + self.outputs = {'Out': (out_data, out_len_lod)} + + def setUp(self): + self.op_type = 'sequence_pad' + self.set_attr() + self.set_data() + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSequencePadOp2(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0, 2.0, 3.0, 4.0] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp3(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = 7 + self.dtype = 'float32' + + +class TestSequencePadOp4(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 4] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0, 2.0, 3.0, 4.0] + self.padded_length = 7 + self.dtype = 'float32' + + +class TestSequencePadOp5(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp6(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [[1.0, 2.0], [3.0, 4.0]] + self.padded_length = -1 + self.dtype = 'float32' + + +class TestSequencePadOp7(TestSequencePadOp): + def set_attr(self): + self.x_shape = [12, 2, 2] + self.x_len_lod = [[2, 3, 4, 3]] + self.pad_value = [1.0] + self.padded_length = 7 + self.dtype = 'float32' -- GitLab