Unverified Commit 08dbb33c, authored by cyber-pioneer, committed by GitHub

[Prim] Fix silu backward prim bug (#56280)

* fix silu backward prim bug

* fix silu double grad prim
Parent 86bb6a01
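Background for the diff below: silu(x) = x * sigmoid(x), and the old kernel recovered the sigmoid as `out / x`, which evaluates 0/0 at x = 0 and propagates NaN. The patch computes the sigmoid directly and uses the identity x * (1 - sigmoid(x)) = x - out. A quick sketch of the identity the new code relies on (out denotes the saved forward output silu(x)):

$$\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \mathrm{out} = \operatorname{silu}(x) = x\,\sigma(x)$$

$$\operatorname{silu}'(x) = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) = \sigma(x)\,\bigl(1 + x - \mathrm{out}\bigr)$$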
......
@@ -66,9 +66,21 @@ void silu_grad(const Tensor& x,
                const Tensor& out_grad,
                Tensor* x_grad) {
   if (x_grad) {
-    auto sigmoid = out / x;
-    auto res = out_grad * sigmoid * (1.0 + x * (1.0 - sigmoid));
-    set_output<T>(res, x_grad);
+    auto org_dtype = x.dtype();
+    bool need_cast = org_dtype == phi::DataType::FLOAT16 ||
+                     org_dtype == phi::DataType::BFLOAT16;
+    if (need_cast) {
+      auto x_cast = cast<T>(x, phi::DataType::FLOAT32);
+      auto out_cast = cast<T>(out, phi::DataType::FLOAT32);
+      auto out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
+      auto sigmoid = 1.0 / (1.0 + exp<T>(-x_cast));
+      auto res = out_grad_cast * sigmoid * (1.0 + x_cast - out_cast);
+      set_output<T>(cast<T>(res, org_dtype), x_grad);
+    } else {
+      auto sigmoid = 1.0 / (1.0 + exp<T>(-x));
+      auto res = out_grad * sigmoid * (1.0 + x - out);
+      set_output<T>(res, x_grad);
+    }
   }
 }
......
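The `need_cast` branch evaluates the composite gradient in float32 for float16/bfloat16 inputs and casts the result back to the original dtype, so the `exp` and the surrounding arithmetic are not done in half precision. A worked check of the x = 0 case that broke the old formula:

$$\mathrm{out} = \operatorname{silu}(0) = 0 \cdot \sigma(0) = 0, \qquad \left.\frac{\mathrm{out}}{x}\right|_{x=0} = \frac{0}{0} = \mathrm{NaN}$$

$$\sigma(0) = \frac{1}{1 + e^{0}} = \frac{1}{2}, \qquad \operatorname{silu}'(0) = \frac{1}{2}\,(1 + 0 - 0) = \frac{1}{2}$$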
......
@@ -439,7 +439,7 @@ void silu_double_grad(const Tensor& x,
                      const Tensor& grad_x_grad,
                      Tensor* grad_x,
                      Tensor* grad_out_grad) {
-  auto sigmoid = out / x;
+  auto sigmoid = 1 / (1 + exp<T>(-x));
   auto tmp1 = 1 - sigmoid;
   auto tmp2 = 1 + tmp1 * x;
   if (grad_out_grad) {
......
@@ -447,8 +447,7 @@ void silu_double_grad(const Tensor& x,
     set_output<T>(ddout, grad_out_grad);
   }
   if (grad_x) {
-    auto dx =
-        sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - x * sigmoid)) * tmp1;
+    auto dx = sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - out)) * tmp1;
     set_output<T>(dx, grad_x);
   }
 }
......
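The double-grad change is the same fix plus a simplification: with the sigmoid computed directly, `x * sigmoid` equals the forward output `out`, so `(tmp2 - x * sigmoid)` becomes `(tmp2 - out)`. Writing s = σ(x), the code's tmp1 = 1 − s and tmp2 = 1 + x(1 − s) = 1 + x − out, so the factor assembled in `dx` agrees with the second derivative of silu:

$$1 + (\mathrm{tmp2} - \mathrm{out}) = 2 + x - 2\,\mathrm{out} = 2 + x\,(1 - 2s)$$

$$\operatorname{silu}''(x) = \frac{d}{dx}\Bigl[s\,(1 + x - \mathrm{out})\Bigr] = s\,(1 - s)\,\bigl(2 + x\,(1 - 2s)\bigr)$$

which gives dx = silu''(x) * out_grad * grad_x_grad, as computed in the `grad_x` branch.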