未验证 提交 436a9f14 编写于 作者: C cc 提交者: GitHub

fix log_softmax if any dimension is 0-d (#34635)

上级 06651c48
...@@ -131,8 +131,10 @@ class LogSoftmaxKernel : public framework::OpKernel<T> { ...@@ -131,8 +131,10 @@ class LogSoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
LogSoftmaxFunctor<DeviceContext, T>()( if (X->numel() != 0) {
context.template device_context<DeviceContext>(), X, Out, axis); LogSoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out, axis);
}
} }
}; };
...@@ -183,8 +185,11 @@ class LogSoftmaxGradKernel : public framework::OpKernel<T> { ...@@ -183,8 +185,11 @@ class LogSoftmaxGradKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
LogSoftmaxGradFunctor<DeviceContext, T>()( if (Out->numel() != 0) {
context.template device_context<DeviceContext>(), Out, dOut, dX, axis); LogSoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX,
axis);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册