提交 8d4e2d4c 编写于 作者: W wanghaoshuang

1. Add unitest for empty sequence case

2. Fix comments and paddle enforce check
上级 9f32b61c
......@@ -25,10 +25,8 @@ class SeqExpandOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SeqExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SeqExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE(
ctx->HasInput("Y"),
"Input(Y) of SeqExpandOp should not be null while repeat == 0.");
......@@ -54,7 +52,7 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
"The element numbers of last level in input('Y') "
"must be equal to dims[0] of input('X').");
AddOutput("Out",
"The output of seq_expand op."
"(LodTensor)The output of seq_expand op."
"The lod of output will be as same as input(Y)'s lod.");
AddComment(R"DOC(
Expand input(X) according to LOD of input(Y).
......@@ -69,6 +67,7 @@ Given 2-level a LoDTensor input(X)
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 2-level LoDTensor
Out.lod = [[0, 2, 4],
[0, 3, 6, 7, 8]]
......@@ -83,6 +82,7 @@ Given a 0-level LoDTensor input(X)
X.dims = [3, 1]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 1-level LoDTensor
Out.lod = [[0, 2, 3, 6]]
Out.data = [a, a, b, c, c, c]
......@@ -96,11 +96,29 @@ Given a 0-level LoDTensor input(X)
X.dims = [3, 2]
and input(Y)
Y.lod = [[0, 2, 3, 6]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 1-level LoDTensor
Out.lod = [[0, 2, 3, 6]]
Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]]
Out.dims = [6, 2]
Case 4:
Given 2-level a LoDTensor input(X)
X.lod = [[0, 2, 3],
[0, 1, 3, 4]]
X.data = [a, b, c, d]
X.dims = [4, 1]
and input(Y)
Y.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
with condition len(Y.lod[-1]) -1 == X.dims[0]
then we get 2-level LoDTensor
Out.lod = [[0, 2, 4],
[0, 3, 6, 6, 8]]
Out.data = [a, a, a, b, b, b, d, d]
Out.dims = [8, 1]
)DOC");
}
......@@ -112,8 +130,8 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Out"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
......
......@@ -36,7 +36,6 @@ class SeqExpandKernel : public framework::OpKernel<T> {
"The size of last lod level in Input(Y)"
"must be equal to dims[0] of Input(X).");
out->set_lod(y->lod());
out->Resize(y->dims());
auto place = context.GetEigenDevice<Place>();
size_t element_len = framework::product(x_dims) / x_dims[0];
T* out_data = out->mutable_data<T>(context.GetPlace());
......@@ -57,6 +56,18 @@ class SeqExpandKernel : public framework::OpKernel<T> {
}
};
/*
*Given Grad(Out)
*
* Grad(Out).lod = [[0, 2],
* [0, 3, 6]]
* Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
* Then
* Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
* = [0.6, 1.5]
* Grad(X).lod = Input(X).lod
*
* */
template <typename Place, typename T>
class SeqExpandGradKernel : public framework::OpKernel<T> {
public:
......@@ -68,10 +79,8 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
auto out_last_level = out->lod().back();
d_x->set_lod(x->lod());
const T* d_out_data = d_out->data<T>();
auto d_out_dims = d_out->dims();
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
size_t element_len = d_out->numel() / d_out->dims()[0];
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
size_t repeat = out_last_level[i + 1] - out_last_level[i];
Eigen::TensorMap<
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册