diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 1325f7b1e5cdd2b7ea7d694ceb771b9030808e28..a8a4bcae4815f4601fff3ea70549ca203304513f 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -63,7 +63,10 @@ black_ops_list = [ # white ops list whose kernel can be deleted after performance analysis # original kernel and its derivative kernel can be deleted when composite_grad # kernel performs same to it. -prim_white_list = ["matmul_double_grad"] +prim_white_list = [ + "matmul_double_grad", + "tanh_double_grad", +] # dict of special api that forward api's output will affect bacward api's output # bacward api's output usually affected by backward api's input diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index a90160f260ac05a48e9e457e7546dc7da4b0b37d..35059ab2ddfbd96ae1635c5427b88bfbe49da989 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -125,6 +125,26 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { set_output(grad_x_tmp, grad_x); } +template <typename T> +void tanh_double_grad(const Tensor& out, + const Tensor& grad_out, + const Tensor& grad_x_grad, + Tensor* out_grad, + Tensor* grad_out_grad) { + // tanh grad grad : ddout = (1 - out^2) * ddx, dout = - (dout_old * 2 * out * + ddx) + auto out_m_grad_x_grad = out * grad_x_grad; + if (out_grad) { + auto out_grad_tmp = -2 * grad_out * out_m_grad_x_grad; + set_output(out_grad_tmp, out_grad); + } + + if (grad_out_grad) { + auto grad_out_grad_tmp = grad_x_grad - out * out_m_grad_x_grad; + set_output(grad_out_grad_tmp, grad_out_grad); + } +} + template <typename T> void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor*
grad_x) { if (grad_x) { diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index bc85748e23aab15f66972169ea5a5a9925287f22..f4be9a37c62fbf3cd276166bfe5485d96552bb00 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1655,6 +1655,7 @@ param : [out, out] kernel : func : tanh_double_grad + composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad) backward : tanh_triple_grad inplace : (grad_x_grad -> grad_out_grad) diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 3f682ecfec1055b5bf317f35a738824cba41f630..93ead4586e88ae9a055f50f2066e1754d25f0646 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -1335,4 +1335,5 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ else: invoke_code = self.invoke params_code = self.get_define_args() + return self.gene_invoke_code(invoke_code, params_code)