Commit 927767b6 authored by Yancey1989

add some checking

Parent a35e82a6
@@ -23,18 +23,19 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContextBase* ctx) const override {
-    PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0UL,
-                      "Inputs(X) of SequenceConcatOp should not be empty.");
+    PADDLE_ENFORCE(ctx->HasInputs("X"),
+                   "Inputs(X) of SequenceConcatOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequenceConcatOp should not be null.");
     const size_t level = static_cast<size_t>(ctx->Attrs().Get<int>("level"));
     const size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
     PADDLE_ENFORCE(level == 0UL || level == 1UL,
-                   "Sequence Concat Op only support one or two sequence now.");
+                   "The sequence_concat operator only accepts sequence "
+                   "or a nested sequence as its input.");
     auto ins_dims = ctx->GetInputsDim("X");
     framework::DDim out_dims = ins_dims[0];
     const size_t n = ins_dims.size();
-    for (size_t i = 1; i < n; i++) {
+    for (size_t i = 1; i < n; ++i) {
       out_dims[axis] += ins_dims[i][axis];
     }
     ctx->SetOutputDim("Out", out_dims);
@@ -47,33 +48,40 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
                          framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "Multip LodTensors, the variable-length inputs of "
-             "SequenceConcatOp")
+             "The input Multip LoDTensors, which are variable-length "
+             "sequence or nested sequence.")
         .AsDuplicable();
     AddOutput("Out",
-              "A float LodTensor, the variable-length output of "
-              "SequenceConcatOp.");
+              "A LoDTensor, the variable-length output of "
+              "sequence_concat Op.");
     AddAttr<int>("axis",
+                 "(int, default 0)"
                  "The axis which the inputs will be joined with."
-                 "If axis is 0, the inputs will be joined with Lod index.")
+                 "If axis is 0, the inputs will be joined with LoD index.")
         .SetDefault(0);
     AddAttr<int>("level",
+                 "(int, default 0)"
                  "The level which the inputs will be joined with."
-                 "If level is 0, the inputs will be joined with word."
-                 "If level is 1, the inputs will be joined with sentence.")
+                 "If level is 0, the inputs will be joined with "
+                 "nested sequences."
+                 "If level is 1, the inputs will be joined with sequences.")
        .SetDefault(0);
     AddComment(R"DOC(
-    SequenceConcatOp concat multip LodTensors and only supports one or two levels.
+    The sequence_concat operator concatenates multiple LoDTensors.
+    It only supports sequences (LoD Tensor with level=1)
+    or nested sequences (LoD tensor with level=0) as its inputs.
     - Case1:
-      axis is 1, level is 1, the Lod of Inputs are the same,
+      If the axis is 1, level is 1, the LoD of Inputs are the same,
       LoD(x0) = {{0,2,4},{0,1,2,3,4}}; Dims(x0) = (2,3,4)
       LoD(x1) = {{0,2,4},{0,1,2,3,4}}; Dims(x1) = (2,4,4)
-      LoD(Out) = {{0,2,4},{01,2,3,4}}; Dims(Out) = (2,7,4)
+      LoD(Out) = {{0,2,4},{0,1,2,3,4}}; Dims(Out) = (2,7,4)
     - Case2:
-      If axis is 0, level is 1, the Lod of inputs are different,
+      If the axis is 0, level is 1, the LoD of inputs are different,
       LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (2,3,4)
       LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (3,3,4)
       LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}; Dims(Out) = (5,3,4)
+    NOTE: The level of all the inputs should be the same.
     )DOC");
   }
 };
@@ -85,9 +93,9 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContextBase* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Gradient of Out should not be null.");
-    PADDLE_ENFORCE_GT(ctx->Outputs(framework::GradVarName("X")).size(), 0UL,
-                      "Gradient of X should not be empty.")
+                   "The gradient of Out should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
+                   "The gradient of X should not be empty.");
     ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
   }
 };
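The two DOC cases above are easiest to verify on the top LoD level: when axis is 0, each input's level-0 offsets are accumulated element-wise into the output's offsets, so {0,2,4} + {0,3,5} gives the {0,5,9} shown in Case2. Below is a minimal standalone sketch of that accumulation rule, using plain std::vector instead of framework::LoD; it is an illustration with made-up names, not code from this commit.

#include <cassert>
#include <cstddef>
#include <vector>

// Element-wise accumulation of level-0 LoD offsets, mirroring the
// axis == 0 rule described in the operator comment. Assumes every
// input has the same number of top-level sequences.
std::vector<size_t> AccumulateOffsets(
    const std::vector<std::vector<size_t>>& lods) {
  std::vector<size_t> out = lods[0];
  for (size_t i = 1; i < lods.size(); ++i) {
    for (size_t j = 0; j < out.size(); ++j) {
      out[j] += lods[i][j];
    }
  }
  return out;
}

int main() {
  // Top-level offsets from Case2: {0,2,4} and {0,3,5} -> {0,5,9}.
  auto out = AccumulateOffsets({{0, 2, 4}, {0, 3, 5}});
  assert((out == std::vector<size_t>{0, 5, 9}));
  return 0;
}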
@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 using LoD = framework::LoD;

-// Concat Lod, the initialized Lod of Output is lod(x0),
+// Concat LoD, the initialized LoD of Output is lod(x0),
 // if axis is not 0, the LoD(Out) will be the same as Inputs, if axis is 0:
 // Case1:
 //   There is one level, the Output LoD will be modified:
@@ -37,26 +37,26 @@ using LoD = framework::LoD;
 //   LoD(x1) = {{0,3,5}, {0,1,3,4,5}}
 //   LoD(Out) = {{0,5,9}, {0,1,2,4,5,6,7,8,9}}
 template <typename T>
-LoD concatLod(const std::vector<const T*> ins, const size_t axis,
+LoD concatLoD(const std::vector<const T*> ins, const size_t axis,
               const size_t level) {
   auto out_lod = ins[0]->lod();
   const size_t n = ins.size();
   if (axis == 0UL) {
     if (level == 0) {
-      for (size_t i = 1; i < n; i++) {
-        for (size_t j = 0; j < ins[i]->lod()[0].size(); j++) {
+      for (size_t i = 1; i < n; ++i) {
+        for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) {
           out_lod[0][j] += ins[i]->lod()[0][j];
         }
       }
     } else if (level == 1) {
-      for (size_t i = 1; i < n; i++) {
-        PADDLE_ENFORCE_EQ(ins[i]->NumLevels(), 2UL,
-                          "All the LoDTensors of Inputs(X) should "
-                          "have two level.");
-        for (size_t j = 0; j < ins[i]->lod()[0].size(); j++) {
+      PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), 2UL,
+                        "If the level is 1, all of the inputs "
+                        "should be the the nested sequence.");
+      for (size_t i = 1; i < n; ++i) {
+        for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) {
           out_lod[0].push_back(ins[i]->lod()[0][j]);
         }
-        for (size_t j = 0; j < ins[i]->lod()[1].size(); j++) {
+        for (size_t j = 0; j < ins[i]->lod()[1].size(); ++j) {
           out_lod[1][j] += ins[i]->lod()[1][j];
         }
       }
@@ -66,7 +66,7 @@ LoD concatLod(const std::vector<const T*> ins, const size_t axis,
 }

 template <typename Place, typename T>
-class SequenceConcatOpKernel : public framework::OpKernel {
+class SequenceConcatOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<LoDTensor>("X");
@@ -74,18 +74,37 @@ class SequenceConcatOpKernel : public framework::OpKernel {
     const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
     const size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
     const size_t n = ins.size();
+    for (size_t i = 1; i < n; ++i) {
+      PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), ins[i]->NumLevels(),
+                        "The level number of all the input LoDTensors "
+                        "should be the same.");
+      PADDLE_ENFORCE_EQ(ins[0]->dims().size(), ins[i]->dims().size(),
+                        "The dimensions size of all the input LoDTensors "
+                        "should be the same.");
+      const size_t dims_size = ins[i]->dims().size();
+      for (size_t j = 0; j < dims_size; ++j) {
+        if (j == axis) continue;
+        PADDLE_ENFORCE_EQ(ins[0]->dims()[j], ins[i]->dims()[j],
+                          "The dimensions of all the input LoDTensors "
+                          "except for the specify axis should be "
+                          "matched exactly.");
+      }
+    }

     out->mutable_data<T>(ctx.GetPlace());
-    auto out_lod = concatLod<LoDTensor>(ins, axis, level);
+    auto out_lod = concatLoD<LoDTensor>(ins, axis, level);
     out->set_lod(out_lod);
     auto out_lod_level = out_lod[level];
-    for (size_t i = 0; i < out_lod_level.size() - 1; i++) {
+    for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
       Tensor out_t = out->Slice<T>(static_cast<int>(out_lod_level[i]),
                                    static_cast<int>(out_lod_level[i + 1]));
       auto out_stride = framework::stride(out_t.dims());
       size_t offset = 0;
-      for (size_t j = 0; j < n; j++) {
+      for (size_t j = 0; j < n; ++j) {
         auto in_lod_level = ins[j]->lod()[level];
         auto in_stride = framework::stride(ins[j]->dims());
         Tensor in_t = ins[j]->Slice<T>(static_cast<int>(in_lod_level[i]),
@@ -100,7 +119,7 @@ class SequenceConcatOpKernel : public framework::OpKernel {
 };

 template <typename Place, typename T>
-class SequenceConcatGradOpKernel : public framework::OpKernel {
+class SequenceConcatGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::LoDTensor>("X");
@@ -118,17 +137,17 @@ class SequenceConcatGradOpKernel : public framework::OpKernel {
       x_grads[i]->mutable_data<T>(ctx.GetPlace());
     }

-    auto out_lod = concatLod<LoDTensor>(ins, axis, level);
+    auto out_lod = concatLoD<LoDTensor>(ins, axis, level);
     auto out_lod_level = out_lod[level];
-    for (size_t i = 0; i < out_lod_level.size() - 1; i++) {
+    for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
       Tensor out_grad_t =
           out_grad->Slice<T>(static_cast<int>(out_lod_level[i]),
                              static_cast<int>(out_lod_level[i + 1]));
       auto out_grad_stride = framework::stride(out_grad_t.dims());
       size_t offset = 0;
-      for (size_t j = 0; j < n; j++) {
+      for (size_t j = 0; j < n; ++j) {
         auto x_grad_lod_level = x_grads[j]->lod()[level];
         auto x_grad_stride = framework::stride(x_grads[j]->dims());
         Tensor x_grad_t =
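The checking added in SequenceConcatOpKernel::Compute amounts to a simple rule: every input must have the same LoD level count and the same rank, and all dimensions must match except along the concatenation axis. A rough standalone sketch of that rule follows, with a hypothetical helper name and plain std::vector dims rather than framework::DDim; it illustrates the check, it is not the operator's actual code.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Shape-compatibility rule mirrored from the new checks: every input
// must have the same rank, and equal extents on all axes except `axis`.
void CheckConcatShapes(const std::vector<std::vector<int64_t>>& dims,
                       size_t axis) {
  for (size_t i = 1; i < dims.size(); ++i) {
    if (dims[i].size() != dims[0].size()) {
      throw std::invalid_argument("rank mismatch between inputs");
    }
    for (size_t j = 0; j < dims[i].size(); ++j) {
      if (j == axis) continue;  // only the concatenated axis may differ
      if (dims[i][j] != dims[0][j]) {
        throw std::invalid_argument("dim mismatch outside the concat axis");
      }
    }
  }
}

int main() {
  // Dims from Case1: (2,3,4) and (2,4,4) concatenated along axis 1 -> valid.
  CheckConcatShapes({{2, 3, 4}, {2, 4, 4}}, /*axis=*/1);
  return 0;
}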