未验证 提交 3d39f5c7 编写于 作者: X xiongkun 提交者: GitHub

Fix unsqueeze op get wrong output shape in compile time infer shape. (#41097)

上级 e494b73b
...@@ -2555,7 +2555,8 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -2555,7 +2555,8 @@ void UnfoldInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes, const ScalarArray& axes,
MetaTensor* xshape, MetaTensor* xshape,
MetaTensor* out) { MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims(); const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6). // Validity Check: input tensor dims (<6).
PADDLE_ENFORCE_LE(x_dims.size(), PADDLE_ENFORCE_LE(x_dims.size(),
...@@ -2564,7 +2565,13 @@ void UnsqueezeInferMeta(const MetaTensor& x, ...@@ -2564,7 +2565,13 @@ void UnsqueezeInferMeta(const MetaTensor& x,
"Invalid " "Invalid "
"dimensions, the rank of Input(X) " "dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)")); "should be in the range of [1, 6] (Eigen limit)"));
if (!axes.GetData().empty()) { if (!config.is_runtime && axes.FromTensor()) {
// compile time infershape. set all elements to -1.
int output_size = x.dims().size() + axes.GetData().size();
std::vector<int64_t> vec_out_dims(output_size, -1);
out->set_dtype(x.dtype());
out->set_dims(phi::make_ddim(vec_out_dims));
} else if (!axes.GetData().empty()) {
std::vector<int32_t> tmp; std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size()); tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(), std::for_each(axes.GetData().begin(),
...@@ -2575,7 +2582,9 @@ void UnsqueezeInferMeta(const MetaTensor& x, ...@@ -2575,7 +2582,9 @@ void UnsqueezeInferMeta(const MetaTensor& x,
if (x_dims[0] == out_dims[0]) { if (x_dims[0] == out_dims[0]) {
out->share_lod(x); out->share_lod(x);
} }
out->set_dtype(x.dtype());
} }
// set xshape dims.
std::vector<int64_t> xshape_dims(x_dims.size() + 1); std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0; xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) { for (int i = 0; i < x_dims.size(); ++i) {
...@@ -2583,7 +2592,6 @@ void UnsqueezeInferMeta(const MetaTensor& x, ...@@ -2583,7 +2592,6 @@ void UnsqueezeInferMeta(const MetaTensor& x,
} }
xshape->set_dims(phi::make_ddim(xshape_dims)); xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x); xshape->share_lod(x);
out->set_dtype(x.dtype());
xshape->set_dtype(x.dtype()); xshape->set_dtype(x.dtype());
} }
......
...@@ -363,7 +363,8 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -363,7 +363,8 @@ void UnfoldInferMeta(const MetaTensor& x,
void UnsqueezeInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes, const ScalarArray& axes,
MetaTensor* xshape, MetaTensor* xshape,
MetaTensor* out); MetaTensor* out,
MetaConfig config = MetaConfig());
void OneHotRawInferMeta(const MetaTensor& x, void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth, int32_t depth,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册