未验证 提交 4beaa754 编写于 作者: W Wei Shengyu 提交者: GitHub

mark extra attr for unsqueeze2 (#35528)

* mark extra attr for unsqueeze2

* debug for inference
上级 f05e444a
...@@ -253,10 +253,7 @@ class Unsqueeze2Op : public UnsqueezeOp { ...@@ -253,10 +253,7 @@ class Unsqueeze2Op : public UnsqueezeOp {
UnsqueezeOp::InferShape(ctx); UnsqueezeOp::InferShape(ctx);
const auto &x_dims = ctx->GetInputDim("X"); const auto &x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ( if (!ctx->HasOutput("XShape")) return;
ctx->HasOutput("XShape"), true,
platform::errors::InvalidArgument("Output(XShape) of Unsqueeze "
"operator should not be null."));
std::vector<int64_t> xshape_dims(x_dims.size() + 1); std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0; xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) { for (int i = 0; i < x_dims.size(); ++i) {
...@@ -274,7 +271,8 @@ class Unsqueeze2OpMaker : public UnsqueezeOpMaker { ...@@ -274,7 +271,8 @@ class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
AddOutput("XShape", AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will " "XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp.") "be used in UnsqueezeGradOp.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册