Unverified commit 00cddf07, authored by Xiaoxu Chen, committed by GitHub

add ExponentialFamily and Dirichlet probability distribution (#38445)

* extend Distribution base class to support multivariate distributions and a prob method

* add ExponentialFamily base class and entropy using Bregman divergence

* add dirichlet probability distribution
Parent c5bf09bb
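For orientation, a minimal usage sketch of the API this commit introduces (an illustration, not part of the diff; `sample` is not exercised because the underlying op is not wired up yet):

import paddle

dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.batch_shape, dirichlet.event_shape)  # () (3,)
print(dirichlet.prob(paddle.to_tensor([0.2, 0.3, 0.5])))
print(dirichlet.entropy())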
...@@ -13,8 +13,16 @@
# limitations under the License.
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .normal import Normal
from .uniform import Uniform
__all__ = [  # noqa
'Categorical',
'Distribution',
'Normal', 'Uniform',
'ExponentialFamily',
'Dirichlet'
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from .exponential_family import ExponentialFamily
class Dirichlet(ExponentialFamily):
r"""
Dirichlet distribution with parameter `concentration`.
The Dirichlet distribution is defined over the `(k-1)-simplex` using a
positive, length-k vector `concentration` (`k > 1`).
The Dirichlet is identically the Beta distribution when `k = 2`.
The probability density function (pdf) is
.. math::
f(x_1,...,x_k; \alpha_1,...,\alpha_k) = \frac{1}{B(\alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}
The normalizing constant is the multivariate beta function.
Args:
concentration (Tensor): Concentration parameter of the Dirichlet
distribution.
Examples:
.. code-block:: python
import paddle
dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [-1.24434423])
print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [10.80000114])
"""
def __init__(self, concentration):
if concentration.dim() < 1:
raise ValueError(
"`concentration` parameter must be at least one dimensional")
self.concentration = concentration
super(Dirichlet, self).__init__(concentration.shape[:-1],
concentration.shape[-1:])
@property
def mean(self):
"""mean of Dirichelt distribution.
Returns:
mean value of distribution.
"""
return self.concentration / self.concentration.sum(-1, keepdim=True)
@property
def variance(self):
"""variance of Dirichlet distribution.
Returns:
variance value of distribution.
"""
concentration0 = self.concentration.sum(-1, keepdim=True)
return (self.concentration * (concentration0 - self.concentration)) / (
concentration0.pow(2) * (concentration0 + 1))
def sample(self, shape=()):
"""sample from dirichlet distribution.
Args:
shape (Tensor, optional): sample shape. Defaults to empty tuple.
"""
shape = shape if isinstance(shape, tuple) else tuple(shape)
return _dirichlet(self.concentration.expand(self._extend_shape(shape)))
def prob(self, value):
"""Probability density function(pdf) evaluated at value.
Args:
value (Tensor): value to be evaluated.
Returns:
pdf evaluated at value.
"""
return paddle.exp(self.log_prob(value))
def log_prob(self, value):
"""log of probability densitiy function.
Args:
value (Tensor): value to be evaluated.
"""
return ((paddle.log(value) * (self.concentration - 1.0)
).sum(-1) + paddle.lgamma(self.concentration.sum(-1)) -
paddle.lgamma(self.concentration).sum(-1))
def entropy(self):
"""entropy of Dirichlet distribution.
Returns:
entropy of distribution.
"""
concentration0 = self.concentration.sum(-1)
k = self.concentration.shape[-1]
return (paddle.lgamma(self.concentration).sum(-1) -
paddle.lgamma(concentration0) -
(k - concentration0) * paddle.digamma(concentration0) - (
(self.concentration - 1.0
) * paddle.digamma(self.concentration)).sum(-1))
@property
def _natural_parameters(self):
return (self.concentration, )
def _log_normalizer(self, x):
return x.lgamma().sum(-1) - paddle.lgamma(x.sum(-1))
def _dirichlet(concentration, name=None):
raise NotImplementedError
# op_type = 'dirichlet'
# check_variable_and_dtype(concentration, 'concentration',
# ['float32', 'float64'], op_type)
# if in_dygraph_mode():
# return paddle._C_ops.dirichlet(concentration)
# else:
# helper = LayerHelper(op_type, **locals())
# out = helper.create_variable_for_type_inference(
# dtype=concentration.dtype)
# helper.append_op(
# type=op_type,
# inputs={"Alpha": concentration},
# outputs={'Out': out},
# attrs={})
# return out
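The commented block above is the intended wiring once the `dirichlet` C++ op lands. For illustration only, a hypothetical eager-mode stand-in (an assumption, not part of this commit) could bridge through NumPy:

import numpy as np

def _dirichlet_numpy_fallback(concentration):
    # Draw one Dirichlet sample per row of the flattened batch via NumPy,
    # then restore the original batch shape.
    alpha = concentration.numpy()
    flat = alpha.reshape(-1, alpha.shape[-1])
    samples = np.stack([np.random.dirichlet(a) for a in flat])
    return paddle.to_tensor(samples.reshape(alpha.shape))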
...@@ -42,10 +42,34 @@ class Distribution(object):
implemented in specific distributions.
"""
def __init__(self, batch_shape=(), event_shape=()):
self._batch_shape = batch_shape if isinstance(
batch_shape, tuple) else tuple(batch_shape)
self._event_shape = event_shape if isinstance(
event_shape, tuple) else tuple(event_shape)
super(Distribution, self).__init__()
@property
def batch_shape(self):
"""Returns batch shape of distribution
Returns:
Tensor: batch shape
"""
return self._batch_shape
@property
def event_shape(self):
"""Returns event shape of distribution
Returns:
Tensor: event shape
"""
return self._event_shape
def sample(self, shape=()):
"""Sampling from the distribution.""" """Sampling from the distribution."""
raise NotImplementedError raise NotImplementedError
...@@ -57,6 +81,14 @@ class Distribution(object): ...@@ -57,6 +81,14 @@ class Distribution(object):
"""The KL-divergence between self distributions and other.""" """The KL-divergence between self distributions and other."""
raise NotImplementedError raise NotImplementedError
def prob(self, value):
"""Probability density/mass function evaluated at value.
Args:
value (Tensor): value which will be evaluated
"""
raise NotImplementedError
def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
...@@ -65,6 +97,17 @@ class Distribution(object):
"""Probability density/mass function."""
raise NotImplementedError
def _extend_shape(self, sample_shape):
"""compute shape of the sample
Args:
sample_shape (Tensor): sample shape
Returns:
Tensor: generated sample data shape
"""
return sample_shape + self._batch_shape + self._event_shape
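# Example (sketch): a Dirichlet with concentration of shape [10, 3] has
# batch_shape (10,) and event_shape (3,), so _extend_shape((5,)) == (5, 10, 3).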
def _validate_args(self, *args):
"""
Argument validation for distribution args
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ..fluid.framework import in_dygraph_mode
from .distribution import Distribution
class ExponentialFamily(Distribution):
""" Base class for exponential family distribution.
"""
@property
def _natural_parameters(self):
raise NotImplementedError
def _log_normalizer(self):
raise NotImplementedError
@property
def _mean_carrier_measure(self):
raise NotImplementedError
def entropy(self):
"""caculate entropy use `bregman divergence`
https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf
"""
entropy_value = -self._mean_carrier_measure
natural_parameters = []
for parameter in self._natural_parameters:
parameter = parameter.detach()
parameter.stop_gradient = False
natural_parameters.append(parameter)
log_norm = self._log_normalizer(*natural_parameters)
if in_dygraph_mode():
grads = paddle.grad(
log_norm.sum(), natural_parameters, create_graph=True)
else:
grads = paddle.static.gradients(log_norm.sum(), natural_parameters)
entropy_value += log_norm
for p, g in zip(natural_parameters, grads):
entropy_value -= p * g
return entropy_value
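The identity used above: for densities p(x) = h(x) * exp(<eta, t(x)> - A(eta)), entropy equals -E[log h(x)] + A(eta) - <eta, grad A(eta)>. A quick eager-mode check against the exponential distribution, whose closed-form entropy is 1 - log(rate) (a sketch under that standard assumption):

import paddle

rate = paddle.to_tensor(2.0)
eta = (-rate).detach()          # natural parameter eta = -rate
eta.stop_gradient = False
log_norm = -paddle.log(-eta)    # log-normalizer A(eta)
(grad,) = paddle.grad(log_norm.sum(), [eta], create_graph=False)
print(log_norm - eta * grad)    # ~0.3069 == 1 - log(2.0)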
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import sys
import numpy as np
import paddle
DEVICES = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
DEVICES.append(paddle.CUDAPlace(0))
DEFAULT_DTYPE = 'float64'
TEST_CASE_NAME = 'suffix'
# All test cases use float64 for comparison precision; refs:
# https://github.com/PaddlePaddle/Paddle/wiki/Upgrade-OP-Precision-to-Float64
RTOL = {
'float32': 1e-03,
'complex64': 1e-3,
'float64': 1e-5,
'complex128': 1e-5
}
ATOL = {'float32': 0.0, 'complex64': 0, 'float64': 0.0, 'complex128': 0}
def xrand(shape=(10, 10, 10), dtype=DEFAULT_DTYPE, min=1.0, max=10.0):
return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min)
def place(devices, key='place'):
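# For each module-level class whose name starts with `cls.__name__`, create
# one subclass per device (e.g. TestFoo.CPUPlace), inject the device under
# `key`, and delete the undecorated original so only per-device variants run.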
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def parameterize(fields, values=None):
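# Expand the decorated class into one subclass per parameter dict
# (TestFoo0, TestFoo1.some-suffix, ...), then strip `test*` methods from the
# base class so only the parameterized variants are collected.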
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]
def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for k, v in enumerate(params):
test_cls = dict(cls.__dict__)
test_cls.update(v)
name = cls.__name__ + str(k)
name = name + '.' + v.get('suffix') if v.get('suffix') else name
test_cls_module[name] = type(name, (cls, ), test_cls)
for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls
return decorate
@contextlib.contextmanager
def stgraph(func, place, x, *args):
"""static graph exec context"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = func(input, *args)
exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class Exponential(paddle.distribution.ExponentialFamily):
"""mock exponential distribution, which support computing entropy and
kl use bregman divergence
"""
_mean_carrier_measure = 0
def __init__(self, rate):
self._rate = rate
super(Exponential, self).__init__(batch_shape=rate.shape)
@property
def rate(self):
return self._rate
def entropy(self):
return 1.0 - paddle.log(self._rate)
@property
def _natural_parameters(self):
return (-self._rate, )
def _log_normalizer(self, x):
return -paddle.log(-x)
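# Note: rate * exp(-rate * x) has natural parameter eta = -rate and
# log-normalizer A(eta) = -log(-eta); the Bregman identity then recovers
# the closed-form entropy 1 - log(rate) returned by `entropy` above.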
# @paddle.distribution.register_kl(Exponential, Exponential)
# def _kl_exponential_exponential(p, q):
# rate_ratio = q.rate / p.rate
# t1 = -rate_ratio.log()
# return t1 + rate_ratio - 1
class DummyExpFamily(paddle.distribution.ExponentialFamily):
"""dummy class extend from exponential family
"""
def __init__(self, *args):
pass
def entropy(self):
return 1.0
@property
def _natural_parameters(self):
return (1.0, )
def _log_normalizer(self, x):
return -paddle.log(-x)
...@@ -21,6 +21,8 @@ from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
import config
paddle.enable_static()
...@@ -128,3 +130,41 @@ class DistributionTestName(unittest.TestCase):
lp = categorical1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'batch_shape', 'event_shape'),
[('test-tuple', (10, 20),
(10, 20)), ('test-list', [100, 100], [100, 200, 300]),
('test-null-eventshape', (100, 100), ())])
class TestDistributionShape(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.dist = paddle.distribution.Distribution(
batch_shape=self.batch_shape, event_shape=self.event_shape)
def tearDown(self):
paddle.enable_static()
def test_batch_shape(self):
self.assertTrue(isinstance(self.dist.batch_shape, tuple))
self.assertTrue(self.dist.batch_shape == tuple(self.batch_shape))
def test_event_shape(self):
self.assertTrue(isinstance(self.dist.event_shape, tuple))
self.assertTrue(self.dist.event_shape == tuple(self.event_shape))
def test_prob(self):
with self.assertRaises(NotImplementedError):
self.dist.prob(paddle.to_tensor(config.xrand()))
def test_extend_shape(self):
shapes = [(34, 20), (56, ), ()]
for shape in shapes:
self.assertEqual(
self.dist._extend_shape(shape),
shape + self.dist.batch_shape + self.dist.event_shape)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
import config
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'concentration'),
[
('test-one-dim', config.xrand((89, ))),
# ('test-multi-dim', config.xrand((10, 20, 30)))
])
class TestDirichlet(unittest.TestCase):
def setUp(self):
self._paddle_diric = paddle.distribution.Dirichlet(
paddle.to_tensor(self.concentration))
def test_mean(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_diric.mean,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_variance(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_diric.variance,
scipy.stats.dirichlet.var(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_prob(self):
value = [np.random.rand(*self.concentration.shape)]
value = [v / v.sum() for v in value]
for v in value:
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_diric.prob(paddle.to_tensor(v)),
scipy.stats.dirichlet.pdf(v, self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_log_prob(self):
value = [np.random.rand(*self.concentration.shape)]
value = [v / v.sum() for v in value]
for v in value:
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_diric.log_prob(paddle.to_tensor(v)),
scipy.stats.dirichlet.logpdf(v, self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_entropy(self):
with paddle.fluid.dygraph.guard(self.place):
np.testing.assert_allclose(
self._paddle_diric.entropy(),
scipy.stats.dirichlet.entropy(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_natural_parameters(self):
self.assertTrue(
isinstance(self._paddle_diric._natural_parameters, tuple))
def test_log_normalizer(self):
self.assertTrue(
np.all(
self._paddle_diric._log_normalizer(
paddle.to_tensor(config.xrand((100, 100, 100)))).numpy() <
0.0))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'concentration'),
[('test-zero-dim', np.array(1.0))])
class TestDirichletException(unittest.TestCase):
def test_init(self):
with self.assertRaises(ValueError):
paddle.distribution.Dirichlet(
paddle.squeeze(paddle.to_tensor(self.concentration)))
def test_sample(self):
with self.assertRaises(NotImplementedError):
paddle.distribution.Dirichlet(
paddle.to_tensor(self.concentration)).sample()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
paddle.enable_static()
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'concentration'),
[('test-one-dim', np.random.rand(89) + 5.0)])
class TestDirichlet(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor()
with paddle.static.program_guard(self.program):
conc = paddle.static.data('conc', self.concentration.shape,
self.concentration.dtype)
self._paddle_diric = paddle.distribution.Dirichlet(conc)
self.feeds = {'conc': self.concentration}
def test_mean(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean])
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_variance(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.variance])
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.var(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_prob(self):
with paddle.static.program_guard(self.program):
random_number = np.random.rand(*self.concentration.shape)
random_number = random_number / random_number.sum()
feeds = dict(self.feeds, value=random_number)
value = paddle.static.data('value', random_number.shape,
random_number.dtype)
out = self._paddle_diric.prob(value)
[out] = self.executor.run(self.program,
feed=feeds,
fetch_list=[out])
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.pdf(random_number, self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_log_prob(self):
with paddle.static.program_guard(self.program):
random_number = np.random.rand(*self.concentration.shape)
random_number = random_number / random_number.sum()
feeds = dict(self.feeds, value=random_number)
value = paddle.static.data('value', random_number.shape,
random_number.dtype)
out = self._paddle_diric.log_prob(value)
[out] = self.executor.run(self.program,
feed=feeds,
fetch_list=[out])
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.logpdf(random_number, self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
def test_entropy(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.entropy()])
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.entropy(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import mock_data as mock
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'dist'), [('test-mock-exp',
mock.Exponential(rate=paddle.rand(
[100, 200, 99],
dtype=config.DEFAULT_DTYPE)))])
class TestExponentialFamily(unittest.TestCase):
def test_entropy(self):
np.testing.assert_allclose(
self.dist.entropy(),
paddle.distribution.ExponentialFamily.entropy(self.dist),
rtol=config.RTOL.get(config.DEFAULT_DTYPE),
atol=config.ATOL.get(config.DEFAULT_DTYPE))
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'dist'),
[('test-dummy-dist', mock.DummyExpFamily(0.5, 0.5)),
('test-dirichlet-dist',
paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand())))])
class TestExponentialFamilyException(unittest.TestCase):
def test_entropy_exception(self):
with self.assertRaises(NotImplementedError):
paddle.distribution.ExponentialFamily.entropy(self.dist)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import mock_data as mock
paddle.enable_static()
@config.place(config.DEVICES)
class TestExponentialFamily(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor()
with paddle.static.program_guard(self.program):
rate_np = config.xrand((100, 200, 99))
rate = paddle.static.data('rate', rate_np.shape, rate_np.dtype)
self.mock_dist = mock.Exponential(rate)
self.feeds = {'rate': rate_np}
def test_entropy(self):
with paddle.static.program_guard(self.program):
[out1, out2] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[
self.mock_dist.entropy(),
paddle.distribution.ExponentialFamily.entropy(
self.mock_dist)
])
np.testing.assert_allclose(
out1,
out2,
rtol=config.RTOL.get(config.DEFAULT_DTYPE),
atol=config.ATOL.get(config.DEFAULT_DTYPE))
def test_entropy_exception(self):
with paddle.static.program_guard(self.program):
with self.assertRaises(NotImplementedError):
paddle.distribution.ExponentialFamily.entropy(
mock.DummyExpFamily(0.5, 0.5))
if __name__ == '__main__':
unittest.main()