Unverified Commit bca1b0c6 authored by dasen, committed by GitHub

【Hackathon 4th No.11】 Add Geometric Distribution API to paddle (#51224)

Parent 8f5eae47
......@@ -29,6 +29,7 @@ from paddle.distribution.transform import * # noqa: F403
from paddle.distribution.transformed_distribution import TransformedDistribution
from paddle.distribution.uniform import Uniform
from paddle.distribution.laplace import Laplace
from paddle.distribution.geometric import Geometric
__all__ = [ # noqa
'Bernoulli',
......@@ -47,6 +48,7 @@ __all__ = [ # noqa
'Laplace',
'LogNormal',
'Gumbel',
'Geometric',
]
__all__.extend(transform.__all__)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import numpy as np
import paddle
from paddle.distribution import distribution, uniform
from paddle.fluid import framework
class Geometric(distribution.Distribution):
r"""
Geometric distribution parameterized by probs.
In probability theory and statistics, the geometric distribution is a
discrete probability distribution, parameterized by one positive shape parameter, denoted by probs.
It models the number of Bernoulli trials needed to obtain the first success,
i.e., the probability that the first k-1 trials fail and the k-th trial succeeds.
The geometric distribution is a special case of the Pascal (negative binomial) distribution with r=1.
The probability mass function (pmf) is
.. math::
Pr(Y=k) = (1-p)^{k-1} p
where k is the number of trials performed, p is the probability of success for each trial, k = 1, 2, 3, ..., and p belongs to (0, 1].
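For example, with p = 0.5 the probability that the first success occurs on the second trial is Pr(Y=2) = (1-0.5)^{2-1} * 0.5 = 0.25, which matches the ``pmf`` example below.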
Args:
probs (Real|Tensor): Probability parameter.
The value of probs must lie in (0, 1]. When the parameter is a tensor, probs is the probability of success for each trial.
Returns:
An instance of Geometric distribution parameterized by probs.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.mean
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2.])
geom.variance
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2.])
geom.stddev
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.41421354])
"""
def __init__(self, probs):
if isinstance(probs, (numbers.Real, paddle.Tensor, framework.Variable)):
if isinstance(probs, numbers.Real):
probs = paddle.full(
shape=(), fill_value=probs, dtype=paddle.float32
)
all_ones = paddle.full(
shape=probs.shape, fill_value=1, dtype=probs.dtype
)
all_zeros = paddle.full(
shape=probs.shape, fill_value=0, dtype=probs.dtype
)
all_false = paddle.full(
shape=probs.shape, fill_value=False, dtype=bool
)
# probs must lie in (0, 1]: flag entries that violate either bound
non_positive = probs <= all_zeros
greater_than_one = probs > all_ones
else:
raise TypeError(
f"Expected type of probs is Number.Real|Tensor|framework.Variable, but got {type(probs)}"
)
if paddle.equal_all(non_positive, all_false) and paddle.equal_all(
greater_than_one, all_false
):
batch_shape = tuple(probs.shape)
else:
raise ValueError(
"Expected parameter probs of distribution Geometric to satisfy the"
"constraint Interval(lower_bound=0.0, upper_bound=1.0)"
)
self.probs = probs
super().__init__(batch_shape)
@property
def mean(self):
"""Mean of geometric distribution."""
return 1.0 / self.probs
@property
def variance(self):
"""Variance of geometric distribution."""
return paddle.to_tensor(
(1.0 / self.probs - 1.0) / self.probs,
dtype=self.probs.dtype,
)
@property
def stddev(self):
"""Standard deviation of Geometric distribution."""
return paddle.sqrt(self.variance)
def pmf(self, k):
r"""Probability mass funciotn evaluated at k.
.. math::
P(X=k) = (1-p)^{k-1} p, \quad k=1,2,3,\ldots
Args:
k (int): Value to be evaluated.
Returns:
Tensor: Probability.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.pmf(2)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.25000000])
"""
if isinstance(k, (numbers.Integral, framework.Variable)):
return paddle.pow((1.0 - self.probs), k - 1.0) * self.probs
else:
raise TypeError(
f"Expected type of k is number.Real|framework.Variable, but got {type(k)}"
)
def log_pmf(self, k):
r"""Log probability mass function evaluated at k.
.. math::
\log P(X=k) = (k-1) \log (1-p) + \log p
Args:
k (int): Value to be evaluated.
Returns:
Tensor: Log probability.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.log_pmf(2)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [-1.38629436])
"""
if isinstance(k, (numbers.Integral, framework.Variable)):
return paddle.log(self.pmf(k))
else:
raise TypeError(
f"Expected type of k is number.Real|framework.Variable, but got {type(k)}"
)
def sample(self, shape=()):
"""Sample from Geometric distribution with sample shape.
Args:
shape (tuple(int)): Sample shape.
Returns:
Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.sample((2,2))
# Tensor(shape=[2, 2, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[[4.28128004],
# [0.53546447]],
# [[0.88012987],
# [0.54070371]]])
"""
with paddle.no_grad():
return self.rsample(shape)
def rsample(self, shape=()):
"""Generate samples of the specified shape.
Args:
shape(tuple(int)): The shape of generated samples.
Returns:
Tensor: A sample tensor that fits the Geometric distribution.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.rsample((2,2))
# Tensor(shape=[2, 2, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[[2.90974379],
# [1.28049409]],
# [[4.60141420],
# [2.98836184]]])
"""
shape = self._extend_shape(sample_shape=shape)
tiny = np.finfo(dtype='float32').tiny
sample_uniform = uniform.Uniform(low=float(tiny), high=float(1))
new_t = sample_uniform.sample(list(shape))
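# Inverse-transform sampling: for U ~ Uniform(tiny, 1), log(U) / log(1 - p)
# is a continuous relaxation of a Geometric(p) draw (taking the ceiling would
# give integer samples); keeping it continuous lets gradients flow through probs.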
return paddle.log(new_t) / paddle.log1p(-(self.probs))
def entropy(self):
r"""Entropy of dirichlet distribution.
.. math::
H(X) = -\left[\log p + \frac{1-p}{p} \log (1-p) \right]
Returns:
Tensor: Entropy.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.38629436])
"""
x = (1.0 - self.probs) * paddle.log(1.0 - self.probs)
y = self.probs * paddle.log(self.probs)
return -(x + y) / self.probs
def cdf(self, k):
r"""Cdf of geometric distribution.
.. math::
F(X \leq k) = 1 - (1-p)^k, \quad k=0,1,2,\ldots
Args:
k (int): The number of trials performed.
Returns:
Tensor: Cumulative probability of k.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom = Geometric(0.5)
geom.cdf(4)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.93750000])
"""
if isinstance(k, (numbers.Integral, framework.Variable)):
return 1.0 - paddle.pow((1.0 - self.probs), k)
else:
raise TypeError(
f"Expected type of k is number.Real|framework.Variable, but got {type(k)}"
)
def kl_divergence(self, other):
r"""Calculate the KL divergence KL(self || other) with two Geometric instances.
.. math::
KL(P \| Q) = \frac{p}{q} \log \frac{p}{q} + \log (1-p) - \log (1-q)
Args:
other (Geometric): An instance of Geometric.
Returns:
Tensor: The kl-divergence between two geometric distributions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Geometric
geom_p = Geometric(0.5)
geom_q = Geometric(0.1)
geom_p.kl_divergence(geom_q)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [0.51082563])
"""
if isinstance(other, Geometric):
p, q = self.probs, other.probs
return p * paddle.log(p / q) + (1.0 - p) * paddle.log(
(1.0 - p) / (1.0 - q)
)
else:
raise TypeError(
f"Exected type of other is geometric.Geometric, but got {type(other)}"
)
......@@ -21,6 +21,7 @@ from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.geometric import Geometric
from paddle.distribution.laplace import Laplace
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.normal import Normal
......@@ -200,6 +201,11 @@ def _kl_laplace_laplace(p, q):
return p.kl_divergence(q)
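# Registering the (Geometric, Geometric) pair lets paddle.distribution.kl_divergence
# dispatch to Geometric.kl_divergence.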
@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
return p.kl_divergence(q)
@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
"""Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_"""
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import scipy.stats
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import paddle
from paddle.distribution import geometric, kl
from paddle.nn.functional import log_softmax
np.random.seed(2023)
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 3),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometric(unittest.TestCase):
def setUp(self):
probs = self.probs
if not isinstance(self.probs, numbers.Real):
probs = paddle.to_tensor(self.probs, dtype=paddle.float32)
self._paddle_geom = geometric.Geometric(probs)
def test_mean(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.mean,
scipy.stats.geom.mean(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_variance(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.variance,
scipy.stats.geom.var(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_stddev(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.stddev,
scipy.stats.geom.std(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_entropy(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.entropy(),
scipy.stats.geom.entropy(self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_init_prob_value_error(self):
with self.assertRaises(ValueError):
paddle.distribution.geometric.Geometric(2)
def test_init_prob_type_error(self):
with self.assertRaises(TypeError):
paddle.distribution.geometric.Geometric([2])
def test_sample_shape(self):
cases = [
{
'input': (),
'expect': ()
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
{
'input': (4, 2),
'expect': (4, 2)
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
]
for case in cases:
self.assertTrue(
tuple(self._paddle_geom.sample(case.get('input')).shape)
== case.get('expect')
)
def test_sample(self):
sample_shape = (80000,)
samples = self._paddle_geom.sample(sample_shape)
sample_values = samples.numpy()
self.assertEqual(sample_values.dtype, self.probs.dtype)
np.testing.assert_allclose(
sample_values.mean(axis=0),
scipy.stats.geom.mean(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
np.testing.assert_allclose(
sample_values.var(axis=0),
scipy.stats.geom.var(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_rsample_shape(self):
cases = [
{
'input': (),
'expect': ()
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
{
'input': (2, 5),
'expect': (2, 5)
+ tuple(paddle.squeeze(self._paddle_geom.probs).shape),
},
]
for case in cases:
self.assertTrue(
tuple(self._paddle_geom.rsample(case.get('input')).shape)
== case.get('expect')
)
def test_rsample(self):
sample_shape = (100000,)
samples = self._paddle_geom.rsample(sample_shape)
sample_values = samples.numpy()
self.assertEqual(sample_values.dtype, self.probs.dtype)
np.testing.assert_allclose(
sample_values.mean(axis=0),
scipy.stats.geom.mean(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
np.testing.assert_allclose(
sample_values.var(axis=0),
scipy.stats.geom.var(self.probs),
rtol=0.7,
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_back_rsample(self):
sample_shape = (100000,)
with paddle.fluid.dygraph.guard(self.place):
self._paddle_geom.probs.stop_gradient = False
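# rsample is reparameterized (probs enters the sample computation directly),
# so gradients with respect to probs should be well-defined.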
rs_value = self._paddle_geom.rsample(sample_shape)
softmax_rs = log_softmax(rs_value)
grads = paddle.grad([softmax_rs], [self._paddle_geom.probs])
self.assertEqual(len(grads), 1)
self.assertEqual(grads[0].dtype, self._paddle_geom.probs.dtype)
self.assertEqual(grads[0].shape, self._paddle_geom.probs.shape)
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs', 'value'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
'mult-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
'mult-dim',
xrand(
(2, 2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
],
)
class TestGeometricPMF(unittest.TestCase):
def setUp(self):
self._paddle_geom = geometric.Geometric(
probs=paddle.to_tensor(self.probs)
)
def test_pmf(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.pmf(self.value),
scipy.stats.geom.pmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_log_pmf(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.log_pmf(self.value),
scipy.stats.geom.logpmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_cdf(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_geom.cdf(self.value),
scipy.stats.geom.cdf(self.value, self.probs),
rtol=RTOL.get(str(self._paddle_geom.probs.numpy().dtype)),
atol=ATOL.get(str(self._paddle_geom.probs.numpy().dtype)),
)
def test_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geom.pmf, [1, 2])
def test_log_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geom.log_pmf, [1, 2])
def test_cdf_error(self):
self.assertRaises(TypeError, self._paddle_geom.cdf, [1, 2])
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs1', 'probs2'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometricKL(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self._geometric1 = geometric.Geometric(
probs=paddle.to_tensor(self.probs1)
)
self._geometric2 = geometric.Geometric(
probs=paddle.to_tensor(self.probs2)
)
def test_kl_divergence(self):
np.testing.assert_allclose(
kl.kl_divergence(self._geometric1, self._geometric2),
self._kl(),
rtol=RTOL.get(str(self._geometric1.probs.numpy().dtype)),
atol=ATOL.get(str(self._geometric1.probs.numpy().dtype)),
)
def test_kl1_error(self):
self.assertRaises(
TypeError,
self._geometric1.kl_divergence,
paddle.distribution.beta.Beta,
)
def test_kl2_error(self):
self.assertRaises(
TypeError,
self._geometric2.kl_divergence,
paddle.distribution.beta.Beta,
)
def _kl(self):
return self.probs1 * np.log(self.probs1 / self.probs2) + (
1.0 - self.probs1
) * np.log((1.0 - self.probs1) / (1.0 - self.probs2))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import scipy.stats
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import paddle
from paddle.distribution import geometric
np.random.seed(2023)
paddle.enable_static()
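# The tests below mirror the dygraph suite under static-graph mode.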
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 3),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometric(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
# probs is declared via paddle.static.data, so it is not converted to a tensor here
probs = paddle.static.data(
'probs', self.probs.shape, self.probs.dtype
)
self._paddle_geometric = geometric.Geometric(probs)
self.feeds = {'probs': self.probs}
def test_mean(self):
with paddle.static.program_guard(self.program):
[mean] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.mean],
)
np.testing.assert_allclose(
mean,
scipy.stats.geom.mean(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_variance(self):
with paddle.static.program_guard(self.program):
[variance] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.variance],
)
np.testing.assert_allclose(
variance,
scipy.stats.geom.var(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_stddev(self):
with paddle.static.program_guard(self.program):
[stddev] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.stddev],
)
np.testing.assert_allclose(
stddev,
scipy.stats.geom.std(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_sample(self):
with paddle.static.program_guard(self.program):
[data] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=self._paddle_geometric.sample(),
)
self.assertEqual(
tuple(data.shape), np.broadcast_arrays(self.probs)[0].shape
)
def test_rsample(self):
with paddle.static.program_guard(self.program):
[data] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=self._paddle_geometric.rsample(),
)
self.assertEqual(
tuple(data.shape), np.broadcast_arrays(self.probs)[0].shape
)
def test_entropy(self):
with paddle.static.program_guard(self.program):
[entropy] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.entropy()],
)
np.testing.assert_allclose(
entropy,
scipy.stats.geom.entropy(self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_init_prob_type_error(self):
with self.assertRaises(TypeError):
paddle.distribution.geometric.Geometric([0.5])
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs', 'value'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
'mult-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
(
'mult-dim',
xrand(
(2, 2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
5,
),
],
)
class TestGeometricPMF(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
probs = paddle.static.data(
'probs', self.probs.shape, self.probs.dtype
)
self._paddle_geometric = geometric.Geometric(probs)
self.feeds = {'probs': self.probs, 'value': self.value}
def test_pmf(self):
with paddle.static.program_guard(self.program):
[pmf] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.pmf(self.value)],
)
np.testing.assert_allclose(
pmf,
scipy.stats.geom.pmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_log_pmf(self):
with paddle.static.program_guard(self.program):
[log_pmf] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.log_pmf(self.value)],
)
np.testing.assert_allclose(
log_pmf,
scipy.stats.geom.logpmf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_cdf(self):
with paddle.static.program_guard(self.program):
[cdf] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_geometric.cdf(self.value)],
)
np.testing.assert_allclose(
cdf,
scipy.stats.geom.cdf(self.value, self.probs),
rtol=RTOL.get(str(self.probs.dtype)),
atol=ATOL.get(str(self.probs.dtype)),
)
def test_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geometric.pmf, [1, 2])
def test_log_pmf_error(self):
self.assertRaises(TypeError, self._paddle_geometric.log_pmf, [1, 2])
def test_cdf_error(self):
self.assertRaises(TypeError, self._paddle_geometric.cdf, [1, 2])
@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'probs1', 'probs2'),
[
(
'one-dim',
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2,),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
(
'multi-dim',
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
xrand(
(2, 2),
dtype='float32',
min=np.finfo(dtype='float32').tiny,
max=1.0,
),
),
],
)
class TestGeometricKL(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.program_p = paddle.static.Program()
self.program_q = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program_p, self.program_q):
probs_p = paddle.static.data(
'probs1', self.probs1.shape, self.probs1.dtype
)
probs_q = paddle.static.data(
'probs2', self.probs2.shape, self.probs2.dtype
)
self._paddle_geomP = geometric.Geometric(probs_p)
self._paddle_geomQ = geometric.Geometric(probs_q)
self.feeds = {
'probs1': self.probs1,
'probs2': self.probs2,
}
def test_kl_divergence(self):
with paddle.static.program_guard(self.program_p, self.program_q):
self.executor.run(self.program_q)
[kl_diver] = self.executor.run(
self.program_p,
feed=self.feeds,
fetch_list=[
self._paddle_geomP.kl_divergence(self._paddle_geomQ)
],
)
np.testing.assert_allclose(
kl_diver,
self._kl(),
rtol=RTOL.get(str(self.probs1.dtype)),
atol=ATOL.get(str(self.probs1.dtype)),
)
def test_kl1_error(self):
self.assertRaises(
TypeError,
self._paddle_geomP.kl_divergence,
paddle.distribution.beta.Beta,
)
def test_kl2_error(self):
self.assertRaises(
TypeError,
self._paddle_geomQ.kl_divergence,
paddle.distribution.beta.Beta,
)
def _kl(self):
return self.probs1 * np.log(self.probs1 / self.probs2) + (
1.0 - self.probs1
) * np.log((1.0 - self.probs1) / (1.0 - self.probs2))