Unverified commit f1a9f877, authored by YuRonan, committed by GitHub

【Hackathon No.8】 add gumbel distribution api (#46255)

* init gumbel api

* commit: update test file

* fix:bug

* update Gumbel API

* upgrade distribution/gumbel.py

* add tests/test_distribution_gumbel.py

* fix:code style

* fix:code style

* fix:code style

* fix:code style

* fix bug

* fix:code style

* fix:code style

* fix:rollback uniform

* fix:delete invalid code

* fix:bug and add static test

* fix:code style

* fix:code style

* fix:delete init transforms

* fix:bug

* fix:bug

* fix:code style

* fix:code style

* fix:add transforms

* fix:code style

* fix:code style

* fix:bug

* fix:bug

* fix:code style

* fix:code style

* fix:bug

* fix:code style

* fix:code style

* fix:bug for gumbel.py / add:judge transforms'len for transformed_distribution.py

* update gumbel.py

* fix:bug for test_distribution_gumbel.py

* fix:bug for test_distribution_gumbel_static.py

* fix:code style

* fix:code style

* fix:coverage

* fix:bug

* fix:bug

* fix:code style

* fix:bug

* delete:no use package for gumbel.py

* add:coverage transforms'len judge for test_distribution_gumbel.py

* fix:code style for test_distribution_gumbel.py

* fix:coverage

* fix:code style

* fix:code style

* fix:code style

* fix:code style

* fix:code style

* fix:en doc

* fix:param

* fix:copyright

* fixSample; test=document_fix
Co-authored-by: Ndasen <sen15530876201@163.com>
Parent commit: abb38136
@@ -17,6 +17,7 @@ from paddle.distribution.beta import Beta
 from paddle.distribution.categorical import Categorical
 from paddle.distribution.dirichlet import Dirichlet
 from paddle.distribution.distribution import Distribution
+from paddle.distribution.gumbel import Gumbel
 from paddle.distribution.exponential_family import ExponentialFamily
 from paddle.distribution.independent import Independent
 from paddle.distribution.kl import kl_divergence, register_kl
@@ -32,7 +33,7 @@ from paddle.distribution.laplace import Laplace
 __all__ = [  # noqa
     'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily',
     'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl',
-    'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal'
+    'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal', 'Gumbel'
 ]
 __all__.extend(transform.__all__)
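With the import and the extra `__all__` entry above, `Gumbel` is re-exported from the package root, so both import spellings below resolve to the same class. A minimal smoke-test sketch (assumes a Paddle build that contains this commit; the printed shape is the expected value, not captured output):

import paddle
from paddle.distribution import Gumbel
from paddle.distribution.gumbel import Gumbel as GumbelDirect

# The package-level re-export and the module path point at the same class.
assert Gumbel is GumbelDirect

dist = paddle.distribution.Gumbel(paddle.full([1], 0.0), paddle.full([1], 1.0))
print(dist.sample([3]).shape)  # expected: [3, 1], matching the docstring example
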
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numbers
import math
import numpy as np
from paddle.distribution.transformed_distribution import TransformedDistribution
from paddle.fluid import framework as framework
class Gumbel(TransformedDistribution):
r"""The Gumbel distribution with location `loc` and `scale` parameters.
Mathematical details
The probability density function (pdf) is
.. math::
        pdf(x; \mu, \sigma) = \exp(-(x - \mu) / \sigma - \exp(-(x - \mu) / \sigma)) / \sigma
In the above equation:
    * :math:`loc = \mu`: is the location parameter.
    * :math:`scale = \sigma`: is the scale parameter.
    Args:
        loc (int|float|Tensor): The location parameter of the Gumbel distribution. The data type is int, float or Tensor.
        scale (int|float|Tensor): The scale parameter of the Gumbel distribution. The data type is int, float or Tensor.
Examples:
.. code-block:: python
import paddle
from paddle.distribution.gumbel import Gumbel
# Gumbel distributed with loc=0, scale=1
dist = Gumbel(paddle.full([1], 0.0), paddle.full([1], 1.0))
dist.sample([2])
# Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[-0.27544352], [-0.64499271]])
value = paddle.full([1], 0.5)
dist.prob(value)
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.33070430])
dist.log_prob(value)
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [-1.10653067])
dist.cdf(value)
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.54523915])
dist.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
dist.rsample([2])
# Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])
"""
def __init__(self, loc, scale):
if not isinstance(loc, (numbers.Real, framework.Variable)):
raise TypeError(
f"Expected type of loc is Real|Variable, but got {type(loc)}")
if not isinstance(scale, (numbers.Real, framework.Variable)):
raise TypeError(
f"Expected type of scale is Real|Variable, but got {type(scale)}"
)
if isinstance(loc, numbers.Real):
loc = paddle.full(shape=(), fill_value=loc)
if isinstance(scale, numbers.Real):
scale = paddle.full(shape=(), fill_value=scale)
if loc.shape != scale.shape:
self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
else:
self.loc, self.scale = loc, scale
finfo = np.finfo(dtype='float32')
self.base_dist = paddle.distribution.Uniform(
paddle.full_like(self.loc, float(finfo.tiny)),
paddle.full_like(self.loc, float(1 - finfo.eps)))
self.transforms = ()
super(Gumbel, self).__init__(self.base_dist, self.transforms)
@property
def mean(self):
"""Mean of distribution
The mean is
.. math::
            mean = \mu + \sigma * \gamma
In the above equation:
* :math:`loc = \mu`: is the location parameter.
* :math:`scale = \sigma`: is the scale parameter.
        * :math:`\gamma`: is the Euler-Mascheroni constant.
Returns:
Tensor: mean value.
"""
return self.loc + self.scale * np.euler_gamma
@property
def variance(self):
"""Variance of distribution.
The variance is
.. math::
variance = \sigma^2 * \pi^2 / 6
In the above equation:
* :math:`scale = \sigma`: is the scale parameter.
Returns:
Tensor: The variance value.
"""
temp = paddle.full(shape=self.loc.shape,
fill_value=math.pi * math.pi,
dtype=self.scale.dtype)
return paddle.pow(self.scale, 2) * temp / 6
@property
def stddev(self):
"""Standard deviation of distribution
The standard deviation is
.. math::
stddev = \sqrt{\sigma^2 * \pi^2 / 6}
In the above equation:
* :math:`scale = \sigma`: is the scale parameter.
Returns:
Tensor: std value
"""
return paddle.sqrt(self.variance)
def prob(self, value):
"""Probability density/mass function
Args:
value (Tensor): The input tensor.
Returns:
            Tensor: probability. The data type is the same as ``value``.
"""
y = (self.loc - value) / self.scale
return paddle.exp(y - paddle.exp(y)) / self.scale
def log_prob(self, value):
"""Log probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
            Tensor: log probability. The data type is the same as ``value``.
"""
return paddle.log(self.prob(value))
def cdf(self, value):
"""Cumulative distribution function.
Args:
value (Tensor): value to be evaluated.
Returns:
Tensor: cumulative probability of value.
"""
return paddle.exp(-paddle.exp(-(value - self.loc) / self.scale))
def entropy(self):
"""Entropy of Gumbel distribution.
Returns:
Entropy of distribution.
"""
return paddle.log(self.scale) + 1 + np.euler_gamma
def sample(self, shape):
"""Sample from ``Gumbel``.
Args:
            shape (Sequence[int]): The sample shape.
Returns:
            Tensor: A sample tensor with its shape prepended by ``shape``. The data type is float32.
"""
with paddle.no_grad():
return self.rsample(shape)
def rsample(self, shape):
"""reparameterized sample
Args:
shape (Sequence[int]): 1D `int32`. Shape of the generated samples.
Returns:
            Tensor: A sample tensor with its shape prepended by ``shape``. The data type is float32.
"""
exp_trans = paddle.distribution.ExpTransform()
affine_trans_1 = paddle.distribution.AffineTransform(
paddle.full(shape=self.scale.shape,
fill_value=0,
dtype=self.loc.dtype), -paddle.ones_like(self.scale))
affine_trans_2 = paddle.distribution.AffineTransform(
self.loc, -self.scale)
return affine_trans_2.forward(
exp_trans.inverse(
affine_trans_1.forward(
exp_trans.inverse(self._base.sample(shape)))))
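The transform chain in ``rsample`` above collapses to the classic inverse-CDF sampler for the Gumbel: for ``u ~ Uniform(0, 1)``, applying exp-inverse, negate, exp-inverse, then the final affine map gives ``loc - scale * log(-log(u))``. A small NumPy sketch of that equivalence (plain NumPy only; the loc/scale values are illustrative and not taken from the code above):

import numpy as np

rng = np.random.default_rng(0)
loc, scale = 1.5, 2.0

# Base draw kept away from 0 and 1, mirroring the Uniform(tiny, 1 - eps) base above.
u = rng.uniform(np.finfo(np.float32).tiny, 1 - np.finfo(np.float32).eps, size=100000)

# Replay of the chain in rsample():
#   exp.inverse          -> log(u)
#   affine(0, -1)        -> -log(u)
#   exp.inverse          -> log(-log(u))
#   affine(loc, -scale)  -> loc - scale * log(-log(u))
samples = loc - scale * np.log(-np.log(u))

# The sample mean should approach the analytic mean loc + scale * euler_gamma.
print(samples.mean(), loc + scale * np.euler_gamma)
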
@@ -62,15 +62,19 @@ class TransformedDistribution(distribution.Distribution):
         chain = transform.ChainTransform(transforms)
         base_shape = base.batch_shape + base.event_shape
-        if len(base_shape) < chain._domain.event_rank:
+        self._base = base
+        self._transforms = transforms
+        if not transforms:
+            super(TransformedDistribution,
+                  self).__init__(base.batch_shape, base.event_shape)
+            return
+        if len(base.batch_shape + base.event_shape) < chain._domain.event_rank:
             raise ValueError(
                 f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
             )
         if chain._domain.event_rank > len(base.event_shape):
             base = independent.Independent(
                 (base, chain._domain.event_rank - len(base.event_shape)))
-        self._base = base
-        self._transforms = transforms
         transformed_shape = chain.forward_shape(base.batch_shape +
                                                 base.event_shape)
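The early-return branch added above lets a subclass register an empty transform tuple and simply inherit the base distribution's batch and event shapes, which is exactly what ``Gumbel`` does with ``self.transforms = ()``. A hedged sketch of that behavior (assumes the patched classes from this commit; printed values are expectations, not captured output):

import paddle
from paddle.distribution.gumbel import Gumbel

# Gumbel passes an empty transform tuple, so the new `if not transforms:` branch
# adopts the Uniform base's shapes and returns early.
dist = Gumbel(paddle.zeros([2]), paddle.ones([2]))
print(dist.transforms)   # expected: (), stored verbatim by Gumbel.__init__
print(dist.batch_shape)  # expected to match the base Uniform's batch shape, e.g. [2]
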
@@ -123,6 +123,8 @@ class Uniform(distribution.Distribution):
         self.low = tensor.cast(self.low, dtype=self.dtype)
         self.high = tensor.cast(self.high, dtype=self.dtype)
+
+        super(Uniform, self).__init__(self.low.shape)

     def sample(self, shape, seed=0):
         """Generate samples of the specified shape.
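The single ``super(Uniform, self).__init__(self.low.shape)`` call makes ``Uniform`` register a proper ``batch_shape``, which ``TransformedDistribution`` reads when ``Gumbel`` wraps a ``Uniform`` base. A quick standalone check (a sketch, not part of the shipped tests):

import paddle

# After the patch, batch_shape is populated from `low`'s shape rather than
# the Distribution default of an empty shape.
u = paddle.distribution.Uniform(paddle.zeros([3]), paddle.ones([3]))
print(u.batch_shape)  # expected: a shape matching `low`, e.g. [3]
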
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import parameterize
from paddle.distribution.gumbel import Gumbel
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls((parameterize.TEST_CASE_NAME, 'loc', 'scale'), [
('one-dim', parameterize.xrand((4, )), parameterize.xrand((4, ))),
('multi-dim', parameterize.xrand((5, 3)), parameterize.xrand((5, 3))),
])
class TestGumbel(unittest.TestCase):
def setUp(self):
self._dist = Gumbel(loc=paddle.to_tensor(self.loc),
scale=paddle.to_tensor(self.scale))
def test_mean(self):
mean = self._dist.mean
self.assertEqual(mean.numpy().dtype, self._np_mean().dtype)
np.testing.assert_allclose(mean,
self._np_mean(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_variance(self):
var = self._dist.variance
self.assertEqual(var.numpy().dtype, self._np_variance().dtype)
np.testing.assert_allclose(var,
self._np_variance(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_stddev(self):
stddev = self._dist.stddev
self.assertEqual(stddev.numpy().dtype, self._np_stddev().dtype)
np.testing.assert_allclose(stddev,
self._np_stddev(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_entropy(self):
entropy = self._dist.entropy()
self.assertEqual(entropy.numpy().dtype, self._np_entropy().dtype)
np.testing.assert_allclose(entropy,
self._np_entropy(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_sample(self):
sample_shape = [10000]
samples = self._dist.sample(sample_shape)
sample_values = samples.numpy()
self.assertEqual(sample_values.dtype, self.scale.dtype)
np.testing.assert_allclose(sample_values.mean(axis=0),
scipy.stats.gumbel_r.mean(self.loc,
scale=self.scale),
rtol=0.1,
atol=config.ATOL.get(str(self.loc.dtype)))
np.testing.assert_allclose(sample_values.var(axis=0),
scipy.stats.gumbel_r.var(self.loc,
scale=self.scale),
rtol=0.1,
atol=config.ATOL.get(str(self.loc.dtype)))
def test_rsample(self):
sample_shape = [10000]
samples = self._dist.rsample(sample_shape)
sample_values = samples.numpy()
self.assertEqual(sample_values.dtype, self.scale.dtype)
np.testing.assert_allclose(sample_values.mean(axis=0),
scipy.stats.gumbel_r.mean(self.loc,
scale=self.scale),
rtol=0.1,
atol=config.ATOL.get(str(self.loc.dtype)))
np.testing.assert_allclose(sample_values.var(axis=0),
scipy.stats.gumbel_r.var(self.loc,
scale=self.scale),
rtol=0.1,
atol=config.ATOL.get(str(self.loc.dtype)))
def _np_mean(self):
return self.loc + self.scale * np.euler_gamma
def _np_stddev(self):
return np.sqrt(self._np_variance())
def _np_variance(self):
return np.divide(
np.multiply(np.power(self.scale, 2), np.power(np.pi, 2)), 6)
def _np_entropy(self):
return np.log(self.scale) + 1 + np.euler_gamma
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'loc', 'scale', 'value'), [
('value-float', np.array([0.1, 0.4]), np.array([1., 4.
]), np.array([3., 7.])),
('value-int', np.array([0.1, 0.4]), np.array([1, 4]), np.array([3, 7])),
('value-multi-dim', np.array([0.1, 0.4]), np.array(
[1, 4]), np.array([[5., 4], [6, 2]])),
])
class TestGumbelPDF(unittest.TestCase):
def setUp(self):
self._dist = Gumbel(loc=paddle.to_tensor(self.loc),
scale=paddle.to_tensor(self.scale))
def test_prob(self):
np.testing.assert_allclose(
self._dist.prob(paddle.to_tensor(self.value)),
scipy.stats.gumbel_r.pdf(self.value, self.loc, self.scale),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_log_prob(self):
np.testing.assert_allclose(
self._dist.log_prob(paddle.to_tensor(self.value)),
scipy.stats.gumbel_r.logpdf(self.value, self.loc, self.scale),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_cdf(self):
np.testing.assert_allclose(self._dist.cdf(paddle.to_tensor(self.value)),
scipy.stats.gumbel_r.cdf(
self.value, self.loc, self.scale),
rtol=0.02,
atol=config.ATOL.get(str(self.loc.dtype)))
if __name__ == '__main__':
unittest.main()
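The dynamic-graph tests above use ``scipy.stats.gumbel_r`` as the reference implementation; it shares the ``(loc, scale)`` parameterization of the pdf in the ``Gumbel`` docstring, which a standalone check can confirm (illustrative values only):

import numpy as np
import scipy.stats

loc, scale, x = 0.1, 1.0, 0.5
y = (loc - x) / scale                       # same substitution as Gumbel.prob
manual_pdf = np.exp(y - np.exp(y)) / scale  # pdf from the class docstring
assert np.isclose(manual_pdf, scipy.stats.gumbel_r.pdf(x, loc=loc, scale=scale))
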
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import scipy.stats
import paddle
import config
import parameterize
from paddle.distribution.gumbel import Gumbel
paddle.enable_static()
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls((parameterize.TEST_CASE_NAME, 'loc', 'scale'), [
('one-dim', parameterize.xrand((4, )), parameterize.xrand((4, ))),
('multi-dim', parameterize.xrand((5, 3)), parameterize.xrand((5, 3))),
])
class TestGumbel(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
self._dist = Gumbel(loc=loc, scale=scale)
self.sample_shape = [50000]
mean = self._dist.mean
var = self._dist.variance
stddev = self._dist.stddev
entropy = self._dist.entropy()
samples = self._dist.sample(self.sample_shape)
fetch_list = [mean, var, stddev, entropy, samples]
self.feeds = {'loc': self.loc, 'scale': self.scale}
executor.run(startup_program)
[self.mean, self.var, self.stddev, self.entropy,
self.samples] = executor.run(main_program,
feed=self.feeds,
fetch_list=fetch_list)
def test_mean(self):
self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype)
np.testing.assert_allclose(self.mean,
self._np_mean(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_variance(self):
self.assertEqual(str(self.var.dtype).split('.')[-1], self.scale.dtype)
np.testing.assert_allclose(self.var,
self._np_variance(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_stddev(self):
self.assertEqual(
str(self.stddev.dtype).split('.')[-1], self.scale.dtype)
np.testing.assert_allclose(self.stddev,
self._np_stddev(),
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_entropy(self):
self.assertEqual(
str(self.entropy.dtype).split('.')[-1], self.scale.dtype)
def test_sample(self):
self.assertEqual(self.samples.dtype, self.scale.dtype)
np.testing.assert_allclose(self.samples.mean(axis=0),
scipy.stats.gumbel_r.mean(self.loc,
scale=self.scale),
rtol=0.1,
atol=config.ATOL.get(str(self.scale.dtype)))
np.testing.assert_allclose(self.samples.var(axis=0),
scipy.stats.gumbel_r.var(self.loc,
scale=self.scale),
rtol=0.1,
atol=config.ATOL.get(str(self.scale.dtype)))
def _np_mean(self):
return self.loc + self.scale * np.euler_gamma
def _np_stddev(self):
return np.sqrt(self._np_variance())
def _np_variance(self):
return np.divide(
np.multiply(np.power(self.scale, 2), np.power(np.pi, 2)), 6)
def _np_entropy(self):
return np.log(self.scale) + 1 + np.euler_gamma
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'loc', 'scale', 'value'), [
('value-float', np.array([0.1, 0.4]), np.array([1., 4.
]), np.array([3., 7.])),
('value-int', np.array([0.1, 0.4]), np.array([1, 4]), np.array([3, 7])),
('value-multi-dim', np.array([0.1, 0.4]), np.array(
[1, 4]), np.array([[5., 4], [6, 2]])),
])
class TestGumbelPDF(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
value = paddle.static.data('value', self.value.shape,
self.value.dtype)
self._dist = Gumbel(loc=loc, scale=scale)
prob = self._dist.prob(value)
log_prob = self._dist.log_prob(value)
cdf = self._dist.cdf(value)
fetch_list = [prob, log_prob, cdf]
self.feeds = {'loc': self.loc, 'scale': self.scale, 'value': self.value}
executor.run(startup_program)
[self.prob, self.log_prob,
self.cdf] = executor.run(main_program,
feed=self.feeds,
fetch_list=fetch_list)
def test_prob(self):
np.testing.assert_allclose(self.prob,
scipy.stats.gumbel_r.pdf(
self.value, self.loc, self.scale),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_log_prob(self):
np.testing.assert_allclose(self.log_prob,
scipy.stats.gumbel_r.logpdf(
self.value, self.loc, self.scale),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_cdf(self):
np.testing.assert_allclose(self.cdf,
scipy.stats.gumbel_r.cdf(
self.value, self.loc, self.scale),
rtol=0.3,
atol=config.ATOL.get(str(self.loc.dtype)))
if __name__ == '__main__':
unittest.main()