diff --git a/paddle/fluid/operators/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_unpad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3a0762b9a4d3e080d5d6d10b249e0bd81980b95 --- /dev/null +++ b/paddle/fluid/operators/sequence_unpad_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_unpad_op.h" + +namespace paddle { +namespace operators { + +class SequenceUnpadOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceUnpadOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Length"), + "Input(Length) of SequenceUnpadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceUnpadOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "The rank of Input(X) can't be less than 2."); + + auto len_dims = ctx->GetInputDim("Length"); + PADDLE_ENFORCE(len_dims.size() == 2 && len_dims[1] == 1, + "The shape of Input(Length) should be [batch_size, 1]."); + PADDLE_ENFORCE( + len_dims[0] == x_dims[0], + "Input(X) and Input(Length) should have the same first dimension."); + + int64_t out_dim_0 = -1; + if (ctx->IsRuntime()) { + out_dim_0 = x_dims[0] * x_dims[1]; + } + + std::vector out_dims_vec{out_dim_0}; + if (x_dims.size() == 2) { + out_dims_vec.push_back(1); + } else { + for (size_t i = 2; i < x_dims.size(); ++i) { + out_dims_vec.push_back(x_dims[i]); + } + } + ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class SequenceUnpadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LoDTensor, default LoDTensor) Input tensor which " + "contains the padded sequences with equal length."); + AddInput("Length", + "(LoDTensor) The input tensor which specifies the actual ength of " + "sequences after unpadding."); + AddOutput( + "Out", + "(LoDTensor) The output tensor which contains unpadded sequences."); + AddComment(R"DOC( + Sequence Unpad Operator + + This operator removes the padding data in the input sequences and convert + them into sequences with actual length as output, identitied by lod + information. + + Example: + + Given input tensor Input(X): + X.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0], + [ 6.0, 7.0, 8.0, 9.0, 10.0], + [11.0, 12.0, 13.0, 14.0, 15.0]], +` + in which there are 3 sequences padded to length 5, and the acutal length + specified by Input(Length): + + Length.data = [[2], [3], [4]], + + after unpadding, Output(Out) will be: + + Out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]] + Out.lod = [[0, 2, 5, 9]] + + )DOC"); + } +}; + +class SequenceUnpadGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceUnpadGradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequenceUnpadGradOp should not be null."); + + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp, + ops::SequenceUnpadOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_unpad, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel); +REGISTER_OP_CPU_KERNEL( + sequence_unpad_grad, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_unpad_op.cu b/paddle/fluid/operators/sequence_unpad_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..75248372237ec2cb23122f6b16e64f6ce750ebf9 --- /dev/null +++ b/paddle/fluid/operators/sequence_unpad_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_unpad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + sequence_unpad, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel, + ops::SequenceUnpadOpKernel); +REGISTER_OP_CUDA_KERNEL( + sequence_unpad_grad, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel, + ops::SequenceUnpadGradOpKernel); diff --git a/paddle/fluid/operators/sequence_unpad_op.h b/paddle/fluid/operators/sequence_unpad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ebe3118b985bdfd41ca55e8c572047aa87502ff4 --- /dev/null +++ b/paddle/fluid/operators/sequence_unpad_op.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sequence_padding.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; + +template +class SequenceUnpadOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_t = ctx.Input("X"); + auto* len_t = ctx.Input("Length"); + auto* out_t = ctx.Output("Out"); + out_t->mutable_data(ctx.GetPlace()); + + const int64_t* seq_len_ptr = nullptr; + if (platform::is_gpu_place(ctx.GetPlace())) { + LoDTensor seq_len_cpu; + seq_len_cpu.Resize(len_t->dims()); + seq_len_ptr = seq_len_cpu.mutable_data(platform::CPUPlace()); + framework::TensorCopy(*len_t, platform::CPUPlace(), + ctx.template device_context(), + &seq_len_cpu); + } else { + seq_len_ptr = len_t->data(); + } + + size_t batch_size = x_t->dims()[0]; + std::vector out_lod0(batch_size + 1, 0); + for (size_t i = 0; i < batch_size; ++i) { + out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i]; + } + + framework::LoD out_lod; + out_lod.push_back(out_lod0); + out_t->set_lod(out_lod); + + std::vector out_dims_vec{static_cast(out_lod0.back())}; + if (x_t->dims().size() == 2) { + out_dims_vec.push_back(1); + } else { + for (size_t i = 2; i < x_t->dims().size(); ++i) { + out_dims_vec.push_back(x_t->dims()[i]); + } + } + out_t->Resize(framework::make_ddim(out_dims_vec)); + + int64_t padded_length = x_t->dims()[1]; + math::UnpaddingLoDTensorFunctor()( + ctx.template device_context(), *x_t, out_t, + padded_length, 0, false, math::kBatchLengthWidth); + } +}; + +template +class SequenceUnpadGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_x = ctx.Output(framework::GradVarName("X")); + if (d_x) { + const auto* d_out = ctx.Input(framework::GradVarName("Out")); + const auto* x_t = ctx.Input("X"); + d_x->mutable_data(ctx.GetPlace()); + + int padded_length = x_t->dims()[1]; + + LoDTensor zero_pads; + zero_pads.Resize({1, 1}); + zero_pads.mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, &zero_pads, static_cast(0)); + + math::PaddingLoDTensorFunctor()( + ctx.template device_context(), *d_out, d_x, zero_pads, + padded_length, 0, false, math::kBatchLengthWidth); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py new file mode 100644 index 0000000000000000000000000000000000000000..673b0ea180464b8b8f6f5c6e76d5c5c80f347d25 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py @@ -0,0 +1,75 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import six +import numpy as np +from op_test import OpTest + + +class TestSequenceUnpadOp(OpTest): + def init(self): + self.length = [2, 3, 4] + self.x_shape = (3, 5) + self.dtype = "float32" + + def compute(self): + assert len(self.length) == self.x_shape[0] + x = np.random.random(self.x_shape).astype(self.dtype) + out_lod = [self.length] + + out = x[0, 0:self.length[0]] + for i in six.moves.xrange(1, x.shape[0]): + out = np.append(out, x[i, 0:self.length[i]], axis=0) + + out_shape = (sum(self.length), ) + if len(self.x_shape) == 2: + out_shape = out_shape + (1, ) + else: + out_shape = out_shape + self.x_shape[2:] + + self.inputs = { + 'X': x, + 'Length': np.array(self.length).astype('int64').reshape(-1, 1) + } + self.outputs = {'Out': (out.reshape(out_shape), out_lod)} + + def setUp(self): + self.op_type = 'sequence_unpad' + self.init() + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestSequenceUnpadOp2(TestSequenceUnpadOp): + def init(self): + self.length = [2, 3, 4] + self.x_shape = (3, 5, 4, 3) + self.dtype = "float32" + + +class TestSequenceUnpadOp3(TestSequenceUnpadOp): + def init(self): + self.length = [5, 2, 3, 4] + self.x_shape = (4, 5, 3, 3, 6) + self.dtype = "float64" + + +if __name__ == '__main__': + unittest.main()