Unverified commit 896c9315, authored by megemini, committed by GitHub

[Hackathon 4th No.12] Add Cauchy API to Paddle (#52999)

* [Hackathon 4th No.12] Add Cauchy API to Paddle

* [Change] Revise the initialization method and type checking

* [Change] Move the test cases into the new directory

* [Change] Adapt to 0-D `to_tensor`
Parent commit: bca1b0c6
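
For orientation, a minimal usage sketch of the new API (illustrative only; the docstring examples in the diff below show verified outputs):

    import paddle
    from paddle.distribution import Cauchy

    rv = Cauchy(loc=0.1, scale=1.2)
    samples = rv.sample([100])       # sampling, no gradients
    rsamples = rv.rsample([100])     # reparameterized, differentiable sampling
    lp = rv.log_prob(paddle.to_tensor(1.5))
    kl = rv.kl_divergence(Cauchy(loc=0.5, scale=2.0))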
@@ -16,6 +16,7 @@ from paddle.distribution import transform
from paddle.distribution.bernoulli import Bernoulli
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.cauchy import Cauchy
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.gumbel import Gumbel
@@ -35,6 +36,7 @@ __all__ = [ # noqa
'Bernoulli',
'Beta',
'Categorical',
'Cauchy',
'Dirichlet',
'Distribution',
'ExponentialFamily',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import numpy as np
import paddle
from paddle.distribution import distribution
from paddle.fluid import framework
class Cauchy(distribution.Distribution):
r"""Cauchy distribution is also called Cauchy–Lorentz distribution. It is a continuous probability distribution named after Augustin-Louis Cauchy and Hendrik Lorentz. It has a very wide range of applications in natural sciences.
The Cauchy distribution has the probability density function (PDF):
.. math::
{ f(x; loc, scale) = \frac{1}{\pi scale \left[1 + \left(\frac{x - loc}{ scale}\right)^2\right]} = { 1 \over \pi } \left[ { scale \over (x - loc)^2 + scale^2 } \right], }
Args:
loc (float|Tensor): Location of the peak of the distribution. The data type is float32 or float64.
scale (float|Tensor): The half-width at half-maximum (HWHM). Must be positive. The data type is float32 or float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.entropy())
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# 2.71334577)
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.entropy())
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2.53102422, 3.22417140])
"""
def __init__(self, loc, scale, name=None):
self.name = name if name is not None else 'Cauchy'
if not isinstance(loc, (numbers.Real, framework.Variable)):
raise TypeError(
f"Expected type of loc is Real|Variable, but got {type(loc)}"
)
if not isinstance(scale, (numbers.Real, framework.Variable)):
raise TypeError(
f"Expected type of scale is Real|Variable, but got {type(scale)}"
)
if isinstance(loc, numbers.Real):
loc = paddle.full(shape=(), fill_value=loc)
if isinstance(scale, numbers.Real):
scale = paddle.full(shape=(), fill_value=scale)
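# Scalars are promoted to 0-D tensors above, so the broadcast logic below treats float and Tensor inputs uniformly.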
if loc.shape != scale.shape:
self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
else:
self.loc, self.scale = loc, scale
self.dtype = self.loc.dtype
super().__init__(batch_shape=self.loc.shape, event_shape=())
@property
def mean(self):
"""Mean of Cauchy distribution."""
raise ValueError("Cauchy distribution has no mean.")
@property
def variance(self):
"""Variance of Cauchy distribution."""
raise ValueError("Cauchy distribution has no variance.")
@property
def stddev(self):
"""Standard Deviation of Cauchy distribution."""
raise ValueError("Cauchy distribution has no stddev.")
def sample(self, shape, name=None):
"""Sample from Cauchy distribution.
Note:
The `sample` method does not support gradients; if you need them, use `rsample` instead.
Args:
shape (Sequence[int]): Sample shape.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.sample([10]).shape)
# [10]
# init Cauchy with 0-Dim tensor
rv = Cauchy(loc=paddle.full((), 0.1), scale=paddle.full((), 1.2))
print(rv.sample([10]).shape)
# [10]
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.sample([10]).shape)
# [10, 2]
# sample 2-Dim data
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.sample([10, 2]).shape)
# [10, 2]
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.sample([10, 2]).shape)
# [10, 2, 2]
"""
name = name if name is not None else (self.name + '_sample')
with paddle.no_grad():
return self.rsample(shape, name)
def rsample(self, shape, name=None):
"""Sample from Cauchy distribution (reparameterized).
Args:
shape (Sequence[int]): Sample shape.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.rsample([10]).shape)
# [10]
# init Cauchy with 0-Dim tensor
rv = Cauchy(loc=paddle.full((), 0.1), scale=paddle.full((), 1.2))
print(rv.rsample([10]).shape)
# [10]
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.rsample([10]).shape)
# [10, 2]
# sample 2-Dim data
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.rsample([10, 2]).shape)
# [10, 2]
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.rsample([10, 2]).shape)
# [10, 2, 2]
"""
name = name if name is not None else (self.name + '_rsample')
if not isinstance(shape, (np.ndarray, framework.Variable, list, tuple)):
raise TypeError(
f"Expected type of shape is Sequence[int], but got {type(shape)}"
)
shape = shape if isinstance(shape, tuple) else tuple(shape)
shape = self._extend_shape(shape)
loc = self.loc.expand(shape)
scale = self.scale.expand(shape)
uniforms = paddle.rand(shape, dtype=self.dtype)
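# Inverse-CDF sampling: for u ~ Uniform(0, 1), loc + scale * tan(pi * (u - 0.5)) follows Cauchy(loc, scale).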
return paddle.add(
loc,
paddle.multiply(scale, paddle.tan(np.pi * (uniforms - 0.5))),
name=name,
)
def prob(self, value):
r"""Probability density function(PDF) evaluated at value.
.. math::
{ f(x; loc, scale) = \frac{1}{\pi scale \left[1 + \left(\frac{x - loc}{ scale}\right)^2\right]} = { 1 \over \pi } \left[ { scale \over (x - loc)^2 + scale^2 } \right], }
Args:
value (Tensor): Value to be evaluated.
Returns:
Tensor: PDF evaluated at value.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.prob(paddle.to_tensor(1.5)))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.11234467])
# broadcast to value
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.prob(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.11234467, 0.01444674])
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor([0.1, 0.1]), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.prob(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.10753712, 0.02195240])
# init Cauchy with N-Dim tensor with broadcast
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.prob(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.10753712, 0.02195240])
"""
name = self.name + '_prob'
if not isinstance(value, framework.Variable):
raise TypeError(
f"Expected type of value is Variable, but got {type(value)}"
)
return self.log_prob(value).exp(name=name)
def log_prob(self, value):
"""Log of probability densitiy function.
Args:
value (Tensor): Value to be evaluated.
Returns:
Tensor: Log of probability density evaluated at value.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.log_prob(paddle.to_tensor(1.5)))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-2.18618369])
# broadcast to value
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.log_prob(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-2.18618369, -4.23728657])
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor([0.1, 0.1]), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.log_prob(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-2.22991920, -3.81887865])
# init Cauchy with N-Dim tensor with broadcast
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.log_prob(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-2.22991920, -3.81887865])
"""
name = self.name + '_log_prob'
if not isinstance(value, framework.Variable):
raise TypeError(
f"Expected type of value is Variable, but got {type(value)}"
)
value = self._check_values_dtype_in_probs(self.loc, value)
loc, scale, value = paddle.broadcast_tensors(
[self.loc, self.scale, value]
)
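# log pdf(x) = -log1p(((x - loc) / scale) ** 2) - (log(pi) + log(scale))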
return paddle.subtract(
-(
paddle.square(paddle.divide(paddle.subtract(value, loc), scale))
).log1p(),
paddle.add(
paddle.full(loc.shape, np.log(np.pi), dtype=self.dtype),
scale.log(),
),
name=name,
)
def cdf(self, value):
r"""Cumulative distribution function(CDF) evaluated at value.
.. math::
{ \frac{1}{\pi} \arctan\left(\frac{x-loc}{ scale}\right)+\frac{1}{2}\! }
Args:
value (Tensor): Value to be evaluated.
Returns:
Tensor: CDF evaluated at value.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.cdf(paddle.to_tensor(1.5)))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.77443725])
# broadcast to value
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.cdf(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.77443725, 0.92502367])
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor([0.1, 0.1]), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.cdf(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.80256844, 0.87888104])
# init Cauchy with N-Dim tensor with broadcast
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.cdf(paddle.to_tensor([1.5, 5.1])))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.80256844, 0.87888104])
"""
name = self.name + '_cdf'
if not isinstance(value, framework.Variable):
raise TypeError(
f"Expected type of value is Variable, but got {type(value)}"
)
value = self._check_values_dtype_in_probs(self.loc, value)
loc, scale, value = paddle.broadcast_tensors(
[self.loc, self.scale, value]
)
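# CDF(x) = arctan((x - loc) / scale) / pi + 1/2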
return (
paddle.atan(
paddle.divide(paddle.subtract(value, loc), scale), name=name
)
/ np.pi
+ 0.5
)
def entropy(self):
r"""Entropy of Cauchy distribution.
.. math::
\log(4\pi \cdot scale)
Returns:
Tensor: Entropy of distribution.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
# init Cauchy with float
rv = Cauchy(loc=0.1, scale=1.2)
print(rv.entropy())
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# 2.71334577)
# init Cauchy with N-Dim tensor
rv = Cauchy(loc=paddle.to_tensor(0.1), scale=paddle.to_tensor([1.0, 2.0]))
print(rv.entropy())
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2.53102422, 3.22417140])
"""
name = self.name + '_entropy'
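# Differential entropy is log(4 * pi * scale); it does not depend on loc.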
return paddle.add(
paddle.full(self.loc.shape, np.log(4 * np.pi), dtype=self.dtype),
self.scale.log(),
name=name,
)
def kl_divergence(self, other):
"""The KL-divergence between two Cauchy distributions.
Note:
[1] Frédéric Chyzak, Frank Nielsen. A closed-form formula for the Kullback-Leibler divergence between Cauchy distributions. 2019.
Args:
other (Cauchy): instance of Cauchy.
Returns:
Tensor: KL-divergence between the two Cauchy distributions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Cauchy
rv = Cauchy(loc=0.1, scale=1.2)
rv_other = Cauchy(loc=paddle.to_tensor(1.2), scale=paddle.to_tensor([2.3, 3.4]))
print(rv.kl_divergence(rv_other))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.19819736, 0.31532931])
"""
name = self.name + '_kl_divergence'
if not isinstance(other, Cauchy):
raise TypeError(
f"Expected type of other is Cauchy, but got {type(other)}"
)
a_loc = self.loc
b_loc = other.loc
a_scale = self.scale
b_scale = other.scale
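# Closed form from Chyzak & Nielsen (2019):
# KL(p || q) = log(((scale_p + scale_q)^2 + (loc_p - loc_q)^2) / (4 * scale_p * scale_q))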
t1 = paddle.add(
paddle.pow(paddle.add(a_scale, b_scale), 2),
paddle.pow(paddle.subtract(a_loc, b_loc), 2),
).log()
t2 = (4 * paddle.multiply(a_scale, b_scale)).log()
return paddle.subtract(t1, t2, name=name)
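
As a quick sanity check (a NumPy sketch, not part of this diff), the closed form above reproduces the values shown in the `kl_divergence` docstring:

    import numpy as np

    a_loc, a_scale = 0.1, 1.2
    b_loc, b_scale = 1.2, np.array([2.3, 3.4])
    kl = np.log(
        ((a_scale + b_scale) ** 2 + (a_loc - b_loc) ** 2)
        / (4 * a_scale * b_scale)
    )
    print(kl)  # approximately [0.1982, 0.3153], matching the docstring above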
@@ -18,6 +18,7 @@ import paddle
from paddle.distribution.bernoulli import Bernoulli
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.cauchy import Cauchy
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
@@ -186,6 +187,11 @@ def _kl_categorical_categorical(p, q):
return p.kl_divergence(q)
@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
return p.kl_divergence(q)
@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
return p.kl_divergence(q)
......
This diff is collapsed.
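The collapsed file above defines the `CauchyNumpy` reference and the `_kstest` helper imported by the static-mode tests below. Its exact contents are not shown; a hypothetical SciPy-based sketch of what such a reference could look like, for orientation only:

    # Hypothetical sketch; the real implementation is in the collapsed diff above.
    import numpy as np
    import scipy.stats


    class CauchyNumpy:
        def __init__(self, loc, scale):
            self.loc, self.scale = np.broadcast_arrays(
                np.asarray(loc, dtype='float64'),
                np.asarray(scale, dtype='float64'),
            )

        def sample(self, shape):
            return scipy.stats.cauchy.rvs(
                loc=self.loc, scale=self.scale,
                size=tuple(shape) + self.loc.shape,
            )

        def log_prob(self, value):
            return scipy.stats.cauchy.logpdf(value, loc=self.loc, scale=self.scale)

        def prob(self, value):
            return scipy.stats.cauchy.pdf(value, loc=self.loc, scale=self.scale)

        def cdf(self, value):
            return scipy.stats.cauchy.cdf(value, loc=self.loc, scale=self.scale)

        def entropy(self):
            return np.log(4 * np.pi * self.scale)

        def kl_divergence(self, other):
            return np.log(
                ((self.scale + other.scale) ** 2 + (self.loc - other.loc) ** 2)
                / (4 * self.scale * other.scale)
            )


    def _kstest(samples_a, samples_b, alpha=0.01):
        # Two-sample Kolmogorov-Smirnov test: True if the samples are
        # plausibly drawn from the same distribution.
        return scipy.stats.ks_2samp(samples_a, samples_b).pvalue > alpha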
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from config import ATOL, DEVICES, RTOL
from parameterize import (
TEST_CASE_NAME,
parameterize_cls,
parameterize_func,
place,
)
from test_distribution_cauchy import CauchyNumpy, _kstest
import paddle
from paddle.distribution import Cauchy
from paddle.distribution.kl import kl_divergence
np.random.seed(2023)
paddle.seed(2023)
paddle.enable_static()
default_dtype = paddle.get_default_dtype()
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'params'),
# params: name, loc, scale, loc_other, scale_other, value
[
(
'params',
(
# 1-D params
(
'params_not_iterable',
0.3,
1.2,
-1.2,
2.3,
3.4,
),
(
'params_not_iterable_and_broadcast_for_value',
0.3,
1.2,
-1.2,
2.3,
np.array([[0.1, 1.2], [1.2, 3.4]], dtype=default_dtype),
),
# N-D params
(
'params_tuple_0305',
(0.3, 0.5),
0.7,
-1.2,
2.3,
3.4,
),
(
'params_tuple_03050104',
((0.3, 0.5), (0.1, 0.4)),
0.7,
-1.2,
2.3,
3.4,
),
),
)
],
)
class CauchyTestFeature(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
self.params_len = len(self.params)
with paddle.static.program_guard(self.program):
self.init_numpy_data(self.params)
self.init_static_data(self.params)
def init_numpy_data(self, params):
self.log_prob_np = []
self.prob_np = []
self.cdf_np = []
self.entropy_np = []
self.kl_np = []
self.shapes = []
for _, loc, scale, loc_other, scale_other, value in params:
rv_np = CauchyNumpy(loc=loc, scale=scale)
rv_np_other = CauchyNumpy(loc=loc_other, scale=scale_other)
self.log_prob_np.append(rv_np.log_prob(value))
self.prob_np.append(rv_np.prob(value))
self.cdf_np.append(rv_np.cdf(value))
self.entropy_np.append(rv_np.entropy())
self.kl_np.append(rv_np.kl_divergence(rv_np_other))
# paddle returns data with ndim > 0, so 0-D numpy shapes () fall back to (1,)
self.shapes.append(
(np.array(loc) + np.array(scale) + np.array(value)).shape
or (1,)
)
def init_static_data(self, params):
with paddle.static.program_guard(self.program):
rv_paddles = []
rv_paddles_other = []
values = []
for name, loc, scale, loc_other, scale_other, value in params:
if not isinstance(value, np.ndarray):
value = paddle.full([1], value, dtype=default_dtype)
else:
value = paddle.to_tensor(value, place=self.place)
# We must set `name` in static mode; otherwise the executor confuses the fetch targets among rv_paddles[i].
rv_paddles.append(
Cauchy(
loc=paddle.to_tensor(loc),
scale=paddle.to_tensor(scale),
name=name,
)
)
rv_paddles_other.append(
Cauchy(
loc=paddle.to_tensor(loc_other),
scale=paddle.to_tensor(scale_other),
name=name,
)
)
values.append(value)
results = self.executor.run(
self.program,
feed={},
fetch_list=[
[
rv_paddles[i].log_prob(values[i]),
rv_paddles[i].prob(values[i]),
rv_paddles[i].cdf(values[i]),
rv_paddles[i].entropy(),
rv_paddles[i].kl_divergence(rv_paddles_other[i]),
kl_divergence(rv_paddles[i], rv_paddles_other[i]),
]
for i in range(self.params_len)
],
)
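# The executor flattens the nested fetch_list, so `results` holds six outputs per parameter set, in the order listed above.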
self.log_prob_paddle = []
self.prob_paddle = []
self.cdf_paddle = []
self.entropy_paddle = []
self.kl_paddle = []
self.kl_func_paddle = []
for i in range(self.params_len):
(
_log_prob,
_prob,
_cdf,
_entropy,
_kl,
_kl_func,
) = results[i * 6 : (i + 1) * 6]
self.log_prob_paddle.append(_log_prob)
self.prob_paddle.append(_prob)
self.cdf_paddle.append(_cdf)
self.entropy_paddle.append(_entropy)
self.kl_paddle.append(_kl)
self.kl_func_paddle.append(_kl_func)
def test_all(self):
for i in range(self.params_len):
self._test_log_prob(i)
self._test_prob(i)
self._test_cdf(i)
self._test_entropy(i)
self._test_kl_divergence(i)
def _test_log_prob(self, i):
np.testing.assert_allclose(
self.log_prob_np[i],
self.log_prob_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
# check shape
self.assertTrue(self.log_prob_paddle[i].shape == self.shapes[i])
def _test_prob(self, i):
np.testing.assert_allclose(
self.prob_np[i],
self.prob_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
# check shape
self.assertTrue(self.prob_paddle[i].shape == self.shapes[i])
def _test_cdf(self, i):
np.testing.assert_allclose(
self.cdf_np[i],
self.cdf_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
# check shape
self.assertTrue(self.cdf_paddle[i].shape == self.shapes[i])
def _test_entropy(self, i):
np.testing.assert_allclose(
self.entropy_np[i],
self.entropy_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_kl_divergence(self, i):
np.testing.assert_allclose(
self.kl_np[i],
self.kl_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
np.testing.assert_allclose(
self.kl_np[i],
self.kl_func_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
@place(DEVICES)
@parameterize_cls(
(
TEST_CASE_NAME,
'loc',
'scale',
'shape',
'expected_shape',
),
[
# 1-D params
(
'params_1d',
[0.1],
[1.2],
[100],
[100, 1],
),
# N-D params
(
'params_2d',
[0.3],
[1.2, 2.3],
[100],
[100, 2],
),
],
)
class CauchyTestSample(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
self.init_numpy_data(self.loc, self.scale, self.shape)
self.init_static_data(self.loc, self.scale, self.shape)
def init_numpy_data(self, loc, scale, shape):
self.rv_np = CauchyNumpy(loc=loc, scale=scale)
self.sample_np = self.rv_np.sample(shape)
def init_static_data(self, loc, scale, shape):
with paddle.static.program_guard(self.program):
self.rv_paddle = Cauchy(
loc=paddle.to_tensor(loc),
scale=paddle.to_tensor(scale),
)
[self.sample_paddle, self.rsample_paddle] = self.executor.run(
self.program,
feed={},
fetch_list=[
self.rv_paddle.sample(shape),
self.rv_paddle.rsample(shape),
],
)
def test_sample(self):
with paddle.static.program_guard(self.program):
self.assertEqual(
list(self.sample_paddle.shape), self.expected_shape
)
for i in range(self.expected_shape[-1]):
self.assertTrue(
_kstest(
self.sample_np[..., i].reshape(-1),
self.sample_paddle[..., i].reshape(-1),
)
)
def test_rsample(self):
"""Compare two samples from `rsample` method, one from scipy and another from paddle."""
with paddle.static.program_guard(self.program):
self.assertEqual(
list(self.rsample_paddle.shape), self.expected_shape
)
for i in range(self.expected_shape[-1]):
self.assertTrue(
_kstest(
self.sample_np[..., i].reshape(-1),
self.rsample_paddle[..., i].reshape(-1),
)
)
@place(DEVICES)
@parameterize_cls([TEST_CASE_NAME], ['CauchyTestError'])
class CauchyTestError(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
@parameterize_func(
[
((0.3,),), # tuple
([0.3],), # list
(np.array([0.3]),), # ndarray
(-1j + 1,), # complex
('0',), # str
]
)
def test_bad_init_type(self, param):
"""Test bad init for loc/scale"""
with paddle.static.program_guard(self.program):
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program,
feed={},
fetch_list=[Cauchy(loc=0.0, scale=param).scale],
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program,
feed={},
fetch_list=[Cauchy(loc=param, scale=1.0).loc],
)
def test_bad_property(self):
"""For property like mean/variance/stddev which is undefined in math,
we should raise `ValueError` instead of `NotImplementedError`.
"""
with paddle.static.program_guard(self.program):
rv = Cauchy(loc=0.0, scale=1.0)
with self.assertRaises(ValueError):
_ = rv.mean
with self.assertRaises(ValueError):
_ = rv.variance
with self.assertRaises(ValueError):
_ = rv.stddev
@parameterize_func(
[
(100,), # int
(100.0,), # float
]
)
def test_bad_sample_shape_type(self, shape):
with paddle.static.program_guard(self.program):
rv = Cauchy(loc=0.0, scale=1.0)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.sample(shape)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.rsample(shape)]
)
@parameterize_func(
[
(1,), # int
(1.0,), # float
([1.0],), # list
((1.0,),), # tuple
(np.array(1.0),), # ndarray
]
)
def test_bad_value_type(self, value):
with paddle.static.program_guard(self.program):
rv = Cauchy(loc=0.0, scale=1.0)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.log_prob(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.prob(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.cdf(value)]
)
@parameterize_func(
[
(np.array(1.0),), # ndarray or other distribution
]
)
def test_bad_kl_other_type(self, other):
with paddle.static.program_guard(self.program):
rv = Cauchy(loc=0.0, scale=1.0)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.kl_divergence(other)]
)
@parameterize_func(
[
(paddle.to_tensor([0.1, 0.2, 0.3]),),
]
)
def test_bad_broadcast(self, value):
with paddle.static.program_guard(self.program):
rv = Cauchy(
loc=paddle.to_tensor(0.0), scale=paddle.to_tensor((1.0, 2.0))
)
# `loc, scale, value = paddle.broadcast_tensors([self.loc, self.scale, value])`
# raises ValueError in dygraph mode and TypeError in static mode.
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.cdf(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.log_prob(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.prob(value)]
)
if __name__ == '__main__':
unittest.main()