diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 0d6aa561a52bf54c31935488278336fa5de3282e..0e77febe5519100b3a4a8e1185adc10d77f17127 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -20,6 +20,7 @@ from paddle.distribution.distribution import Distribution from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.independent import Independent from paddle.distribution.kl import kl_divergence, register_kl +from paddle.distribution.lognormal import LogNormal from paddle.distribution.multinomial import Multinomial from paddle.distribution.normal import Normal from paddle.distribution.transform import * # noqa: F403 @@ -31,7 +32,7 @@ from paddle.distribution.laplace import Laplace __all__ = [ # noqa 'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily', 'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl', - 'Independent', 'TransformedDistribution', 'Laplace' + 'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal' ] __all__.extend(transform.__all__) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 6dae2f64fb733a44b3b7b6546c65fcf360f9f364..80a093ad8b4915395d21f6929c6fb3dca6cc052f 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -21,6 +21,7 @@ from paddle.distribution.dirichlet import Dirichlet from paddle.distribution.distribution import Distribution from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.normal import Normal +from paddle.distribution.lognormal import LogNormal from paddle.distribution.uniform import Uniform from paddle.distribution.laplace import Laplace from paddle.fluid.framework import _non_static_mode, in_dygraph_mode @@ -96,7 +97,7 @@ def register_kl(cls_p, cls_q): def _dispatch(cls_p, cls_q): - """Multiple dispatch into concrete implement function""" + """Multiple dispatch into concrete implement function.""" # find all matched super class pair of p and q matchs = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE @@ -212,5 +213,10 @@ def _kl_expfamily_expfamily(p, q): return kl +@register_kl(LogNormal, LogNormal) +def _kl_lognormal_lognormal(p, q): + return p._base.kl_divergence(q._base) + + def _sum_rightmost(value, n): return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py new file mode 100644 index 0000000000000000000000000000000000000000..b171e1ecbc61ebb6e4e491064c3a20a5bf1dae4a --- /dev/null +++ b/python/paddle/distribution/lognormal.py @@ -0,0 +1,175 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.distribution.normal import Normal +from paddle.distribution.transform import ExpTransform +from paddle.distribution.transformed_distribution import \ + TransformedDistribution + + +class LogNormal(TransformedDistribution): + r"""The LogNormal distribution with location `loc` and `scale` parameters. + + .. math:: + + X \sim Normal(\mu, \sigma) + + Y = exp(X) \sim LogNormal(\mu, \sigma) + + + Due to LogNormal distribution is based on the transformation of Normal distribution, we call that :math:`Normal(\mu, \sigma)` is the underlying distribution of :math:`LogNormal(\mu, \sigma)` + + Mathematical details + + The probability density function (pdf) is + + .. math:: + pdf(x; \mu, \sigma) = \frac{1}{\sigma x \sqrt{2\pi}}e^{(-\frac{(ln(x) - \mu)^2}{2\sigma^2})} + + In the above equation: + + * :math:`loc = \mu`: is the means of the underlying Normal distribution. + * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. + + Args: + loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution. + scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import LogNormal + + # Define a single scalar LogNormal distribution. + dist = LogNormal(loc=0., scale=3.) + # Define a batch of two scalar valued LogNormals. + # The underlying Normal of first has mean 1 and standard deviation 11, the underlying Normal of second 2 and 22. + dist = LogNormal(loc=[1., 2.], scale=[11., 22.]) + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample((3, )) + + # Define a batch of two scalar valued LogNormals. + # Their underlying Normal have mean 1, but different standard deviations. + dist = LogNormal(loc=1., scale=[11., 22.]) + + # Complete example + value_tensor = paddle.to_tensor([0.8], dtype="float32") + + lognormal_a = LogNormal([0.], [1.]) + lognormal_b = LogNormal([0.5], [2.]) + sample = lognormal_a.sample((2, )) + # a random tensor created by lognormal distribution with shape: [2, 1] + entropy = lognormal_a.entropy() + # [1.4189385] with shape: [1] + lp = lognormal_a.log_prob(value_tensor) + # [-0.72069150] with shape: [1] + p = lognormal_a.probs(value_tensor) + # [0.48641577] with shape: [1] + kl = lognormal_a.kl_divergence(lognormal_b) + # [0.34939718] with shape: [1] + """ + + def __init__(self, loc, scale): + self._base = Normal(loc=loc, scale=scale) + self.loc = self._base.loc + self.scale = self._base.scale + super(LogNormal, self).__init__(self._base, [ExpTransform()]) + + @property + def mean(self): + """Mean of lognormal distribuion. + + Returns: + Tensor: mean value. + """ + return paddle.exp(self._base.mean + self._base.variance / 2) + + @property + def variance(self): + """Variance of lognormal distribution. + + Returns: + Tensor: variance value. + """ + return (paddle.expm1(self._base.variance) * + paddle.exp(2 * self._base.mean + self._base.variance)) + + def entropy(self): + r"""Shannon entropy in nats. + + The entropy is + + .. math:: + + entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2) + \mu + + In the above equation: + + * :math:`loc = \mu`: is the mean of the underlying Normal distribution. + * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. + + Returns: + Tensor: Shannon entropy of lognormal distribution. + + """ + return self._base.entropy() + self._base.mean + + def probs(self, value): + """Probability density/mass function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: probability.The data type is same with :attr:`value` . + + """ + return paddle.exp(self.log_prob(value)) + + def kl_divergence(self, other): + r"""The KL-divergence between two lognormal distributions. + + The probability density function (pdf) is + + .. math:: + + KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio}) + + .. math:: + + ratio = \frac{\sigma_0}{\sigma_1} + + .. math:: + + diff = \mu_1 - \mu_0 + + In the above equation: + + * :math:`loc = \mu_0`: is the means of current underlying Normal distribution. + * :math:`scale = \sigma_0`: is the stddevs of current underlying Normal distribution. + * :math:`loc = \mu_1`: is the means of other underlying Normal distribution. + * :math:`scale = \sigma_1`: is the stddevs of other underlying Normal distribution. + * :math:`ratio`: is the ratio of scales. + * :math:`diff`: is the difference between means. + + Args: + other (LogNormal): instance of LogNormal. + + Returns: + Tensor: kl-divergence between two lognormal distributions. + + """ + return self._base.kl_divergence(other._base) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index c9235dc94066514cd08d3a2f83c6f3c57baa9c10..33e36fbe72dac69bf76ded1ca35417e66db6d6c7 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -16,6 +16,7 @@ import math import warnings import numpy as np +import paddle from paddle import _C_ops, _legacy_C_ops from paddle.distribution import distribution from paddle.fluid import core @@ -25,6 +26,10 @@ from paddle.fluid.framework import _non_static_mode, in_dygraph_mode from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, elementwise_mul, elementwise_sub, nn, ops, tensor) +try: + from collections.abc import Iterable +except: + from collections import Iterable class Normal(distribution.Distribution): @@ -128,21 +133,42 @@ class Normal(distribution.Distribution): self.scale = tensor.cast(self.scale, dtype=self.dtype) super(Normal, self).__init__(self.loc.shape) - def sample(self, shape, seed=0): + @property + def mean(self): + """Mean of multinomial distribuion. + + Returns: + Tensor: mean value. + """ + return self.loc + + @property + def variance(self): + """Variance of lognormal distribution. + + Returns: + Tensor: variance value. + """ + return self.scale.pow(2) + + def sample(self, shape=(), seed=0): """Generate samples of the specified shape. Args: - shape (list): 1D `int32`. Shape of the generated samples. + shape (Sequence[int], optional): Shape of the generated samples. seed (int): Python integer number. Returns: Tensor, A tensor with prepended dimensions shape.The data type is float32. """ + if not isinstance(shape, Iterable): + raise TypeError('sample shape must be Iterable object.') + if not _non_static_mode(): - check_type(shape, 'shape', (list), 'sample') check_type(seed, 'seed', (int), 'sample') + shape = list(shape) batch_shape = list((self.loc + self.scale).shape) name = self.name + '_sample' @@ -162,14 +188,32 @@ class Normal(distribution.Distribution): return output else: output_shape = shape + batch_shape - output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \ - (tensor.zeros(output_shape, dtype=self.dtype) + self.scale) + output = nn.gaussian_random( + output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * ( + tensor.zeros(output_shape, dtype=self.dtype) + self.scale) output = elementwise_add(output, self.loc, name=name) if self.all_arg_is_float: return nn.reshape(output, shape, name=name) else: return output + def rsample(self, shape=()): + """Generate reparameterized samples of the specified shape. + + Args: + shape (Sequence[int], optional): Shape of the generated samples. + + Returns: + Tensor: A tensor with prepended dimensions shape.The data type is float32. + + """ + if not isinstance(shape, Iterable): + raise TypeError('sample shape must be Iterable object.') + + shape = self._extend_shape(tuple(shape)) + eps = paddle.normal(shape=shape) + return (self.loc + eps * self.scale) + def entropy(self): r"""Shannon entropy in nats. @@ -204,7 +248,7 @@ class Normal(distribution.Distribution): value (Tensor): The input tensor. Returns: - Tensor: log probability.The data type is same with value. + Tensor: log probability.The data type is same with :attr:`value` . """ name = self.name + '_log_prob' @@ -224,7 +268,7 @@ class Normal(distribution.Distribution): value (Tensor): The input tensor. Returns: - Tensor, probability. The data type is same with value. + Tensor, probability. The data type is same with :attr:`value` . """ name = self.name + '_probs' diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py index 890b7c737aa716c896ee9e1715642cf10a77f372..ff2e13f94acf97387deed2dd70b5c5c48318518a 100644 --- a/python/paddle/distribution/transform.py +++ b/python/paddle/distribution/transform.py @@ -142,7 +142,7 @@ class Transform(object): input, [self]) if isinstance(input, Transform): return ChainTransform([self, input]) - return self.forward(x) + return self.forward(input) def forward(self, x): """Forward transformation with mapping :math:`y = f(x)`. @@ -285,7 +285,7 @@ class Transform(object): if hasattr(self, '_forward_log_det_jacobian'): return self._forward_log_det_jacobian(x) if hasattr(self, '_inverse_log_det_jacobian'): - return -self._inverse_log_det_jacobian(self.forward(y)) + return -self._inverse_log_det_jacobian(self.forward(x)) raise NotImplementedError( 'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian' 'is implemented. One of them is required.') @@ -1133,8 +1133,8 @@ class StickBreakingTransform(Transform): offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) z = F.sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) - return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \ - F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1) + return F.pad(z, [0] * 2 * (len(x.shape) - 1) + [0, 1], value=1) * \ + F.pad(z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1) def _inverse(self, y): y_crop = y[..., :-1] diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index 880bab7d6e3c861404917cb6657610d76b678a03..da0e5908f0ce1bab37a8ab651de6095dee8d5924 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -61,9 +61,10 @@ class TransformedDistribution(distribution.Distribution): raise TypeError("All element of transforms must be Transform type.") chain = transform.ChainTransform(transforms) - if len(base.batch_shape + base.event_shape) < chain._domain.event_rank: + base_shape = base.batch_shape + base.event_shape + if len(base_shape) < chain._domain.event_rank: raise ValueError( - f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}." + f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}." ) if chain._domain.event_rank > len(base.event_shape): base = independent.Independent( @@ -74,7 +75,7 @@ class TransformedDistribution(distribution.Distribution): transformed_shape = chain.forward_shape(base.batch_shape + base.event_shape) transformed_event_rank = chain._codomain.event_rank + \ - max(len(base.event_shape)-chain._domain.event_rank, 0) + max(len(base.event_shape) - chain._domain.event_rank, 0) super(TransformedDistribution, self).__init__( transformed_shape[:len(transformed_shape) - transformed_event_rank], transformed_shape[len(transformed_shape) - transformed_event_rank:]) @@ -83,7 +84,7 @@ class TransformedDistribution(distribution.Distribution): """Sample from ``TransformedDistribution``. Args: - shape (tuple, optional): The sample shape. Defaults to (). + shape (Sequence[int], optional): The sample shape. Defaults to (). Returns: [Tensor]: The sample result. @@ -93,6 +94,20 @@ class TransformedDistribution(distribution.Distribution): x = t.forward(x) return x + def rsample(self, shape=()): + """Reparameterized sample from ``TransformedDistribution``. + + Args: + shape (Sequence[int], optional): The sample shape. Defaults to (). + + Returns: + [Tensor]: The sample result. + """ + x = self._base.rsample(shape) + for t in self._transforms: + x = t.forward(x) + return x + def log_prob(self, value): """The log probability evaluated at value. @@ -110,7 +125,7 @@ class TransformedDistribution(distribution.Distribution): event_rank += t._domain.event_rank - t._codomain.event_rank log_prob = log_prob - \ _sum_rightmost(t.forward_log_det_jacobian( - x), event_rank-t._domain.event_rank) + x), event_rank - t._domain.event_rank) y = x log_prob += _sum_rightmost(self._base.log_prob(y), event_rank - len(self._base.event_shape)) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c97047505c333c96775259240f0548cdb452ff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -0,0 +1,245 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import unittest + +import config +import numpy as np +import paddle +from paddle.distribution.kl import kl_divergence +from paddle.distribution.lognormal import LogNormal +from paddle.distribution.normal import Normal +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats + +from test_distribution import DistributionNumpy + + +class LogNormalNumpy(DistributionNumpy): + + def __init__(self, loc, scale): + self.loc = np.array(loc) + self.scale = np.array(scale) + if str(self.loc.dtype) not in ['float32', 'float64']: + self.loc = self.loc.astype('float32') + self.scale = self.scale.astype('float32') + + @property + def mean(self): + var = self.scale * self.scale + return np.exp(self.loc + var / 2) + + @property + def variance(self): + var = self.scale * self.scale + return (np.exp(var) - 1) * np.exp(2 * self.loc + var) + + def log_prob(self, value): + var = self.scale * self.scale + log_scale = np.log(self.scale) + return -( + (np.log(value) - self.loc) * + (np.log(value) - self.loc)) / (2. * var) - log_scale - math.log( + math.sqrt(2. * math.pi)) - np.log(value) + + def probs(self, value): + var = self.scale * self.scale + return np.exp( + -1. * ((np.log(value) - self.loc) * (np.log(value) - self.loc)) / + (2. * var)) / (math.sqrt(2 * math.pi) * self.scale * value) + + def entropy(self): + return 0.5 + self.loc + 0.5 * np.log( + np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale) + + def kl_divergence(self, other): + var_ratio = (self.scale / other.scale) + var_ratio = var_ratio * var_ratio + t1 = ((self.loc - other.loc) / other.scale) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale', 'value'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3)))]) +class LogNormalTest(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + self.paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc), + scale=paddle.to_tensor(self.scale)) + self.np_lognormal = LogNormalNumpy(self.loc, self.scale) + + def test_mean(self): + mean = self.paddle_lognormal.mean + np_mean = self.np_lognormal.mean + self.assertEqual(mean.numpy().dtype, np_mean.dtype) + np.testing.assert_allclose(mean, + np_mean, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_variance(self): + var = self.paddle_lognormal.variance + np_var = self.np_lognormal.variance + self.assertEqual(var.numpy().dtype, np_var.dtype) + np.testing.assert_allclose(var, + np_var, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_entropy(self): + entropy = self.paddle_lognormal.entropy() + np_entropy = self.np_lognormal.entropy() + self.assertEqual(entropy.numpy().dtype, np_entropy.dtype) + np.testing.assert_allclose(entropy, + np_entropy, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_probs(self): + with paddle.fluid.dygraph.guard(self.place): + probs = self.paddle_lognormal.probs(paddle.to_tensor(self.value)) + np_probs = self.np_lognormal.probs(self.value) + np.testing.assert_allclose( + probs, + np_probs, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_log_prob(self): + with paddle.fluid.dygraph.guard(self.place): + log_prob = self.paddle_lognormal.log_prob( + paddle.to_tensor(self.value)) + np_log_prob = self.np_lognormal.log_prob(self.value) + np.testing.assert_allclose( + log_prob, + np_log_prob, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, ), min=0, max=1))]) +class TestLogNormalSample(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + self.paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) + n = 100000 + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.samples = self.paddle_lognormal.sample(self.sample_shape) + self.rsamples = self.paddle_lognormal.rsample(self.rsample_shape) + + def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, + self.paddle_lognormal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(samples_var, + self.paddle_lognormal.variance, + rtol=0.1, + atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, + self.paddle_lognormal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(rsamples_var, + self.paddle_lognormal.variance, + rtol=0.1, + atol=0) + + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, + list(self.sample_shape + batch_shape)) + self.assertEqual(self.rsamples.shape, + list(self.rsample_shape + batch_shape)) + + for i in range(len(self.scale)): + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, + scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf) + return ks < 0.02 + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))]) +class TestLogNormalKL(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + self.ln_a = LogNormal(loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(self.scale1)) + self.ln_b = LogNormal(loc=paddle.to_tensor(self.loc2), + scale=paddle.to_tensor(self.scale2)) + self.normal_a = Normal(loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(self.scale1)) + self.normal_b = Normal(loc=paddle.to_tensor(self.loc2), + scale=paddle.to_tensor(self.scale2)) + + def test_kl_divergence(self): + kl0 = self.ln_a.kl_divergence(self.ln_b) + kl1 = kl_divergence(self.ln_a, self.ln_b) + kl_normal = kl_divergence(self.normal_a, self.normal_b) + kl_formula = self._kl(self.ln_a, self.ln_b) + + self.assertEqual(tuple(kl0.shape), self.scale1.shape) + self.assertEqual(tuple(kl1.shape), self.scale1.shape) + np.testing.assert_allclose(kl0, + kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + np.testing.assert_allclose(kl1, + kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + np.testing.assert_allclose(kl_normal, + kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + def _kl(self, dist1, dist2): + loc1 = np.array(dist1.loc) + loc2 = np.array(dist2.loc) + scale1 = np.array(dist1.scale) + scale2 = np.array(dist2.scale) + var_ratio = (scale1 / scale2) + var_ratio = var_ratio * var_ratio + t1 = ((loc1 - loc2) / scale2) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py new file mode 100644 index 0000000000000000000000000000000000000000..75a9e497f34b7f263510aab39dc1d490c105a4b8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -0,0 +1,238 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import config +import numpy as np +import paddle +from paddle.distribution.kl import kl_divergence +from paddle.distribution.lognormal import LogNormal +from paddle.distribution.normal import Normal +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats + +from test_distribution_lognormal import LogNormalNumpy + + +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale', 'value'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3)))]) +class TestLogNormal(unittest.TestCase): + + def setUp(self): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + value = paddle.static.data('value', self.value.shape, + self.value.dtype) + self.paddle_lognormal = LogNormal(loc=loc, scale=scale) + self.np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) + mean = self.paddle_lognormal.mean + var = self.paddle_lognormal.variance + entropy = self.paddle_lognormal.entropy() + probs = self.paddle_lognormal.probs(value) + log_prob = self.paddle_lognormal.log_prob(value) + fetch_list = [mean, var, entropy, probs, log_prob] + self.feeds = {'loc': self.loc, 'scale': self.scale, 'value': self.value} + + executor.run(startup_program) + [self.mean, self.var, self.entropy, self.probs, + self.log_prob] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_mean(self): + np_mean = self.np_lognormal.mean + self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.mean, + np_mean, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_var(self): + np_var = self.np_lognormal.variance + self.assertEqual(str(self.var.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.var, + np_var, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_entropy(self): + np_entropy = self.np_lognormal.entropy() + self.assertEqual( + str(self.entropy.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.entropy, + np_entropy, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_probs(self): + np_probs = self.np_lognormal.probs(self.value) + np.testing.assert_allclose(self.probs, + np_probs, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_log_prob(self): + np_log_prob = self.np_lognormal.log_prob(self.value) + np.testing.assert_allclose(self.log_prob, + np_log_prob, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, ), min=0, max=1))]) +class TestLogNormalSample(unittest.TestCase): + + def setUp(self): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + n = 100000 + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.paddle_lognormal = LogNormal(loc=loc, scale=scale) + mean = self.paddle_lognormal.mean + variance = self.paddle_lognormal.variance + samples = self.paddle_lognormal.sample(self.sample_shape) + rsamples = self.paddle_lognormal.rsample(self.rsample_shape) + fetch_list = [mean, variance, samples, rsamples] + self.feeds = {'loc': self.loc, 'scale': self.scale} + + executor.run(startup_program) + [self.mean, self.variance, self.samples, + self.rsamples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(rsamples_var, + self.variance, + rtol=0.1, + atol=0) + + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, self.sample_shape + batch_shape) + self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape) + + for i in range(len(self.scale)): + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, + scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf) + return ks < 0.02 + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))]) +class TestLogNormalKL(unittest.TestCase): + + def setUp(self): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc1 = paddle.static.data('loc1', self.loc1.shape, self.loc1.dtype) + scale1 = paddle.static.data('scale1', self.scale1.shape, + self.scale1.dtype) + loc2 = paddle.static.data('loc2', self.loc2.shape, self.loc2.dtype) + scale2 = paddle.static.data('scale2', self.scale2.shape, + self.scale2.dtype) + + self.ln_a = LogNormal(loc=loc1, scale=scale1) + self.ln_b = LogNormal(loc=loc2, scale=scale2) + self.normal_a = Normal(loc=loc1, scale=scale1) + self.normal_b = Normal(loc=loc2, scale=scale2) + + kl0 = self.ln_a.kl_divergence(self.ln_b) + kl1 = kl_divergence(self.ln_a, self.ln_b) + kl_normal = kl_divergence(self.normal_a, self.normal_b) + kl_formula = self._kl(self.ln_a, self.ln_b) + + fetch_list = [kl0, kl1, kl_normal, kl_formula] + self.feeds = { + 'loc1': self.loc1, + 'scale1': self.scale1, + 'loc2': self.loc2, + 'scale2': self.scale2 + } + + executor.run(startup_program) + [self.kl0, self.kl1, self.kl_normal, + self.kl_formula] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_kl_divergence(self): + np.testing.assert_allclose(self.kl0, + self.kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + np.testing.assert_allclose(self.kl1, + self.kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + np.testing.assert_allclose(self.kl_normal, + self.kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + def _kl(self, dist1, dist2): + loc1 = dist1.loc + loc2 = dist2.loc + scale1 = (dist1.scale) + scale2 = (dist2.scale) + var_ratio = (scale1 / scale2) + var_ratio = var_ratio * var_ratio + t1 = ((loc1 - loc2) / scale2) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py index f9beb6b7702f8edfc573e00f55af0213b49b30ed..56341d7fc0ef814e486c7532680c99a2c69c35e8 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py @@ -154,13 +154,13 @@ class TestMultinomialException(unittest.TestCase): self.main_program = paddle.static.Program() self.executor = paddle.static.Executor(self.place) - with paddle.static.program_guard(main_program, startup_program): + with paddle.static.program_guard(self.main_program, startup_program): probs = paddle.static.data('probs', self.probs.shape, self.probs.dtype) dist = paddle.distribution.Multinomial(self.total_count, probs) self.feed = {'probs': self.probs} - executor.run(startup_program) + self.executor.run(startup_program) def TestInit(self): with self.assertRaises(ValueError): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py index 5023905caa744f78d82c102cc262365ba590f860..1873ac7efa6b56aaeb49b3ff5afb2bf3d120a54a 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py @@ -15,12 +15,14 @@ import math import unittest +import config import numpy as np import paddle from paddle import fluid from paddle.distribution import * from paddle.fluid import layers - +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats from test_distribution import DistributionNumpy np.random.seed(2022) @@ -131,7 +133,6 @@ class NormalTest(unittest.TestCase): # There is a loss of accuracy in this conversion. # So set the tolerance from 1e-6 to 1e-4. log_tolerance = 1e-4 - np.testing.assert_equal(sample.shape, np_sample.shape) np.testing.assert_allclose(entropy, np_entropy, @@ -499,5 +500,123 @@ class NormalTest10(NormalTest): dtype='float32') +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, )))]) +class TestNormalSampleDygraph(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + self.paddle_normal = Normal(loc=self.loc, scale=self.scale) + n = 100000 + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.samples = self.paddle_normal.sample(self.sample_shape) + self.rsamples = self.paddle_normal.rsample(self.rsample_shape) + + def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, + self.paddle_normal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(samples_var, + self.paddle_normal.variance, + rtol=0.1, + atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, + self.paddle_normal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(rsamples_var, + self.paddle_normal.variance, + rtol=0.1, + atol=0) + + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, + list(self.sample_shape + batch_shape)) + self.assertEqual(self.rsamples.shape, + list(self.rsample_shape + batch_shape)) + + for i in range(len(self.scale)): + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest(samples, + scipy.stats.norm(loc=loc, scale=scale).cdf) + return ks < 0.02 + + +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, )))]) +class TestNormalSampleStaic(unittest.TestCase): + + def setUp(self): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + n = 100000 + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.paddle_normal = Normal(loc=loc, scale=scale) + mean = self.paddle_normal.mean + variance = self.paddle_normal.variance + samples = self.paddle_normal.sample(self.sample_shape) + rsamples = self.paddle_normal.rsample(self.rsample_shape) + fetch_list = [mean, variance, samples, rsamples] + self.feeds = {'loc': self.loc, 'scale': self.scale} + + executor.run(startup_program) + [self.mean, self.variance, self.samples, + self.rsamples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(rsamples_var, + self.variance, + rtol=0.1, + atol=0) + + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, self.sample_shape + batch_shape) + self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape) + + for i in range(len(self.scale)): + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest(samples, + scipy.stats.norm(loc=loc, scale=scale).cdf) + return ks < 0.02 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py index c47250195daab1fd8f25701a32b9639ac303a177..15fd94117f0080ab10f6ff13ff62654fc3bc4003 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py @@ -61,6 +61,13 @@ class TestIndependent(unittest.TestCase): self.assertEqual(tuple(data.shape), expected_shape) self.assertEqual(data.dtype, self.base.loc.dtype) + def test_rsample(self): + shape = [5, 10, 8] + expected_shape = (5, 10, 8, 1) + data = self._t.rsample(shape) + self.assertEqual(tuple(data.shape), expected_shape) + self.assertEqual(data.dtype, self.base.loc.dtype) + if __name__ == '__main__': unittest.main()