diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 64d59b04864baa077e74806d6bf6a931442ee5ab..0d6aa561a52bf54c31935488278336fa5de3282e 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -26,11 +26,12 @@ from paddle.distribution.transform import * # noqa: F403 from paddle.distribution.transformed_distribution import \ TransformedDistribution from paddle.distribution.uniform import Uniform +from paddle.distribution.laplace import Laplace __all__ = [ # noqa 'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily', 'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl', - 'Independent', 'TransformedDistribution' + 'Independent', 'TransformedDistribution', 'Laplace' ] __all__.extend(transform.__all__) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index ce3b828eaebadf41ce382e87336612aa6ddc1a78..6dae2f64fb733a44b3b7b6546c65fcf360f9f364 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -22,6 +22,7 @@ from paddle.distribution.distribution import Distribution from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.normal import Normal from paddle.distribution.uniform import Uniform +from paddle.distribution.laplace import Laplace from paddle.fluid.framework import _non_static_mode, in_dygraph_mode __all__ = ["register_kl", "kl_divergence"] @@ -168,6 +169,11 @@ def _kl_uniform_uniform(p, q): return p.kl_divergence(q) +@register_kl(Laplace, Laplace) +def _kl_laplace_laplace(p, q): + return p.kl_divergence(q) + + @register_kl(ExponentialFamily, ExponentialFamily) def _kl_expfamily_expfamily(p, q): """Compute kl-divergence using `Bregman divergences `_ diff --git a/python/paddle/distribution/laplace.py b/python/paddle/distribution/laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..1796f50893e60831c5612231e357c171ea20e52a --- /dev/null +++ b/python/paddle/distribution/laplace.py @@ -0,0 +1,406 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import numpy as np +import paddle +from paddle.distribution import distribution +from paddle.fluid import framework as framework + + +class Laplace(distribution.Distribution): + r""" + Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`. + + Mathematical details + + The probability density function (pdf) is + + .. math:: + pdf(x; \mu, \sigma) = \frac{1}{2 * \sigma} * e^{\frac{-|x - \mu|}{\sigma}} + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`scale = \sigma`: is the scale parameter. + + Args: + loc (scalar|Tensor): The mean of the distribution. + scale (scalar|Tensor): The scale of the distribution. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + m.sample() # Laplace distributed with loc=0, scale=1 + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [3.68546247]) + + """ + + def __init__(self, loc, scale): + if not isinstance(loc, (numbers.Real, framework.Variable)): + raise TypeError( + f"Expected type of loc is Real|Variable, but got {type(loc)}") + + if not isinstance(scale, (numbers.Real, framework.Variable)): + raise TypeError( + f"Expected type of scale is Real|Variable, but got {type(scale)}" + ) + + if isinstance(loc, numbers.Real): + loc = paddle.full(shape=(), fill_value=loc) + + if isinstance(scale, numbers.Real): + scale = paddle.full(shape=(), fill_value=scale) + + if (len(scale.shape) > 0 or len(loc.shape) > 0) and (loc.dtype + == scale.dtype): + self.loc, self.scale = paddle.broadcast_tensors([loc, scale]) + else: + self.loc, self.scale = loc, scale + + super(Laplace, self).__init__(self.loc.shape) + + @property + def mean(self): + """Mean of distribution. + + Returns: + Tensor: The mean value. + """ + return self.loc + + @property + def stddev(self): + r"""Standard deviation. + + The stddev is + + .. math:: + stddev = \sqrt{2} * \sigma + + In the above equation: + + * :math:`scale = \sigma`: is the scale parameter. + + Returns: + Tensor: The std value. + """ + return (2**0.5) * self.scale + + @property + def variance(self): + """Variance of distribution. + + The variance is + + .. math:: + variance = 2 * \sigma^2 + + In the above equation: + + * :math:`scale = \sigma`: is the scale parameter. + + Returns: + Tensor: The variance value. + """ + return self.stddev.pow(2) + + def _validate_value(self, value): + """Argument dimension check for distribution methods such as `log_prob`, + `cdf` and `icdf`. + + Args: + value (Tensor|Scalar): The input value, which can be a scalar or a tensor. + + Returns: + loc, scale, value: The broadcasted loc, scale and value, with the same dimension and data type. + """ + if isinstance(value, numbers.Real): + value = paddle.full(shape=(), fill_value=value) + if value.dtype != self.scale.dtype: + value = paddle.cast(value, self.scale.dtype) + if len(self.scale.shape) > 0 or len(self.loc.shape) > 0 or len( + value.shape) > 0: + loc, scale, value = paddle.broadcast_tensors( + [self.loc, self.scale, value]) + else: + loc, scale = self.loc, self.scale + + return loc, scale, value + + def log_prob(self, value): + r"""Log probability density/mass function. + + The log_prob is + + .. math:: + log\_prob(value) = \frac{-log(2 * \sigma) - |value - \mu|}{\sigma} + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`scale = \sigma`: is the scale parameter. + + Args: + value (Tensor|Scalar): The input value, can be a scalar or a tensor. + + Returns: + Tensor: The log probability, whose data type is same with value. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + value = paddle.to_tensor([0.1]) + m.log_prob(value) + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [-0.79314721]) + + """ + loc, scale, value = self._validate_value(value) + log_scale = -paddle.log(2 * scale) + + return (log_scale - paddle.abs(value - loc) / scale) + + def entropy(self): + r"""Entropy of Laplace distribution. + + The entropy is: + + .. math:: + entropy() = 1 + log(2 * \sigma) + + In the above equation: + + * :math:`scale = \sigma`: is the scale parameter. + + Returns: + The entropy of distribution. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + m.entropy() + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [1.69314718]) + """ + return 1 + paddle.log(2 * self.scale) + + def cdf(self, value): + r"""Cumulative distribution function. + + The cdf is + + .. math:: + cdf(value) = 0.5 - 0.5 * sign(value - \mu) * e^\frac{-|(\mu - \sigma)|}{\sigma} + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`scale = \sigma`: is the scale parameter. + + Args: + value (Tensor): The value to be evaluated. + + Returns: + Tensor: The cumulative probability of value. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + value = paddle.to_tensor([0.1]) + m.cdf(value) + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [0.54758132]) + """ + loc, scale, value = self._validate_value(value) + iterm = (0.5 * (value - loc).sign() * + paddle.expm1(-(value - loc).abs() / scale)) + + return 0.5 - iterm + + def icdf(self, value): + r"""Inverse Cumulative distribution function. + + The icdf is + + .. math:: + cdf^{-1}(value)= \mu - \sigma * sign(value - 0.5) * ln(1 - 2 * |value-0.5|) + + In the above equation: + + * :math:`loc = \mu`: is the location parameter. + * :math:`scale = \sigma`: is the scale parameter. + + Args: + value (Tensor): The value to be evaluated. + + Returns: + Tensor: The cumulative probability of value. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + value = paddle.to_tensor([0.1]) + m.icdf(value) + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [-1.60943794]) + """ + loc, scale, value = self._validate_value(value) + term = value - 0.5 + + return (loc - scale * (term).sign() * paddle.log1p(-2 * term.abs())) + + def sample(self, shape=()): + r"""Generate samples of the specified shape. + + Args: + shape(tuple[int]): The shape of generated samples. + + Returns: + Tensor: A sample tensor that fits the Laplace distribution. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + m.sample() # Laplace distributed with loc=0, scale=1 + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [3.68546247]) + """ + if not isinstance(shape, tuple): + raise TypeError( + f'Expected shape should be tuple[int], but got {type(shape)}') + + with paddle.no_grad(): + return self.rsample(shape) + + def rsample(self, shape): + r"""Reparameterized sample. + + Args: + shape(tuple[int]): The shape of generated samples. + + Returns: + Tensor: A sample tensor that fits the Laplace distribution. + + Examples: + .. code-block:: python + + import paddle + + m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + m.rsample((1,)) # Laplace distributed with loc=0, scale=1 + # Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [[0.04337667]]) + """ + + eps = self._get_eps() + shape = self._extend_shape(shape) or (1, ) + uniform = paddle.uniform(shape=shape, + min=float(np.nextafter(-1, 1)) + eps / 2, + max=1. - eps / 2, + dtype=self.loc.dtype) + + if len(self.scale.shape) == 0 and len(self.loc.shape) == 0: + loc, scale, uniform = paddle.broadcast_tensors( + [self.loc, self.scale, uniform]) + else: + loc, scale = self.loc, self.scale + + return (loc - scale * uniform.sign() * paddle.log1p(-uniform.abs())) + + def _get_eps(self): + """ + Get the eps of certain data type. + + Note: + Since paddle.finfo is temporarily unavailable, we + use hard-coding style to get eps value. + + Returns: + Float: An eps value by different data types. + """ + eps = 1.19209e-07 + if (self.loc.dtype == paddle.float64 + or self.loc.dtype == paddle.complex128): + eps = 2.22045e-16 + + return eps + + def kl_divergence(self, other): + r"""Calculate the KL divergence KL(self || other) with two Laplace instances. + + The kl_divergence between two Laplace distribution is + + .. math:: + KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio}) + + .. math:: + ratio = \frac{\sigma_0}{\sigma_1} + + .. math:: + diff = \mu_1 - \mu_0 + + In the above equation: + + * :math:`loc = \mu`: is the location parameter of self. + * :math:`scale = \sigma`: is the scale parameter of self. + * :math:`loc = \mu_1`: is the location parameter of the reference Laplace distribution. + * :math:`scale = \sigma_1`: is the scale parameter of the reference Laplace distribution. + * :math:`ratio`: is the ratio between the two distribution. + * :math:`diff`: is the difference between the two distribution. + + Args: + other (Laplace): An instance of Laplace. + + Returns: + Tensor: The kl-divergence between two laplace distributions. + + Examples: + .. code-block:: python + + import paddle + + m1 = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0])) + m2 = paddle.distribution.Laplace(paddle.to_tensor([1.0]), paddle.to_tensor([0.5])) + m1.kl_divergence(m2) + # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + # [1.04261160]) + """ + + var_ratio = other.scale / self.scale + t = paddle.abs(self.loc - other.loc) + term1 = ((self.scale * paddle.exp(-t / self.scale) + t) / other.scale) + term2 = paddle.log(var_ratio) + + return term1 + term2 - 1 diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_laplace.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..867fc1846a34d58ec3fd06d27d33364c447d228e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_laplace.py @@ -0,0 +1,210 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import scipy.stats + +import paddle +import config +import parameterize + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'scale'), [ + ('one-dim', parameterize.xrand((2, )),\ + parameterize.xrand((2, ))), + ('multi-dim', parameterize.xrand((5, 5)),\ + parameterize.xrand((5, 5))), + ]) +class TestLaplace(unittest.TestCase): + + def setUp(self): + self._dist = paddle.distribution.Laplace(loc=paddle.to_tensor(self.loc), + scale=paddle.to_tensor(\ + self.scale)) + + def test_mean(self): + mean = self._dist.mean + self.assertEqual(mean.numpy().dtype, self.scale.dtype) + np.testing.assert_allclose(mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_variance(self): + var = self._dist.variance + self.assertEqual(var.numpy().dtype, self.scale.dtype) + np.testing.assert_allclose(var, + self._np_variance(), + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_stddev(self): + stddev = self._dist.stddev + self.assertEqual(stddev.numpy().dtype, self.scale.dtype) + np.testing.assert_allclose(stddev, + self._np_stddev(), + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_entropy(self): + entropy = self._dist.entropy() + self.assertEqual(entropy.numpy().dtype, self.scale.dtype) + + def test_sample(self): + + sample_shape = (50000, ) + samples = self._dist.sample(sample_shape) + sample_values = samples.numpy() + + self.assertEqual(samples.numpy().dtype, self.scale.dtype) + self.assertEqual(tuple(samples.shape), + tuple(self._dist._extend_shape(sample_shape))) + + self.assertEqual(samples.shape, list(sample_shape + self.loc.shape)) + self.assertEqual(sample_values.shape, sample_shape + self.loc.shape) + + np.testing.assert_allclose(sample_values.mean(axis=0), + scipy.stats.laplace.mean(self.loc, + scale=self.scale), + rtol=0.2, + atol=0.) + np.testing.assert_allclose(sample_values.var(axis=0), + scipy.stats.laplace.var(self.loc, + scale=self.scale), + rtol=0.1, + atol=0.) + + def _np_mean(self): + return self.loc + + def _np_stddev(self): + return (2**0.5) * self.scale + + def _np_variance(self): + stddev = (2**0.5) * self.scale + return np.power(stddev, 2) + + def _np_entropy(self): + return scipy.stats.laplace.entropy(loc=self.loc, scale=self.scale) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls((parameterize.TEST_CASE_NAME, 'loc', 'scale'), [ + ('float', 1., 2.), + ('int', 3, 4), +]) +class TestLaplaceKS(unittest.TestCase): + + def setUp(self): + self._dist = paddle.distribution.Laplace(loc=self.loc, scale=self.scale) + + def test_sample(self): + + sample_shape = (20000, ) + samples = self._dist.sample(sample_shape) + sample_values = samples.numpy() + self.assertTrue(self._kstest(self.loc, self.scale, sample_values)) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, p_value = scipy.stats.kstest( + samples, + scipy.stats.laplace(loc, scale=scale).cdf) + return ks < 0.02 + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'scale', 'value'), + [ + ('value-float', np.array([0.2, 0.3]),\ + np.array([2., 3.]), np.array([2., 5.])), + ('value-int', np.array([0.2, 0.3]),\ + np.array([2., 3.]), np.array([2, 5])), + ('value-multi-dim', np.array([0.2, 0.3]), np.array([2., 3.]),\ + np.array([[4., 6], [8, 2]])), + ]) +class TestLaplacePDF(unittest.TestCase): + + def setUp(self): + self._dist = paddle.distribution.Laplace(loc=paddle.to_tensor(self.loc), + scale=paddle.to_tensor(\ + self.scale)) + + def test_prob(self): + np.testing.assert_allclose( + self._dist.prob(paddle.to_tensor(self.value)), + scipy.stats.laplace.pdf(self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_log_prob(self): + np.testing.assert_allclose( + self._dist.log_prob(paddle.to_tensor(self.value)), + scipy.stats.laplace.logpdf(self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_cdf(self): + np.testing.assert_allclose(self._dist.cdf(paddle.to_tensor(self.value)), + scipy.stats.laplace.cdf( + self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_icdf(self): + np.testing.assert_allclose( + self._dist.icdf(paddle.to_tensor(self.value)), + scipy.stats.laplace.ppf(self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc1', 'scale1',\ + 'loc2', 'scale2'), [ + ('kl', np.array([0.0]), np.array([1.0]), \ + np.array([1.0]), np.array([0.5])) + ]) +class TestLaplaceAndLaplaceKL(unittest.TestCase): + + def setUp(self): + self._dist_1 = paddle.distribution.Laplace(loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(\ + self.scale1)) + self._dist_2 = paddle.distribution.Laplace(loc=paddle.to_tensor(self.loc2), + scale=paddle.to_tensor(\ + self.scale2)) + + def test_kl_divergence(self): + np.testing.assert_allclose(paddle.distribution.kl_divergence( + self._dist_1, self._dist_2), + self._np_kl(), + atol=0, + rtol=0.50) + + def _np_kl(self): + x = np.linspace(scipy.stats.laplace.ppf(0.01),\ + scipy.stats.laplace.ppf(0.99), 1000) + d1 = scipy.stats.laplace.pdf(x, loc=0., scale=1.) + d2 = scipy.stats.laplace.pdf(x, loc=1., scale=0.5) + return scipy.stats.entropy(d1, d2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_laplace_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_laplace_static.py new file mode 100644 index 0000000000000000000000000000000000000000..62fe225849c43bdf6a1f529a0f2cc8739768a3d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_laplace_static.py @@ -0,0 +1,273 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import scipy.stats + +import paddle +import config +import parameterize + +paddle.enable_static() + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'scale'), [ + ('one-dim', parameterize.xrand((2, )),\ + parameterize.xrand((2, ))), + ('multi-dim', parameterize.xrand((5, 5)),\ + parameterize.xrand((5, 5))), + ]) +class TestLaplace(unittest.TestCase): + + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + self._dist = paddle.distribution.Laplace(loc=loc, scale=scale) + self.sample_shape = (50000, ) + mean = self._dist.mean + var = self._dist.variance + stddev = self._dist.stddev + entropy = self._dist.entropy() + samples = self._dist.sample(self.sample_shape) + fetch_list = [mean, var, stddev, entropy, samples] + self.feeds = {'loc': self.loc, 'scale': self.scale} + + executor.run(startup_program) + [self.mean, self.var, self.stddev, self.entropy, + self.samples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_mean(self): + self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_variance(self): + self.assertEqual(str(self.var.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.var, + self._np_variance(), + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_stddev(self): + self.assertEqual( + str(self.stddev.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.stddev, + self._np_stddev(), + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_entropy(self): + self.assertEqual( + str(self.entropy.dtype).split('.')[-1], self.scale.dtype) + + def test_sample(self): + + self.assertEqual(self.samples.dtype, self.scale.dtype) + self.assertEqual(tuple(self.samples.shape), + tuple(self._dist._extend_shape(self.sample_shape))) + + self.assertEqual(self.samples.shape, self.sample_shape + self.loc.shape) + self.assertEqual(self.samples.shape, self.sample_shape + self.loc.shape) + + np.testing.assert_allclose(self.samples.mean(axis=0), + scipy.stats.laplace.mean(self.loc, + scale=self.scale), + rtol=0.2, + atol=0.) + np.testing.assert_allclose(self.samples.var(axis=0), + scipy.stats.laplace.var(self.loc, + scale=self.scale), + rtol=0.1, + atol=0.) + + def _np_mean(self): + return self.loc + + def _np_stddev(self): + return (2**0.5) * self.scale + + def _np_variance(self): + stddev = (2**0.5) * self.scale + return np.power(stddev, 2) + + def _np_entropy(self): + return scipy.stats.laplace.entropy(loc=self.loc, scale=self.scale) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'scale', 'value'), + [ + ('value-float', np.array([0.2, 0.3]),\ + np.array([2., 3.]), np.array([2., 5.])), + ('value-int', np.array([0.2, 0.3]),\ + np.array([2., 3.]), np.array([2, 5])), + ('value-multi-dim', np.array([0.2, 0.3]), np.array([2., 3.]),\ + np.array([[4., 6], [8, 2]])), + ]) +class TestLaplacePDF(unittest.TestCase): + + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + value = paddle.static.data('value', self.value.shape, + self.value.dtype) + self._dist = paddle.distribution.Laplace(loc=loc, scale=scale) + prob = self._dist.prob(value) + log_prob = self._dist.log_prob(value) + cdf = self._dist.cdf(value) + icdf = self._dist.icdf(value) + fetch_list = [prob, log_prob, cdf, icdf] + self.feeds = {'loc': self.loc, 'scale': self.scale, 'value': self.value} + + executor.run(startup_program) + [self.prob, self.log_prob, self.cdf, + self.icdf] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_prob(self): + np.testing.assert_allclose(self.prob, + scipy.stats.laplace.pdf( + self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_log_prob(self): + np.testing.assert_allclose(self.log_prob, + scipy.stats.laplace.logpdf( + self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_cdf(self): + np.testing.assert_allclose(self.cdf, + scipy.stats.laplace.cdf( + self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_icdf(self): + np.testing.assert_allclose(self.icdf, + scipy.stats.laplace.ppf( + self.value, self.loc, self.scale), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc1', 'scale1',\ + 'loc2', 'scale2'), [ + ('kl', np.array([0.0]), np.array([1.0]), \ + np.array([1.0]), np.array([0.5])) + ]) +class TestLaplaceAndLaplaceKL(unittest.TestCase): + + def setUp(self): + self.mp = paddle.static.Program() + self.sp = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(self.mp, self.sp): + loc1 = paddle.static.data('loc1', self.loc1.shape, self.loc1.dtype) + scale1 = paddle.static.data('scale1', self.scale1.shape, + self.scale1.dtype) + loc2 = paddle.static.data('loc2', self.loc2.shape, self.loc2.dtype) + scale2 = paddle.static.data('scale2', self.scale2.shape, + self.scale2.dtype) + self._dist_1 = paddle.distribution.Laplace(loc=loc1, scale=scale1) + self._dist_2 = paddle.distribution.Laplace(loc=loc2, scale=scale2) + self.feeds = { + 'loc1': self.loc1, + 'scale1': self.scale1, + 'loc2': self.loc2, + 'scale2': self.scale2 + } + + def test_kl_divergence(self): + with paddle.static.program_guard(self.mp, self.sp): + out = paddle.distribution.kl_divergence(self._dist_1, self._dist_2) + self.executor.run(self.sp) + [out] = self.executor.run(self.mp, + feed=self.feeds, + fetch_list=[out]) + np.testing.assert_allclose(out, self._np_kl(), atol=0, rtol=0.50) + + def _np_kl(self): + x = np.linspace(scipy.stats.laplace.ppf(0.01),\ + scipy.stats.laplace.ppf(0.99), 1000) + d1 = scipy.stats.laplace.pdf(x, loc=0., scale=1.) + d2 = scipy.stats.laplace.pdf(x, loc=1., scale=0.5) + return scipy.stats.entropy(d1, d2) + + +""" +# Note: Zero dimension of a Tensor is not supported by static mode of paddle; +# therefore, ks test below cannot be conducted temporarily. + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'loc', 'scale', 'sample_shape'), [ + ('one-dim', np.array(4.0), np.array(3.0), np.array([3000]))]) +class TestLaplaceKS(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + loc = paddle.static.data('loc', self.loc.shape, + self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + self.sample = paddle.static.data('sample_shape', self.sample_shape.shape, + self.sample_shape.dtype) + self._dist = paddle.distribution.Laplace(loc=loc, scale=scale) + self.feeds = {'loc': self.loc, 'scale': self.scale, 'sample_shape': self.sample_shape} + + def test_sample(self): + with paddle.static.program_guard(self.program): + [sample_values] = self.executor.run(self.program, + feed=self.feeds, + fetch_list=self._dist.sample((3000,))) + self.assertTrue(self._kstest(self.loc, self.scale, sample_values)) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, p_value = scipy.stats.kstest( + samples, + scipy.stats.laplace(loc, scale=scale).cdf) + return ks < 0.02 +""" + +if __name__ == '__main__': + unittest.main()