Unverified commit f05c870b authored by megemini, committed by GitHub

【Hackathon 4th No.13】Add the Bernoulli API to Paddle (#52244)

* 【Hackathon 4th No.13】Add the Bernoulli API to Paddle

* [Change] Change the scipy version used by unittest_py

* [Change] Modify the dtype parameter of BernoulliNumpy; optimize the static-graph test flow

* [Change] Optimize class initialization and logic; add 0-D test cases
Parent 7a78a571
......@@ -13,6 +13,7 @@
# limitations under the License.
from paddle.distribution import transform
from paddle.distribution.bernoulli import Bernoulli
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
......@@ -30,6 +31,7 @@ from paddle.distribution.uniform import Uniform
from paddle.distribution.laplace import Laplace
__all__ = [ # noqa
'Bernoulli',
'Beta',
'Categorical',
'Dirichlet',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.layers import tensor
from paddle.nn.functional import (
binary_cross_entropy_with_logits,
sigmoid,
softplus,
)
# Machine epsilon for each supported dtype.
EPS = {
'float32': paddle.finfo(paddle.float32).eps,
'float64': paddle.finfo(paddle.float64).eps,
}
def _clip_probs(probs, dtype):
"""Clip probs from [0, 1] to (0, 1) with ``eps``.
Args:
probs (Tensor): probs of Bernoulli.
dtype (str): data type.
Returns:
Tensor: Clipped probs.
"""
eps = EPS.get(dtype)
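    # `eps` is the machine epsilon of the dtype (roughly 1.19e-07 for float32,
    # 2.22e-16 for float64), so probabilities of exactly 0 or 1 are nudged into
    # the open interval (0, 1) and later log/logit computations stay finite.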
return paddle.clip(probs, min=eps, max=1 - eps).astype(dtype)
class Bernoulli(exponential_family.ExponentialFamily):
r"""Bernoulli distribution parameterized by ``probs``, which is the probability of value 1.
In probability theory and statistics, the Bernoulli distribution, named after Swiss
mathematician Jacob Bernoulli, is the discrete probability distribution of a random
variable which takes the value 1 with probability ``p`` and the value 0 with
probability ``q=1-p``.
The probability mass function of this distribution, over possible outcomes ``k``, is
.. math::
{\begin{cases}
q=1-p & \text{if }value=0 \\
p & \text{if }value=1
\end{cases}}
Args:
probs (float|Tensor): The ``probs`` input of Bernoulli distribution. The data type is float32 or float64. The range must be in [0, 1].
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
# init `probs` with a float
rv = Bernoulli(probs=0.3)
print(rv.mean)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.30000001])
print(rv.variance)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.21000001])
print(rv.entropy())
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.61086434])
"""
def __init__(self, probs, name=None):
self.name = name or 'Bernoulli'
if not _non_static_mode():
check_type(
probs,
'probs',
(float, tensor.Variable),
self.name,
)
# Get/convert probs to tensor.
if self._validate_args(probs):
self.probs = probs
self.dtype = convert_dtype(probs.dtype)
else:
[self.probs] = self._to_tensor(probs)
self.dtype = paddle.get_default_dtype()
# Check probs range [0, 1].
if _non_static_mode():
"""Not use `paddle.any` in static mode, which always be `True`."""
if (
paddle.any(self.probs < 0)
or paddle.any(self.probs > 1)
or paddle.any(paddle.isnan(self.probs))
):
raise ValueError("The arg of `probs` must be in range [0, 1].")
# Clip probs from [0, 1] to (0, 1) with smallest representable number `eps`.
self.probs = _clip_probs(self.probs, self.dtype)
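        # `logits` are the log-odds log(p / (1 - p)) of `probs`; the clipping
        # above guarantees they are finite.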
self.logits = self._probs_to_logits(self.probs, is_binary=True)
super().__init__(batch_shape=self.probs.shape, event_shape=())
@property
def mean(self):
"""Mean of Bernoulli distribution.
Returns:
Tensor: Mean value of distribution.
"""
return self.probs
@property
def variance(self):
"""Variance of Bernoulli distribution.
Returns:
Tensor: Variance value of distribution.
"""
return paddle.multiply(self.probs, (1 - self.probs))
def sample(self, shape):
"""Sample from Bernoulli distribution.
Args:
shape (Sequence[int]): Sample shape.
Returns:
Tensor: Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
rv = Bernoulli(paddle.full((), 0.3))
print(rv.sample([100]).shape)
# [100]
rv = Bernoulli(paddle.to_tensor(0.3))
print(rv.sample([100]).shape)
# [100, 1]
rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
print(rv.sample([100]).shape)
# [100, 2]
rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
print(rv.sample([100, 2]).shape)
# [100, 2, 2]
"""
name = self.name + '_sample'
if not _non_static_mode():
check_type(
shape,
'shape',
(np.ndarray, tensor.Variable, list, tuple),
name,
)
shape = shape if isinstance(shape, tuple) else tuple(shape)
shape = self._extend_shape(shape)
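        # `_extend_shape` appends `batch_shape` (plus the empty `event_shape`)
        # to the requested sample shape, e.g. probs of shape [2] sampled with
        # shape [100] yields samples of shape [100, 2] (see the examples above).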
with paddle.no_grad():
return paddle.bernoulli(self.probs.expand(shape), name=name)
def rsample(self, shape, temperature=1.0):
"""Sample from Bernoulli distribution (reparameterized).
        `rsample` is a continuous relaxation of the Bernoulli distribution that enables reparameterized sampling.
[1] Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. 2016.
[2] Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparameterization with Gumbel-Softmax. 2016.
Note:
            `rsample` needs to be followed by a `sigmoid`, which converts the sample values to the unit interval (0, 1).
Args:
shape (Sequence[int]): Sample shape.
            temperature (float): Temperature for `rsample`; must be positive.
Returns:
Tensor: Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
paddle.seed(2023)
rv = Bernoulli(paddle.full((), 0.3))
print(rv.sample([100]).shape)
# [100]
rv = Bernoulli(0.3)
print(rv.rsample([100]).shape)
# [100, 1]
rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
print(rv.rsample([100]).shape)
# [100, 2]
rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
print(rv.rsample([100, 2]).shape)
# [100, 2, 2]
# `rsample` has to be followed by a `sigmoid`
rv = Bernoulli(0.3)
rsample = rv.rsample([3, ])
rsample_sigmoid = paddle.nn.functional.sigmoid(rsample)
print(rsample, rsample_sigmoid)
# Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[-0.88315082],
# [-0.62347704],
# [-0.31513220]]) Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[0.29252526],
# [0.34899110],
# [0.42186251]])
                # The smaller the `temperature`, the closer the distribution of `rsample` is to `sample`, here with `probs` of 0.3.
print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=1.0)).sum())
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [361.06829834])
print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=0.1)).sum())
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [288.66418457])
"""
name = self.name + '_rsample'
if not _non_static_mode():
check_type(
shape,
'shape',
(np.ndarray, tensor.Variable, list, tuple),
name,
)
check_type(
temperature,
'temperature',
(float,),
name,
)
shape = shape if isinstance(shape, tuple) else tuple(shape)
shape = self._extend_shape(shape)
temperature = paddle.full(
shape=(), fill_value=temperature, dtype=self.dtype
)
probs = self.probs.expand(shape)
uniforms = paddle.rand(shape, dtype=self.dtype)
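        # Logistic reparameterization (the pre-sigmoid Binary Concrete sample of
        # Maddison et al. / Jang et al., cited in the docstring):
        #   z = (logit(u) + logit(p)) / temperature, with u ~ Uniform(0, 1),
        # where logit(x) = log(x) - log(1 - x). Applying `sigmoid` to z gives a
        # relaxed sample in (0, 1).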
return paddle.divide(
paddle.add(
paddle.subtract(uniforms.log(), (-uniforms).log1p()),
paddle.subtract(probs.log(), (-probs).log1p()),
),
temperature,
)
def cdf(self, value):
r"""Cumulative distribution function(CDF) evaluated at value.
.. math::
{ \begin{cases}
0 & \text{if } value \lt 0 \\
1 - p & \text{if } 0 \leq value \lt 1 \\
1 & \text{if } value \geq 1
\end{cases}
}
Args:
value (Tensor): Value to be evaluated.
Returns:
Tensor: CDF evaluated at value.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
rv = Bernoulli(0.3)
print(rv.cdf(paddle.to_tensor([1.0])))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.])
"""
name = self.name + '_cdf'
if not _non_static_mode():
check_type(value, 'value', tensor.Variable, name)
value = self._check_values_dtype_in_probs(self.probs, value)
probs, value = paddle.broadcast_tensors([self.probs, value])
zeros = paddle.zeros_like(probs)
ones = paddle.ones_like(probs)
return paddle.where(
value < 0,
zeros,
paddle.where(value < 1, paddle.subtract(ones, probs), ones),
name=name,
)
def log_prob(self, value):
"""Log of probability densitiy function.
Args:
value (Tensor): Value to be evaluated.
Returns:
            Tensor: Log of the probability density evaluated at value.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
rv = Bernoulli(0.3)
print(rv.log_prob(paddle.to_tensor([1.0])))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-1.20397282])
"""
name = self.name + '_log_prob'
if not _non_static_mode():
check_type(value, 'value', tensor.Variable, name)
value = self._check_values_dtype_in_probs(self.probs, value)
logits, value = paddle.broadcast_tensors([self.logits, value])
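        # -BCE_with_logits(logits, value) = value * log(p) + (1 - value) * log(1 - p),
        # i.e. the Bernoulli log probability, computed in a numerically stable way.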
return -binary_cross_entropy_with_logits(
logits, value, reduction='none', name=name
)
def prob(self, value):
r"""Probability density function(PDF) evaluated at value.
.. math::
{ \begin{cases}
q=1-p & \text{if }value=0 \\
p & \text{if }value=1
\end{cases}
}
Args:
value (Tensor): Value to be evaluated.
Returns:
Tensor: PDF evaluated at value.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
rv = Bernoulli(0.3)
print(rv.prob(paddle.to_tensor([1.0])))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.29999998])
"""
name = self.name + '_prob'
if not _non_static_mode():
check_type(value, 'value', tensor.Variable, name)
return self.log_prob(value).exp(name=name)
def entropy(self):
r"""Entropy of Bernoulli distribution.
.. math::
{
entropy = -(q \log q + p \log p)
}
Returns:
Tensor: Entropy of distribution.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
rv = Bernoulli(0.3)
print(rv.entropy())
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.61086434])
"""
name = self.name + '_entropy'
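        # With `probs` itself as the target, BCE_with_logits(logits, probs)
        # equals -(p * log(p) + (1 - p) * log(1 - p)), which is the entropy.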
return binary_cross_entropy_with_logits(
self.logits, self.probs, reduction='none', name=name
)
def kl_divergence(self, other):
r"""The KL-divergence between two Bernoulli distributions.
.. math::
{
KL(a || b) = p_a \log(p_a / p_b) + (1 - p_a) \log((1 - p_a) / (1 - p_b))
}
Args:
other (Bernoulli): instance of Bernoulli.
Returns:
            Tensor: KL divergence between two Bernoulli distributions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Bernoulli
rv = Bernoulli(0.3)
rv_other = Bernoulli(0.7)
print(rv.kl_divergence(rv_other))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.33891910])
"""
name = self.name + '_kl_divergence'
if not _non_static_mode():
check_type(other, 'other', Bernoulli, name)
a_logits = self.logits
b_logits = other.logits
log_pa = -softplus(-a_logits)
log_pb = -softplus(-b_logits)
pa = sigmoid(a_logits)
one_minus_pa = sigmoid(-a_logits)
log_one_minus_pa = -softplus(a_logits)
log_one_minus_pb = -softplus(b_logits)
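        # Using log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x),
        # the result below is
        #   p_a * (log(p_a) - log(p_b)) + (1 - p_a) * (log(1 - p_a) - log(1 - p_b)).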
return paddle.add(
paddle.subtract(
paddle.multiply(log_pa, pa), paddle.multiply(log_pb, pa)
),
paddle.subtract(
paddle.multiply(log_one_minus_pa, one_minus_pa),
paddle.multiply(log_one_minus_pb, one_minus_pa),
),
)
......@@ -15,6 +15,7 @@ import functools
import warnings
import paddle
from paddle.distribution.bernoulli import Bernoulli
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
......@@ -143,6 +144,11 @@ class _Compare:
return True
@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
return p.kl_divergence(q)
@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
return (
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import scipy.special
import scipy.stats
from config import ATOL, DEVICES, RTOL
from parameterize import (
TEST_CASE_NAME,
parameterize_cls,
parameterize_func,
place,
)
from test_distribution import DistributionNumpy
import paddle
from paddle.distribution import Bernoulli
from paddle.distribution.kl import kl_divergence
from paddle.fluid.data_feeder import convert_dtype
np.random.seed(2023)
paddle.seed(2023)
# Machine epsilon for each supported dtype.
EPS = {
'float32': np.finfo('float32').eps,
'float64': np.finfo('float64').eps,
}
def _clip_probs_ndarray(probs, dtype):
"""Clip probs from [0, 1] to (0, 1) with ``eps``"""
eps = EPS.get(dtype)
return np.clip(probs, a_min=eps, a_max=1 - eps).astype(dtype)
def _sigmoid(z):
return scipy.special.expit(z)
def _kstest(samples_a, samples_b, temperature=1):
"""Uses the Kolmogorov-Smirnov test for goodness of fit."""
_, p_value = scipy.stats.ks_2samp(samples_a, samples_b)
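    # Accept "same distribution" unless the two-sample KS p-value drops below
    # 0.02; the threshold shrinks with temperatures below 1, making the check
    # more lenient for low-temperature relaxed samples.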
return not (p_value < 0.02 * (min(1, temperature)))
class BernoulliNumpy(DistributionNumpy):
def __init__(self, probs):
probs = np.array(probs)
if str(probs.dtype) not in ['float32', 'float64']:
self.dtype = 'float32'
else:
self.dtype = probs.dtype
self.batch_shape = np.shape(probs)
self.probs = _clip_probs_ndarray(
np.array(probs, dtype=self.dtype), str(self.dtype)
)
self.logits = self._probs_to_logits(self.probs, is_binary=True)
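        # scipy.stats.bernoulli serves as the reference distribution; probs are
        # promoted to float64 so reference statistics are computed at full
        # precision and then cast back to `self.dtype` for comparison.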
self.rv = scipy.stats.bernoulli(self.probs.astype('float64'))
@property
def mean(self):
return self.rv.mean().astype(self.dtype)
@property
def variance(self):
return self.rv.var().astype(self.dtype)
def sample(self, shape):
shape = np.array(shape, dtype='int')
if shape.ndim:
shape = shape.tolist()
else:
shape = [shape.tolist()]
return self.rv.rvs(size=shape + list(self.batch_shape)).astype(
self.dtype
)
def log_prob(self, value):
return self.rv.logpmf(value).astype(self.dtype)
def prob(self, value):
return self.rv.pmf(value).astype(self.dtype)
def cdf(self, value):
return self.rv.cdf(value).astype(self.dtype)
def entropy(self):
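        # Numerically stable form of -(p * log(p) + (1 - p) * log(1 - p)) via the
        # BCE-with-logits identity max(x, 0) - x * p + log(1 + exp(-|x|)).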
return (
np.maximum(
self.logits,
0,
)
- self.logits * self.probs
+ np.log(1 + np.exp(-np.abs(self.logits)))
).astype(self.dtype)
def kl_divergence(self, other):
"""
.. math::
KL[a || b] = Pa * Log[Pa / Pb] + (1 - Pa) * Log[(1 - Pa) / (1 - Pb)]
"""
p_a = self.probs
p_b = other.probs
return (
p_a * np.log(p_a / p_b) + (1 - p_a) * np.log((1 - p_a) / (1 - p_b))
).astype(self.dtype)
def _probs_to_logits(self, probs, is_binary=False):
return (
(np.log(probs) - np.log1p(-probs)) if is_binary else np.log(probs)
).astype(self.dtype)
class BernoulliTest(unittest.TestCase):
def setUp(self):
paddle.disable_static(self.place)
with paddle.fluid.dygraph.guard(self.place):
# just for convenience
self.dtype = self.expected_dtype
# init numpy with `dtype`
self.init_numpy_data(self.probs, self.dtype)
            # init paddle and check dtype conversion.
self.init_dynamic_data(self.probs, self.default_dtype, self.dtype)
def init_numpy_data(self, probs, dtype):
probs = np.array(probs).astype(dtype)
self.rv_np = BernoulliNumpy(probs)
def init_dynamic_data(self, probs, default_dtype, dtype):
self.rv_paddle = Bernoulli(probs)
self.assertTrue(
dtype == convert_dtype(self.rv_paddle.probs.dtype),
(dtype, self.rv_paddle.probs.dtype),
)
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs', 'default_dtype', 'expected_dtype'),
[
# 0-D probs
('probs_00_32', paddle.full((), 0.0), 'float32', 'float32'),
('probs_03_32', paddle.full((), 0.3), 'float32', 'float32'),
('probs_10_32', paddle.full((), 1.0), 'float32', 'float32'),
(
'probs_00_64',
paddle.full((), 0.0, dtype='float64'),
'float64',
'float64',
),
(
'probs_03_64',
paddle.full((), 0.3, dtype='float64'),
'float64',
'float64',
),
(
'probs_10_64',
paddle.full((), 1.0, dtype='float64'),
'float64',
'float64',
),
# 1-D probs
('probs_00', 0.0, 'float64', 'float32'),
('probs_03', 0.3, 'float64', 'float32'),
('probs_10', 1.0, 'float64', 'float32'),
('probs_tensor_03_32', paddle.to_tensor(0.3), 'float32', 'float32'),
(
'probs_tensor_03_64',
paddle.to_tensor(0.3, dtype='float64'),
'float64',
'float64',
),
(
'probs_tensor_03_list_32',
paddle.to_tensor(
[
0.3,
]
),
'float32',
'float32',
),
(
'probs_tensor_03_list_64',
paddle.to_tensor(
[
0.3,
],
dtype='float64',
),
'float64',
'float64',
),
# N-D probs
(
'probs_tensor_0305',
paddle.to_tensor((0.3, 0.5)),
'float32',
'float32',
),
(
'probs_tensor_03050104',
paddle.to_tensor(((0.3, 0.5), (0.1, 0.4))),
'float32',
'float32',
),
],
)
class BernoulliTestFeature(BernoulliTest):
def test_mean(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self.rv_paddle.mean,
self.rv_np.mean,
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
def test_variance(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self.rv_paddle.variance,
self.rv_np.variance,
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
@parameterize_func(
[
(
paddle.to_tensor(
[
0.0,
]
),
),
(
paddle.to_tensor(
0.0,
),
),
(paddle.to_tensor(1.0),),
(paddle.to_tensor(0.0, dtype='float64'),),
]
)
def test_log_prob(self, value):
with paddle.fluid.dygraph.guard(self.place):
if convert_dtype(value.dtype) == convert_dtype(
self.rv_paddle.probs.dtype
):
log_prob = self.rv_paddle.log_prob(value)
np.testing.assert_allclose(
log_prob,
self.rv_np.log_prob(value),
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
self.assertTrue(self.dtype == convert_dtype(log_prob.dtype))
else:
with self.assertWarns(UserWarning):
self.rv_paddle.log_prob(value)
@parameterize_func(
[
(
paddle.to_tensor(
[
0.0,
]
),
),
(paddle.to_tensor(0.0),),
(paddle.to_tensor(1.0),),
(paddle.to_tensor(0.0, dtype='float64'),),
]
)
def test_prob(self, value):
with paddle.fluid.dygraph.guard(self.place):
if convert_dtype(value.dtype) == convert_dtype(
self.rv_paddle.probs.dtype
):
prob = self.rv_paddle.prob(value)
np.testing.assert_allclose(
prob,
self.rv_np.prob(value),
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
self.assertTrue(self.dtype == convert_dtype(prob.dtype))
else:
with self.assertWarns(UserWarning):
self.rv_paddle.prob(value)
@parameterize_func(
[
(
paddle.to_tensor(
[
0.0,
]
),
),
(paddle.to_tensor(0.0),),
(paddle.to_tensor(0.3),),
(paddle.to_tensor(0.7),),
(paddle.to_tensor(1.0),),
(paddle.to_tensor(0.0, dtype='float64'),),
]
)
def test_cdf(self, value):
with paddle.fluid.dygraph.guard(self.place):
if convert_dtype(value.dtype) == convert_dtype(
self.rv_paddle.probs.dtype
):
cdf = self.rv_paddle.cdf(value)
np.testing.assert_allclose(
cdf,
self.rv_np.cdf(value),
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
self.assertTrue(self.dtype == convert_dtype(cdf.dtype))
else:
with self.assertWarns(UserWarning):
self.rv_paddle.cdf(value)
def test_entropy(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self.rv_paddle.entropy(),
self.rv_np.entropy(),
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
def test_kl_divergence(self):
with paddle.fluid.dygraph.guard(self.place):
other_probs = paddle.to_tensor(0.9, dtype=self.dtype)
rv_paddle_other = Bernoulli(other_probs)
rv_np_other = BernoulliNumpy(other_probs)
np.testing.assert_allclose(
self.rv_paddle.kl_divergence(rv_paddle_other),
self.rv_np.kl_divergence(rv_np_other),
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
np.testing.assert_allclose(
kl_divergence(self.rv_paddle, rv_paddle_other),
self.rv_np.kl_divergence(rv_np_other),
rtol=RTOL.get(self.dtype),
atol=ATOL.get(self.dtype),
)
@place(DEVICES)
@parameterize_cls(
(
TEST_CASE_NAME,
'probs',
'default_dtype',
'expected_dtype',
'shape',
'expected_shape',
),
[
# 0-D probs
(
'probs_0d_1d',
paddle.full((), 0.3),
'float32',
'float32',
[
100,
],
[
100,
],
),
(
'probs_0d_2d',
paddle.full((), 0.3),
'float32',
'float32',
[100, 1],
[100, 1],
),
(
'probs_0d_3d',
paddle.full((), 0.3),
'float32',
'float32',
[100, 2, 3],
[100, 2, 3],
),
# 1-D probs
(
'probs_1d_1d_32',
paddle.to_tensor(0.3),
'float32',
'float32',
[
100,
],
[100, 1],
),
(
'probs_1d_1d_64',
paddle.to_tensor(0.3, dtype='float64'),
'float64',
'float64',
paddle.to_tensor(
[
100,
]
),
[100, 1],
),
(
'probs_1d_2d',
paddle.to_tensor(0.3),
'float32',
'float32',
[100, 2],
[100, 2, 1],
),
(
'probs_1d_3d',
paddle.to_tensor(0.3),
'float32',
'float32',
[100, 2, 3],
[100, 2, 3, 1],
),
# N-D probs
(
'probs_2d_1d',
paddle.to_tensor((0.3, 0.5)),
'float32',
'float32',
[
100,
],
[100, 2],
),
(
'probs_2d_2d',
paddle.to_tensor((0.3, 0.5)),
'float32',
'float32',
[100, 3],
[100, 3, 2],
),
(
'probs_2d_3d',
paddle.to_tensor((0.3, 0.5)),
'float32',
'float32',
[100, 4, 3],
[100, 4, 3, 2],
),
],
)
class BernoulliTestSample(BernoulliTest):
def test_sample(self):
with paddle.fluid.dygraph.guard(self.place):
sample_np = self.rv_np.sample(self.shape)
sample_paddle = self.rv_paddle.sample(self.shape)
self.assertEqual(list(sample_paddle.shape), self.expected_shape)
self.assertEqual(sample_paddle.dtype, self.rv_paddle.probs.dtype)
if self.probs.ndim:
for i in range(len(self.probs)):
self.assertTrue(
_kstest(
sample_np[..., i].reshape(-1),
sample_paddle.numpy()[..., i].reshape(-1),
)
)
else:
self.assertTrue(
_kstest(
sample_np.reshape(-1),
sample_paddle.numpy().reshape(-1),
)
)
@parameterize_func(
[
(1.0,),
(0.1,),
]
)
def test_rsample(self, temperature):
"""Compare two samples from `rsample` method, one from scipy `sample` and another from paddle `rsample`."""
with paddle.fluid.dygraph.guard(self.place):
sample_np = self.rv_np.sample(self.shape)
rsample_paddle = self.rv_paddle.rsample(self.shape, temperature)
self.assertEqual(list(rsample_paddle.shape), self.expected_shape)
self.assertEqual(rsample_paddle.dtype, self.rv_paddle.probs.dtype)
if self.probs.ndim:
for i in range(len(self.probs)):
self.assertTrue(
_kstest(
sample_np[..., i].reshape(-1),
(
_sigmoid(rsample_paddle.numpy()[..., i]) > 0.5
).reshape(-1),
temperature,
)
)
else:
self.assertTrue(
_kstest(
sample_np.reshape(-1),
(_sigmoid(rsample_paddle.numpy()) > 0.5).reshape(-1),
temperature,
)
)
def test_rsample_backpropagation(self):
with paddle.fluid.dygraph.guard(self.place):
self.rv_paddle.probs.stop_gradient = False
rsample_paddle = self.rv_paddle.rsample(self.shape)
rsample_paddle = paddle.nn.functional.sigmoid(rsample_paddle)
grads = paddle.grad([rsample_paddle], [self.rv_paddle.probs])
self.assertEqual(len(grads), 1)
self.assertEqual(grads[0].dtype, self.rv_paddle.probs.dtype)
self.assertEqual(grads[0].shape, self.rv_paddle.probs.shape)
@place(DEVICES)
@parameterize_cls([TEST_CASE_NAME], ['BernoulliTestError'])
class BernoulliTestError(unittest.TestCase):
def setUp(self):
paddle.disable_static(self.place)
@parameterize_func(
[
(-0.1, ValueError),
(1.1, ValueError),
(np.nan, ValueError),
(-1j + 1, TypeError),
]
)
def test_bad_init(self, probs, error):
with paddle.fluid.dygraph.guard(self.place):
self.assertRaises(error, Bernoulli, probs)
@parameterize_func(
[
(
[0.3, 0.5],
paddle.to_tensor([0.1, 0.2, 0.3]),
),
]
)
def test_bad_broadcast(self, probs, value):
with paddle.fluid.dygraph.guard(self.place):
rv = Bernoulli(probs)
self.assertRaises(ValueError, rv.cdf, value)
self.assertRaises(ValueError, rv.log_prob, value)
self.assertRaises(ValueError, rv.prob, value)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from config import ATOL, DEVICES, RTOL
from parameterize import (
TEST_CASE_NAME,
parameterize_cls,
parameterize_func,
place,
)
from test_distribution_bernoulli import BernoulliNumpy, _kstest, _sigmoid
import paddle
from paddle.distribution import Bernoulli
from paddle.distribution.kl import kl_divergence
np.random.seed(2023)
paddle.seed(2023)
paddle.enable_static()
default_dtype = paddle.get_default_dtype()
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'params'), # params: name, probs, probs_other, value
[
(
'params',
(
# 1-D probs
(
'probs_not_iterable',
0.3,
0.7,
1.0,
),
(
'probs_not_iterable_and_broadcast_for_value',
0.3,
0.7,
np.array([[0.0, 1.0], [1.0, 0.0]], dtype=default_dtype),
),
# N-D probs
(
'probs_tuple_0305',
(0.3, 0.5),
0.7,
1.0,
),
(
'probs_tuple_03050104',
((0.3, 0.5), (0.1, 0.4)),
0.7,
1.0,
),
),
)
],
)
class BernoulliTestFeature(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
self.params_len = len(self.params)
with paddle.static.program_guard(self.program):
self.init_numpy_data(self.params)
self.init_static_data(self.params)
def init_numpy_data(self, params):
self.mean_np = []
self.variance_np = []
self.log_prob_np = []
self.prob_np = []
self.cdf_np = []
self.entropy_np = []
self.kl_np = []
for _, probs, probs_other, value in params:
rv_np = BernoulliNumpy(probs)
rv_np_other = BernoulliNumpy(probs_other)
self.mean_np.append(rv_np.mean)
self.variance_np.append(rv_np.variance)
self.log_prob_np.append(rv_np.log_prob(value))
self.prob_np.append(rv_np.prob(value))
self.cdf_np.append(rv_np.cdf(value))
self.entropy_np.append(rv_np.entropy())
self.kl_np.append(rv_np.kl_divergence(rv_np_other))
def init_static_data(self, params):
with paddle.static.program_guard(self.program):
rv_paddles = []
rv_paddles_other = []
values = []
for _, probs, probs_other, value in params:
if not isinstance(value, np.ndarray):
value = paddle.full([1], value, dtype=default_dtype)
else:
value = paddle.to_tensor(value, place=self.place)
rv_paddles.append(Bernoulli(probs=paddle.to_tensor(probs)))
rv_paddles_other.append(
Bernoulli(probs=paddle.to_tensor(probs_other))
)
values.append(value)
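            # Each parameter set contributes 8 fetch targets (mean, variance,
            # log_prob, prob, cdf, entropy and two KL results); `executor.run`
            # returns them flattened, so they are sliced in blocks of 8 below.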
results = self.executor.run(
self.program,
feed={},
fetch_list=[
[
rv_paddles[i].mean,
rv_paddles[i].variance,
rv_paddles[i].log_prob(values[i]),
rv_paddles[i].prob(values[i]),
rv_paddles[i].cdf(values[i]),
rv_paddles[i].entropy(),
rv_paddles[i].kl_divergence(rv_paddles_other[i]),
kl_divergence(rv_paddles[i], rv_paddles_other[i]),
]
for i in range(self.params_len)
],
)
self.mean_paddle = []
self.variance_paddle = []
self.log_prob_paddle = []
self.prob_paddle = []
self.cdf_paddle = []
self.entropy_paddle = []
self.kl_paddle = []
self.kl_func_paddle = []
for i in range(self.params_len):
(
_mean,
_variance,
_log_prob,
_prob,
_cdf,
_entropy,
_kl,
_kl_func,
) = results[i * 8 : (i + 1) * 8]
self.mean_paddle.append(_mean)
self.variance_paddle.append(_variance)
self.log_prob_paddle.append(_log_prob)
self.prob_paddle.append(_prob)
self.cdf_paddle.append(_cdf)
self.entropy_paddle.append(_entropy)
self.kl_paddle.append(_kl)
self.kl_func_paddle.append(_kl_func)
def test_all(self):
for i in range(self.params_len):
self._test_mean(i)
self._test_variance(i)
self._test_log_prob(i)
self._test_prob(i)
self._test_cdf(i)
self._test_entropy(i)
self._test_kl_divergence(i)
def _test_mean(self, i):
np.testing.assert_allclose(
self.mean_np[i],
self.mean_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_variance(self, i):
np.testing.assert_allclose(
self.variance_np[i],
self.variance_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_log_prob(self, i):
np.testing.assert_allclose(
self.log_prob_np[i],
self.log_prob_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_prob(self, i):
np.testing.assert_allclose(
self.prob_np[i],
self.prob_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_cdf(self, i):
np.testing.assert_allclose(
self.cdf_np[i],
self.cdf_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_entropy(self, i):
np.testing.assert_allclose(
self.entropy_np[i],
self.entropy_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
def _test_kl_divergence(self, i):
np.testing.assert_allclose(
self.kl_np[i],
self.kl_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
np.testing.assert_allclose(
self.kl_np[i],
self.kl_func_paddle[i],
rtol=RTOL.get(default_dtype),
atol=ATOL.get(default_dtype),
)
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs', 'shape', 'temperature', 'expected_shape'),
[
# 1-D probs
(
'probs_03',
(0.3,),
[
100,
],
0.1,
[100, 1],
),
# N-D probs
(
'probs_0305',
(0.3, 0.5),
[
100,
],
0.1,
[100, 2],
),
],
)
class BernoulliTestSample(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
self.init_numpy_data(self.probs, self.shape)
self.init_static_data(self.probs, self.shape, self.temperature)
def init_numpy_data(self, probs, shape):
self.rv_np = BernoulliNumpy(probs)
self.sample_np = self.rv_np.sample(shape)
def init_static_data(self, probs, shape, temperature):
with paddle.static.program_guard(self.program):
self.rv_paddle = Bernoulli(probs=paddle.to_tensor(probs))
[self.sample_paddle, self.rsample_paddle] = self.executor.run(
self.program,
feed={},
fetch_list=[
self.rv_paddle.sample(shape),
self.rv_paddle.rsample(shape, temperature),
],
)
def test_sample(self):
with paddle.static.program_guard(self.program):
self.assertEqual(
list(self.sample_paddle.shape), self.expected_shape
)
for i in range(len(self.probs)):
self.assertTrue(
_kstest(
self.sample_np[..., i].reshape(-1),
self.sample_paddle[..., i].reshape(-1),
)
)
def test_rsample(self):
"""Compare two samples from `rsample` method, one from scipy and another from paddle."""
with paddle.static.program_guard(self.program):
self.assertEqual(
list(self.rsample_paddle.shape), self.expected_shape
)
for i in range(len(self.probs)):
self.assertTrue(
_kstest(
self.sample_np[..., i].reshape(-1),
(_sigmoid(self.rsample_paddle[..., i]) > 0.5).reshape(
-1
),
self.temperature,
)
)
@place(DEVICES)
@parameterize_cls([TEST_CASE_NAME], ['BernoulliTestError'])
class BernoulliTestError(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
@parameterize_func(
[
(0,), # int
((0.3,),), # tuple
(
[
0.3,
],
), # list
(
np.array(
[
0.3,
]
),
), # ndarray
(-1j + 1,), # complex
('0',), # str
]
)
def test_bad_init_type(self, probs):
with paddle.static.program_guard(self.program):
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[Bernoulli(probs=probs)]
)
@parameterize_func(
[
(100,), # int
(100.0,), # float
]
)
def test_bad_sample_shape_type(self, shape):
with paddle.static.program_guard(self.program):
rv = Bernoulli(0.3)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.sample(shape)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.rsample(shape)]
)
@parameterize_func(
[
(1,), # int
]
)
def test_bad_rsample_temperature_type(self, temperature):
with paddle.static.program_guard(self.program):
rv = Bernoulli(0.3)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program,
feed={},
fetch_list=[rv.rsample([100], temperature)],
)
@parameterize_func(
[
(1,), # int
(1.0,), # float
([1.0],), # list
((1.0),), # tuple
(np.array(1.0),), # ndarray
]
)
def test_bad_value_type(self, value):
with paddle.static.program_guard(self.program):
rv = Bernoulli(0.3)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.log_prob(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.prob(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.cdf(value)]
)
@parameterize_func(
[
(np.array(1.0),), # ndarray or other distribution
]
)
def test_bad_kl_other_type(self, other):
with paddle.static.program_guard(self.program):
rv = Bernoulli(0.3)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.kl_divergence(other)]
)
@parameterize_func(
[
(paddle.to_tensor([0.1, 0.2, 0.3]),),
]
)
def test_bad_broadcast(self, value):
with paddle.static.program_guard(self.program):
rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
            # `logits, value = paddle.broadcast_tensors([self.logits, value])`
            # raises ValueError in dygraph mode and TypeError in static mode.
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.cdf(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.log_prob(value)]
)
with self.assertRaises(TypeError):
[_] = self.executor.run(
self.program, feed={}, fetch_list=[rv.prob(value)]
)
if __name__ == '__main__':
unittest.main()