Unverified commit 7c4a3556, authored by lxd-cumt, committed by GitHub

add tanh_triple_grad composite logic (#56072)

* decompose tanh_triple_grad and add it into prim_white_list test=develop

* fix TanhTripleGradKernel bugs test=develop

* decompose tanh_triple_grad test=develop
Parent 110f769d
@@ -72,7 +72,7 @@ prim_white_list = [
    "subtract_double_grad",
    "add_triple_grad",
    "silu_double_grad",
-   "tanh_double_grad",
+   "tanh_triple_grad",
]
# dict of special api that forward api's output will affect bacward api's output
......
@@ -53,6 +53,76 @@ void tanh_double_grad(const Tensor& out,
  }
}
template <typename T>
void tanh_triple_grad(const Tensor& out,
const Tensor& grad_out_forward,
const Tensor& grad_x_grad_forward,
const paddle::optional<Tensor>& grad_out_new_grad,
const paddle::optional<Tensor>& grad_out_grad_grad,
Tensor* out_grad,
Tensor* grad_out_forward_grad,
Tensor* grad_x_grad_forward_grad) {
if (out_grad) {
if (grad_out_grad_grad) {
if (grad_out_new_grad) {
auto out_grad_tmp =
(-2 * out * grad_x_grad_forward * grad_out_grad_grad.get()) -
(2 * grad_out_forward * grad_x_grad_forward *
grad_out_new_grad.get());
set_output<T>(out_grad_tmp, out_grad);
} else {
auto out_grad_tmp =
-2 * out * grad_x_grad_forward * grad_out_grad_grad.get();
set_output<T>(out_grad_tmp, out_grad);
}
} else {
if (grad_out_new_grad) {
auto out_grad_tmp = -(2 * grad_out_forward * grad_x_grad_forward *
grad_out_new_grad.get());
set_output<T>(out_grad_tmp, out_grad);
} else {
auto out_grad_tmp = 0 * out;
set_output<T>(out_grad_tmp, out_grad);
}
}
}
if (grad_out_forward_grad) {
if (grad_out_new_grad) {
auto grad_out_forward_grad_tmp =
-2 * out * grad_x_grad_forward * grad_out_new_grad.get();
set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
} else {
auto grad_out_forward_grad_tmp = 0 * out;
set_output<T>(grad_out_forward_grad_tmp, grad_out_forward_grad);
}
}
if (grad_x_grad_forward_grad) {
if (grad_out_grad_grad) {
if (grad_out_new_grad) {
auto grad_x_grad_forward_grad_tmp =
(1 - (out * out)) * grad_out_grad_grad.get() -
2 * out * grad_out_forward * grad_out_new_grad.get();
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
} else {
auto grad_x_grad_forward_grad_tmp =
(1 - (out * out)) * grad_out_grad_grad.get();
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
}
} else {
if (grad_out_new_grad) {
auto grad_x_grad_forward_grad_tmp =
-(2 * out * grad_out_forward * grad_out_new_grad.get());
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
} else {
auto grad_x_grad_forward_grad_tmp = 0 * grad_x_grad_forward;
set_output<T>(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad);
}
}
}
}
template <typename T>
void matmul_double_grad(const Tensor& x,
                        const Tensor& y,
......
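For reference, the branches in tanh_triple_grad above follow from differentiating the two outputs of tanh_double_grad in this same file, grad_out_new = -2 * out * grad_out_forward * grad_x_grad_forward and grad_out_grad = (1 - out^2) * grad_x_grad_forward, with respect to out, grad_out_forward and grad_x_grad_forward, then contracting with the incoming cotangents. Writing y = out, g = grad_out_forward, d = grad_x_grad_forward, a = grad_out_new_grad and b = grad_out_grad_grad, the decomposition computes (a sketch of the math, not part of the commit):

\[
\begin{aligned}
\text{out\_grad} &= -2\,y\,d\,b \;-\; 2\,g\,d\,a,\\
\text{grad\_out\_forward\_grad} &= -2\,y\,d\,a,\\
\text{grad\_x\_grad\_forward\_grad} &= (1 - y^{2})\,b \;-\; 2\,y\,g\,a.
\end{aligned}
\]

The nested if/else branches simply drop the terms whose optional cotangent (a or b) is absent, and fall back to a zero tensor of the matching shape when both are missing.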
@@ -2345,6 +2345,7 @@
  param : [out, out, grad_x_grad_forward]
  kernel :
    func : tanh_triple_grad
  composite : tanh_triple_grad(out, grad_out_forward, grad_x_grad_forward, grad_out_new_grad, grad_out_grad_grad, out_grad, grad_out_forward_grad, grad_x_grad_forward_grad)
  inplace : (grad_x_grad_forward -> grad_out_forward_grad)
  optional : grad_out_new_grad, grad_out_grad_grad
......
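The composite entry above registers the decomposition as the prim-mode replacement for the tanh_triple_grad kernel. As a rough illustration of how the rule gets exercised (not part of the commit, and assuming the eager prim switch core.set_prim_eager_enabled is exposed on this branch), third-order gradients of tanh can be requested with repeated paddle.grad calls:

# Hypothetical usage sketch: route backward through composite (prim) rules and
# take three orders of gradients of tanh, which is where tanh_triple_grad or
# its decomposition is invoked.
import paddle
from paddle.fluid import core  # assumption: prim switches are exposed here on this branch

core.set_prim_eager_enabled(True)  # assumption: enables composite backward in eager mode

x = paddle.to_tensor([0.1, 0.2, 0.3], dtype="float32", stop_gradient=False)
out = paddle.tanh(x)

(dx,) = paddle.grad(out, x, create_graph=True)   # 1st order: tanh_grad
(ddx,) = paddle.grad(dx, x, create_graph=True)   # 2nd order: tanh_double_grad
(dddx,) = paddle.grad(ddx, x)                    # 3rd order: tanh_triple_grad / composite
print(dddx)

With the prim switch off, the same script runs the handwritten TanhTripleGradKernel, which is what the Resize fix in the next hunk affects.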
@@ -189,11 +189,11 @@ void TanhTripleGradKernel(const Context& dev_ctx,
    dev_ctx.template Alloc<T>(d_dout);
  }
  if (d_out_new) {
-   d_dout->Resize(out.dims());
+   d_out_new->Resize(out.dims());
    dev_ctx.template Alloc<T>(d_out_new);
  }
  if (d_ddx) {
-   d_dout->Resize(ddx.dims());
+   d_ddx->Resize(ddx.dims());
    dev_ctx.template Alloc<T>(d_ddx);
  }
  funcs::TanhTripleGradFunctor<T> functor;
......