未验证 提交 9bc1e0a1 编写于 作者: Z Zhang Ting 提交者: GitHub

fix the CI random failure for dist op (#23743)

上级 54d3b5a1
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -169,7 +170,9 @@ static void DistGradFunction(const framework::ExecutionContext& context) { ...@@ -169,7 +170,9 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
// 1: Lp-norm(z), z = x-y, compute dz // 1: Lp-norm(z), z = x-y, compute dz
if (p == 0) { if (p == 0) {
grad_t.device(place) = grad_t * static_cast<T>(0); math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, &grad, static_cast<T>(0));
} else if (p == INFINITY || p == -INFINITY) { } else if (p == INFINITY || p == -INFINITY) {
// p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if // p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if
// j!=i, or equals to sign(z_i) * dout if j=i. // j!=i, or equals to sign(z_i) * dout if j=i.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册