Unverified commit bca1b0c6, authored by D dasen, committed by GitHub

[Hackathon 4th No.11] Add Geometric Distribution API to paddle (#51224)

Parent 8f5eae47
@@ -29,6 +29,7 @@ from paddle.distribution.transform import *  # noqa: F403
 from paddle.distribution.transformed_distribution import TransformedDistribution
 from paddle.distribution.uniform import Uniform
 from paddle.distribution.laplace import Laplace
+from paddle.distribution.geometric import Geometric
 __all__ = [  # noqa
     'Bernoulli',
@@ -47,6 +48,7 @@ __all__ = [  # noqa
     'Laplace',
     'LogNormal',
     'Gumbel',
+    'Geometric',
 ]
 __all__.extend(transform.__all__)
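With the export in place, the new distribution is importable from the package root. A minimal usage sketch, mirroring the docstring examples in the new module below (the printed values are the ones documented there):

.. code-block:: python

    import paddle
    from paddle.distribution import Geometric

    # Geometric(p): number of Bernoulli trials (k = 1, 2, ...) needed
    # to observe the first success, each with success probability p.
    geom = Geometric(0.5)
    print(geom.mean)       # 1 / p = 2.
    print(geom.pmf(2))     # (1 - p)**(2 - 1) * p = 0.25
    print(geom.entropy())  # about 1.38629436 nats for p = 0.5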
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import numpy as np
import paddle
from paddle.distribution import distribution, uniform
from paddle.fluid import framework
class Geometric(distribution.Distribution):
r"""
Geometric distribution parameterized by probs.
    In probability theory and statistics, the geometric distribution is a
    discrete probability distribution parameterized by a single parameter,
    denoted by probs, the probability of success in each Bernoulli trial.
    It models the number of independent Bernoulli trials needed to obtain
    the first success: Pr(Y=k) is the probability that the first k-1
    trials fail and the k-th trial succeeds. The geometric distribution is
    the special case of the Pascal (negative binomial) distribution with
    r=1.

    The probability mass function (pmf) is

    .. math::

        Pr(Y=k) = (1-p)^{k-1} p

    where k is the number of trials performed, p is the probability of
    success for each trial, k = 1, 2, 3, ..., and p belongs to (0, 1].
    For example, with p = 0.5, Pr(Y=2) = 0.5 * 0.5 = 0.25.

    Args:
        probs (Real|Tensor): Probability of success in each Bernoulli
            trial. The value of probs must lie in (0, 1]. When the
            parameter is a tensor, each element is the success
            probability of the corresponding trial.

    Returns:
        Geometric: A Geometric distribution instance parameterized by probs.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.mean
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2.])
geom.variance
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2.])
geom.stddev
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.41421354])
"""
    def __init__(self, probs):
        if isinstance(probs, (numbers.Real, paddle.Tensor, framework.Variable)):
            if isinstance(probs, numbers.Real):
                # Wrap a scalar into a 0-D tensor so that the shape and
                # dtype handling below is uniform.
                probs = paddle.full(
                    shape=(), fill_value=probs, dtype=paddle.float32
                )

            all_ones = paddle.full(
                shape=probs.shape, fill_value=1, dtype=probs.dtype
            )
            all_zeros = paddle.full(
                shape=probs.shape, fill_value=0, dtype=probs.dtype
            )
            all_false = paddle.full(
                shape=probs.shape, fill_value=False, dtype=bool
            )
            # Element-wise violations of the constraint 0 < probs <= 1.
            less_equal_0 = probs <= all_zeros
            greater_than_1 = probs > all_ones
        else:
            raise TypeError(
                f"Expected type of probs is numbers.Real|Tensor|framework.Variable, but got {type(probs)}"
            )

        if paddle.equal_all(less_equal_0, all_false) and paddle.equal_all(
            greater_than_1, all_false
        ):
            batch_shape = tuple(probs.shape)
        else:
            raise ValueError(
                "Expected parameter probs of distribution Geometric to satisfy the "
                "constraint Interval(lower_bound=0.0, upper_bound=1.0)"
            )

        self.probs = probs
        super().__init__(batch_shape)
@property
def mean(self):
"""Mean of geometric distribution."""
return 1.0 / self.probs
@property
def variance(self):
"""Variance of geometric distribution."""
return paddle.to_tensor(
(1.0 / self.probs - 1.0) / self.probs,
dtype=self.probs.dtype,
)
@property
def stddev(self):
"""Standard deviation of Geometric distribution."""
return paddle.sqrt(self.variance)
def pmf(self, k):
r"""Probability mass funciotn evaluated at k.
.. math::
P(X=k) = (1-p)^{k-1} p, \quad k=1,2,3,\ldots
Args:
k (int): Value to be evaluated.
Returns:
Tensor: Probability.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.pmf(2)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.25000000])
"""
        if isinstance(k, (numbers.Integral, framework.Variable)):
            return paddle.pow((1.0 - self.probs), k - 1.0) * self.probs
        else:
            raise TypeError(
                f"Expected type of k is numbers.Integral|framework.Variable, but got {type(k)}"
            )
def log_pmf(self, k):
r"""Log probability mass function evaluated at k.
        .. math::

            \log P(X = k) = (k-1) \log (1-p) + \log p, \quad k=1,2,3,\ldots

Args:
k (int): Value to be evaluated.
Returns:
Tensor: Log probability.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.log_pmf(2)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-1.38629436])
"""
        if isinstance(k, (numbers.Integral, framework.Variable)):
            return paddle.log(self.pmf(k))
        else:
            raise TypeError(
                f"Expected type of k is numbers.Integral|framework.Variable, but got {type(k)}"
            )
def sample(self, shape=()):
"""Sample from Geometric distribution with sample shape.
Args:
shape (tuple(int)): Sample shape.
Returns:
Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.sample((2,2))
# Tensor(shape=[2, 2, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[[4.28128004],
# [0.53546447]],
# [[0.88012987],
# [0.54070371]]])
"""
with paddle.no_grad():
return self.rsample(shape)
def rsample(self, shape=()):
"""Generate samples of the specified shape.
Args:
shape(tuple(int)): The shape of generated samples.
Returns:
Tensor: A sample tensor that fits the Geometric distribution.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.rsample((2,2))
# Tensor(shape=[2, 2, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[[2.90974379],
# [1.28049409]],
# [[4.60141420],
# [2.98836184]]])
"""
shape = distribution.Distribution._extend_shape(
self, sample_shape=shape
)
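        # Inverse-CDF sampling: for U ~ Uniform(tiny, 1), the ratio
        # log(U) / log(1 - p) has survival function (1 - p)^x, the
        # continuous analogue of the geometric law; the lower bound
        # `tiny` keeps log(U) finite.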
tiny = np.finfo(dtype='float32').tiny
sample_uniform = uniform.Uniform(low=float(tiny), high=float(1))
new_t = sample_uniform.sample(list(shape))
return paddle.log(new_t) / paddle.log1p(-(self.probs))
def entropy(self):
r"""Entropy of dirichlet distribution.
.. math::
H(X) = -\left[\frac{1}{p} \log p + \frac{1-p}{p^2} \log (1-p) \right]
Returns:
Tensor: Entropy.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.38629436])
"""
x = (1.0 - self.probs) * paddle.log(1.0 - self.probs)
y = self.probs * paddle.log(self.probs)
return -(x + y) / self.probs
def cdf(self, k):
r"""Cdf of geometric distribution.
.. math::
F(X \leq k) = 1 - (1-p)^k, \quad k=0,1,2,\ldots
        Args:
            k (int): The number of trials performed.

        Returns:
            Tensor: Probability that the first success occurs within k trials.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.cdf(4)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.93750000])
"""
        if isinstance(k, (numbers.Integral, framework.Variable)):
            return 1.0 - paddle.pow((1.0 - self.probs), k)
        else:
            raise TypeError(
                f"Expected type of k is numbers.Integral|framework.Variable, but got {type(k)}"
            )
def kl_divergence(self, other):
r"""Calculate the KL divergence KL(self || other) with two Geometric instances.
        .. math::

            KL(P \| Q) = p \log \frac{p}{q} + (1-p) \log \frac{1-p}{1-q}

Args:
other (Geometric): An instance of Geometric.
Returns:
Tensor: The kl-divergence between two geometric distributions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom_p = Geometric(0.5)
geom_q = Geometric(0.1)
geom_p.kl_divergence(geom_q)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.51082563])
"""
if isinstance(other, Geometric):
p, q = self.probs, other.probs
return p * paddle.log(p / q) + (1.0 - p) * paddle.log(
(1.0 - p) / (1.0 - q)
)
else:
raise TypeError(
f"Exected type of other is geometric.Geometric, but got {type(other)}"
)
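A short sketch of why the inverse-CDF transform in `rsample` is geometric (a derivation added here for clarity, not part of the patch): for U drawn uniformly from (0, 1),

.. math::

    P\left(\frac{\log U}{\log(1-p)} \ge x\right)
        = P\left(U \le (1-p)^{x}\right) = (1-p)^{x}, \quad x \ge 0,

so rounding X = log(U) / log(1-p) up to the next integer gives P(ceil(X) >= k) = (1-p)^{k-1}, exactly the survival function of the geometric distribution on k = 1, 2, .... Note that the implementation returns the continuous X itself rather than its ceiling.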
@@ -21,6 +21,7 @@ from paddle.distribution.categorical import Categorical
 from paddle.distribution.dirichlet import Dirichlet
 from paddle.distribution.distribution import Distribution
 from paddle.distribution.exponential_family import ExponentialFamily
+from paddle.distribution.geometric import Geometric
 from paddle.distribution.laplace import Laplace
 from paddle.distribution.lognormal import LogNormal
 from paddle.distribution.normal import Normal
@@ -200,6 +201,11 @@ def _kl_laplace_laplace(p, q):
     return p.kl_divergence(q)
+
+
+@register_kl(Geometric, Geometric)
+def _kl_geometric_geometric(p, q):
+    return p.kl_divergence(q)
 @register_kl(ExponentialFamily, ExponentialFamily)
 def _kl_expfamily_expfamily(p, q):
     """Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_"""
...
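With the dispatch registered, `kl.kl_divergence` resolves a pair of Geometric instances automatically. A minimal sketch (the numeric output mirrors the kl_divergence docstring example above):

.. code-block:: python

    import paddle
    from paddle.distribution import Geometric, kl

    geom_p = Geometric(0.5)
    geom_q = Geometric(0.1)
    # register_kl routes this call to _kl_geometric_geometric,
    # which in turn calls geom_p.kl_divergence(geom_q).
    print(kl.kl_divergence(geom_p, geom_q))
    # [0.51082563]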
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import scipy.stats
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import paddle
from paddle.distribution import geometric, kl
from paddle.nn.functional import log_softmax
np.random.seed(2023)
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 3),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometric(unittest.TestCase):
def setUp(self):
probs = self.probs
if not isinstance(self.probs, numbers.Real):
probs = paddle.to_tensor(self.probs, dtype=paddle.float32)
self._paddle_geom = geometric.Geometric(probs)
def test_mean(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.mean,
scipy.stats.geom.mean(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_variance(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.variance,
scipy.stats.geom.var(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_stddev(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.stddev,
scipy.stats.geom.std(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_entropy(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.entropy(),
scipy.stats.geom.entropy(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_init_prob_value_error(self):
with self.assertRaises(ValueError):
paddle.distribution.geometric.Geometric(2)
def test_init_prob_type_error(self):
with self.assertRaises(TypeError):
paddle.distribution.geometric.Geometric([2])
def test_sample_shape(self):
cases = [
{
'input': (),
'expect': ()
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
{
'input': (4, 2),
'expect': (4, 2)
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
]
for case in cases:
self.assertTrue(
tuple(self._paddle_geom.sample(case.get('input')).shape)
== case.get('expect')
)
def test_sample(self):
sample_shape = (80000,)
samples = self._paddle_geom.sample(sample_shape)
sample_values = samples.numpy()
self.assertEqual(sample_values.dtype, self.probs.dtype)
np.testing.assert_allclose(
sample_values.mean(axis=0),
scipy.stats.geom.mean(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
np.testing.assert_allclose(
sample_values.var(axis=0),
scipy.stats.geom.var(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_rsample_shape(self):
cases = [
{
'input': (),
'expect': ()
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
{
'input': (2, 5),
'expect': (2, 5)
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
]
for case in cases:
self.assertTrue(
tuple(self._paddle_geom.rsample(case.get('input')).shape)
== case.get('expect')
)
def test_rsample(self):
sample_shape = (100000,)
samples = self._paddle_geom.rsample(sample_shape)
sample_values = samples.numpy()
self.assertEqual(sample_values.dtype, self.probs.dtype)
np.testing.assert_allclose(
sample_values.mean(axis=0),
scipy.stats.geom.mean(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
np.testing.assert_allclose(
sample_values.var(axis=0),
scipy.stats.geom.var(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_back_rsample(self):
sample_shape = (100000,)
with paddle.fluid.dygraph.guard(self.place):
self._paddle_geom.probs.stop_gradient = False
rs_value = self._paddle_geom.rsample(sample_shape)
softmax_rs = log_softmax(rs_value)
grads = paddle.grad([softmax_rs], [self._paddle_geom.probs])
self.assertEqual(len(grads), 1)
self.assertEqual(grads[0].dtype, self._paddle_geom.probs.dtype)
self.assertEqual(grads[0].shape, self._paddle_geom.probs.shape)
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs', 'value'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
            'multi-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
            'multi-dim',
xrand(
(2, 2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
],
)
class TestGeometricPMF(unittest.TestCase):
def setUp(self):
self._paddle_geom = geometric.Geometric(
probs=paddle.to_tensor(self.probs)
)
def test_pmf(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.pmf(self.value),
scipy.stats.geom.pmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_log_pmf(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.log_pmf(self.value),
scipy.stats.geom.logpmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_cdf(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.cdf(self.value),
scipy.stats.geom.cdf(self.value, self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geom.pmf, [1, 2])
def test_log_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geom.log_pmf, [1, 2])
def test_cdf_error(self):
self.assertRaises(TypeError, self._paddle_geom.cdf, [1, 2])
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs1', 'probs2'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometricKL(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self._geometric1 = geometric.Geometric(
probs=paddle.to_tensor(self.probs1)
)
self._geometric2 = geometric.Geometric(
probs=paddle.to_tensor(self.probs2)
)
def test_kl_divergence(self):
np.testing.assert_allclose(
kl.kl_divergence(self._geometric1, self._geometric2),
self._kl(),
rtol=RTOL.get(str(self._geometric1.probs.numpy().dtype)),
atol=ATOL.get(str(self._geometric1.probs.numpy().dtype)),
)
def test_kl1_error(self):
self.assertRaises(
TypeError,
self._geometric1.kl_divergence,
paddle.distribution.beta.Beta,
)
def test_kl2_error(self):
self.assertRaises(
TypeError,
self._geometric2.kl_divergence,
paddle.distribution.beta.Beta,
)
def _kl(self):
return self.probs1 * np.log(self.probs1 / self.probs2) + (
1.0 - self.probs1
) * np.log((1.0 - self.probs1) / (1.0 - self.probs2))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import scipy.stats
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import paddle
from paddle.distribution import geometric
np.random.seed(2023)
paddle.enable_static()
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 3),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometric(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
            # probs is declared as static-graph input data; no conversion to tensor is needed
probs = paddle.static.data(
'probs', self.probs.shape, self.probs.dtype
)
self._paddle_geometric = geometric.Geometric(probs)
self.feeds = {'probs': self.probs}
def test_mean(self):
with paddle.static.program_guard(self.program):
[mean] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.mean],
)
np.testing.assert_allclose(
mean,
scipy.stats.geom.mean(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_variance(self):
with paddle.static.program_guard(self.program):
[variance] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.variance],
)
np.testing.assert_allclose(
variance,
scipy.stats.geom.var(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_stddev(self):
with paddle.static.program_guard(self.program):
[stddev] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.stddev],
)
np.testing.assert_allclose(
stddev,
scipy.stats.geom.std(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_sample(self):
with paddle.static.program_guard(self.program):
[data] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=self._paddle_geometric.sample(),
)
            self.assertEqual(
                data.shape, np.broadcast_arrays(self.probs)[0].shape
            )
def test_rsample(self):
with paddle.static.program_guard(self.program):
[data] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=self._paddle_geometric.rsample(),
)
            self.assertEqual(
                data.shape, np.broadcast_arrays(self.probs)[0].shape
            )
def test_entropy(self):
with paddle.static.program_guard(self.program):
[entropy] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.entropy()],
)
np.testing.assert_allclose(
entropy,
scipy.stats.geom.entropy(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_init_prob_type_error(self):
with self.assertRaises(TypeError):
paddle.distribution.geometric.Geometric([0.5])
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs', 'value'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
            'multi-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
            'multi-dim',
xrand(
(2, 2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
],
)
class TestGeometricPMF(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
probs = paddle.static.data(
'probs', self.probs.shape, self.probs.dtype
)
self._paddle_geometric = geometric.Geometric(probs)
self.feeds = {'probs': self.probs, 'value': self.value}
def test_pmf(self):
with paddle.static.program_guard(self.program):
[pmf] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.pmf(self.value)],
)
np.testing.assert_allclose(
pmf,
scipy.stats.geom.pmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_log_pmf(self):
with paddle.static.program_guard(self.program):
[log_pmf] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.log_pmf(self.value)],
)
np.testing.assert_allclose(
log_pmf,
scipy.stats.geom.logpmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_cdf(self):
with paddle.static.program_guard(self.program):
[cdf] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.cdf(self.value)],
)
np.testing.assert_allclose(
cdf,
scipy.stats.geom.cdf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geometric.pmf, [1, 2])
def test_log_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geometric.log_pmf, [1, 2])
def test_cdf_error(self):
self.assertRaises(TypeError, self._paddle_geometric.cdf, [1, 2])
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs1', 'probs2'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometricKL(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.program_p = paddle.static.Program()
self.program_q = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program_p, self.program_q):
probs_p = paddle.static.data(
'probs1', self.probs1.shape, self.probs1.dtype
)
probs_q = paddle.static.data(
'probs2', self.probs2.shape, self.probs2.dtype
)
self._paddle_geomP = geometric.Geometric(probs_p)
self._paddle_geomQ = geometric.Geometric(probs_q)
self.feeds = {
'probs1': self.probs1,
'probs2': self.probs2,
}
def test_kl_divergence(self):
with paddle.static.program_guard(self.program_p, self.program_q):
self.executor.run(self.program_q)
[kl_diver] = self.executor.run(
self.program_p,
feed=self.feeds,
fetch_list=[
self._paddle_geomP.kl_divergence(self._paddle_geomQ)
],
)
np.testing.assert_allclose(
kl_diver,
self._kl(),
rtol=RTOL.get(str(self.probs1.dtype)),
atol=ATOL.get(str(self.probs1.dtype)),
)
def test_kl1_error(self):
self.assertRaises(
TypeError,
self._paddle_geomP.kl_divergence,
paddle.distribution.beta.Beta,
)
def test_kl2_error(self):
self.assertRaises(
TypeError,
self._paddle_geomQ.kl_divergence,
paddle.distribution.beta.Beta,
)
def _kl(self):
return self.probs1 * np.log(self.probs1 / self.probs2) + (
1.0 - self.probs1
) * np.log((1.0 - self.probs1) / (1.0 - self.probs2))