未验证 提交 436a9f14 编写于 作者: C cc 提交者: GitHub

fix log_softmax if any dimension is 0-d (#34635)

上级 06651c48
......@@ -131,8 +131,10 @@ class LogSoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
LogSoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out, axis);
if (X->numel() != 0) {
LogSoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out, axis);
}
}
};
......@@ -183,8 +185,11 @@ class LogSoftmaxGradKernel : public framework::OpKernel<T> {
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
LogSoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX, axis);
if (Out->numel() != 0) {
LogSoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX,
axis);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册