Unverified · Commit 4794a44f authored by Xiaoxu Chen, committed by GitHub

Probability distribution API of Beta and KL-Divergence (#38558)

* add beta distribution
* add kl_divergence and register_kl api
Parent 761055f0
@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .beta import Beta
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .kl import kl_divergence, register_kl
from .normal import Normal
from .uniform import Uniform

__all__ = [  # noqa
    'Beta',
    'Categorical',
    'Dirichlet',
    'Distribution',
    'ExponentialFamily',
    'Normal',
    'Uniform',
    'kl_divergence',
    'register_kl'
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import paddle
from .dirichlet import Dirichlet
from .exponential_family import ExponentialFamily
class Beta(ExponentialFamily):
r"""
Beta distribution parameterized by alpha and beta
The probability density function (pdf) is
.. math::
f(x; \alpha, \beta) = \frac{1}{B(\alpha, \beta)}x^{\alpha-1}(1-x)^{\beta-1}
where the normalization, B, is the beta function,
.. math::
B(\alpha, \beta) = \int_{0}^{1} t^{\alpha - 1} (1-t)^{\beta - 1}\mathrm{d}t
Args:
alpha (float|Tensor): alpha parameter of beta distribution, positive (>0).
beta (float|Tensor): beta parameter of beta distribution, positive (>0).
Examples:
.. code-block:: python
import paddle
# scalar input
beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
print(beta.mean)
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.50000000])
print(beta.variance)
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.12500000])
print(beta.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [-0.24156448])
# tensor input with broadcast
beta = paddle.distribution.Beta(alpha=paddle.to_tensor([0.2, 0.4]), beta=0.6)
print(beta.mean)
# Tensor(shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.25000000, 0.40000001])
print(beta.variance)
# Tensor(shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.10416666, 0.12000000])
print(beta.entropy())
# Tensor(shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [-1.91923141, -0.38095069])
"""
def __init__(self, alpha, beta):
if isinstance(alpha, numbers.Real):
alpha = paddle.full(shape=[1], fill_value=alpha)
if isinstance(beta, numbers.Real):
beta = paddle.full(shape=[1], fill_value=beta)
self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])
self._dirichlet = Dirichlet(paddle.stack([self.alpha, self.beta], -1))
super(Beta, self).__init__(self._dirichlet._batch_shape)
@property
def mean(self):
"""mean of beta distribution.
"""
return self.alpha / (self.alpha + self.beta)
@property
def variance(self):
"""variance of beat distribution
"""
sum = self.alpha + self.beta
return self.alpha * self.beta / (sum.pow(2) * (sum + 1))
def prob(self, value):
"""probability density funciotn evaluated at value
Args:
value (Tensor): value to be evaluated.
Returns:
Tensor: probability.
"""
return paddle.exp(self.log_prob(value))
def log_prob(self, value):
"""log probability density funciton evaluated at value
Args:
value (Tensor): value to be evaluated
Returns:
Tensor: log probability.
"""
return self._dirichlet.log_prob(paddle.stack([value, 1.0 - value], -1))
def sample(self, shape=()):
"""sample from beta distribution with sample shape.
Args:
shape (Sequence[int], optional): sample shape.
Returns:
sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
"""
shape = shape if isinstance(shape, tuple) else tuple(shape)
return paddle.squeeze(self._dirichlet.sample(shape)[..., 0])
def entropy(self):
"""entropy of dirichlet distribution
Returns:
Tensor: entropy.
"""
return self._dirichlet.entropy()
@property
def _natural_parameters(self):
return (self.alpha, self.beta)
def _log_normalizer(self, x, y):
return paddle.lgamma(x) + paddle.lgamma(y) - paddle.lgamma(x + y)
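For orientation, here is a minimal dygraph usage sketch of the density and sampling methods defined above; it assumes the API exactly as committed here, and the shape comment follows from `sample_shape` + `batch_shape`:

import paddle

beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
value = paddle.to_tensor(0.25)
# density and log-density at a point inside the support (0, 1)
density = beta.prob(value)
log_density = beta.log_prob(value)
# 100 draws; scalar parameters give batch_shape [1], which squeeze reduces,
# so the result has shape [100]
samples = beta.sample((100, ))
print(samples.shape)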
@@ -90,7 +90,7 @@ class Dirichlet(ExponentialFamily):
        """sample from dirichlet distribution.

        Args:
            shape (Sequence[int], optional): sample shape. Defaults to empty tuple.
        """
        shape = shape if isinstance(shape, tuple) else tuple(shape)
        return _dirichlet(self.concentration.expand(self._extend_shape(shape)))
@@ -139,24 +139,21 @@ class Dirichlet(ExponentialFamily):
def _dirichlet(concentration, name=None):
    op_type = 'dirichlet'

    check_variable_and_dtype(concentration, 'concentration',
                             ['float32', 'float64'], op_type)

    if in_dygraph_mode():
        return paddle._C_ops.dirichlet(concentration)
    else:
        helper = LayerHelper(op_type, **locals())
        out = helper.create_variable_for_type_inference(
            dtype=concentration.dtype)
        helper.append_op(
            type=op_type,
            inputs={"Alpha": concentration},
            outputs={'Out': out},
            attrs={})
        return out
@@ -40,6 +40,14 @@ class Distribution(object):
    """
    The abstract base class for probability distributions. Functions are
    implemented in specific distributions.

    Args:
        batch_shape(Sequence[int], optional): independent, not identically
            distributed draws, aka a "collection" or "bunch" of distributions.
        event_shape(Sequence[int], optional): the shape of a single
            draw from the distribution; it may be dependent across dimensions.
            For scalar distributions, the event shape is []. For an
            n-dimensional multivariate distribution, the event shape is [n].
    """

    def __init__(self, batch_shape=(), event_shape=()):
@@ -56,7 +64,7 @@ class Distribution(object):
        """Returns batch shape of distribution

        Returns:
            Sequence[int]: batch shape
        """
        return self._batch_shape
@@ -65,7 +73,7 @@ class Distribution(object):
        """Returns event shape of distribution

        Returns:
            Sequence[int]: event shape
        """
        return self._event_shape
...
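To make the two shapes concrete, here is a small sketch using the Beta distribution added in this commit; broadcasting the parameters yields a batch of two scalar distributions:

import paddle

beta = paddle.distribution.Beta(alpha=paddle.to_tensor([0.2, 0.4]), beta=0.6)
print(beta.batch_shape)  # [2]: two independent, non-identical scalar Betas
print(beta.event_shape)  # []: each single draw is a scalar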
@@ -19,7 +19,20 @@ from .distribution import Distribution

class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the base class for probability distributions belonging
    to the exponential family, whose probability mass/density function has the
    form defined below.

    ExponentialFamily is derived from `paddle.distribution.Distribution`.

    .. math::

        f_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes
    the sufficient statistic, :math:`F(\theta)` is the log normalizer function
    for a given family and :math:`k(x)` is the carrier measure.
    """

    @property
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
import paddle
from ..fluid.framework import in_dygraph_mode
from .beta import Beta
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .normal import Normal
from .uniform import Uniform
__all__ = ["register_kl", "kl_divergence"]
_REGISTER_TABLE = {}
def kl_divergence(p, q):
r"""
Kullback-Leibler divergence between distributions p and q.
.. math::
KL(p || q) = \int p(x) \log\frac{p(x)}{q(x)} \mathrm{d}x
Args:
p (Distribution): ``Distribution`` object.
q (Distribution): ``Distribution`` object.
Returns:
Tensor: batchwise KL-divergence between distributions p and q.
Raises:
NotImplementedError: cannot find a registered KL divergence function for KL(p||q).
Examples:
.. code-block:: python
import paddle
p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
q = paddle.distribution.Beta(alpha=0.3, beta=0.7)
print(paddle.distribution.kl_divergence(p, q))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.21193528])
"""
return _dispatch(type(p), type(q))(p, q)
def register_kl(cls_p, cls_q):
"""Decorator for register a KL divergence implemention function.
Args:
cls_p(Distribution): subclass derived from ``Distribution``.
cls_q(Distribution): subclass derived from ``Distribution``.
Examples:
.. code-block:: python
import paddle
@paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta)
def kl_beta_beta():
pass # insert implementation here
"""
if (not issubclass(cls_p, Distribution) or
not issubclass(cls_q, Distribution)):
raise TypeError('cls_p and cls_q must be subclass of Distribution')
def decorator(f):
_REGISTER_TABLE[cls_p, cls_q] = f
return f
return decorator
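A slightly fuller sketch of how registration and dispatch fit together; ``Triangle`` is a hypothetical placeholder distribution invented for this illustration, not part of Paddle:

import paddle
from paddle.distribution import Distribution, kl_divergence, register_kl

class Triangle(Distribution):
    # hypothetical distribution used only to demonstrate registration
    pass

@register_kl(Triangle, Triangle)
def _kl_triangle_triangle(p, q):
    return paddle.to_tensor(0.)  # placeholder result for the sketch

# kl_divergence looks up the (type(p), type(q)) pair in _REGISTER_TABLE
print(kl_divergence(Triangle(), Triangle()))

Dispatch prefers the most specific registered pair, so the (Beta, Beta) registration below wins over the generic (ExponentialFamily, ExponentialFamily) fallback.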
def _dispatch(cls_p, cls_q):
"""multiple dispatch into concrete implement function"""
# find all matched super class pair of p and q
matches = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE
           if issubclass(cls_p, super_p) and issubclass(cls_q, super_q)]
if not matches:
    raise NotImplementedError

left_p, left_q = min(_Compare(*m) for m in matches).classes
right_p, right_q = min(_Compare(*reversed(m)) for m in matches).classes
if _REGISTER_TABLE[left_p, left_q] is not _REGISTER_TABLE[right_p, right_q]:
warnings.warn(
'Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.
format(cls_p.__name__, cls_q.__name__, left_p.__name__,
right_q.__name__), RuntimeWarning)
return _REGISTER_TABLE[left_p, left_q]
@functools.total_ordering
class _Compare(object):
def __init__(self, *classes):
self.classes = classes
def __eq__(self, other):
return self.classes == other.classes
def __le__(self, other):
for cls_x, cls_y in zip(self.classes, other.classes):
if not issubclass(cls_x, cls_y):
return False
if cls_x is not cls_y:
break
return True
@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
return ((q.alpha.lgamma() + q.beta.lgamma() + (p.alpha + p.beta).lgamma()) -
(p.alpha.lgamma() + p.beta.lgamma() + (q.alpha + q.beta).lgamma()) +
((p.alpha - q.alpha) * p.alpha.digamma()) + (
(p.beta - q.beta) * p.beta.digamma()) + (
((q.alpha + q.beta) -
(p.alpha + p.beta)) * (p.alpha + p.beta).digamma()))
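In the reST math notation used throughout this file, ``_kl_beta_beta`` evaluates the closed form

.. math::

    KL(p \| q) = \log B(\alpha_q, \beta_q) - \log B(\alpha_p, \beta_p)
        + (\alpha_p - \alpha_q)\psi(\alpha_p)
        + (\beta_p - \beta_q)\psi(\beta_p)
        + (\alpha_q - \alpha_p + \beta_q - \beta_p)\psi(\alpha_p + \beta_p)

where :math:`B` is the beta function and :math:`\psi` is the digamma function; the scipy reference expression in the tests below checks exactly this formula.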
@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
return (
(p.concentration.sum(-1).lgamma() - q.concentration.sum(-1).lgamma()) -
((p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)) + (
((p.concentration - q.concentration) *
(p.concentration.digamma() -
p.concentration.sum(-1).digamma().unsqueeze(-1))).sum(-1)))
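``_kl_dirichlet_dirichlet`` is the standard Dirichlet-Dirichlet closed form, written element-wise over the concentration vectors:

.. math::

    KL(p \| q) = \log\Gamma\Big(\sum_i \alpha_{p,i}\Big)
        - \log\Gamma\Big(\sum_i \alpha_{q,i}\Big)
        - \sum_i \big(\log\Gamma(\alpha_{p,i}) - \log\Gamma(\alpha_{q,i})\big)
        + \sum_i (\alpha_{p,i} - \alpha_{q,i})
          \Big(\psi(\alpha_{p,i}) - \psi\Big(\sum_j \alpha_{p,j}\Big)\Big)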
@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
return p.kl_divergence(q)
@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
return p.kl_divergence(q)
@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
return p.kl_divergence(q)
@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
"""compute kl-divergence using `Bregman divergences`
https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf
"""
if not type(p) == type(q):
raise NotImplementedError
p_natural_params = []
for param in p._natural_parameters:
param = param.detach()
param.stop_gradient = False
p_natural_params.append(param)
q_natural_params = q._natural_parameters
p_log_norm = p._log_normalizer(*p_natural_params)
try:
if in_dygraph_mode():
p_grads = paddle.grad(
p_log_norm, p_natural_params, create_graph=True)
else:
p_grads = paddle.static.gradients(p_log_norm, p_natural_params)
except RuntimeError as e:
raise TypeError(
    "Cannot compute kl_divergence({cls_p}, {cls_q}) using Bregman "
    "divergence. Please register_kl({cls_p}, {cls_q}).".format(
        cls_p=type(p).__name__, cls_q=type(q).__name__)) from e
kl = q._log_normalizer(*q_natural_params) - p_log_norm
for p_param, q_param, p_grad in zip(p_natural_params, q_natural_params,
p_grads):
term = (q_param - p_param) * p_grad
kl -= _sum_rightmost(term, len(q.event_shape))
return kl
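The identity behind this generic fallback (see the Nielsen reference above): for two members of the same exponential family with natural parameters :math:`\theta_p, \theta_q` and log normalizer :math:`F`, the KL divergence equals the Bregman divergence of :math:`F`,

.. math::

    KL(p \| q) = F(\theta_q) - F(\theta_p)
        - \langle \theta_q - \theta_p, \nabla F(\theta_p) \rangle

which is why the code only needs :math:`F` and its gradient at :math:`\theta_p`, computed with ``paddle.grad`` in dygraph mode or ``paddle.static.gradients`` in static mode.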
def _sum_rightmost(value, n):
"""sum value along rightmost n dim"""
return value.sum(list(range(-n, 0))) if n > 0 else value
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -40,13 +40,6 @@ class Exponential(paddle.distribution.ExponentialFamily):
        return -paddle.log(-x)


class DummyExpFamily(paddle.distribution.ExponentialFamily):
    """dummy class extending the exponential family
    """
@@ -63,3 +56,10 @@ class DummyExpFamily(paddle.distribution.ExponentialFamily):
    def _log_normalizer(self, x):
        return -paddle.log(-x)


@paddle.distribution.register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1
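For reference, the mock registration above implements the textbook exponential-exponential closed form for rate parameters :math:`\lambda_p, \lambda_q`:

.. math::

    KL\big(Exp(\lambda_p) \,\|\, Exp(\lambda_q)\big)
        = \log\frac{\lambda_p}{\lambda_q} + \frac{\lambda_q}{\lambda_p} - 1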
@@ -164,7 +164,3 @@ class TestDistributionShape(unittest.TestCase):
        self.assertTrue(
            self.dist._extend_shape(shape),
            shape + self.dist.batch_shape + self.dist.event_shape)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.stats
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'alpha', 'beta'),
[('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()),
('test-broadcast', xrand((2, 1)), xrand((2, 5)))])
class TestBeta(unittest.TestCase):
def setUp(self):
# scalar inputs need not be converted to tensors for the scalar-input test
alpha, beta = self.alpha, self.beta
if not isinstance(self.alpha, numbers.Real):
alpha = paddle.to_tensor(self.alpha)
if not isinstance(self.beta, numbers.Real):
beta = paddle.to_tensor(self.beta)
self._paddle_beta = paddle.distribution.Beta(alpha, beta)
def test_mean(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_beta.mean,
scipy.stats.beta.mean(self.alpha, self.beta),
rtol=RTOL.get(str(self._paddle_beta.alpha.numpy().dtype)),
atol=ATOL.get(str(self._paddle_beta.alpha.numpy().dtype)))
def test_variance(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_beta.variance,
scipy.stats.beta.var(self.alpha, self.beta),
rtol=RTOL.get(str(self._paddle_beta.alpha.numpy().dtype)),
atol=ATOL.get(str(self._paddle_beta.alpha.numpy().dtype)))
def test_prob(self):
value = [np.random.rand(*self._paddle_beta.alpha.shape)]
for v in value:
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_beta.prob(paddle.to_tensor(v)),
scipy.stats.beta.pdf(v, self.alpha, self.beta),
rtol=RTOL.get(str(self._paddle_beta.alpha.numpy().dtype)),
atol=ATOL.get(str(self._paddle_beta.alpha.numpy().dtype)))
def test_log_prob(self):
value = [np.random.rand(*self._paddle_beta.alpha.shape)]
for v in value:
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_beta.log_prob(paddle.to_tensor(v)),
scipy.stats.beta.logpdf(v, self.alpha, self.beta),
rtol=RTOL.get(str(self._paddle_beta.alpha.numpy().dtype)),
atol=ATOL.get(str(self._paddle_beta.alpha.numpy().dtype)))
def test_entropy(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_beta.entropy(),
scipy.stats.beta.entropy(self.alpha, self.beta),
rtol=RTOL.get(str(self._paddle_beta.alpha.numpy().dtype)),
atol=ATOL.get(str(self._paddle_beta.alpha.numpy().dtype)))
def test_sample_shape(self):
cases = [
{
'input': [],
'expect': [] + paddle.squeeze(self._paddle_beta.alpha).shape
},
{
'input': [2, 3],
'expect': [2, 3] + paddle.squeeze(self._paddle_beta.alpha).shape
},
]
for case in cases:
self.assertTrue(
self._paddle_beta.sample(case.get('input')).shape ==
case.get('expect'))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.stats
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
paddle.enable_static()
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand(
(10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand(
(2, 5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))])
class TestBeta(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.program):
# scalar inputs need not be converted to tensors for the scalar-input test
alpha = paddle.static.data('alpha', self.alpha.shape,
self.alpha.dtype)
beta = paddle.static.data('beta', self.beta.shape, self.beta.dtype)
self._paddle_beta = paddle.distribution.Beta(alpha, beta)
self.feeds = {'alpha': self.alpha, 'beta': self.beta}
def test_mean(self):
with paddle.static.program_guard(self.program):
[mean] = self.executor.run(self.program,
feed=self.feeds,
fetch_list=[self._paddle_beta.mean])
np.testing.assert_allclose(
mean,
scipy.stats.beta.mean(self.alpha, self.beta),
rtol=RTOL.get(str(self.alpha.dtype)),
atol=ATOL.get(str(self.alpha.dtype)))
def test_variance(self):
with paddle.static.program_guard(self.program):
[variance] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_beta.variance])
np.testing.assert_allclose(
variance,
scipy.stats.beta.var(self.alpha, self.beta),
rtol=RTOL.get(str(self.alpha.dtype)),
atol=ATOL.get(str(self.alpha.dtype)))
def test_prob(self):
with paddle.static.program_guard(self.program):
value = paddle.static.data('value', self._paddle_beta.alpha.shape,
self._paddle_beta.alpha.dtype)
prob = self._paddle_beta.prob(value)
random_number = np.random.rand(*self._paddle_beta.alpha.shape)
feeds = dict(self.feeds, value=random_number)
[prob] = self.executor.run(self.program,
feed=feeds,
fetch_list=[prob])
np.testing.assert_allclose(
prob,
scipy.stats.beta.pdf(random_number, self.alpha, self.beta),
rtol=RTOL.get(str(self.alpha.dtype)),
atol=ATOL.get(str(self.alpha.dtype)))
def test_log_prob(self):
with paddle.static.program_guard(self.program):
value = paddle.static.data('value', self._paddle_beta.alpha.shape,
self._paddle_beta.alpha.dtype)
prob = self._paddle_beta.log_prob(value)
random_number = np.random.rand(*self._paddle_beta.alpha.shape)
feeds = dict(self.feeds, value=random_number)
[prob] = self.executor.run(self.program,
feed=feeds,
fetch_list=[prob])
np.testing.assert_allclose(
prob,
scipy.stats.beta.logpdf(random_number, self.alpha, self.beta),
rtol=RTOL.get(str(self.alpha.dtype)),
atol=ATOL.get(str(self.alpha.dtype)))
def test_entropy(self):
with paddle.static.program_guard(self.program):
[entropy] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_beta.entropy()])
np.testing.assert_allclose(
entropy,
scipy.stats.beta.entropy(self.alpha, self.beta),
rtol=RTOL.get(str(self.alpha.dtype)),
atol=ATOL.get(str(self.alpha.dtype)))
def test_sample(self):
with paddle.static.program_guard(self.program):
[data] = self.executor.run(
    self.program,
    feed=self.feeds,
    fetch_list=[self._paddle_beta.sample()])
self.assertEqual(data.shape,
                 np.broadcast_arrays(self.alpha, self.beta)[0].shape)
@@ -102,12 +102,3 @@ class TestDirichlet(unittest.TestCase):
        with self.assertRaises(ValueError):
            paddle.distribution.Dirichlet(
                paddle.squeeze(self.concentration))
@@ -104,7 +104,3 @@ class TestDirichlet(unittest.TestCase):
                scipy.stats.dirichlet.entropy(self.concentration),
                rtol=RTOL.get(str(self.concentration.dtype)),
                atol=ATOL.get(str(self.concentration.dtype)))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -40,14 +40,13 @@ class TestExponentialFamily(unittest.TestCase):
@config.place(config.DEVICES)
@config.parameterize(
    (config.TEST_CASE_NAME, 'dist'),
    [('test-dummy', mock.DummyExpFamily(0.5, 0.5)),
     ('test-dirichlet',
      paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand()))),
     ('test-beta', paddle.distribution.Beta(
         paddle.to_tensor(config.xrand()),
         paddle.to_tensor(config.xrand())))])
class TestExponentialFamilyException(unittest.TestCase):
    def test_entropy_exception(self):
        with self.assertRaises(NotImplementedError):
            paddle.distribution.ExponentialFamily.entropy(self.dist)
@@ -52,12 +52,8 @@ class TestExponentialFamily(unittest.TestCase):
            rtol=config.RTOL.get(config.DEFAULT_DTYPE),
            atol=config.ATOL.get(config.DEFAULT_DTYPE))

    def test_entropy_exception(self):
        with paddle.static.program_guard(self.program):
            with self.assertRaises(NotImplementedError):
                paddle.distribution.ExponentialFamily.entropy(
                    mock.DummyExpFamily(0.5, 0.5))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.special
import scipy.stats
from paddle.distribution import kl
import config
import mock_data as mock
paddle.set_default_dtype('float64')
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
('test_regular_input', 6.0 * np.random.random((4, 5)) + 1e-4,
6.0 * np.random.random((4, 5)) + 1e-4, 6.0 * np.random.random(
(4, 5)) + 1e-4, 6.0 * np.random.random((4, 5)) + 1e-4),
])
class TestKLBetaBeta(unittest.TestCase):
def setUp(self):
self.p = paddle.distribution.Beta(
paddle.to_tensor(self.a1), paddle.to_tensor(self.b1))
self.q = paddle.distribution.Beta(
paddle.to_tensor(self.a2), paddle.to_tensor(self.b2))
def test_kl_divergence(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
paddle.distribution.kl_divergence(self.p, self.q),
self.scipy_kl_beta_beta(self.a1, self.b1, self.a2, self.b2),
rtol=config.RTOL.get(str(self.a1.dtype)),
atol=config.ATOL.get(str(self.a1.dtype)))
def scipy_kl_beta_beta(self, a1, b1, a2, b2):
return (scipy.special.betaln(a2, b2) - scipy.special.betaln(a1, b1) +
(a1 - a2) * scipy.special.digamma(a1) +
(b1 - b2) * scipy.special.digamma(b1) +
(a2 - a1 + b2 - b1) * scipy.special.digamma(a1 + b1))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'conc1', 'conc2'), [
('test-regular-input', np.random.random((5, 7, 8, 10)), np.random.random(
(5, 7, 8, 10))),
])
class TestKLDirichletDirichlet(unittest.TestCase):
def setUp(self):
self.p = paddle.distribution.Dirichlet(paddle.to_tensor(self.conc1))
self.q = paddle.distribution.Dirichlet(paddle.to_tensor(self.conc2))
def test_kl_divergence(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
paddle.distribution.kl_divergence(self.p, self.q),
self.scipy_kl_diric_diric(self.conc1, self.conc2),
rtol=config.RTOL.get(str(self.conc1.dtype)),
atol=config.ATOL.get(str(self.conc1.dtype)))
def scipy_kl_diric_diric(self, conc1, conc2):
return (
scipy.special.gammaln(np.sum(conc1, -1)) -
scipy.special.gammaln(np.sum(conc2, -1)) - np.sum(
scipy.special.gammaln(conc1) - scipy.special.gammaln(conc2), -1)
+ np.sum((conc1 - conc2) *
(scipy.special.digamma(conc1) -
scipy.special.digamma(np.sum(conc1, -1, keepdims=True))),
-1))
class DummyDistribution(paddle.distribution.Distribution):
pass
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'p', 'q'),
[('test-unregister', DummyDistribution(), DummyDistribution())])
class TestDispatch(unittest.TestCase):
def test_dispatch_with_unregister(self):
with self.assertRaises(NotImplementedError):
paddle.distribution.kl_divergence(self.p, self.q)
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'p', 'q'),
[('test-diff-dist', mock.Exponential(paddle.rand((100, 200, 100)) + 1.0),
mock.Exponential(paddle.rand((100, 200, 100)) + 2.0)),
('test-same-dist', mock.Exponential(paddle.to_tensor(1.0)),
mock.Exponential(paddle.to_tensor(1.0)))])
class TestKLExpfamilyExpFamily(unittest.TestCase):
def test_kl_expfamily_expfamily(self):
np.testing.assert_allclose(
paddle.distribution.kl_divergence(self.p, self.q),
kl._kl_expfamily_expfamily(self.p, self.q),
rtol=config.RTOL.get(config.DEFAULT_DTYPE),
atol=config.ATOL.get(config.DEFAULT_DTYPE))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.special
import scipy.stats
from paddle.distribution import kl
import config
import mock_data as mock
paddle.enable_static()
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
('test_regular_input', 6.0 * np.random.random((4, 5)) + 1e-4,
6.0 * np.random.random((4, 5)) + 1e-4, 6.0 * np.random.random(
(4, 5)) + 1e-4, 6.0 * np.random.random((4, 5)) + 1e-4),
])
class TestKLBetaBeta(unittest.TestCase):
def setUp(self):
self.mp = paddle.static.Program()
self.sp = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.mp, self.sp):
a1 = paddle.static.data('a1', self.a1.shape, dtype=self.a1.dtype)
b1 = paddle.static.data('b1', self.b1.shape, dtype=self.b1.dtype)
a2 = paddle.static.data('a2', self.a2.shape, dtype=self.a2.dtype)
b2 = paddle.static.data('b2', self.b2.shape, dtype=self.b2.dtype)
self.p = paddle.distribution.Beta(a1, b1)
self.q = paddle.distribution.Beta(a2, b2)
self.feeds = {
'a1': self.a1,
'b1': self.b1,
'a2': self.a2,
'b2': self.b2
}
def test_kl_divergence(self):
with paddle.static.program_guard(self.mp, self.sp):
out = paddle.distribution.kl_divergence(self.p, self.q)
self.executor.run(self.sp)
[out] = self.executor.run(self.mp,
feed=self.feeds,
fetch_list=[out])
np.testing.assert_allclose(
out,
self.scipy_kl_beta_beta(self.a1, self.b1, self.a2, self.b2),
rtol=config.RTOL.get(str(self.a1.dtype)),
atol=config.ATOL.get(str(self.a1.dtype)))
def scipy_kl_beta_beta(self, a1, b1, a2, b2):
return (scipy.special.betaln(a2, b2) - scipy.special.betaln(a1, b1) +
(a1 - a2) * scipy.special.digamma(a1) +
(b1 - b2) * scipy.special.digamma(b1) +
(a2 - a1 + b2 - b1) * scipy.special.digamma(a1 + b1))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'conc1', 'conc2'), [
('test-regular-input', np.random.random((5, 7, 8, 10)), np.random.random(
(5, 7, 8, 10))),
])
class TestKLDirichletDirichlet(unittest.TestCase):
def setUp(self):
self.mp = paddle.static.Program()
self.sp = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.mp, self.sp):
conc1 = paddle.static.data('conc1', self.conc1.shape,
self.conc1.dtype)
conc2 = paddle.static.data('conc2', self.conc2.shape,
self.conc2.dtype)
self.p = paddle.distribution.Dirichlet(conc1)
self.q = paddle.distribution.Dirichlet(conc2)
self.feeds = {'conc1': self.conc1, 'conc2': self.conc2}
def test_kl_divergence(self):
with paddle.static.program_guard(self.mp, self.sp):
out = paddle.distribution.kl_divergence(self.p, self.q)
self.executor.run(self.sp)
[out] = self.executor.run(self.mp,
feed=self.feeds,
fetch_list=[out])
np.testing.assert_allclose(
out,
self.scipy_kl_diric_diric(self.conc1, self.conc2),
rtol=config.RTOL.get(str(self.conc1.dtype)),
atol=config.ATOL.get(str(self.conc1.dtype)))
def scipy_kl_diric_diric(self, conc1, conc2):
return (
scipy.special.gammaln(np.sum(conc1, -1)) -
scipy.special.gammaln(np.sum(conc2, -1)) - np.sum(
scipy.special.gammaln(conc1) - scipy.special.gammaln(conc2), -1)
+ np.sum((conc1 - conc2) *
(scipy.special.digamma(conc1) -
scipy.special.digamma(np.sum(conc1, -1, keepdims=True))),
-1))
class DummyDistribution(paddle.distribution.Distribution):
pass
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'p', 'q'),
                     [('test-dispatch-exception', None, None)])
class TestDispatch(unittest.TestCase):
def setUp(self):
self.mp = paddle.static.Program()
self.sp = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.mp, self.sp):
self.p = DummyDistribution()
self.q = DummyDistribution()
def test_dispatch_with_unregister(self):
with self.assertRaises(NotImplementedError):
with paddle.static.program_guard(self.mp, self.sp):
out = paddle.distribution.kl_divergence(self.p, self.q)
self.executor.run(self.sp)
self.executor.run(self.mp, feed={}, fetch_list=[out])
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'rate1', 'rate2'),
[('test-diff-dist', np.random.rand(100, 200, 100) + 1.0,
np.random.rand(100, 200, 100) + 2.0),
('test-same-dist', np.array([1.0]), np.array([1.0]))])
class TestKLExpfamilyExpFamily(unittest.TestCase):
def setUp(self):
self.mp = paddle.static.Program()
self.sp = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.mp, self.sp):
rate1 = paddle.static.data(
'rate1', shape=self.rate1.shape, dtype=self.rate1.dtype)
rate2 = paddle.static.data(
'rate2', shape=self.rate2.shape, dtype=self.rate2.dtype)
self.p = mock.Exponential(rate1)
self.q = mock.Exponential(rate2)
self.feeds = {'rate1': self.rate1, 'rate2': self.rate2}
def test_kl_expfamily_expfamily(self):
with paddle.static.program_guard(self.mp, self.sp):
out1 = paddle.distribution.kl_divergence(self.p, self.q)
out2 = kl._kl_expfamily_expfamily(self.p, self.q)
self.executor.run(self.sp)
[out1, out2] = self.executor.run(self.mp,
feed=self.feeds,
fetch_list=[out1, out2])
np.testing.assert_allclose(
out1,
out2,
rtol=config.RTOL.get(config.DEFAULT_DTYPE),
atol=config.ATOL.get(config.DEFAULT_DTYPE))