diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 688de2700c4bc2a00ae76024d227ca0446a417a8..10ea47e4bfd5dc09d463925bf0a92822cb100b39 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include
 #include
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
@@ -457,6 +458,26 @@ class PowDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
   }
 };
+template <typename T>
+class PowTripleGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("pow_triple_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("DOut", this->Input("DOut"));
+    op->SetInput("DDX", this->Input("DDX"));
+    op->SetInput("D_DX", this->OutputGrad("DX"));
+    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
+    op->SetOutput("D_X", this->InputGrad("X"));
+    op->SetOutput("D_DOut", this->InputGrad("DOut"));
+    op->SetOutput("D_DDX", this->InputGrad("DDX"));
+    op->SetInput("FactorTensor", this->Input("FactorTensor"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
 class PowOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -523,6 +544,16 @@ class PowOpDoubleGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return GetKernelType(ctx, *this, "X");
   }
 };
+class PowOpTripleGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "X");
+  }
+};
 DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
 
 }  // namespace operators
 }  // namespace paddle
@@ -575,6 +606,9 @@ REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
 DECLARE_INFER_SHAPE_FUNCTOR(pow_double_grad,
                             PowDoubleGradInferShapeFunctor,
                             PD_INFER_META(phi::GeneralBinaryGradInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(pow_triple_grad,
+                            PowTripleGradInferShapeFunctor,
+                            PD_INFER_META(phi::GeneralTernaryGradInferMeta));
 
 REGISTER_OPERATOR(
     pow,
@@ -594,7 +628,12 @@ REGISTER_OPERATOR(pow_grad,
 REGISTER_OPERATOR(pow_double_grad,
                   ops::PowOpDoubleGrad,
                   ops::ActivationDoubleGradOpInplaceInferer,
+                  ops::PowTripleGradOpMaker<paddle::framework::OpDesc>,
+                  ops::PowTripleGradOpMaker<paddle::imperative::OpBase>,
                   PowDoubleGradInferShapeFunctor);
+REGISTER_OPERATOR(pow_triple_grad,
+                  ops::PowOpTripleGrad,
+                  PowTripleGradInferShapeFunctor);
 /* ========================================================================== */
 
 /* ========================== register checkpoint ===========================*/
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index ff66751103548d0335d812f3e3fae6662f5ce0be..a4720905ad54ad4ea9039742e669361e32a43183 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -27,7 +27,6 @@ limitations under the License. */
 #include
 
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/enforce.h"
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index cda591df9595fa7591d424e9f18d7b8ad37a8fd1..0cf0555f9ef3e392e92f90389388c184d0a74e7e 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -1330,9 +1330,10 @@
   output : Tensor(x_grad), Tensor(grad_out_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
-    param: [x, x]
+    param: [x, grad_out]
   kernel :
     func : pow_double_grad
+  backward : pow_triple_grad
   inplace : (grad_x_grad -> x_grad)
 
 - backward_op : pow_grad
@@ -1347,6 +1348,16 @@
   backward: pow_double_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : pow_triple_grad
+  forward : pow_double_grad(Tensor x, Tensor grad_out, Tensor grad_grad_x, Scalar y) -> Tensor(grad_x), Tensor(grad_grad_out)
+  args : (Tensor x, Tensor grad_out, Tensor grad_grad_x, Tensor grad_x_grad, Tensor grad_grad_out_grad, Scalar y)
+  output : Tensor(x_grad), Tensor(grad_out_grad), Tensor(grad_grad_x_grad)
+  infer_meta :
+    func : GeneralTernaryGradInferMeta
+    param: [x, grad_out, grad_grad_x]
+  kernel :
+    func : pow_triple_grad
+
 - backward_op : prelu_grad
   forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out)
   args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode)
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index edd621244676285c6896cdcf905708b79caf963e..8004f4ad80b9ca36adbc44a44415fff248fd6681 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -226,6 +226,18 @@ void PowDoubleGradKernel(const Context& dev_ctx,
                          const Scalar& factor,
                          DenseTensor* dx,
                          DenseTensor* ddout);
+
+template <typename T, typename Context>
+void PowTripleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         const DenseTensor& d_dx,
+                         const DenseTensor& d_ddout,
+                         const Scalar& factor,
+                         DenseTensor* out_d_x,
+                         DenseTensor* out_d_dout,
+                         DenseTensor* out_d_ddx);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 74dc1f3b5c0cb15e4773e3ccbeffc38c37f3dc9b..edbe4083fb02f56a797504ad695fac53c4c8c6b3 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -390,3 +390,11 @@ PD_REGISTER_KERNEL(pow_double_grad,
                    double,
                    int,
                    int64_t) {}
+PD_REGISTER_KERNEL(pow_triple_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PowTripleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index a21c6511b03b2b108715b93f126fbf710ca4c1cf..763af5652ee5b2171ecc7667f6229aabd4d49106 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -472,7 +472,6 @@ PD_REGISTER_KERNEL(pow_grad,
                    int64_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
-
 PD_REGISTER_KERNEL(pow_double_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -483,3 +482,13 @@ PD_REGISTER_KERNEL(pow_double_grad,
                    int64_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(pow_triple_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PowTripleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h
index 5c15564f524b8a69a790bd370ae00c38a8f46628..76082608e8227a18ef3d369253275f495b3d49ef 100644
--- a/paddle/phi/kernels/impl/activation_grad_impl.h
+++ b/paddle/phi/kernels/impl/activation_grad_impl.h
@@ -17,6 +17,7 @@
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/activation_kernel.h"
+#include "paddle/phi/kernels/elementwise_add_kernel.h"
 #include "paddle/phi/kernels/elementwise_multiply_kernel.h"
 #include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/activation_functor.h"
@@ -347,10 +348,10 @@ void PowDoubleGradKernel(const Context& dev_ctx,
                          DenseTensor* dx,
                          DenseTensor* ddout) {
   PADDLE_ENFORCE_NOT_NULL(
-      dx, errors::NotFound("The output DenseTensor dx can not be nullptr"));
+      dx, errors::NotFound("The output DenseTensor DX can not be nullptr"));
   PADDLE_ENFORCE_NOT_NULL(
       ddout,
-      errors::NotFound("The output DenseTensor ddout can not be nullptr"));
+      errors::NotFound("The output DenseTensor DDOut can not be nullptr"));
   float exponent = factor.to<float>();
   if (exponent == 1) {
     *dx = phi::FullLike<T, Context>(dev_ctx, x, static_cast<T>(0));
@@ -366,6 +367,150 @@ void PowDoubleGradKernel(const Context& dev_ctx,
   *ddout = phi::Scale<T, Context>(dev_ctx, ddout_tmp, exponent, 0.0, true);
 }
 
+template <typename T, typename Context>
+void PowTripleGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         const DenseTensor& ddx,
+                         const DenseTensor& d_dx,
+                         const DenseTensor& d_ddout,
+                         const Scalar& factor,
+                         DenseTensor* out_d_x,
+                         DenseTensor* out_d_dout,
+                         DenseTensor* out_d_ddx) {
+  PADDLE_ENFORCE_NOT_NULL(
+      out_d_x,
+      errors::NotFound("The output DenseTensor D_X can not be nullptr"));
+  PADDLE_ENFORCE_NOT_NULL(
+      out_d_dout,
+      errors::NotFound("The output DenseTensor D_DOut can not be nullptr"));
+  PADDLE_ENFORCE_NOT_NULL(
+      out_d_ddx,
+      errors::NotFound("The output DenseTensor D_DDX can not be nullptr"));
+  float exponent = factor.to<float>();
+
+  if (exponent != 2 && exponent != 1) {
+    // case1: b != 2 and b != 1
+    // D_X = D_DX * DDX * DOut * b * (b-1) * (b-2) * X^(b-3)
+    //       + D_DDOut * DDX * b * (b-1) * X^(b-2)
+    DenseTensor out_d_x_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, ddx);
+    DenseTensor out_d_x_tmp2 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 3),
+                               exponent * (exponent - 1) * (exponent - 2),
+                               0.0,
+                               true);
+    DenseTensor out_d_x_part1 = phi::Multiply<T, Context>(
+        dev_ctx,
+        phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, dout),
+        out_d_x_tmp2);
+
+    DenseTensor out_d_x_tmp3 =
+        phi::Multiply<T, Context>(dev_ctx, d_ddout, ddx);
+    DenseTensor out_d_x_tmp4 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
+                               exponent * (exponent - 1),
+                               0.0,
+                               true);
+    DenseTensor out_d_x_part2 =
+        phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp3, out_d_x_tmp4);
+
+    *out_d_x = phi::Add<T, Context>(dev_ctx, out_d_x_part1, out_d_x_part2);
+
+    // D_DOut = D_DX * DDX * b * (b-1) * X^(b-2)
+    DenseTensor out_d_dout_tmp =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
+                               exponent * (exponent - 1),
+                               0.0,
+                               true);
+    *out_d_dout =
+        phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, out_d_dout_tmp);
+
+    // D_DDX = D_DX * DOut * b * (b-1) * X^(b-2) + D_DDOut * b * X^(b-1)
+    DenseTensor out_d_ddx_tmp1 =
+        phi::Multiply<T, Context>(dev_ctx, d_dx, dout);
+    DenseTensor out_d_ddx_part1 = phi::Multiply<T, Context>(
+        dev_ctx,
+        out_d_ddx_tmp1,
+        out_d_dout_tmp);
+
+    DenseTensor out_d_ddx_tmp2 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 1),
+                               exponent,
+                               0.0,
+                               true);
+    DenseTensor out_d_ddx_part2 =
+        phi::Multiply<T, Context>(dev_ctx, d_ddout, out_d_ddx_tmp2);
+
+    *out_d_ddx =
+        phi::Add<T, Context>(dev_ctx, out_d_ddx_part1, out_d_ddx_part2);
+  } else if (exponent == 2) {
+    // case2: b = 2
+    // D_X = D_DDOut * DDX * b * (b-1) * X^(b-2)
+    DenseTensor out_d_x_tmp1 =
+        phi::Multiply<T, Context>(dev_ctx, d_ddout, ddx);
+    DenseTensor out_d_x_tmp2 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
+                               exponent * (exponent - 1),
+                               0.0,
+                               true);
+
+    *out_d_x = phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, out_d_x_tmp2);
+
+    // D_DOut = D_DX * DDX * b * (b-1) * X^(b-2)
+    DenseTensor out_d_dout_tmp1 =
+        phi::Multiply<T, Context>(dev_ctx, d_dx, ddx);
+    DenseTensor out_d_dout_tmp2 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 2),
+                               exponent * (exponent - 1),
+                               0.0,
+                               true);
+
+    *out_d_dout =
+        phi::Multiply<T, Context>(dev_ctx, out_d_dout_tmp1, out_d_dout_tmp2);
+
+    // D_DDX = D_DX * DOut * b * (b-1) * X^(b-2) + D_DDOut * b * X^(b-1)
+    DenseTensor out_d_ddx_tmp1 =
+        phi::Multiply<T, Context>(dev_ctx, d_dx, dout);
+    DenseTensor out_d_ddx_part1 =
+        phi::Multiply<T, Context>(dev_ctx, out_d_ddx_tmp1, out_d_dout_tmp2);
+
+    DenseTensor out_d_ddx_tmp2 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 1),
+                               exponent,
+                               0.0,
+                               true);
+    DenseTensor out_d_ddx_part2 =
+        phi::Multiply<T, Context>(dev_ctx, d_ddout, out_d_ddx_tmp2);
+
+    *out_d_ddx =
+        phi::Add<T, Context>(dev_ctx, out_d_ddx_part1, out_d_ddx_part2);
+  } else {
+    // case3: b = 1
+    // D_X = D_DX * DDX * DOut * b * (b-1) * (b-2) * X^(b-3)
+    DenseTensor out_d_x_tmp1 = phi::Multiply<T, Context>(dev_ctx, d_dx, ddx);
+    DenseTensor out_d_x_tmp2 =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 3),
+                               exponent * (exponent - 1) * (exponent - 2),
+                               0.0,
+                               true);
+
+    *out_d_x = phi::Multiply<T, Context>(
+        dev_ctx,
+        phi::Multiply<T, Context>(dev_ctx, out_d_x_tmp1, dout),
+        out_d_x_tmp2);
+
+    // D_DOut = 0
+    *out_d_dout = phi::FullLike<T, Context>(dev_ctx, dout, static_cast<T>(0));
+
+    // D_DDX = D_DDOut * b * X^(b-1)
+    DenseTensor out_d_ddx_tmp =
+        phi::Scale<T, Context>(dev_ctx,
+                               phi::Pow<T, Context>(dev_ctx, x, exponent - 1),
+                               exponent,
+                               0.0,
+                               true);
+
+    *out_d_ddx = phi::Multiply<T, Context>(dev_ctx, d_ddout, out_d_ddx_tmp);
+  }
+}
+
 template <typename T, typename Context>
 void SqrtDoubleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc
index 2436198403c642c7ec8a33def816c00e36af19b4..fbff006ee93af69450796308037b5db7a46d4495 100644
--- a/paddle/phi/ops/compat/activation_sig.cc
+++ b/paddle/phi/ops/compat/activation_sig.cc
@@ -83,6 +83,21 @@ KernelSignature PowDoubleGradOpArgumentMapping(
         "pow_double_grad", {"X", "DOut", "DDX"}, {"factor"}, {"DX", "DDOut"});
   }
 }
+
+KernelSignature PowTripleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("FactorTensor")) {
+    return KernelSignature("pow_triple_grad",
+                           {"X", "DOut", "DDX", "D_DX", "D_DDOut"},
+                           {"FactorTensor"},
+                           {"D_X", "D_DOut", "D_DDX"});
+  } else {
+    return KernelSignature("pow_triple_grad",
+                           {"X", "DOut", "DDX", "D_DX", "D_DDOut"},
+                           {"factor"},
+                           {"D_X", "D_DOut", "D_DDX"});
+  }
+}
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh);
@@ -100,4 +115,6 @@ PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(pow_double_grad,
                            phi::PowDoubleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(pow_triple_grad,
+                           phi::PowTripleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
index ea9ac9fa9d2c85e284473efffe3ffa3f7543c8f6..f8ec154f92dc256b7f882af52d76e6d8df1183bd 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -597,5 +597,98 @@ class TestSinTripleGradCheck(unittest.TestCase):
             self.func(p)
 
 
+class TestPowTripleGradCheck1(unittest.TestCase):
+    def pow_wrapper(self, x):
+        return paddle.pow(x[0], 1)
+
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 3, 7, 9]
+        eps = 1e-6
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype=dtype)
+        x.persistable = True
+        y = paddle.pow(x, 1)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        gradient_checker.triple_grad_check_for_dygraph(
+            self.pow_wrapper, [x], y, x_init=x_arr, place=place
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestPowTripleGradCheck2(unittest.TestCase):
+    def pow_wrapper(self, x):
+        return paddle.pow(x[0], 2)
+
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 3, 7, 9]
+        eps = 1e-6
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype=dtype)
+        x.persistable = True
+        y = paddle.pow(x, 2)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        gradient_checker.triple_grad_check_for_dygraph(
+            self.pow_wrapper, [x], y, x_init=x_arr, place=place
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestPowTripleGradCheck3(unittest.TestCase):
+    def pow_wrapper(self, x):
+        return paddle.pow(x[0], 4)
+
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 3, 7, 9]
+        eps = 1e-6
+        dtype = np.float64
+        x = layers.data('x', shape, False, dtype=dtype)
+        x.persistable = True
+        y = paddle.pow(x, 4)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        gradient_checker.triple_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
+        gradient_checker.triple_grad_check_for_dygraph(
+            self.pow_wrapper, [x], y, x_init=x_arr, place=place
+        )
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 if __name__ == "__main__":
     unittest.main()
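
A note on the formulas (not part of the patch itself): the closed-form expressions in PowTripleGradKernel follow from differentiating the outputs of PowDoubleGradKernel for y = x^b, namely DX = DDX * DOut * b(b-1)x^(b-2) and DDOut = DDX * b x^(b-1). Taking the vector-Jacobian product of (DX, DDOut) with the incoming cotangents D_DX and D_DDOut gives, in the notation of the kernel comments:

```latex
\begin{aligned}
D\_X    &= D\_DX \cdot DDX \cdot DOut \cdot b(b-1)(b-2)\,x^{b-3}
          + D\_DDOut \cdot DDX \cdot b(b-1)\,x^{b-2},\\
D\_DOut &= D\_DX \cdot DDX \cdot b(b-1)\,x^{b-2},\\
D\_DDX  &= D\_DX \cdot DOut \cdot b(b-1)\,x^{b-2} + D\_DDOut \cdot b\,x^{b-1}.
\end{aligned}
```

For b = 2 the (b-2) factor removes the x^(b-3) term, and for b = 1 the (b-1) factors remove everything except D_DDX = D_DDOut; that is why the kernel special-cases those exponents instead of evaluating negative powers of x.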
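
As a quick local sanity check of these formulas, separate from the unit tests above, one can compare chained dygraph gradients of paddle.pow against the analytic third derivative b(b-1)(b-2)x^(b-3). This is only a sketch, and it assumes a Paddle build that contains the pow_triple_grad kernel added by this patch:

```python
import numpy as np
import paddle

# For y = x**b, the analytic third derivative is b*(b-1)*(b-2)*x**(b-3).
b = 4
x = paddle.to_tensor(
    np.random.uniform(1.0, 2.0, [8]), dtype='float64', stop_gradient=False
)
y = paddle.pow(x, b)

# paddle.grad fills grad_outputs with ones by default, so for an elementwise
# op these are elementwise derivatives; create_graph=True keeps the graph
# alive so the next-order gradient can be taken.
(dy,) = paddle.grad(y, x, create_graph=True)
(d2y,) = paddle.grad(dy, x, create_graph=True)
(d3y,) = paddle.grad(d2y, x, create_graph=True)

expected = b * (b - 1) * (b - 2) * x.numpy() ** (b - 3)
np.testing.assert_allclose(d3y.numpy(), expected, rtol=1e-6)
print("third derivative of x**4 matches 24 * x")
```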