diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index b556062343d627590f6d94d69bebce479d4f0d86..532eabdef437c7549cfe83c69c105e898a608521 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -66,7 +66,6 @@ black_ops_list = [ # kernel performs same to it. prim_white_list = [ "matmul_double_grad", - "tanh_double_grad", "subtract_double_grad", "silu_double_grad", ] diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index e6b29051d7c8512b3653e25ab0ccea70241b19e5..f9868df491388b3d9dc30d6b59ed88a9d90cacdb 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -2017,6 +2017,7 @@ func : tanh_double_grad composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad) inplace : (grad_x_grad -> grad_out_grad) + backward : tanh_triple_grad - backward_op : tanh_grad forward : tanh (Tensor x) -> Tensor(out) @@ -2042,6 +2043,18 @@ func : tanh_shrink_grad inplace : (out_grad -> x_grad) +- backward_op : tanh_triple_grad + forward : tanh_double_grad (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_out_new), Tensor(grad_out_grad) + args : (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_out_new_grad, Tensor grad_out_grad_grad) + output : Tensor(out_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [out, out, grad_x_grad_forward] + kernel : + func : tanh_triple_grad + inplace : (grad_x_grad_forward -> grad_out_forward_grad) + optional : grad_out_new_grad, grad_out_grad_grad + - backward_op : temporal_shift_grad forward : temporal_shift(Tensor x, int seg_num, float shift_ratio = 0.25f, str data_format = "NCHW") -> Tensor(out) args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 0d77b08acbadd56e174df4a2ed3ffcef35bc4785..bbe3017e27eba704fedd5273dab4c9b6649d1af1 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2323,7 +2323,7 @@ attrs : [bool use_mkldnn = false, bool use_cudnn = false] - op : tanh - backward : tanh_grad, tanh_double_grad (tanh_grad_grad) + backward : tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad inputs : x : X outputs :