diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 2194435af3c6637b9e2e2f1e4ea47f61f2f45d27..9f7eba86b8fdb14d666614ce7423b42b6e9befcb 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -88,10 +88,14 @@ void BmmGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, MetaTensor* x_grad, MetaTensor* y_grad) { - x_grad->set_dims(x.dims()); - y_grad->set_dims(y.dims()); - x_grad->set_dtype(x.dtype()); - y_grad->set_dtype(y.dtype()); + if (x_grad) { + x_grad->set_dims(x.dims()); + x_grad->set_dtype(x.dtype()); + } + if (y_grad) { + y_grad->set_dims(y.dims()); + y_grad->set_dtype(y.dtype()); + } } void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,