diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 2958b7c0ad8ddd4f26676edb639094e771e24f5c..529986e6ebe1cabb34c8e226e42c5478dc20bc64 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1500,6 +1500,16 @@
   kernel :
     func : poisson_grad
 
+- backward_op : polygamma_grad
+  forward : polygamma (Tensor x, int n) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, int n)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : polygamma_grad
+
 - backward_op : pow_double_grad
   forward : pow_grad(Tensor x, Tensor grad_out, Scalar y) -> Tensor(grad_x)
   args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, Scalar y)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 83f62e2bf598058ac35cf7b38e5e4591b74d8150..ef4ebb6f6a14dad146bc041b5eab2159605b6be0 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1725,6 +1725,16 @@
     func : poisson
   backward : poisson_grad
 
+- op : polygamma
+  args : (Tensor x, int n)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param: [x]
+  kernel :
+    func : polygamma
+  backward : polygamma_grad
+
 - op : pow
   args : (Tensor x, Scalar y=1.0f)
   output : Tensor(out)
diff --git a/paddle/phi/kernels/cpu/polygamma_grad_kernel.cc b/paddle/phi/kernels/cpu/polygamma_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6a131214b4fbe174601a927977dc62b1655db481
--- /dev/null
+++ b/paddle/phi/kernels/cpu/polygamma_grad_kernel.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/polygamma_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PolygammaGradKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const int n,
+                         DenseTensor* x_grad) {
+  auto size = x.numel();
+  auto* x_data = x.data<T>();
+  auto* out_grad_data = out_grad.data<T>();
+  auto* x_grad_data = ctx.template Alloc<T>(x_grad);
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  PolygammaGradFunctor<T> functor(
+      x_data, n + 1, out_grad_data, x_grad_data, size);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    polygamma_grad, CPU, ALL_LAYOUT, phi::PolygammaGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/polygamma_kernel.cc b/paddle/phi/kernels/cpu/polygamma_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..61838033fd9187fa29883503dc13fbddb49bd7be
--- /dev/null
+++ b/paddle/phi/kernels/cpu/polygamma_kernel.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/polygamma_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PolygammaKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     const int n,
+                     DenseTensor* out) {
+  const int64_t size = x.numel();
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  PolygammaFunctor<T> functor(x_data, n, out_data, size);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    polygamma, CPU, ALL_LAYOUT, phi::PolygammaKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/polygamma_grad_kernel.cu b/paddle/phi/kernels/gpu/polygamma_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..61353d2ce6bb53964240ee705e3801d5beb98bf4
--- /dev/null
+++ b/paddle/phi/kernels/gpu/polygamma_grad_kernel.cu
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/polygamma_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PolygammaGradKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const int n,
+                         DenseTensor* x_grad) {
+  ctx.template Alloc<T>(x_grad);
+  std::vector<const DenseTensor*> ins = {&x, &out_grad};
+  std::vector<DenseTensor*> outs = {x_grad};
+  auto functor = CudaPolygammaGradFunctor<T>(n + 1);
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    polygamma_grad, GPU, ALL_LAYOUT, phi::PolygammaGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/polygamma_kernel.cu b/paddle/phi/kernels/gpu/polygamma_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a858e58aac28a664efbb6e217594a9fd7e102928
--- /dev/null
+++ b/paddle/phi/kernels/gpu/polygamma_kernel.cu
@@ -0,0 +1,40 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/polygamma_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PolygammaKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     const int n,
+                     DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+  std::vector<const DenseTensor*> ins = {&x};
+  std::vector<DenseTensor*> outs = {out};
+
+  auto functor = CudaPolygammaFunctor<T>(n);
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    polygamma, GPU, ALL_LAYOUT, phi::PolygammaKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/polygamma_kernel_impl.h b/paddle/phi/kernels/impl/polygamma_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..8b4274b0882c84f0e01405d5e1cbce66de9bf78f
--- /dev/null
+++ b/paddle/phi/kernels/impl/polygamma_kernel_impl.h
@@ -0,0 +1,276 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cmath>
+#include <limits>
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/kernel_registry.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#else
+#include "paddle/phi/kernels/funcs/for_range.h"
+#endif
+
+namespace phi {
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+template <typename T>
+__host__ __device__ T zeta(T x, T q) {
+  /*
+   * Hurwitz zeta function, used below to evaluate polygamma.
+   * REFERENCE:
+   * Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
+   * Series, and Products, p. 1073; Academic Press, 1980.
+   * From https://netlib.org/cephes/doubldoc.html - zeta.c
+   */
+  const T MACHEP = T{1.11022302462515654042E-16};
+  constexpr T zero = T{0.0};
+  constexpr T half = T{0.5};
+  constexpr T one = T{1.0};
+  static const T A[] = {
+      12.0,
+      -720.0,
+      30240.0,
+      -1209600.0,
+      47900160.0,
+      -1.8924375803183791606e9,  /*1.307674368e12/691*/
+      7.47242496e10,
+      -2.950130727918164224e12,  /*1.067062284288e16/3617*/
+      1.1646782814350067249e14,  /*5.109094217170944e18/43867*/
+      -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
+      1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/
+      -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/
+  };
+
+  int i = 0;
+  T a, b, k, s, t, w;
+  if (x == one) {
+    return std::numeric_limits<T>::infinity();
+  }
+
+  if (x < one) {
+    return std::numeric_limits<T>::quiet_NaN();
+  }
+
+  if (q <= zero) {
+    if (q == std::floor(q)) {
+      return std::numeric_limits<T>::infinity();
+    }
+    if (x != std::floor(x)) {
+      return std::numeric_limits<T>::quiet_NaN();
+    }
+  }
+
+  s = std::pow(q, -x);
+  a = q;
+  i = 0;
+  b = zero;
+  while ((i < 9) || (a <= T{9.0})) {
+    i += 1;
+    a += one;
+    b = ::pow(a, -x);
+    s += b;
+    if ((-MACHEP * s < b) && (b < MACHEP * s)) {
+      return static_cast<T>(s);
+    }
+  }
+
+  w = a;
+  s += b * w / (x - one);
+  s -= half * b;
+  a = one;
+  k = zero;
+  for (int i = 0; i < 12; i++) {
+    a *= x + k;
+    b /= w;
+    t = a * b / A[i];
+    s = s + t;
+    t = std::fabs(t / s);
+    if (t < MACHEP) {
+      return static_cast<T>(s);
+    }
+    k += one;
+    a *= x + k;
+    b /= w;
+    k += one;
+  }
+  return static_cast<T>(s);
+}
+
+template <typename T>
+struct CudaPolygammaFunctor {
+  int _n;
+  __forceinline__ CudaPolygammaFunctor(int n) { _n = n; }
+  __device__ __forceinline__ T operator()(const T _x) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(_x);
+    const auto one = MT{1};
+    // polygamma(n, x) = (-1)^(n + 1) * n! * zeta(n + 1, x) for n >= 1
+    return static_cast<T>(((_n % 2) ? one : -one) *
+                          std::exp(std::lgamma(static_cast<MT>(_n) + one)) *
+                          zeta<MT>(static_cast<MT>(_n + 1), mp_x));
+  }
+};
+
+template <typename T>
+struct CudaPolygammaGradFunctor {
+  int _n;
+  __forceinline__ CudaPolygammaGradFunctor(int n) { _n = n; }
+  __device__ __forceinline__ T operator()(const T _x,
+                                          const T _out_grad) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(_x);
+    const MT mp_out_grad = static_cast<MT>(_out_grad);
+    const auto one = MT{1};
+    // _n already holds the forward order plus one, so this evaluates
+    // out_grad * polygamma(n + 1, x).
+    return static_cast<T>(mp_out_grad * ((_n % 2) ? one : -one) *
+                          std::exp(std::lgamma(static_cast<MT>(_n) + one)) *
+                          zeta<MT>(static_cast<MT>(_n + 1), mp_x));
+  }
+};
+#else
+template <typename T>
+static inline T zeta(T x, T q) {
+  /*
+   * Hurwitz zeta function, used below to evaluate polygamma.
+   * REFERENCE:
+   * Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
+   * Series, and Products, p. 1073; Academic Press, 1980.
+   * From https://netlib.org/cephes/doubldoc.html - zeta.c
+   */
+  const T MACHEP = T{1.11022302462515654042E-16};
+  constexpr T zero = T{0.0};
+  constexpr T half = T{0.5};
+  constexpr T one = T{1.0};
+  static const T A[] = {
+      12.0,
+      -720.0,
+      30240.0,
+      -1209600.0,
+      47900160.0,
+      -1.8924375803183791606e9,  /*1.307674368e12/691*/
+      7.47242496e10,
+      -2.950130727918164224e12,  /*1.067062284288e16/3617*/
+      1.1646782814350067249e14,  /*5.109094217170944e18/43867*/
+      -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
+      1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/
+      -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/
+  };
+
+  int i = 0;
+  T a, b, k, s, t, w;
+  if (x == one) {
+    return std::numeric_limits<T>::infinity();
+  }
+
+  if (x < one) {
+    return std::numeric_limits<T>::quiet_NaN();
+  }
+
+  if (q <= zero) {
+    if (q == std::floor(q)) {
+      return std::numeric_limits<T>::infinity();
+    }
+    if (x != std::floor(x)) {
+      return std::numeric_limits<T>::quiet_NaN();
+    }
+  }
+
+  s = std::pow(q, -x);
+  a = q;
+  i = 0;
+  b = zero;
+  while ((i < 9) || (a <= T{9.0})) {
+    i += 1;
+    a += one;
+    b = std::pow(a, -x);
+    s += b;
+    if ((-MACHEP * s < b) && (b < MACHEP * s)) {
+      return static_cast<T>(s);
+    }
+  }
+
+  w = a;
+  s += b * w / (x - one);
+  s -= half * b;
+  a = one;
+  k = zero;
+  for (int i = 0; i < 12; i++) {
+    a *= x + k;
+    b /= w;
+    t = a * b / A[i];
+    s = s + t;
+    t = std::fabs(t / s);
+    if (t < MACHEP) {
+      return static_cast<T>(s);
+    }
+    k += one;
+    a *= x + k;
+    b /= w;
+    k += one;
+  }
+  return static_cast<T>(s);
+}
+
+template <typename T>
+struct PolygammaFunctor {
+  PolygammaFunctor(const T* input, const int n, T* output, int64_t size)
+      : input_(input), n_(n), output_(output), size_(size) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(input_[idx]);
+
+    const auto one = MT{1};
+    // polygamma(n, x) = (-1)^(n + 1) * n! * zeta(n + 1, x) for n >= 1
+    output_[idx] =
+        static_cast<T>(((n_ % 2) ? one : -one) *
+                       std::exp(std::lgamma(static_cast<MT>(n_) + one)) *
+                       zeta<MT>(static_cast<MT>(n_ + 1), mp_x));
+  }
+
+ private:
+  const T* input_;
+  const int n_;
+  T* output_;
+  int64_t size_;
+};
+
+template <typename T>
+struct PolygammaGradFunctor {
+  PolygammaGradFunctor(
+      const T* input, const int n, const T* out_grad, T* output, int64_t size)
+      : input_(input),
+        n_(n),
+        out_grad_(out_grad),
+        output_(output),
+        size_(size) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(input_[idx]);
+    const MT mp_out_grad = static_cast<MT>(out_grad_[idx]);
+
+    const auto one = MT{1};
+    // n_ already holds the forward order plus one, so partial_x is
+    // polygamma(n + 1, x).
+    auto partial_x = ((n_ % 2) ? one : -one) *
+                     std::exp(std::lgamma(static_cast<MT>(n_) + one)) *
+                     zeta<MT>(static_cast<MT>(n_ + 1), mp_x);
+    output_[idx] = static_cast<T>(mp_out_grad * partial_x);
+  }
+
+ private:
+  const T* input_;
+  const int n_;
+  const T* out_grad_;
+  T* output_;
+  int64_t size_;
+};
+#endif
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/polygamma_grad_kernel.h b/paddle/phi/kernels/polygamma_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..77108fcadf0c9cabbf75445e9d3e61463146fff2
--- /dev/null
+++ b/paddle/phi/kernels/polygamma_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PolygammaGradKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& out_grad,
+                         const int n,
+                         DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/polygamma_kernel.h b/paddle/phi/kernels/polygamma_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..06c2365be634a921525938a4572c9ddf97b09112
--- /dev/null
+++ b/paddle/phi/kernels/polygamma_kernel.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+/**
+ * @brief This kernel computes the polygamma function of x, element-wise.
+ * @param  ctx  device context
+ * @param  x    the input tensor of polygamma
+ * @param  n    the order of the polygamma function
+ * @param  out  the output tensor of polygamma
+ */
+template <typename T, typename Context>
+void PolygammaKernel(const Context& ctx,
+                     const DenseTensor& x,
+                     const int n,
+                     DenseTensor* out);
+
+}  // namespace phi
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 719849500376c75ac1ed3685bc939d653350ffee..a5c3754de7f7285461cf77b161ce4e1034900f65 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -309,6 +309,7 @@ from .tensor.math import i0  # noqa: F401
 from .tensor.math import i0e  # noqa: F401
 from .tensor.math import i1  # noqa: F401
 from .tensor.math import i1e  # noqa: F401
+from .tensor.math import polygamma  # noqa: F401
 
 from .tensor.random import bernoulli  # noqa: F401
 from .tensor.random import poisson  # noqa: F401
@@ -708,4 +709,5 @@ __all__ = [  # noqa
     'i0e',
     'i1',
     'i1e',
+    'polygamma',
 ]
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 4d083a5febb353dc14eb21137a41ba04bbd0f8f9..4271ee22d4c9823acbb92ac0eeeb209ea5982741 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -261,6 +261,7 @@ from .math import i0  # noqa: F401
 from .math import i0e  # noqa: F401
 from .math import i1  # noqa: F401
 from .math import i1e  # noqa: F401
+from .math import polygamma  # noqa: F401
 
 from .random import multinomial  # noqa: F401
 from .random import standard_normal  # noqa: F401
@@ -560,6 +561,7 @@ tensor_method_func = [  # noqa
     'i0e',
     'i1',
     'i1e',
+    'polygamma',
 ]
 
 # this list used in math_op_patch.py for magic_method bind
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index c8838a1cb5d97b4507b97e935b3c6d08d3064b94..32676239c257e453240e366019a08d8130794484 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -5727,3 +5727,61 @@
         type='i1e', inputs={'x': x}, outputs={'out': out}, attrs={}
     )
     return out
+
+
+def polygamma(x, n, name=None):
+    r"""
+    Calculates the polygamma of the given input tensor, element-wise.
+
+    The equation is:
+
+    .. math::
+        \Phi^n(x) = \frac{d^{n+1}}{dx^{n+1}} [\ln(\Gamma(x))]
+
+    so that :math:`\Phi^0(x)` is the digamma function.
+
+    Args:
+        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
+        n (int): Order of the polygamma function. Must be a non-negative integer.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        - out (Tensor), the polygamma of the input Tensor; the shape and data type are the same as the input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            data = paddle.to_tensor([2, 3, 25.5], dtype='float32')
+            res = paddle.polygamma(data, 1)
+            print(res)
+            # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [0.64493407, 0.39493407, 0.03999467])
+    """
+    if not isinstance(n, int):
+        raise TypeError(
+            "The input of n must be int type, but received: %s " % (type(n))
+        )
+    if n < 0:
+        raise ValueError(
+            "The input of n must be greater than or equal to 0. But received n = %s"
+            % (n)
+        )
+    if n == 0:
+        return digamma(x)
+    else:
+        if in_dynamic_mode():
+            return _C_ops.polygamma(x, n)
+        else:
+            check_variable_and_dtype(
+                x, "x", ["float32", "float64"], "polygamma"
+            )
+
+            helper = LayerHelper("polygamma", **locals())
+            out = helper.create_variable_for_type_inference(dtype=x.dtype)
+            helper.append_op(
+                type='polygamma',
+                inputs={'x': x},
+                outputs={'out': out},
+                attrs={'n': n},
+            )
+            return out
diff --git a/test/legacy_test/test_polygamma_op.py b/test/legacy_test/test_polygamma_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5cf3062f3d2b924b30c66ce9ae1810982777bf
--- /dev/null
+++ b/test/legacy_test/test_polygamma_op.py
@@ -0,0 +1,213 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from eager_op_test import OpTest
+from scipy import special
+
+import paddle
+from paddle.fluid import core
+
+np.random.seed(100)
+paddle.seed(100)
+
+
+def ref_polygamma(x, n):
+    """
+    The value at x = 0 differs from the current mainstream
+    implementations, so it has to be specified explicitly
+    as a special point.
+    """
+    mask = x == 0
+    if n == 0:
+        out = special.psi(x)
+        out[mask] = np.nan
+    else:
+        out = special.polygamma(n, x)
+    return out
+
+
+def ref_polygamma_grad(x, dout, n):
+    """
+    The value at x = 0 differs from the current mainstream
+    implementations, so it has to be specified explicitly
+    as a special point.
+ """ + mask = x == 0 + gradx = special.polygamma(n + 1, x) + if n == 0: + gradx[mask] = np.nan + return dout * gradx + + +class TestPolygammaAPI(unittest.TestCase): + DTYPE = "float64" + DATA = [0, 1, 2, 3, 4, 5] + ORDER = 1 + + def setUp(self): + self.x = np.array(self.DATA).astype(self.DTYPE) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_api_static(self): + def run(place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + name="x", shape=self.x.shape, dtype=self.DTYPE + ) + y = paddle.polygamma(x, self.ORDER) + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={"x": self.x}, + fetch_list=[y], + ) + out_ref = ref_polygamma(self.x, self.ORDER) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-5) + paddle.disable_static() + + for place in self.place: + run(place) + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + out = paddle.polygamma(x, self.ORDER) + + out_ref = ref_polygamma(self.x, self.ORDER) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_empty_input_error(self): + for place in self.place: + paddle.disable_static(place) + x = None + self.assertRaises(ValueError, paddle.polygamma, x, self.ORDER) + paddle.enable_static() + + def test_input_type_error(self): + for place in self.place: + paddle.disable_static(place) + self.assertRaises( + TypeError, paddle.polygamma, self.x, float(self.ORDER) + ) + paddle.enable_static() + + def test_negative_order_error(self): + for place in self.place: + paddle.disable_static(place) + self.assertRaises(ValueError, paddle.polygamma, self.x, -self.ORDER) + paddle.enable_static() + + +class TestPolygammaFloat32Order1(TestPolygammaAPI): + DTYPE = "float32" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 1 + + +class TestPolygammaFloat32Order2(TestPolygammaAPI): + DTYPE = "float32" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 2 + + +class TestPolygammaFloat32Order3(TestPolygammaAPI): + DTYPE = "float32" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 3 + + +class TestPolygammaFloat64Order1(TestPolygammaAPI): + DTYPE = "float64" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 1 + + +class TestPolygammaFloat64Order2(TestPolygammaAPI): + DTYPE = "float64" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 2 + + +class TestPolygammaFloat64Order3(TestPolygammaAPI): + DTYPE = "float64" + DATA = [2, 3, 5, 2.25, 7, 7.25] + ORDER = 3 + + +class TestPolygammaNegativeInputOrder1(TestPolygammaAPI): + DTYPE = "float64" + DATA = [-2, 3, 5, 2.25, 7, 7.25] + ORDER = 1 + + +class TestPolygammaMultiDimOrder1(TestPolygammaAPI): + DTYPE = "float64" + DATA = [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]] + ORDER = 1 + + +class TestPolygammaMultiDimOrder2(TestPolygammaAPI): + DTYPE = "float64" + DATA = [ + [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]], + [[6, 7, 8, 9, 1, 2], [0, 1, 2, 3, 4, 5]], + ] + ORDER = 2 + + +class TestPolygammaOp(OpTest): + def setUp(self) -> None: + self.op_type = "polygamma" + self.python_api = paddle.polygamma + self.init_config() + self.outputs = {"out": self.target} + + def init_config(self): + self.dtype = np.float64 + self.order = 1 + rand_case = np.random.randn(100).astype(self.dtype) + int_case = np.random.randint(low=1, high=100, size=100).astype( + self.dtype + ) + self.case = np.concatenate([rand_case, int_case]) 
+        self.inputs = {'x': self.case}
+        self.attrs = {'n': self.order}
+        self.target = ref_polygamma(self.inputs['x'], self.order)
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['x'],
+            'out',
+            user_defined_grads=[
+                ref_polygamma_grad(self.case, 1 / self.case.size, self.order)
+            ],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
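
A minimal usage sketch of the API added above, for reference only; it is not part of the patch and assumes a Paddle build that contains this change. The printed values follow from the identity the kernels implement, polygamma(n, x) = (-1)^(n+1) * n! * zeta(n+1, x) for n >= 1, and the backward kernel evaluates out_grad * polygamma(n + 1, x):

import paddle

# Order-1 polygamma (trigamma) of a few points; order 0 falls back to digamma.
x = paddle.to_tensor([2.0, 3.0, 25.5], dtype='float32', stop_gradient=False)
y = paddle.polygamma(x, 1)
print(y)  # approximately [0.6449, 0.3949, 0.0400]

# The gradient of sum(y) is the order-2 polygamma at each point,
# matching the n + 1 passed to the grad functors above.
y.sum().backward()
print(x.grad)  # approximately [-0.4041, -0.1541, -0.0016]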