未验证 提交 53d01afe 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the nan bug when passing all zero values into clip_by_norm_op. (#30777)

上级 3858f458
...@@ -81,7 +81,12 @@ class ClipByNormKernel : public framework::OpKernel<T> { ...@@ -81,7 +81,12 @@ class ClipByNormKernel : public framework::OpKernel<T> {
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
auto temp = (x_norm <= max_norm).template cast<T>(); auto temp = (x_norm <= max_norm).template cast<T>();
auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm; auto epsilon =
((x_norm <= static_cast<T>(1e-30)).all().template cast<T>()) *
static_cast<T>(1e-6);
auto scaling =
temp + (static_cast<T>(1) - temp) * max_norm / (x_norm + epsilon);
Eigen::array<int, 1> one_dim{{1}}; Eigen::array<int, 1> one_dim{{1}};
Eigen::DSizes<int, 1> m_dsize(input->numel()); Eigen::DSizes<int, 1> m_dsize(input->numel());
if (context.GetPlace() == platform::CPUPlace()) { if (context.GetPlace() == platform::CPUPlace()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册