From d966301e620dd14dce82776d57402713e7d6f32b Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Wed, 29 Mar 2023 14:28:54 +0800
Subject: [PATCH] tanh_double_grad_rules (#52192)

* tanh_double_grad_rules

* delete log got api_base

* modify composite yaml

* optimize rules
---
 .../generator/eager_gen.py                    |  5 ++++-
 .../composite_backward_api.h                  | 20 +++++++++++++++++++
 paddle/phi/api/yaml/backward.yaml             |  1 +
 paddle/phi/api/yaml/generator/api_base.py     |  1 +
 4 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index 1325f7b1e5c..a8a4bcae481 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -63,7 +63,10 @@ black_ops_list = [
 # white ops list whose kernel can be deleted after performance analysis
 # original kernel and its derivative kernel can be deleted when composite_grad
 # kernel performs same to it.
-prim_white_list = ["matmul_double_grad"]
+prim_white_list = [
+    "matmul_double_grad",
+    "tanh_double_grad",
+]
 
 # dict of special api that forward api's output will affect bacward api's output
 # bacward api's output usually affected by backward api's input
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index a90160f260a..35059ab2ddf 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -125,6 +125,26 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   set_output<T>(grad_x_tmp, grad_x);
 }
 
+template <typename T>
+void tanh_double_grad(const Tensor& out,
+                      const Tensor& grad_out,
+                      const Tensor& grad_x_grad,
+                      Tensor* out_grad,
+                      Tensor* grad_out_grad) {
+  // tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out *
+  // ddx)
+  auto out_m_grad_x_grad = out * grad_x_grad;
+  if (out_grad) {
+    auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad;
+    set_output<T>(out_grad_tmp, out_grad);
+  }
+
+  if (grad_out_grad) {
+    auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad;
+    set_output<T>(grad_out_grad_tmp, grad_out_grad);
+  }
+}
+
 template <typename T>
 void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
   if (grad_x) {
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index bc85748e23a..f4be9a37c62 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1655,6 +1655,7 @@
     param : [out, out]
   kernel :
     func : tanh_double_grad
+  composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad)
   backward : tanh_triple_grad
   inplace : (grad_x_grad -> grad_out_grad)
 
diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py
index 3f682ecfec1..93ead4586e8 100644
--- a/paddle/phi/api/yaml/generator/api_base.py
+++ b/paddle/phi/api/yaml/generator/api_base.py
@@ -1335,4 +1335,5 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
         else:
             invoke_code = self.invoke
             params_code = self.get_define_args()
+
         return self.gene_invoke_code(invoke_code, params_code)
-- 
GitLab
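
Note (not part of the patch): the sketch below is a minimal NumPy check of the composite rule added in composite_backward_api.h. It treats the first-order rule grad_x = grad_out * (1 - out^2) as a function of (out, grad_out), contracts its finite-difference derivatives with the upstream cotangent grad_x_grad (ddx), and compares against the closed-form outputs. The helper names tanh_grad_ref and tanh_double_grad_ref are illustrative only, not Paddle APIs.

import numpy as np

def tanh_grad_ref(out, grad_out):
    # First-order rule from composite_backward_api.h: grad_x = grad_out * (1 - out^2)
    return grad_out * (1.0 - out**2)

def tanh_double_grad_ref(out, grad_out, grad_x_grad):
    # Composite rule added by this patch:
    #   out_grad      = -2 * grad_out * out * grad_x_grad   (dout)
    #   grad_out_grad = (1 - out^2) * grad_x_grad            (ddout)
    out_m_grad_x_grad = out * grad_x_grad
    out_grad = -2.0 * grad_out * out_m_grad_x_grad
    grad_out_grad = grad_x_grad - out * out_m_grad_x_grad
    return out_grad, grad_out_grad

rng = np.random.default_rng(0)
out = np.tanh(rng.standard_normal(8))
grad_out = rng.standard_normal(8)
grad_x_grad = rng.standard_normal(8)  # upstream cotangent (ddx)
eps = 1e-6

# Central finite differences of grad_x w.r.t. out and grad_out, contracted with grad_x_grad
fd_out_grad = (tanh_grad_ref(out + eps, grad_out)
               - tanh_grad_ref(out - eps, grad_out)) / (2 * eps) * grad_x_grad
fd_grad_out_grad = (tanh_grad_ref(out, grad_out + eps)
                    - tanh_grad_ref(out, grad_out - eps)) / (2 * eps) * grad_x_grad

out_grad, grad_out_grad = tanh_double_grad_ref(out, grad_out, grad_x_grad)
assert np.allclose(out_grad, fd_out_grad, atol=1e-5)
assert np.allclose(grad_out_grad, fd_grad_out_grad, atol=1e-5)
print("composite tanh_double_grad matches finite differences")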