Unverified commit d966301e, authored by xiaoguoguo626807, committed by GitHub

tanh_double_grad_rules (#52192)

* tanh_double_grad_rules

* delete log from api_base

* modify composite yaml

* optimize rules
Parent d96fbdf1
@@ -63,7 +63,10 @@ black_ops_list = [
 # white ops list whose kernels can be deleted after performance analysis:
 # an original kernel and its derivative kernel can be deleted when the
 # composite_grad kernel performs the same as they do.
-prim_white_list = ["matmul_double_grad"]
+prim_white_list = [
+    "matmul_double_grad",
+    "tanh_double_grad",
+]
 # dict of special APIs whose forward output affects the backward API's output;
 # a backward API's output is usually affected only by the backward API's inputs
@@ -125,6 +125,26 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   set_output<T>(grad_x_tmp, grad_x);
 }
 
+template <typename T>
+void tanh_double_grad(const Tensor& out,
+                      const Tensor& grad_out,
+                      const Tensor& grad_x_grad,
+                      Tensor* out_grad,
+                      Tensor* grad_out_grad) {
+  // tanh double grad: ddout = (1 - out^2) * ddx, dout = -(dout_old * 2 * out * ddx)
+  auto out_m_grad_x_grad = out * grad_x_grad;
+  if (out_grad) {
+    auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad;
+    set_output<T>(out_grad_tmp, out_grad);
+  }
+  if (grad_out_grad) {
+    auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad;
+    set_output<T>(grad_out_grad_tmp, grad_out_grad);
+  }
+}
+
 template <typename T>
 void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
   if (grad_x) {
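For reference, the rule above falls out of differentiating the first backward. With `y = tanh(x)` the first backward computes `grad_x = (1 - out^2) * grad_out`; taking the VJP of that expression with respect to its inputs `out` and `grad_out`, seeded with `grad_x_grad` (`ddx`), gives exactly the two branches in the code: `grad_out_grad = (1 - out^2) * ddx` and `out_grad = -2 * out * grad_out * ddx`. A minimal NumPy sketch (illustrative only, not part of this PR) that checks the `out_grad` branch against central finite differences:

```python
# Standalone numerical check of the composite tanh_double_grad formulas.
import numpy as np

x = np.linspace(-2.0, 2.0, 5)
out = np.tanh(x)
dout = np.full_like(x, 0.7)  # upstream grad fed to the first backward
ddx = np.full_like(x, 1.3)   # grad_x_grad: incoming grad of grad_x

# Closed forms used by the composite rule above.
grad_out_grad = (1.0 - out**2) * ddx   # ddout branch
out_grad = -2.0 * out * dout * ddx     # dout branch (gradient w.r.t. out)


def g(t):
    # <grad_x, ddx> as a function of x, holding dout and ddx fixed.
    return (1.0 - np.tanh(t) ** 2) * dout * ddx


# out_grad is a gradient w.r.t. out, so chain through d(out)/dx = 1 - out^2
# before comparing with d<grad_x, ddx>/dx from central differences.
eps = 1e-6
fd = (g(x + eps) - g(x - eps)) / (2.0 * eps)
np.testing.assert_allclose(out_grad * (1.0 - out**2), fd, rtol=1e-5, atol=1e-9)
```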
@@ -1655,6 +1655,7 @@
     param : [out, out]
   kernel :
     func : tanh_double_grad
+  composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad)
   backward : tanh_triple_grad
   inplace : (grad_x_grad -> grad_out_grad)
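Once the composite entry is in place, second-order tanh gradients can be exercised end to end from Python. A hedged usage sketch (assumes Paddle 2.x dynamic-graph mode; whether the composite rule or the native `tanh_double_grad` kernel actually runs depends on the prim switches, which this sketch leaves at their defaults):

```python
import paddle

x = paddle.to_tensor([0.3, -1.2], stop_gradient=False)
y = paddle.tanh(x)

# First-order grad; create_graph=True keeps the graph so it can be
# differentiated a second time.
(dx,) = paddle.grad([y], [x], create_graph=True)
(ddx,) = paddle.grad([dx], [x])

# Analytic second derivative of tanh: -2 * tanh(x) * (1 - tanh(x)^2).
ref = -2.0 * paddle.tanh(x) * (1.0 - paddle.tanh(x) ** 2)
print(ddx, ref)  # the two tensors should agree
```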
@@ -1335,4 +1335,5 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
         else:
             invoke_code = self.invoke
             params_code = self.get_define_args()
+        return self.gene_invoke_code(invoke_code, params_code)