From 436a9f142d62b0233dc5e1bc259458b05d9c1632 Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Fri, 6 Aug 2021 10:28:14 +0800
Subject: [PATCH] fix log_softmax if any dimension is 0-d (#34635)

---
 paddle/fluid/operators/log_softmax_op.h | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/log_softmax_op.h b/paddle/fluid/operators/log_softmax_op.h
index c732ec5a2da..162087a7566 100644
--- a/paddle/fluid/operators/log_softmax_op.h
+++ b/paddle/fluid/operators/log_softmax_op.h
@@ -131,8 +131,10 @@ class LogSoftmaxKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
-    LogSoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), X, Out, axis);
+    if (X->numel() != 0) {
+      LogSoftmaxFunctor<DeviceContext, T>()(
+          context.template device_context<DeviceContext>(), X, Out, axis);
+    }
   }
 };
 
@@ -183,8 +185,11 @@ class LogSoftmaxGradKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
 
-    LogSoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), Out, dOut, dX, axis);
+    if (Out->numel() != 0) {
+      LogSoftmaxGradFunctor<DeviceContext, T>()(
+          context.template device_context<DeviceContext>(), Out, dOut, dX,
+          axis);
+    }
   }
 };
-- 
GitLab
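
Note (not part of the patch): the change simply skips the forward and backward functors when the tensor has zero elements, so a 0-sized dimension no longer reaches the reduction code. The snippet below is a minimal standalone C++ sketch of that same early-exit guard; naive_log_softmax and the plain std::vector representation are illustrative assumptions, not Paddle's actual kernel or API.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Naive 1-D log-softmax over a single group, used only to illustrate the
// empty-input guard that the patch adds around the Paddle functors.
std::vector<float> naive_log_softmax(const std::vector<float>& x) {
  std::vector<float> out(x.size());
  // Analogous to `if (X->numel() != 0)`: with zero elements there is nothing
  // to reduce over, so return the (empty) output instead of computing.
  if (x.empty()) return out;

  float max_val = x[0];
  for (float v : x) max_val = std::max(max_val, v);
  float sum = 0.f;
  for (float v : x) sum += std::exp(v - max_val);
  const float log_sum = std::log(sum);
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] - max_val - log_sum;
  return out;
}

int main() {
  std::vector<float> empty;                  // stand-in for a tensor with a 0-sized dimension
  std::vector<float> data = {1.f, 2.f, 3.f};
  auto a = naive_log_softmax(empty);         // empty result, no NaNs or reads past the end
  auto b = naive_log_softmax(data);
  std::printf("empty out size: %zu, first value for data: %f\n", a.size(), b[0]);
  return 0;
}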