diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 47f505e53f4afa53e1a3e7a5ca525a85cbee6c3a..b7959d2809e4f64ef06f719818222c049daf5b41 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -436,6 +436,8 @@ void EinsumInferMeta(const std::vector& inputs, << paddle::string::join_strings(output_dims, ","); VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype); VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape); + out->set_dims(make_ddim(output_dims)); + out->set_dtype(inputs[0]->dtype()); } void ExpandInferMeta(const MetaTensor& x,