Unverified · Commit ed604569 authored by PommesPeter, committed by GitHub

【Hackathon 4 No.19】Add polygamma API to Paddle (#53791)

* feat: added polygamma init code

* feat: added polygamma unittest code

* test: added more test cases

* refactor: added forward impl

* refactor: added backward impl

* test: updated cases

* refactor: updated test cases

* refactor: added more case and fixed some bugs

* test: updated ref func

* refactor: updated code style

* refactor: move the code

* refactor: updated test

* refactor: updated test

* docs: updated en doc
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* docs: updated math eq

---------
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
Parent 30881647
......@@ -1500,6 +1500,16 @@
  kernel :
    func : poisson_grad

- backward_op : polygamma_grad
  forward : polygamma (Tensor x, int n) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, int n)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : polygamma_grad

- backward_op : pow_double_grad
  forward : pow_grad(Tensor x, Tensor grad_out, Scalar y) -> Tensor(grad_x)
  args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, Scalar y)
......
......@@ -1725,6 +1725,16 @@
    func : poisson
  backward : poisson_grad

- op : polygamma
  args : (Tensor x, int n)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
    param: [x]
  kernel :
    func : polygamma
  backward : polygamma_grad

- op : pow
  args : (Tensor x, Scalar y=1.0f)
  output : Tensor(out)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/polygamma_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
namespace phi {

template <typename T, typename Context>
void PolygammaGradKernel(const Context& ctx,
                         const DenseTensor& x,
                         const DenseTensor& out_grad,
                         const int n,
                         DenseTensor* x_grad) {
  auto size = x.numel();
  auto* x_data = x.data<T>();
  auto* out_grad_data = out_grad.data<T>();
  auto* x_grad_data = ctx.template Alloc<T>(x_grad);
  phi::funcs::ForRange<Context> for_range(ctx, size);

  PolygammaGradFunctor<T> functor(
      x_data, n + 1, out_grad_data, x_grad_data, size);
  for_range(functor);
}

}  // namespace phi

PD_REGISTER_KERNEL(
    polygamma_grad, CPU, ALL_LAYOUT, phi::PolygammaGradKernel, float, double) {}
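The n + 1 passed to PolygammaGradFunctor above reflects the identity d/dx polygamma(x, n) = polygamma(x, n + 1): the backward pass simply evaluates the next-order polygamma and scales it by the incoming gradient. A short SciPy finite-difference sanity check of that identity (illustrative only, not part of the patch):

import numpy as np
from scipy import special

x = np.linspace(1.5, 8.0, 7)
n, eps = 2, 1e-6

# Central difference of polygamma(n, .) versus the analytic polygamma(n + 1, .).
numeric = (special.polygamma(n, x + eps) - special.polygamma(n, x - eps)) / (2 * eps)
analytic = special.polygamma(n + 1, x)
print(np.max(np.abs(numeric - analytic)))  # small, limited only by finite-difference error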
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/polygamma_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
namespace phi {

template <typename T, typename Context>
void PolygammaKernel(const Context& ctx,
                     const DenseTensor& x,
                     const int n,
                     DenseTensor* out) {
  const int64_t size = x.numel();
  const T* x_data = x.data<T>();
  T* out_data = ctx.template Alloc<T>(out);
  phi::funcs::ForRange<Context> for_range(ctx, size);

  PolygammaFunctor<T> functor(x_data, n, out_data, size);
  for_range(functor);
}

}  // namespace phi

PD_REGISTER_KERNEL(
    polygamma, CPU, ALL_LAYOUT, phi::PolygammaKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/polygamma_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
namespace phi {

template <typename T, typename Context>
void PolygammaGradKernel(const Context& ctx,
                         const DenseTensor& x,
                         const DenseTensor& out_grad,
                         const int n,
                         DenseTensor* x_grad) {
  ctx.template Alloc<T>(x_grad);
  std::vector<const DenseTensor*> ins = {&x, &out_grad};
  std::vector<DenseTensor*> outs = {x_grad};
  auto functor = CudaPolygammaGradFunctor<T>(n + 1);
  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}

}  // namespace phi

PD_REGISTER_KERNEL(
    polygamma_grad, GPU, ALL_LAYOUT, phi::PolygammaGradKernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/polygamma_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/polygamma_kernel_impl.h"
namespace phi {

template <typename T, typename Context>
void PolygammaKernel(const Context& ctx,
                     const DenseTensor& x,
                     const int n,
                     DenseTensor* out) {
  ctx.template Alloc<T>(out);
  std::vector<const DenseTensor*> ins = {&x};
  std::vector<DenseTensor*> outs = {out};
  auto functor = CudaPolygammaFunctor<T>(n);
  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}

}  // namespace phi

PD_REGISTER_KERNEL(
    polygamma, GPU, ALL_LAYOUT, phi::PolygammaKernel, float, double) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#else
#include "paddle/phi/kernels/funcs/for_range.h"
#endif
namespace phi {

#if defined(__NVCC__) || defined(__HIPCC__)

template <typename T>
__host__ __device__ T zeta(T x, T q) {
  /*
   * REFERENCE:
   * Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
   * Series, and Products, p. 1073; Academic Press, 1980.
   * From https://netlib.org/cephes/doubldoc.html - zeta.c
   */
  const T MACHEP = T{1.11022302462515654042E-16};
  constexpr T zero = T{0.0};
  constexpr T half = T{0.5};
  constexpr T one = T{1.0};
  static const T A[] = {
      12.0,
      -720.0,
      30240.0,
      -1209600.0,
      47900160.0,
      -1.8924375803183791606e9, /*1.307674368e12/691*/
      7.47242496e10,
      -2.950130727918164224e12,  /*1.067062284288e16/3617*/
      1.1646782814350067249e14,  /*5.109094217170944e18/43867*/
      -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
      1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/
      -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/
  };

  int i = 0;
  T a, b, k, s, t, w;
  if (x == one) {
    return std::numeric_limits<T>::infinity();
  }
  if (x < one) {
    return std::numeric_limits<T>::quiet_NaN();
  }
  if (q <= zero) {
    if (q == std::floor(q)) {
      return std::numeric_limits<T>::infinity();
    }
    if (x != std::floor(x)) {
      return std::numeric_limits<T>::quiet_NaN();
    }
  }

  s = std::pow(q, -x);
  a = q;
  i = 0;
  b = zero;
  while ((i < 9) || (a <= T{9.0})) {
    i += 1;
    a += one;
    b = ::pow(a, -x);
    s += b;
    if ((-MACHEP * s < b) && (b < MACHEP * s)) {
      return static_cast<T>(s);
    }
  }

  w = a;
  s += b * w / (x - one);
  s -= half * b;
  a = one;
  k = zero;
  for (int i = 0; i < 12; i++) {
    a *= x + k;
    b /= w;
    t = a * b / A[i];
    s = s + t;
    t = std::fabs(t / s);
    if (t < MACHEP) {
      return static_cast<T>(s);
    }
    k += one;
    a *= x + k;
    b /= w;
    k += one;
  }
  return static_cast<T>(s);
}

template <typename T>
struct CudaPolygammaFunctor {
  int _n;
  __forceinline__ CudaPolygammaFunctor(int n) { _n = n; }
  __device__ __forceinline__ T operator()(const T _x) const {
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
    const MT mp_x = static_cast<MT>(_x);
    const auto one = MT{1};
    return static_cast<T>(((_n % 2) ? one : -one) *
                          std::exp(std::lgamma(static_cast<MT>(_n) + one)) *
                          zeta<MT>(static_cast<MT>(_n + 1), mp_x));
  }
};

template <typename T>
struct CudaPolygammaGradFunctor {
  int _n;
  __forceinline__ CudaPolygammaGradFunctor(int n) { _n = n; }
  __device__ __forceinline__ T operator()(const T _x, const T _out_grad) const {
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
    const MT mp_x = static_cast<MT>(_x);
    const MT mp_out_grad = static_cast<MT>(_out_grad);
    const auto one = MT{1};
    return static_cast<T>(mp_out_grad * ((_n % 2) ? one : -one) *
                          std::exp(std::lgamma(static_cast<MT>(_n) + one)) *
                          zeta<MT>(static_cast<MT>(_n + 1), mp_x));
  }
};
#else

template <typename T>
static inline T zeta(T x, T q) {
  /*
   * REFERENCE:
   * Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
   * Series, and Products, p. 1073; Academic Press, 1980.
   * From https://netlib.org/cephes/doubldoc.html - zeta.c
   */
  const T MACHEP = T{1.11022302462515654042E-16};
  constexpr T zero = T{0.0};
  constexpr T half = T{0.5};
  constexpr T one = T{1.0};
  static const T A[] = {
      12.0,
      -720.0,
      30240.0,
      -1209600.0,
      47900160.0,
      -1.8924375803183791606e9, /*1.307674368e12/691*/
      7.47242496e10,
      -2.950130727918164224e12,  /*1.067062284288e16/3617*/
      1.1646782814350067249e14,  /*5.109094217170944e18/43867*/
      -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
      1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/
      -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/
  };

  int i = 0;
  T a, b, k, s, t, w;
  if (x == one) {
    return std::numeric_limits<T>::infinity();
  }
  if (x < one) {
    return std::numeric_limits<T>::quiet_NaN();
  }
  if (q <= zero) {
    if (q == std::floor(q)) {
      return std::numeric_limits<T>::infinity();
    }
    if (x != std::floor(x)) {
      return std::numeric_limits<T>::quiet_NaN();
    }
  }

  s = std::pow(q, -x);
  a = q;
  i = 0;
  b = zero;
  while ((i < 9) || (a <= T{9.0})) {
    i += 1;
    a += one;
    b = std::pow(a, -x);
    s += b;
    if ((-MACHEP * s < b) && (b < MACHEP * s)) {
      return static_cast<T>(s);
    }
  }

  w = a;
  s += b * w / (x - one);
  s -= half * b;
  a = one;
  k = zero;
  for (int i = 0; i < 12; i++) {
    a *= x + k;
    b /= w;
    t = a * b / A[i];
    s = s + t;
    t = std::fabs(t / s);
    if (t < MACHEP) {
      return static_cast<T>(s);
    }
    k += one;
    a *= x + k;
    b /= w;
    k += one;
  }
  return static_cast<T>(s);
}

template <typename T>
struct PolygammaFunctor {
  PolygammaFunctor(const T* input, const int n, T* output, int64_t size)
      : input_(input), n_(n), output_(output), size_(size) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
    const MT mp_x = static_cast<MT>(input_[idx]);
    const auto one = MT{1};
    output_[idx] =
        static_cast<T>(((n_ % 2) ? one : -one) *
                       std::exp(std::lgamma(static_cast<MT>(n_) + one)) *
                       zeta<MT>(static_cast<MT>(n_ + 1), mp_x));
  }

 private:
  const T* input_;
  const int n_;
  T* output_;
  int64_t size_;
};

template <typename T>
struct PolygammaGradFunctor {
  PolygammaGradFunctor(
      const T* input, const int n, const T* out_grad, T* output, int64_t size)
      : input_(input),
        n_(n),
        out_grad_(out_grad),
        output_(output),
        size_(size) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
    const MT mp_x = static_cast<MT>(input_[idx]);
    const MT mp_out_grad = static_cast<MT>(out_grad_[idx]);
    const auto one = MT{1};
    auto partial_x = ((n_ % 2) ? one : -one) *
                     std::exp(std::lgamma(static_cast<MT>(n_) + one)) *
                     zeta<MT>(static_cast<MT>(n_ + 1), mp_x);
    output_[idx] = static_cast<T>(mp_out_grad * partial_x);
  }

 private:
  const T* input_;
  const int n_;
  const T* out_grad_;
  T* output_;
  int64_t size_;
};

#endif

}  // namespace phi
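Both the CUDA and the CPU functors above evaluate the same closed form, polygamma(x, n) = (-1)^(n + 1) * n! * zeta(n + 1, x) for n >= 1: the sign comes from the (_n % 2) ? one : -one term and the factorial from std::exp(std::lgamma(n + 1)). A small NumPy/SciPy sketch (illustrative only, not part of the patch) confirming that this expression is the quantity SciPy calls polygamma:

import numpy as np
from scipy import special

def closed_form(n, x):
    # (-1)^(n + 1) * n! * zeta(n + 1, x) -- the expression the functors compute.
    sign = 1.0 if n % 2 else -1.0
    return sign * np.exp(special.gammaln(n + 1)) * special.zeta(n + 1, x)

x = np.array([0.5, 1.0, 2.5, 10.0])
for n in (1, 2, 3):
    print(np.allclose(closed_form(n, x), special.polygamma(n, x)))  # True for each order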
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {

template <typename T, typename Context>
void PolygammaGradKernel(const Context& ctx,
                         const DenseTensor& x,
                         const DenseTensor& out_grad,
                         const int n,
                         DenseTensor* x_grad);

}  // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {

/**
 * @brief This kernel computes the polygamma of x element-wise.
 * @param  ctx  device context
 * @param  x    the input tensor of polygamma
 * @param  n    the order of the polygamma function
 * @param  out  the output tensor of polygamma
 */
template <typename T, typename Context>
void PolygammaKernel(const Context& ctx,
                     const DenseTensor& x,
                     const int n,
                     DenseTensor* out);

}  // namespace phi
......@@ -309,6 +309,7 @@ from .tensor.math import i0 # noqa: F401
from .tensor.math import i0e # noqa: F401
from .tensor.math import i1 # noqa: F401
from .tensor.math import i1e # noqa: F401
from .tensor.math import polygamma # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
......@@ -708,4 +709,5 @@ __all__ = [ # noqa
    'i0e',
    'i1',
    'i1e',
    'polygamma',
]
......@@ -261,6 +261,7 @@ from .math import i0 # noqa: F401
from .math import i0e # noqa: F401
from .math import i1 # noqa: F401
from .math import i1e # noqa: F401
from .math import polygamma # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
......@@ -560,6 +561,7 @@ tensor_method_func = [ # noqa
    'i0e',
    'i1',
    'i1e',
    'polygamma',
]
# this list used in math_op_patch.py for magic_method bind
......
......@@ -5727,3 +5727,61 @@ def i1e(x, name=None):
        type='i1e', inputs={'x': x}, outputs={'out': out}, attrs={}
    )
    return out


def polygamma(x, n, name=None):
    r"""
    Calculates the polygamma of the given input tensor, element-wise.

    The equation is:

    .. math::
        \Phi^n(x) = \frac{d^n}{dx^n} [\psi(x)]

    where :math:`\psi(x)` is the digamma function, so ``polygamma(x, 0)`` is equivalent to ``paddle.digamma(x)``.

    Args:
        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
        n (int): Order of the derivative. Must be a non-negative integer.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        - out (Tensor), the polygamma of the input Tensor. The shape and data type are the same as the input.

    Examples:
        .. code-block:: python

            import paddle

            data = paddle.to_tensor([2, 3, 25.5], dtype='float32')
            res = paddle.polygamma(data, 1)
            print(res)
            # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [0.64493407, 0.39493407, 0.03999467])
    """
    if not isinstance(n, int):
        raise TypeError(
            "The input of n must be int type, but received: %s " % (type(n))
        )
    if n < 0:
        raise ValueError(
            "The input of n must be greater than or equal to 0. But received n = %s"
            % (n)
        )
    if n == 0:
        return digamma(x)
    else:
        if in_dynamic_mode():
            return _C_ops.polygamma(x, n)
        else:
            check_variable_and_dtype(
                x, "x", ["float32", "float64"], "polygamma"
            )
            helper = LayerHelper("polygamma", **locals())
            out = helper.create_variable_for_type_inference(dtype=x.dtype)
            helper.append_op(
                type='polygamma',
                inputs={'x': x},
                outputs={'out': out},
                attrs={'n': n},
            )
            return out
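A quick dygraph sketch of the new API together with its gradient (illustrative only, assuming a Paddle build that contains this patch); since the derivative of polygamma(x, n) is polygamma(x, n + 1), the gradient of an order-1 call should match an order-2 forward call:

import paddle

x = paddle.to_tensor([2.0, 3.0, 5.0], dtype='float32', stop_gradient=False)
y = paddle.polygamma(x, 1)
y.sum().backward()

# x.grad should agree with the next-order polygamma evaluated at x.
print(x.grad)
print(paddle.polygamma(x.detach(), 2))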
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest
from scipy import special
import paddle
from paddle.fluid import core
np.random.seed(100)
paddle.seed(100)
def ref_polygamma(x, n):
    """
    The value at x = 0 differs from the mainstream (SciPy)
    implementation, so it is overridden explicitly.
    """
    mask = x == 0
    if n == 0:
        out = special.psi(x)
        out[mask] = np.nan
    else:
        out = special.polygamma(n, x)
    return out


def ref_polygamma_grad(x, dout, n):
    """
    The value at x = 0 differs from the mainstream (SciPy)
    implementation, so it is overridden explicitly.
    """
    mask = x == 0
    gradx = special.polygamma(n + 1, x)
    if n == 0:
        gradx[mask] = np.nan
    return dout * gradx
class TestPolygammaAPI(unittest.TestCase):
    DTYPE = "float64"
    DATA = [0, 1, 2, 3, 4, 5]
    ORDER = 1

    def setUp(self):
        self.x = np.array(self.DATA).astype(self.DTYPE)
        self.place = [paddle.CPUPlace()]
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def test_api_static(self):
        def run(place):
            paddle.enable_static()
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data(
                    name="x", shape=self.x.shape, dtype=self.DTYPE
                )
                y = paddle.polygamma(x, self.ORDER)
                exe = paddle.static.Executor(place)
                res = exe.run(
                    paddle.static.default_main_program(),
                    feed={"x": self.x},
                    fetch_list=[y],
                )
            out_ref = ref_polygamma(self.x, self.ORDER)
            np.testing.assert_allclose(out_ref, res[0], rtol=1e-5)
            paddle.disable_static()

        for place in self.place:
            run(place)

    def test_api_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            x = paddle.to_tensor(self.x)
            out = paddle.polygamma(x, self.ORDER)
            out_ref = ref_polygamma(self.x, self.ORDER)
            np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5)
            paddle.enable_static()

        for place in self.place:
            run(place)

    def test_empty_input_error(self):
        for place in self.place:
            paddle.disable_static(place)
            x = None
            self.assertRaises(ValueError, paddle.polygamma, x, self.ORDER)
            paddle.enable_static()

    def test_input_type_error(self):
        for place in self.place:
            paddle.disable_static(place)
            self.assertRaises(
                TypeError, paddle.polygamma, self.x, float(self.ORDER)
            )
            paddle.enable_static()

    def test_negative_order_error(self):
        for place in self.place:
            paddle.disable_static(place)
            self.assertRaises(ValueError, paddle.polygamma, self.x, -self.ORDER)
            paddle.enable_static()
class TestPolygammaFloat32Order1(TestPolygammaAPI):
    DTYPE = "float32"
    DATA = [2, 3, 5, 2.25, 7, 7.25]
    ORDER = 1


class TestPolygammaFloat32Order2(TestPolygammaAPI):
    DTYPE = "float32"
    DATA = [2, 3, 5, 2.25, 7, 7.25]
    ORDER = 2


class TestPolygammaFloat32Order3(TestPolygammaAPI):
    DTYPE = "float32"
    DATA = [2, 3, 5, 2.25, 7, 7.25]
    ORDER = 3


class TestPolygammaFloat64Order1(TestPolygammaAPI):
    DTYPE = "float64"
    DATA = [2, 3, 5, 2.25, 7, 7.25]
    ORDER = 1


class TestPolygammaFloat64Order2(TestPolygammaAPI):
    DTYPE = "float64"
    DATA = [2, 3, 5, 2.25, 7, 7.25]
    ORDER = 2


class TestPolygammaFloat64Order3(TestPolygammaAPI):
    DTYPE = "float64"
    DATA = [2, 3, 5, 2.25, 7, 7.25]
    ORDER = 3


class TestPolygammaNegativeInputOrder1(TestPolygammaAPI):
    DTYPE = "float64"
    DATA = [-2, 3, 5, 2.25, 7, 7.25]
    ORDER = 1


class TestPolygammaMultiDimOrder1(TestPolygammaAPI):
    DTYPE = "float64"
    DATA = [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]]
    ORDER = 1


class TestPolygammaMultiDimOrder2(TestPolygammaAPI):
    DTYPE = "float64"
    DATA = [
        [[-2, 3, 5, 2.25, 7, 7.25], [0, 1, 2, 3, 4, 5]],
        [[6, 7, 8, 9, 1, 2], [0, 1, 2, 3, 4, 5]],
    ]
    ORDER = 2
class TestPolygammaOp(OpTest):
    def setUp(self) -> None:
        self.op_type = "polygamma"
        self.python_api = paddle.polygamma
        self.init_config()
        self.outputs = {"out": self.target}

    def init_config(self):
        self.dtype = np.float64
        self.order = 1
        rand_case = np.random.randn(100).astype(self.dtype)
        int_case = np.random.randint(low=1, high=100, size=100).astype(
            self.dtype
        )
        self.case = np.concatenate([rand_case, int_case])
        self.inputs = {'x': self.case}
        self.attrs = {'n': self.order}
        self.target = ref_polygamma(self.inputs['x'], self.order)

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ['x'],
            'out',
            user_defined_grads=[
                ref_polygamma_grad(self.case, 1 / self.case.size, self.order)
            ],
        )


if __name__ == "__main__":
    unittest.main()
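To make the x = 0 special case in ref_polygamma concrete: SciPy's digamma has a pole at zero and returns an infinity there, while the reference used for this patch overrides that entry with NaN (matching what the Paddle kernels are expected to produce at the pole; that expectation is an assumption read off the reference functions above). A minimal sketch:

import numpy as np
from scipy import special

x = np.array([0.0, 1.0, 2.0])
print(special.psi(x))  # SciPy's digamma returns an infinity at x == 0 (a pole)

ref = special.psi(x)
ref[x == 0] = np.nan   # the override applied by ref_polygamma for n == 0
print(ref)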