未验证 提交 0c71edc3 编写于 作者: D dyning 提交者: GitHub

operators/flatten_op.cc add AsExtra (#35471)

* operators/flatten_op.cc add AsExtra

* operators/flatten_op.cc add AsExtra

* fix format
上级 7907e241
......@@ -188,8 +188,8 @@ class Flatten2Op : public framework::OperatorWithKernel {
// are the same.
ctx->ShareLoD("X", "Out");
}
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
if (!ctx->HasOutput("XShape")) return;
// OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
......@@ -207,7 +207,8 @@ class Flatten2OpMaker : public FlattenOpMaker {
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
}
};
......@@ -281,8 +282,8 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
// are the same.
ctx->ShareLoD("X", "Out");
}
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
if (!ctx->HasOutput("XShape")) return;
// OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
......@@ -361,7 +362,8 @@ Case 2:
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册