未验证 提交 4beaa754 编写于 作者: W Wei Shengyu 提交者: GitHub

mark extra attr for unsqueeze2 (#35528)

* mark extra attr for unsqueeze2

* debug for inference
上级 f05e444a
......@@ -253,10 +253,7 @@ class Unsqueeze2Op : public UnsqueezeOp {
UnsqueezeOp::InferShape(ctx);
const auto &x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("XShape"), true,
platform::errors::InvalidArgument("Output(XShape) of Unsqueeze "
"operator should not be null."));
if (!ctx->HasOutput("XShape")) return;
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
......@@ -274,7 +271,8 @@ class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册