diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 013a58e30ed4c09a1071db157c2425e83fcbb7c5..0d91f7965485f36f8afa9ca2242d66fc5f48cbb1 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -814,7 +814,7 @@ void MultiClassNMSInferMeta(const MetaTensor& bboxes, out->set_dims(phi::make_ddim({-1, box_dims[2] + 2})); out->set_dtype(bboxes.dtype()); - index->set_dims(phi::make_ddim({-1, box_dims[2] + 2})); + index->set_dims(phi::make_ddim({-1, 1})); index->set_dtype(DataType::INT32); nms_rois_num->set_dims(phi::make_ddim({-1})); nms_rois_num->set_dtype(DataType::INT32);