diff --git a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py
index 796e0089110da8d57911f78038a5dc36dcee3af9..ac0b01dd4de98e9d5149d90d78cc5eb5ef583b0e 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py
@@ -38,6 +38,8 @@ ops_to_fill_zero_for_empty_grads = set(
         "tanh_triple_grad",
         "sin_double_grad",
         "sin_triple_grad",
+        "cos_double_grad",
+        "cos_triple_grad",
         "subtract_double_grad",
         "divide_double_grad",
         "log_double_grad",
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 82c2b12b1762dd63c4602c71e7fb91d67cebe9c0..919f69525bbc2700b6a1848e53be7ba35d0d1a9b 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -172,6 +172,18 @@
   kernel :
     func : cholesky_solve_grad
 
+- backward_op : cos_double_grad
+  forward : cos_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
+  output : Tensor(x_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, x]
+  kernel :
+    func : cos_double_grad
+  backward : cos_triple_grad
+  inplace : (grad_x_grad -> grad_out_grad)
+
 - backward_op : cos_grad
   forward : cos (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
@@ -181,8 +193,20 @@
     param : [x]
   kernel :
     func : cos_grad
+  backward : cos_double_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : cos_triple_grad
+  forward : cos_double_grad (Tensor x, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_x), Tensor(grad_out_grad)
+  args : (Tensor x, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_x_grad, Tensor grad_out_grad_grad)
+  output : Tensor(x_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad)
+  infer_meta :
+    func : GeneralTernaryGradInferMeta
+    param : [x, x, grad_x_grad_forward]
+  kernel :
+    func : cos_triple_grad
+  inplace : (grad_x_grad_forward -> grad_out_forward_grad)
+
 - backward_op : cosh_grad
   forward : cosh (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 0af8731d5aa22e5b34040a12eb74d61dbc2c6ac3..a0894a9aca8f6333b602ea264cc8d97bca84066f 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -229,7 +229,7 @@
   attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
 
 - op : cos
-  backward : cos_grad
+  backward : cos_grad, cos_double_grad, cos_triple_grad
   inputs :
     x : X
   outputs :
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index 8004f4ad80b9ca36adbc44a44415fff248fd6681..847383fc38e94260b9daba84b35b2490820e95d8 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -88,6 +88,14 @@ void SinDoubleGradKernel(const Context& dev_ctx,
                          DenseTensor* dx,
                          DenseTensor* ddout);
 
+template <typename T, typename Context>
+void CosDoubleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         DenseTensor* dx,
+                         DenseTensor* ddout);
+
 template <typename T, typename Context>
 void TanhDoubleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
@@ -118,6 +126,17 @@ void SinTripleGradKernel(const Context& dev_ctx,
                          DenseTensor* d_dout,
                          DenseTensor* d_ddx);
 
+template <typename T, typename Context>
+void CosTripleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         const DenseTensor& d_dx_new,
+                         const DenseTensor& d_ddout,
+                         DenseTensor* d_x_new,
+                         DenseTensor* d_dout,
+                         DenseTensor* d_ddx);
+
 template <typename T, typename Context>
 void LeakyReluDoubleGradKernel(const Context& dev_ctx,
                                const DenseTensor& x,
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index edbe4083fb02f56a797504ad695fac53c4c8c6b3..06485e847d6ada977e44ff9d81cb4dcda34da0b2 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -336,6 +336,7 @@ PD_REGISTER_KERNEL(square_double_grad,
                    phi::dtype::float16,
                    int,
                    int64_t) {}
+
 PD_REGISTER_KERNEL(sin_double_grad,
                    CPU,
                    ALL_LAYOUT,
@@ -345,6 +346,7 @@ PD_REGISTER_KERNEL(sin_double_grad,
                    phi::dtype::float16,
                    int,
                    int64_t) {}
+
 PD_REGISTER_KERNEL(sin_triple_grad,
                    CPU,
                    ALL_LAYOUT,
@@ -354,6 +356,27 @@ PD_REGISTER_KERNEL(sin_triple_grad,
                    phi::dtype::float16,
                    int,
                    int64_t) {}
+
+PD_REGISTER_KERNEL(cos_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CosDoubleGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
+
+PD_REGISTER_KERNEL(cos_triple_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CosTripleGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {}
+
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 460e6300c4087fa6db63fe85effe6dd2d7c176fc..ccdff93d5b23c7932ef2255f8fc188be87ed3e49 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -117,23 +117,22 @@ struct SinDoubleGradFunctor : public BaseActivationFunctor<T> {
                   DenseTensor* dX,
                   DenseTensor* ddOut) const {
     auto* d = dev.eigen_device();
-    auto ddx = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddX, "Input", "DDX", "SinDoubleGrad"));
+    auto d2d1x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinDoubleGrad"));
     auto x = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(X, "Input", "X", "SinDoubleGrad"));
-    // sin DoubleGrad: ddy=cos(x)*ddx, dx=-sin(x)*dy*ddx
+        GET_DATA_SAFELY(X, "Input", "x", "SinDoubleGrad"));
 
-    // calculate dx first, so ddy can inplace ddx
-    auto dx = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(dX, "Output", "DX", "SinDoubleGrad"));
-    auto dout = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(dOut, "Output", "DOut", "SinDoubleGrad"));
-    dx.device(*d) = -ddx * x.unaryExpr(Sine<T>()) * dout;
+    // calculate d2x first, so d2d1y can inplace d2d1x
+    auto d2x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dX, "Output", "d2x", "SinDoubleGrad"));
+    auto d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad"));
+    d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * d1y;
 
-    // calculate ddout
-    auto ddout = EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SinDoubleGrad"));
-    ddout.device(*d) = ddx * x.unaryExpr(Cosine<T>());
+    // calculate d2d1y
+    auto d2d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad"));
+    d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine<T>());
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
@@ -221,6 +220,22 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+// 1st reverse grad
+// y = cos(x)
+// x --> y
+// d1x = d1y * -sin(x)
+//
+// 2nd reverse grad
+// x, d1y --> d1x
+// d2x = -cos(x) * d1y * d2d1x
+// d2d1y = -sin(x) * d2d1x
+//
+// 3rd reverse grad
+// x, d1y, d2d1x --> d2x, d2d1y
+// d3x = sin(x) * d1y * d2d1x * d3d2x - cos(x) * d2d1x * d3d2d1y
+// d3d1y = -cos(x) * d2d1x * d3d2x
+// d3d2d1x = -cos(x) * d1y * d3d2x - sin(x) * d3d2d1y
+
 // cosine'(x) = -sin(x)
 template <typename T>
 struct CosGradFunctor : public BaseActivationFunctor<T> {
@@ -236,6 +251,80 @@ struct CosGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
+// cos''(x) = -cos(x)
+template <typename T>
+struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* X,
+                  const DenseTensor* dOut,
+                  const DenseTensor* ddX,
+                  DenseTensor* dX,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto d2d1x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosDoubleGrad"));
+    auto x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "x", "CosDoubleGrad"));
+
+    // calculate d2x first, so d2d1y can inplace d2d1x
+    auto d2x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dX, "Output", "d2x", "CosDoubleGrad"));
+    auto d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad"));
+    d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine<T>()) * d1y;
+
+    // calculate d2d1y
+    auto d2d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad"));
+    d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine<T>());
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct CosTripleGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* X,
+                  const DenseTensor* ddX,
+                  const DenseTensor* dOut,
+                  const DenseTensor* d_DDOut,
+                  const DenseTensor* d_dx_New,
+                  DenseTensor* d_d_Out,
+                  DenseTensor* d_x_New,
+                  DenseTensor* d_DDx) const {
+    auto* d = dev.eigen_device();
+    auto x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "x", "CosTripleGrad"));
+    auto d2d1x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
+    auto d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
+    auto d3d2d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
+    auto d3d2x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));
+
+    auto d3x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
+    d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x -
+                     x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
+
+    auto d3d1y = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad"));
+    d3d1y.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2x;
+
+    auto d3d2d1x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad"));
+    d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x -
+                         x.unaryExpr(Sine<T>()) * d3d2d1y;
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 // cosine(x) = cos(x)
 template <typename T>
 struct CosFunctor : public BaseActivationFunctor<T> {
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 763af5652ee5b2171ecc7667f6229aabd4d49106..5e75909649a65e3c315e6739dbe991c6c4534871 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -437,6 +437,26 @@ PD_REGISTER_KERNEL(sin_triple_grad,
                    int64_t,
                    phi::dtype::float16) {}
 
+PD_REGISTER_KERNEL(cos_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CosDoubleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(cos_triple_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CosTripleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
+
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h
index 76082608e8227a18ef3d369253275f495b3d49ef..dd7dadc1e1cf9ec58ae8fa9a66c372531b6b473a 100644
--- a/paddle/phi/kernels/impl/activation_grad_impl.h
+++ b/paddle/phi/kernels/impl/activation_grad_impl.h
@@ -646,4 +646,56 @@ void SinTripleGradKernel(const Context& dev_ctx,
       d_ddx);  // output
 }
 
+template <typename T, typename Context>
+void CosDoubleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         DenseTensor* dx,
+                         DenseTensor* ddout) {
+  if (dx) {
+    dx->Resize(x.dims());
+    dev_ctx.template Alloc<T>(dx);
+  }
+  if (ddout) {
+    dev_ctx.template Alloc<T>(ddout);
+  }
+  phi::funcs::CosDoubleGradFunctor<T> functor;
+  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
+}
+
+template <typename T, typename Context>
+void CosTripleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         const DenseTensor& d_dx_new,
+                         const DenseTensor& d_ddout,
+                         DenseTensor* d_x_new,
+                         DenseTensor* d_dout,
+                         DenseTensor* d_ddx) {
+  if (d_dout) {
+    d_dout->Resize(x.dims());
+    dev_ctx.template Alloc<T>(d_dout);
+  }
+  if (d_x_new) {
+    d_x_new->Resize(x.dims());
+    dev_ctx.template Alloc<T>(d_x_new);
+  }
+  if (d_ddx) {
+    d_ddx->Resize(ddx.dims());
+    dev_ctx.template Alloc<T>(d_ddx);
+  }
+  funcs::CosTripleGradFunctor<T> functor;
+  functor(dev_ctx,
+          &x,
+          &ddx,
+          &dout,
+          &d_ddout,
+          &d_dx_new,  // input
+          d_dout,
+          d_x_new,
+          d_ddx);  // output
+}
+
 }  // namespace phi
diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
index f8ec154f92dc256b7f882af52d76e6d8df1183bd..38a894755f464b101f1769c47bc9582064e07c2a 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -503,6 +503,38 @@ class TestSinDoubleGradCheck(unittest.TestCase):
             self.func(p)
 
 
+class TestCosDoubleGradCheck(unittest.TestCase):
+    def cos_wrapper(self, x):
+        return paddle.cos(x[0])
+
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 3, 7, 9]
+        eps = 0.0005
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype=dtype)
+        x.persistable = True
+        y = paddle.cos(x)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        x_arr[np.abs(x_arr) < 0.005] = 0.002
+        gradient_checker.double_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        gradient_checker.double_grad_check_for_dygraph(
+            self.cos_wrapper, [x], y, x_init=x_arr, place=place
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 class TestPowDoubleGradCheck1(unittest.TestCase):
     def pow_wrapper(self, x):
         return paddle.pow(x[0], 2)
@@ -690,5 +722,37 @@ class TestPowTripleGradCheck3(unittest.TestCase):
             self.func(p)
 
 
+class TestCosTripleGradCheck(unittest.TestCase):
+    def cos_wrapper(self, x):
+        return paddle.cos(x[0])
+
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 3, 7, 9]
+        eps = 0.0005
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype=dtype)
+        x.persistable = True
+        y = layers.cos(x)
+        x_arr = np.random.random(shape).astype(dtype)
+        x_arr[np.abs(x_arr) < 0.005] = 0.002
+        gradient_checker.triple_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        gradient_checker.triple_grad_check_for_dygraph(
+            self.cos_wrapper, [x], y, x_init=x_arr, place=place
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 if __name__ == "__main__":
     unittest.main()
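
Note (not part of the patch): the VJP formulas in the comment above CosDoubleGradFunctor / CosTripleGradFunctor can be sanity-checked numerically without building Paddle. The following is a minimal NumPy sketch using central finite differences; the variable names (d1y, d2d1x, d3d2x, d3d2d1y) mirror the comment, and the helper num_diff is introduced here for illustration only.

import numpy as np

rng = np.random.default_rng(0)
x, d1y, d2d1x, d3d2x, d3d2d1y = rng.uniform(-1.0, 1.0, size=5)
eps = 1e-6


def num_diff(f, v):
    # Central finite difference of a scalar function f at point v.
    return (f(v + eps) - f(v - eps)) / (2.0 * eps)


# 1st reverse grad: d1x = -sin(x) * d1y
def d1x(x_, d1y_):
    return -np.sin(x_) * d1y_


# 2nd reverse grad: VJP of d1x with incoming cotangent d2d1x.
d2x_formula = -np.cos(x) * d1y * d2d1x
d2d1y_formula = -np.sin(x) * d2d1x
assert np.isclose(d2x_formula, num_diff(lambda v: d1x(v, d1y), x) * d2d1x)
assert np.isclose(d2d1y_formula, num_diff(lambda v: d1x(x, v), d1y) * d2d1x)


# 3rd reverse grad: VJP of (d2x, d2d1y) with cotangents (d3d2x, d3d2d1y).
def d2x(x_, d1y_, d2d1x_):
    return -np.cos(x_) * d1y_ * d2d1x_


def d2d1y(x_, d2d1x_):
    return -np.sin(x_) * d2d1x_


d3x_formula = (np.sin(x) * d1y * d2d1x * d3d2x
               - np.cos(x) * d2d1x * d3d2d1y)
d3d1y_formula = -np.cos(x) * d2d1x * d3d2x
d3d2d1x_formula = (-np.cos(x) * d1y * d3d2x
                   - np.sin(x) * d3d2d1y)

assert np.isclose(
    d3x_formula,
    num_diff(lambda v: d2x(v, d1y, d2d1x), x) * d3d2x
    + num_diff(lambda v: d2d1y(v, d2d1x), x) * d3d2d1y)
assert np.isclose(
    d3d1y_formula,
    num_diff(lambda v: d2x(x, v, d2d1x), d1y) * d3d2x)
assert np.isclose(
    d3d2d1x_formula,
    num_diff(lambda v: d2x(x, d1y, v), d2d1x) * d3d2x
    + num_diff(lambda v: d2d1y(x, v), d2d1x) * d3d2d1y)

print("cos double/triple grad formulas match finite differences")

The end-to-end check of the registered kernels themselves is done by the TestCosDoubleGradCheck and TestCosTripleGradCheck cases added above, which drive cos_double_grad and cos_triple_grad through gradient_checker in both static and dygraph modes.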