From 3d39f5c771197f4eb40fc9cb0de0a3998ddca3f2 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 30 Mar 2022 15:44:29 +0800 Subject: [PATCH] Fix unsqueeze op get wrong output shape in compile time infer shape. (#41097) --- paddle/phi/infermeta/unary.cc | 14 +++++++++++--- paddle/phi/infermeta/unary.h | 3 ++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c5cc845625..bbeb14363e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2555,7 +2555,8 @@ void UnfoldInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x, const ScalarArray& axes, MetaTensor* xshape, - MetaTensor* out) { + MetaTensor* out, + MetaConfig config) { const auto& x_dims = x.dims(); // Validity Check: input tensor dims (<6). PADDLE_ENFORCE_LE(x_dims.size(), @@ -2564,7 +2565,13 @@ void UnsqueezeInferMeta(const MetaTensor& x, "Invalid " "dimensions, the rank of Input(X) " "should be in the range of [1, 6] (Eigen limit)")); - if (!axes.GetData().empty()) { + if (!config.is_runtime && axes.FromTensor()) { + // compile time infershape. set all elements to -1. + int output_size = x.dims().size() + axes.GetData().size(); + std::vector vec_out_dims(output_size, -1); + out->set_dtype(x.dtype()); + out->set_dims(phi::make_ddim(vec_out_dims)); + } else if (!axes.GetData().empty()) { std::vector tmp; tmp.reserve(axes.GetData().size()); std::for_each(axes.GetData().begin(), @@ -2575,7 +2582,9 @@ void UnsqueezeInferMeta(const MetaTensor& x, if (x_dims[0] == out_dims[0]) { out->share_lod(x); } + out->set_dtype(x.dtype()); } + // set xshape dims. std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; for (int i = 0; i < x_dims.size(); ++i) { @@ -2583,7 +2592,6 @@ void UnsqueezeInferMeta(const MetaTensor& x, } xshape->set_dims(phi::make_ddim(xshape_dims)); xshape->share_lod(x); - out->set_dtype(x.dtype()); xshape->set_dtype(x.dtype()); } diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 3b6a34cff6..ea902e0d98 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -363,7 +363,8 @@ void UnfoldInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x, const ScalarArray& axes, MetaTensor* xshape, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); void OneHotRawInferMeta(const MetaTensor& x, int32_t depth, -- GitLab