diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..884c49276cead5b329217fceb174d40b2d632025
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.cc
@@ -0,0 +1,130 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceReshapeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceReshapeOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceReshapeOp should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_numel = product(x_dims);
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
+    int new_dim = ctx->Attrs().Get<int>("new_dim");
+    ctx->SetOutputDim("Out",
+                      {x_numel / new_dim, static_cast<int64_t>(new_dim)});
+  }
+};
+
+class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+             "shape being [N, M].");
+    AddOutput("Out",
+              "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+              "shape [T, new_dim] where T is calculated based on X.lod, M and "
+              "new_dim.");
+    AddAttr<int>("new_dim", "Sequence dimension of the output LoDTensor.");
+    AddComment(R"DOC(
+Sequence Reshape Operator.
+
+This operator rearranges the input sequences. The new dimension is set by the
+attribute new_dim; the length of each sequence may become longer or shorter,
+depending on the original length, the original dimension and the new
+dimension. The following example illustrates the behavior of this operator:
+
+x is a LoDTensor:
+    x.lod  = [[0, 2, 6]]
+    x.data = [[1, 2], [3, 4],
+              [5, 6], [7, 8], [9, 10], [11, 12]]
+    x.dims = [6, 2]
+
+set new_dim = 4
+
+then out is a LoDTensor:
+    out.lod  = [[0, 1, 3]]
+    out.data = [[1, 2, 3, 4],
+                [5, 6, 7, 8], [9, 10, 11, 12]]
+    out.dims = [3, 4]
+
+Currently, only 1-level LoDTensor is supported. Please make sure that
+(original length * original dimension) is divisible by new_dim with no
+remainder for each sequence.
+
+)DOC");
+  }
+};
+
+class SequenceReshapeGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput(framework::GradVarName("Out")),
+        "Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceReshapeGradOp should not be null.");
+
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
+  }
+};
+
+class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op_desc_ptr = new framework::OpDesc();
+    op_desc_ptr->SetType("sequence_reshape_grad");
+    op_desc_ptr->SetInput("X", Input("X"));
+    op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op_desc_ptr->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
+                  ops::SequenceReshapeOpMaker,
+                  ops::SequenceReshapeGradOpMaker);
+REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext,
+                                   int64_t>);
diff --git a/paddle/operators/sequence_reshape_op.cu b/paddle/operators/sequence_reshape_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d9c2f7e9a4149371867cf2a8b81d58566999bfba
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.cu
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
+                                   double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
+                                   int64_t>);
diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..dd9b611250bf61f82ee23a31717cb4363f0c388e
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+
+template <typename DeviceContext, typename T>
+class SequenceReshapeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    int out_width = context.Attr<int>("new_dim");
+
+    auto in_dims = in->dims();
+    int64_t in_width = in_dims[1];
+    auto& in_lod = in->lod();
+
+    PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
+                      "Only support one level sequence now.");
+    PADDLE_ENFORCE_EQ(
+        in_dims[0], in_lod[0].back(),
+        "Inconsistent size between X.shape[0] and X.lod()[0].back().");
+
+    auto in_lod_l0 = in_lod[0];
+    int seq_num = in_lod_l0.size() - 1;
+
+    if (in_width == out_width) {
+      out->set_lod(in->lod());
+    } else {
+      auto& out_lod = *out->mutable_lod();
+      out_lod.resize(1);
+      out_lod[0].resize(seq_num + 1);
+      out_lod[0][0] = 0;
+      for (int i = 0; i < seq_num; ++i) {
+        size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
+        size_t offset = (seq_len * in_width) / out_width;
+        PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
+                          "Please make sure (sequence_length * dimension) can "
+                          "be divided by new_dim with no remainder for each "
+                          "sequence. The %dth sequence is invalid.",
+                          i + 1);
+        out_lod[0][i + 1] = out_lod[0][i] + offset;
+      }
+    }
+
+    framework::Copy(*in, context.GetPlace(), out);
+    out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SequenceReshapeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x_tensor_ptr = context.Input<LoDTensor>("X");
+    auto* outg_tensor_ptr =
+        context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* xg_tensor_ptr =
+        context.Output<LoDTensor>(framework::GradVarName("X"));
+
+    xg_tensor_ptr->mutable_data<T>(context.GetPlace());
+    framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
+    xg_tensor_ptr->Resize(x_tensor_ptr->dims());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/tests/test_sequence_reshape.py b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
new file mode 100644
index 0000000000000000000000000000000000000000..857b15237ae872278c1c4f3d1bfe13d1b69fb6b1
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+
+
+class TestSequenceReshape(OpTest):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 5, 8, 11]]
+        x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+    def compute_output(self, x, x_lod, dimension):
+        x_width = x.shape[1]
+        out_lod = [[0]]
+        for i in xrange(len(x_lod[0]) - 1):
+            seq_len = x_lod[0][i + 1] - x_lod[0][i]
+            offset = (seq_len * x_width) / dimension
+            assert int(offset) * dimension == seq_len * x_width
+            out_lod[0].append(out_lod[0][-1] + int(offset))
+        out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
+        out.ravel()[:] = x.ravel()[:]
+        return out, out_lod
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestSequenceReshape_reduce(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 24
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
+class TestSequenceReshape_same(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
+if __name__ == '__main__':
+    unittest.main()
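
For reference, below is a minimal NumPy sketch (not part of the patch) of the operator's semantics. It mirrors compute_output in the new test and the DOC example above, assuming a 1-level offset-style LoD; the helper name sequence_reshape_ref is hypothetical.

    import numpy as np

    def sequence_reshape_ref(x, x_lod, new_dim):
        # x: [N, M] array; x_lod: 1-level offset LoD, e.g. [[0, 2, 6]].
        in_width = x.shape[1]
        out_lod = [[0]]
        for i in range(len(x_lod[0]) - 1):
            seq_len = x_lod[0][i + 1] - x_lod[0][i]
            # Each sequence keeps the same number of elements, so its new
            # length is (old length * old width) / new_dim, which must
            # divide with no remainder.
            assert (seq_len * in_width) % new_dim == 0
            out_lod[0].append(out_lod[0][-1] + seq_len * in_width // new_dim)
        # The underlying data is copied unchanged in row-major order.
        out = x.reshape(out_lod[0][-1], new_dim)
        return out, out_lod

    # DOC example: x.lod = [[0, 2, 6]], x.dims = [6, 2], new_dim = 4
    x = np.arange(1, 13, dtype='float32').reshape(6, 2)
    out, out_lod = sequence_reshape_ref(x, [[0, 2, 6]], 4)
    # out.shape == (3, 4), out_lod == [[0, 1, 3]]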