未验证 提交 3d39f5c7 编写于 作者: X xiongkun 提交者: GitHub

Fix unsqueeze op get wrong output shape in compile time infer shape. (#41097)

上级 e494b73b
......@@ -2555,7 +2555,8 @@ void UnfoldInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes,
MetaTensor* xshape,
MetaTensor* out) {
MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE_LE(x_dims.size(),
......@@ -2564,7 +2565,13 @@ void UnsqueezeInferMeta(const MetaTensor& x,
"Invalid "
"dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)"));
if (!axes.GetData().empty()) {
if (!config.is_runtime && axes.FromTensor()) {
// compile time infershape. set all elements to -1.
int output_size = x.dims().size() + axes.GetData().size();
std::vector<int64_t> vec_out_dims(output_size, -1);
out->set_dtype(x.dtype());
out->set_dims(phi::make_ddim(vec_out_dims));
} else if (!axes.GetData().empty()) {
std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(),
......@@ -2575,7 +2582,9 @@ void UnsqueezeInferMeta(const MetaTensor& x,
if (x_dims[0] == out_dims[0]) {
out->share_lod(x);
}
out->set_dtype(x.dtype());
}
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
......@@ -2583,7 +2592,6 @@ void UnsqueezeInferMeta(const MetaTensor& x,
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
out->set_dtype(x.dtype());
xshape->set_dtype(x.dtype());
}
......
......@@ -363,7 +363,8 @@ void UnfoldInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes,
MetaTensor* xshape,
MetaTensor* out);
MetaTensor* out,
MetaConfig config = MetaConfig());
void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册