diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index c5cc845625479dd76cf13bb82f96dae221af6ecc..bbeb14363e84eaabe6486c27893a3b6c7a881d25 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2555,7 +2555,8 @@ void UnfoldInferMeta(const MetaTensor& x,
 void UnsqueezeInferMeta(const MetaTensor& x,
                         const ScalarArray& axes,
                         MetaTensor* xshape,
-                        MetaTensor* out) {
+                        MetaTensor* out,
+                        MetaConfig config) {
   const auto& x_dims = x.dims();
   // Validity Check: input tensor dims (<6).
   PADDLE_ENFORCE_LE(x_dims.size(),
@@ -2564,7 +2565,13 @@ void UnsqueezeInferMeta(const MetaTensor& x,
                         "Invalid "
                         "dimensions, the rank of Input(X) "
                         "should be in the range of [1, 6] (Eigen limit)"));
-  if (!axes.GetData().empty()) {
+  if (!config.is_runtime && axes.FromTensor()) {
+    // compile time infershape. set all elements to -1.
+    int output_size = x.dims().size() + axes.GetData().size();
+    std::vector<int64_t> vec_out_dims(output_size, -1);
+    out->set_dtype(x.dtype());
+    out->set_dims(phi::make_ddim(vec_out_dims));
+  } else if (!axes.GetData().empty()) {
     std::vector<int32_t> tmp;
     tmp.reserve(axes.GetData().size());
     std::for_each(axes.GetData().begin(),
@@ -2575,7 +2582,9 @@ void UnsqueezeInferMeta(const MetaTensor& x,
     if (x_dims[0] == out_dims[0]) {
       out->share_lod(x);
     }
+    out->set_dtype(x.dtype());
   }
+  // set xshape dims.
   std::vector<int64_t> xshape_dims(x_dims.size() + 1);
   xshape_dims[0] = 0;
   for (int i = 0; i < x_dims.size(); ++i) {
@@ -2583,7 +2592,6 @@ void UnsqueezeInferMeta(const MetaTensor& x,
   }
   xshape->set_dims(phi::make_ddim(xshape_dims));
   xshape->share_lod(x);
-  out->set_dtype(x.dtype());
   xshape->set_dtype(x.dtype());
 }

diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 3b6a34cff610dc4db85a906456842b467a851783..ea902e0d98eca22df2b178c91e035016714d47f0 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -363,7 +363,8 @@ void UnfoldInferMeta(const MetaTensor& x,
 void UnsqueezeInferMeta(const MetaTensor& x,
                         const ScalarArray& axes,
                         MetaTensor* xshape,
-                        MetaTensor* out);
+                        MetaTensor* out,
+                        MetaConfig config = MetaConfig());

 void OneHotRawInferMeta(const MetaTensor& x,
                         int32_t depth,
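
For context (not part of the patch): a standalone C++ sketch of the shape rules the new branch implements. When the axes come from a tensor and shape inference happens at compile time, the output rank is known but every dimension is marked unknown (-1); when the axes values are known, a size-1 dimension is inserted at each axis position. The helper names CompileTimeUnsqueezeDims and UnsqueezeDims below are hypothetical, not Paddle APIs, and the sketch skips the negative-axis normalization and validation that funcs::GetUnsqueezeShape performs.

// Standalone illustration of the unsqueeze shape rules; helper names are
// hypothetical, not Paddle API.
#include <cstdint>
#include <iostream>
#include <vector>

// Compile-time path: axes are only known as a tensor, so every output
// dimension is marked unknown (-1); rank = input rank + number of axes.
std::vector<int64_t> CompileTimeUnsqueezeDims(const std::vector<int64_t>& x_dims,
                                              size_t num_axes) {
  return std::vector<int64_t>(x_dims.size() + num_axes, -1);
}

// Known-axes path: insert a size-1 dimension at each axis position.
// Assumes axes are already non-negative and sorted ascending (the real
// funcs::GetUnsqueezeShape also normalizes negative axes and validates).
std::vector<int64_t> UnsqueezeDims(const std::vector<int64_t>& x_dims,
                                   const std::vector<int32_t>& axes) {
  std::vector<int64_t> out(x_dims.begin(), x_dims.end());
  for (int32_t axis : axes) {
    out.insert(out.begin() + axis, 1);
  }
  return out;
}

int main() {
  const std::vector<int64_t> x_dims = {3, 4};
  for (int64_t d : CompileTimeUnsqueezeDims(x_dims, 1)) std::cout << d << ' ';
  std::cout << '\n';  // prints: -1 -1 -1
  for (int64_t d : UnsqueezeDims(x_dims, {0, 2})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 3 1 4
  return 0;
}

Running the sketch on a (3, 4) input prints -1 -1 -1 for the compile-time path with one unknown axis, and 1 3 1 4 for known axes {0, 2}, matching the two branches added in UnsqueezeInferMeta above.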