diff --git a/paddle/fluid/operators/log_softmax_op.h b/paddle/fluid/operators/log_softmax_op.h index c732ec5a2da0abcf608018dc0494f657ca2b4e59..162087a75662d711a63cbbe4beeaecf265367c6a 100644 --- a/paddle/fluid/operators/log_softmax_op.h +++ b/paddle/fluid/operators/log_softmax_op.h @@ -131,8 +131,10 @@ class LogSoftmaxKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); - LogSoftmaxFunctor()( - context.template device_context(), X, Out, axis); + if (X->numel() != 0) { + LogSoftmaxFunctor()( + context.template device_context(), X, Out, axis); + } } }; @@ -183,8 +185,11 @@ class LogSoftmaxGradKernel : public framework::OpKernel { // allocate memory on device. dX->mutable_data(context.GetPlace()); - LogSoftmaxGradFunctor()( - context.template device_context(), Out, dOut, dX, axis); + if (Out->numel() != 0) { + LogSoftmaxGradFunctor()( + context.template device_context(), Out, dOut, dX, + axis); + } } };