diff --git a/paddle/operators/sequence_concat_op.cu b/paddle/operators/sequence_concat_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..200b2a8ab9e0d7f841c3e9499d0ea90aa81142ae
--- /dev/null
+++ b/paddle/operators/sequence_concat_op.cu
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/sequence_concat_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    sequence_concat,
+    ops::SequenceConcatOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    sequence_concat_grad,
+    ops::SequenceConcatGradOpKernel<paddle::platform::GPUPlace, float>);
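The .cu file above only registers the GPU kernels; the actual logic lives in the shared header further down. As a quick orientation, here is a rough numpy sketch of what sequence_concat computes in the axis=1 case; every name in it is illustrative and not part of the patch:

import numpy as np

# Concat two 2-level LoD tensors along axis 1 at level 1. Each level-1
# segment of the output is the axis-1 concatenation of the matching
# segments of x0 and x1 (the segments must pair up one-to-one).
x0 = np.random.random((4, 3, 4)).astype('float32')
x1 = np.random.random((4, 4, 4)).astype('float32')
lod = [[0, 2, 4], [0, 1, 2, 3, 4]]  # shared by x0 and x1 when axis != 0
segs = lod[1]
out = np.concatenate(
    [np.concatenate((x0[segs[i]:segs[i + 1]], x1[segs[i]:segs[i + 1]]),
                    axis=1) for i in range(len(segs) - 1)],
    axis=0)
assert out.shape == (4, 7, 4)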
+ "If axis is 0, the inputs will be joined with Lod index.") + .SetDefault(0); + AddAttr("level", + "The level which the inputs will be joined with." + "If level is 0, the inputs will be joined with word." + "If level is 1, the inputs will be joined with sentence.") + .SetDefault(0); + AddComment(R"DOC( + SequenceConcatOp concat multip LodTensors and only supports one or two levels. + - Case1: + axis is 1, level is 1, the Lod of Inputs are the same, + LoD(x0) = {{0,2,4},{0,1,2,3,4}}; Dims(x0) = (2,3,4) + LoD(x1) = {{0,2,4},{0,1,2,3,4}}; Dims(x1) = (2,4,4) + LoD(Out) = {{0,2,4},{01,2,3,4}}; Dims(Out) = (2,7,4) + - Case2: + If axis is 0, level is 1, the Lod of inputs are different, + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (2,3,4) + LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (3,3,4) + LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}; Dims(Out) = (5,3,4) + )DOC"); + } +}; + +class SequenceConcatGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Gradient of Out should not be null."); + PADDLE_ENFORCE_GT(ctx->Outputs(framework::GradVarName("X")).size(), 0UL, + "Gradient of X should not be empty.") + ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker, + sequence_concat_grad, ops::SequenceConcatGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_concat, + ops::SequenceConcatOpKernel); +REGISTER_OP_CPU_KERNEL( + sequence_concat_grad, + ops::SequenceConcatGradOpKernel); diff --git a/paddle/operators/sequence_concat_op.h b/paddle/operators/sequence_concat_op.h new file mode 100644 index 0000000000000000000000000000000000000000..79e372a797b2a59ddc8b2f443443711ab7a09e39 --- /dev/null +++ b/paddle/operators/sequence_concat_op.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/paddle/operators/sequence_concat_op.h b/paddle/operators/sequence_concat_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..79e372a797b2a59ddc8b2f443443711ab7a09e39
--- /dev/null
+++ b/paddle/operators/sequence_concat_op.h
@@ -0,0 +1,148 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using LoD = framework::LoD;
+
+// Concatenates the LoD of the inputs. The output LoD is initialized to
+// LoD(x0). If axis is not 0, LoD(Out) is the same as LoD(x0). If axis is 0:
+// Case 1:
+//   The inputs have one level; every output offset is the element-wise
+//   sum of the inputs' offsets:
+//     LoD(x0) = {{0,2,4}}
+//     LoD(x1) = {{0,1,5}}
+//     LoD(Out) = {{0,3,9}}
+// Case 2:
+//   The inputs have two levels and the concat level is 1; both output
+//   levels are the element-wise sums of the inputs' offsets:
+//     LoD(x0) = {{0,2,4}, {0,1,2,3,4}}
+//     LoD(x1) = {{0,3,5}, {0,1,3,4,5}}
+//     LoD(Out) = {{0,5,9}, {0,2,5,7,9}}
+template <typename T>
+LoD concatLod(const std::vector<const T*> ins, const size_t axis,
+              const size_t level) {
+  auto out_lod = ins[0]->lod();
+  const size_t n = ins.size();
+  if (axis == 0UL) {
+    if (level == 0) {
+      for (size_t i = 1; i < n; i++) {
+        for (size_t j = 0; j < ins[i]->lod()[0].size(); j++) {
+          out_lod[0][j] += ins[i]->lod()[0][j];
+        }
+      }
+    } else if (level == 1) {
+      for (size_t i = 1; i < n; i++) {
+        PADDLE_ENFORCE_EQ(ins[i]->NumLevels(), 2UL,
+                          "All the LoDTensors of Inputs(X) should "
+                          "have two levels.");
+        for (size_t j = 0; j < ins[i]->lod()[0].size(); j++) {
+          out_lod[0][j] += ins[i]->lod()[0][j];
+        }
+        for (size_t j = 0; j < ins[i]->lod()[1].size(); j++) {
+          out_lod[1][j] += ins[i]->lod()[1][j];
+        }
+      }
+    }
+  }
+  return out_lod;
+}
+
+template <typename Place, typename T>
+class SequenceConcatOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<LoDTensor>("X");
+    auto* out = ctx.Output<LoDTensor>("Out");
+    const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+    const size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
+    const size_t n = ins.size();
+    out->mutable_data<T>(ctx.GetPlace());
+    auto out_lod = concatLod<LoDTensor>(ins, axis, level);
+    out->set_lod(out_lod);
+
+    auto out_lod_level = out_lod[level];
+    for (size_t i = 0; i < out_lod_level.size() - 1; i++) {
+      // Slice the i-th segment of the output, then pack the i-th segment
+      // of every input into it at a growing offset along `axis`.
+      Tensor out_t = out->Slice<T>(static_cast<int>(out_lod_level[i]),
+                                   static_cast<int>(out_lod_level[i + 1]));
+      auto out_stride = framework::stride(out_t.dims());
+      size_t offset = 0;
+
+      for (size_t j = 0; j < n; j++) {
+        auto in_lod_level = ins[j]->lod()[level];
+        auto in_stride = framework::stride(ins[j]->dims());
+        Tensor in_t = ins[j]->Slice<T>(static_cast<int>(in_lod_level[i]),
+                                       static_cast<int>(in_lod_level[i + 1]));
+        size_t axis_dim = in_t.dims()[axis];
+        StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
+                         in_t.dims(), out_stride, out_t.data<T>() + offset);
+        offset += axis_dim * in_stride[axis];
+      }
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SequenceConcatGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<LoDTensor>("X");
+    auto* out_grad =
+        ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto x_grads =
+        ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
+    size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+    size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
+    const size_t n = x_grads.size();
+
+    // The LoD of Grad(X) is the same as that of X.
+    for (size_t i = 0; i < n; i++) {
+      x_grads[i]->set_lod(ins[i]->lod());
+      x_grads[i]->mutable_data<T>(ctx.GetPlace());
+    }
+
+    auto out_lod = concatLod<LoDTensor>(ins, axis, level);
+    auto out_lod_level = out_lod[level];
+
+    for (size_t i = 0; i < out_lod_level.size() - 1; i++) {
+      // The inverse of the forward loop: scatter the i-th segment of
+      // Grad(Out) back into the i-th segment of each Grad(X).
+      Tensor out_grad_t =
+          out_grad->Slice<T>(static_cast<int>(out_lod_level[i]),
+                             static_cast<int>(out_lod_level[i + 1]));
+      auto out_grad_stride = framework::stride(out_grad_t.dims());
+      size_t offset = 0;
+
+      for (size_t j = 0; j < n; j++) {
+        auto x_grad_lod_level = x_grads[j]->lod()[level];
+        auto x_grad_stride = framework::stride(x_grads[j]->dims());
+        Tensor x_grad_t =
+            x_grads[j]->Slice<T>(static_cast<int>(x_grad_lod_level[i]),
+                                 static_cast<int>(x_grad_lod_level[i + 1]));
+        size_t axis_dim = x_grad_t.dims()[axis];
+        StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>() + offset,
+                         out_grad_stride, x_grad_t.dims(), x_grad_stride,
+                         x_grad_t.data<T>());
+        offset += axis_dim * out_grad_stride[axis];
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
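Both kernels walk the output segment by segment and use StridedMemcpy to pack (or, for the gradient, unpack) each input's slice at a growing offset along `axis`. A numpy analogue of the forward loop follows; seq_concat_forward is an illustrative name, not a function in this patch:

import numpy as np

def seq_concat_forward(ins, lods, axis, level):
    # For each level-`level` segment i, concatenate every input's
    # segment i along `axis`, then stack the segments along axis 0.
    segs = [lod[level] for lod in lods]
    outs = []
    for i in range(len(segs[0]) - 1):
        parts = [x[s[i]:s[i + 1]] for x, s in zip(ins, segs)]
        outs.append(np.concatenate(parts, axis=axis))
    return np.concatenate(outs, axis=0)

This is exactly the reference computation used by the Python tests below.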
diff --git a/python/paddle/v2/framework/tests/test_seq_concat_op.py b/python/paddle/v2/framework/tests/test_seq_concat_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d40d82ae7b8ca540582cb6e7a80f90381646ca4
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_seq_concat_op.py
@@ -0,0 +1,57 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestConcatOp(OpTest):
+    def set_data(self):
+        # two-level LoD, batch size is 3
+        x0 = np.random.random((11, 6, 3)).astype('float32')
+        lod0 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
+        x1 = np.random.random((11, 8, 3)).astype('float32')
+        lod1 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
+        axis = 1
+        level = 1
+        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
+        self.attrs = {'axis': axis, 'level': level}
+        outs = []
+        for i in range(len(lod0[level]) - 1):
+            sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
+            sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
+            outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))
+
+        self.outputs = {'Out': np.concatenate(outs, axis=0)}
+
+    def setUp(self):
+        self.op_type = "sequence_concat"
+        self.set_data()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out')
+
+
+class TestConcatOpDiffLod(TestConcatOp):
+    def set_data(self):
+        # two-level LoD, batch size is 3, inputs with different LoDs
+        x0 = np.random.random((12, 6, 3)).astype('float32')
+        lod0 = [[0, 3, 9, 12], [0, 2, 3, 5, 9, 12]]
+        x1 = np.random.random((11, 6, 3)).astype('float32')
+        lod1 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
+        axis = 0
+        level = 1
+        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
+        self.attrs = {'axis': axis, 'level': level}
+        outs = []
+        for i in range(len(lod0[level]) - 1):
+            sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
+            sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
+            outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))
+
+        self.outputs = {'Out': np.concatenate(outs, axis=0)}
+
+
+if __name__ == '__main__':
+    unittest.main()
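As a hand check of TestConcatOpDiffLod's expected shape (numpy only; nothing here is part of the patch): with axis = 0 the output has 12 + 11 = 23 rows, and level-1 segment i of the output holds the i-th segment of x0 followed by the i-th segment of x1.

import numpy as np

sizes0 = np.diff([0, 2, 3, 5, 9, 12])  # x0 level-1 segment sizes: 2,1,2,4,3
sizes1 = np.diff([0, 1, 2, 5, 7, 11])  # x1 level-1 segment sizes: 1,1,3,2,4
assert (sizes0 + sizes1).sum() == 23   # total output rows along axis 0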