提交 84f471b4 编写于 作者: W wanghaoshuang

Fix comments

上级 8d4e2d4c
...@@ -27,9 +27,7 @@ class SeqExpandOp : public framework::OperatorWithKernel { ...@@ -27,9 +27,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasOutput("Out")); PADDLE_ENFORCE(ctx->HasOutput("Out"));
PADDLE_ENFORCE( PADDLE_ENFORCE(ctx->HasInput("Y"));
ctx->HasInput("Y"),
"Input(Y) of SeqExpandOp should not be null while repeat == 0.");
framework::DDim out_dim; framework::DDim out_dim;
out_dim = ctx->GetInputDim("Y"); out_dim = ctx->GetInputDim("Y");
ctx->ShareLoD("Y", "Out"); ctx->ShareLoD("Y", "Out");
...@@ -43,14 +41,14 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,14 +41,14 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor or LoDTensor) The input('X') of this operator can be a " "(Tensor or LoDTensor) The input(X) of this operator can be a "
"LoDTensor or a base Tensor."); "LoDTensor or a base Tensor.");
AddInput("Y", AddInput("Y",
"(LoDTensor)The reference input('Y') of seq_expand op." "(LoDTensor)The reference input(Y) of seq_expand op."
"It must be a LoDTensor with k-level(k>0)." "It must be a LoDTensor with k-level(k>0)."
"Input(X) will be expanded according to LOD of input(Y)." "The input(X) will be expanded according to LOD of input(Y)."
"The element numbers of last level in input('Y') " "The element numbers of last level in input(Y) "
"must be equal to dims[0] of input('X')."); "must be equal to dims[0] of input(X).");
AddOutput("Out", AddOutput("Out",
"(LodTensor)The output of seq_expand op." "(LodTensor)The output of seq_expand op."
"The lod of output will be as same as input(Y)'s lod."); "The lod of output will be as same as input(Y)'s lod.");
...@@ -133,7 +131,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel { ...@@ -133,7 +131,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Out")); PADDLE_ENFORCE(ctx->HasInput("Out"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "The input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册