diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index eda461be95a406a19a6049fda57acab8d19ada01..6d37a31f5456208070a3a3e6f77a1efbd3510ebc 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -436,6 +436,8 @@ void EinsumInferMeta(const std::vector& inputs, << paddle::string::join_strings(output_dims, ","); VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype); VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape); + out->set_dims(make_ddim(output_dims)); + out->set_dtype(inputs[0]->dtype()); } void ExpandInferMeta(const MetaTensor& x,