Commit be3fa792 authored by Yancey1989

add sequence concat op

Parent d7db15f3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_concat_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_concat_grad,
ops::SequenceConcatGradOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_concat_op.h"
namespace paddle {
namespace operators {
class SequenceConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0UL,
"Inputs(X) of SequenceConcatOp should not be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceConcatOp should not be null.");
const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level"));
const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE(level == 0UL || level == 1UL,
"The sequence_concat operator only supports one or two LoD levels for now.");
auto ins_dims = ctx->GetInputsDim("X");
framework::DDim out_dims = ins_dims[0];
const size_t n = ins_dims.size();
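// The output shape follows the first input, with the `axis` dimension summed over all inputs.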
for (size_t i = 1; i < n; i++) {
out_dims[axis] += ins_dims[i][axis];
}
ctx->SetOutputDim("Out", out_dims);
}
};
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConcatOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"Multip LodTensors, the variable-length inputs of "
"SequenceConcatOp")
.AsDuplicable();
AddOutput("Out",
"A float LodTensor, the variable-length output of "
"SequenceConcatOp.");
AddAttr<int>("axis",
"The axis which the inputs will be joined with."
"If axis is 0, the inputs will be joined with Lod index.")
.SetDefault(0);
AddAttr<int>("level",
"The level which the inputs will be joined with."
"If level is 0, the inputs will be joined with word."
"If level is 1, the inputs will be joined with sentence.")
.SetDefault(0);
AddComment(R"DOC(
The sequence_concat operator concatenates multiple LoDTensors. It supports one or two LoD levels only.
- Case1:
If axis is 1, level is 1, and the LoD of the inputs is the same:
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (2,3,4)
LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (2,4,4)
LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (2,7,4)
- Case2:
If axis is 0, level is 1, and the LoD of the inputs is different:
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (2,3,4)
LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (3,3,4)
LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}; Dims(Out) = (5,3,4)
)DOC");
}
};
class SequenceConcatGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null.");
PADDLE_ENFORCE_GT(ctx->Outputs(framework::GradVarName("X")).size(), 0UL,
"Gradient of X should not be empty.")
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_concat, ops::SequenceConcatOp, ops::SequenceConcatOpMaker,
sequence_concat_grad, ops::SequenceConcatGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_concat_grad,
ops::SequenceConcatGradOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
// Concatenate the LoD of the inputs; the output LoD is initialized from lod(x0).
// If axis is not 0, LoD(Out) is the same as the inputs'; if axis is 0:
// Case1:
// There is one LoD level, and the output LoD is modified:
// LoD(x0) = {{0,2,4}}
// LoD(x1) = {{0,1,5}}
// LoD(Out) = {{0,3,9}}
// Case2:
// There are two LoD levels, the concat level is 1,
// and the output LoD is modified as follows:
// LoD(x0) = {{0,2,4}, {0,1,2,3,4}}
// LoD(x1) = {{0,3,5}, {0,1,3,4,5}}
// LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}
template <typename T>
LoD concatLod(const std::vector<const T*>& ins, const size_t axis,
const size_t level) {
auto out_lod = ins[0]->lod();
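// Start from the LoD of the first input; when axis is 0, merge the remaining inputs' LoD into it.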
const size_t n = ins.size();
if (axis == 0UL) {
if (level == 0) {
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < ins[i]->lod()[0].size(); j++) {
out_lod[0][j] += ins[i]->lod()[0][j];
}
}
} else if (level == 1) {
for (size_t i = 1; i < n; i++) {
PADDLE_ENFORCE_EQ(ins[i]->NumLevels(), 2UL,
"All the LoDTensors of Inputs(X) should "
"have two level.");
for (size_t j = 0; j < ins[i]->lod()[0].size(); j++) {
out_lod[0].push_back(ins[i]->lod()[0][j]);
}
for (size_t j = 0; j < ins[i]->lod()[1].size(); j++) {
out_lod[1][j] += ins[i]->lod()[1][j];
}
}
}
}
return out_lod;
}
template <typename Place, typename T>
class SequenceConcatOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
const size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
const size_t n = ins.size();
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = concatLod<LoDTensor>(ins, axis, level);
out->set_lod(out_lod);
auto out_lod_level = out_lod[level];
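// For each sequence at the concat level, copy every input's slice into the
// output slice at a running offset along the `axis` dimension.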
for (size_t i = 0; i < out_lod_level.size() - 1; i++) {
Tensor out_t = out->Slice<T>(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_stride = framework::stride(out_t.dims());
size_t offset = 0;
for (size_t j = 0; j < n; j++) {
auto in_lod_level = ins[j]->lod()[level];
auto in_stride = framework::stride(ins[j]->dims());
Tensor in_t = ins[j]->Slice<T>(static_cast<int>(in_lod_level[i]),
static_cast<int>(in_lod_level[i + 1]));
size_t axis_dim = in_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
in_t.dims(), out_stride, out_t.data<T>() + offset);
offset += axis_dim * in_stride[axis];
}
}
}
};
template <typename Place, typename T>
class SequenceConcatGradOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
auto* out_grad =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grads =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
const size_t n = x_grads.size();
// Set Grad(X) LoD as X
for (size_t i = 0; i < n; i++) {
x_grads[i]->set_lod(ins[i]->lod());
x_grads[i]->mutable_data<T>(ctx.GetPlace());
}
auto out_lod = concatLod<LoDTensor>(ins, axis, level);
auto out_lod_level = out_lod[level];
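// Scatter each sequence's slice of Grad(Out) back into the matching Grad(X),
// advancing the offset along the `axis` dimension.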
for (size_t i = 0; i < out_lod_level.size() - 1; i++) {
Tensor out_grad_t =
out_grad->Slice<T>(static_cast<int>(out_lod_level[i]),
static_cast<int>(out_lod_level[i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
size_t offset = 0;
for (size_t j = 0; j < n; j++) {
auto x_grad_lod_level = x_grads[j]->lod()[level];
auto x_grad_stride = framework::stride(x_grads[j]->dims());
Tensor x_grad_t =
x_grads[j]->Slice<T>(static_cast<int>(x_grad_lod_level[i]),
static_cast<int>(x_grad_lod_level[i + 1]));
size_t axis_dim = x_grad_t.dims()[axis];
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>() + offset,
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
offset += axis_dim * out_grad_stride[axis];
}
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest


class TestConcatOp(OpTest):
    def set_data(self):
        # two LoD levels, batch size is 3
        x0 = np.random.random((11, 6, 3)).astype('float32')
        lod0 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
        x1 = np.random.random((11, 8, 3)).astype('float32')
        lod1 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
        axis = 1
        level = 1
        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
        self.attrs = {'axis': axis, 'level': level}
        outs = []
        for i in range(5):
            sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
            sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
            outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))
        self.outputs = {'Out': np.concatenate(outs, axis=0)}

    def setUp(self):
        self.op_type = "sequence_concat"
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['x0'], 'Out')


class TestConcatOpDiffLod(TestConcatOp):
    def set_data(self):
        # two LoD levels, batch size is 3
        x0 = np.random.random((12, 6, 3)).astype('float32')
        lod0 = [[0, 3, 9, 12], [0, 2, 3, 5, 9, 12]]
        x1 = np.random.random((11, 6, 3)).astype('float32')
        lod1 = [[0, 2, 5, 11], [0, 1, 2, 5, 7, 11]]
        axis = 0
        level = 1
        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
        self.attrs = {'axis': axis, 'level': level}
        outs = []
        for i in range(5):
            sub_x0 = x0[lod0[level][i]:lod0[level][i + 1], :]
            sub_x1 = x1[lod1[level][i]:lod1[level][i + 1], :]
            outs.append(np.concatenate((sub_x0, sub_x1), axis=axis))
        self.outputs = {'Out': np.concatenate(outs, axis=0)}


if __name__ == '__main__':
    unittest.main()
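As a rough NumPy sketch of the single-level, axis-0 behaviour described in the concatLod comment (the variable names and shapes below are illustrative only, not part of the operator), the merged LoD is the element-wise sum of the input offset vectors, and the output stacks the per-sequence slices:

import numpy as np

# Two single-level LoD tensors, mirroring Case1 in sequence_concat_op.h:
# LoD(x0) = {{0, 2, 4}}, LoD(x1) = {{0, 1, 5}}
x0, lod0 = np.random.random((4, 3)).astype('float32'), [0, 2, 4]
x1, lod1 = np.random.random((5, 3)).astype('float32'), [0, 1, 5]

# The merged LoD is the element-wise sum of the offset vectors: [0, 3, 9]
out_lod = [a + b for a, b in zip(lod0, lod1)]

# For each sequence, stack the rows of x0 on top of the rows of x1
outs = []
for i in range(len(out_lod) - 1):
    outs.append(np.concatenate(
        (x0[lod0[i]:lod0[i + 1]], x1[lod1[i]:lod1[i + 1]]), axis=0))
out = np.concatenate(outs, axis=0)
assert out.shape[0] == out_lod[-1]  # 9 rows in total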