Unverified commit af6d80fb, authored by MayYouBeProsperous, committed by GitHub

[Hackathon No.10] Add LogNormal API (#46426)

* add LogNormal API

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* add comment

* fix bug

* fix docs

* fix bug

* fix bug

* fix bug

* add test

* add test

* change the args type of Normal sample

* fix bug

* fix bug

* fix bug

* fix bug

* add test

* add test

* format

* add comment

* add comment

* add comment

* add comment

* format code

* fix bug

* fix bug

* fix bug

* add comment

* remove name parameter for LogNormal

* organize imports
Parent c4bbe5d9
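For reference, a minimal usage sketch of the API introduced by this commit, mirroring the docstring example added below (parameter values are illustrative):

import paddle
from paddle.distribution import LogNormal

dist = LogNormal(loc=paddle.to_tensor([0.]), scale=paddle.to_tensor([1.]))
samples = dist.sample((100, ))                 # tensor of shape [100, 1]
print(dist.mean, dist.variance)                # exp(mu + sigma^2 / 2), (exp(sigma^2) - 1) * exp(2 * mu + sigma^2)
print(dist.entropy())                          # 0.5 * log(2 * pi * e * sigma^2) + mu
print(dist.log_prob(paddle.to_tensor([0.8])))  # log pdf evaluated at 0.8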
@@ -20,6 +20,7 @@ from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.independent import Independent
from paddle.distribution.kl import kl_divergence, register_kl
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.multinomial import Multinomial
from paddle.distribution.normal import Normal
from paddle.distribution.transform import *  # noqa: F403
@@ -31,7 +32,7 @@ from paddle.distribution.laplace import Laplace
__all__ = [  # noqa
'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily',
'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl',
'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal'
]
__all__.extend(transform.__all__)
@@ -21,6 +21,7 @@ from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.normal import Normal
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.uniform import Uniform
from paddle.distribution.laplace import Laplace
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
@@ -96,7 +97,7 @@ def register_kl(cls_p, cls_q):
def _dispatch(cls_p, cls_q):
"""Multiple dispatch into concrete implement function."""
# find all matched super class pair of p and q
matchs = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE
@@ -212,5 +213,10 @@ def _kl_expfamily_expfamily(p, q):
return kl
@register_kl(LogNormal, LogNormal)
def _kl_lognormal_lognormal(p, q):
return p._base.kl_divergence(q._base)
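Because exp is a bijection, the KL divergence between two LogNormal distributions equals the KL divergence between their underlying Normal distributions, which is exactly what the registration above relies on. A quick, Paddle-independent numerical sanity check of that identity (a sketch; parameters and sample size are illustrative):

import numpy as np
import scipy.stats

mu0, s0, mu1, s1 = 0.0, 1.0, 0.5, 2.0
x = np.random.lognormal(mean=mu0, sigma=s0, size=200000)
p = scipy.stats.lognorm(s=s0, scale=np.exp(mu0))
q = scipy.stats.lognorm(s=s1, scale=np.exp(mu1))
kl_mc = np.mean(p.logpdf(x) - q.logpdf(x))  # Monte-Carlo estimate of KL(p || q)
ratio = s0 / s1
kl_normal = 0.5 * (ratio**2 + ((mu0 - mu1) / s1)**2 - 1) - np.log(ratio)  # closed form for the underlying Normals
print(kl_mc, kl_normal)  # agree up to Monte-Carlo error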
def _sum_rightmost(value, n):
return value.sum(list(range(-n, 0))) if n > 0 else value
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distribution.normal import Normal
from paddle.distribution.transform import ExpTransform
from paddle.distribution.transformed_distribution import \
TransformedDistribution
class LogNormal(TransformedDistribution):
r"""The LogNormal distribution with location `loc` and `scale` parameters.
.. math::
X \sim Normal(\mu, \sigma)
Y = exp(X) \sim LogNormal(\mu, \sigma)
Since the LogNormal distribution is defined by transforming a Normal distribution, we call :math:`Normal(\mu, \sigma)` the underlying distribution of :math:`LogNormal(\mu, \sigma)`.
Mathematical details
The probability density function (pdf) is
.. math::
pdf(x; \mu, \sigma) = \frac{1}{\sigma x \sqrt{2\pi}}e^{(-\frac{(ln(x) - \mu)^2}{2\sigma^2})}
In the above equation:
* :math:`loc = \mu`: is the mean of the underlying Normal distribution.
* :math:`scale = \sigma`: is the standard deviation of the underlying Normal distribution.
Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import LogNormal
# Define a single scalar LogNormal distribution.
dist = LogNormal(loc=0., scale=3.)
# Define a batch of two scalar valued LogNormals.
# The underlying Normal of the first has mean 1 and standard deviation 11; that of the second has mean 2 and standard deviation 22.
dist = LogNormal(loc=[1., 2.], scale=[11., 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample((3, ))
# Define a batch of two scalar valued LogNormals.
# Their underlying Normals have mean 1, but different standard deviations.
dist = LogNormal(loc=1., scale=[11., 22.])
# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")
lognormal_a = LogNormal([0.], [1.])
lognormal_b = LogNormal([0.5], [2.])
sample = lognormal_a.sample((2, ))
# a random tensor created by lognormal distribution with shape: [2, 1]
entropy = lognormal_a.entropy()
# [1.4189385] with shape: [1]
lp = lognormal_a.log_prob(value_tensor)
# [-0.72069150] with shape: [1]
p = lognormal_a.probs(value_tensor)
# [0.48641577] with shape: [1]
kl = lognormal_a.kl_divergence(lognormal_b)
# [0.34939718] with shape: [1]
"""
def __init__(self, loc, scale):
self._base = Normal(loc=loc, scale=scale)
self.loc = self._base.loc
self.scale = self._base.scale
super(LogNormal, self).__init__(self._base, [ExpTransform()])
@property
def mean(self):
"""Mean of lognormal distribution.
Returns:
Tensor: mean value.
"""
return paddle.exp(self._base.mean + self._base.variance / 2)
@property
def variance(self):
"""Variance of lognormal distribution.
Returns:
Tensor: variance value.
"""
return (paddle.expm1(self._base.variance) *
paddle.exp(2 * self._base.mean + self._base.variance))
def entropy(self):
r"""Shannon entropy in nats.
The entropy is
.. math::
entropy(\mu, \sigma) = 0.5 \log (2 \pi e \sigma^2) + \mu
In the above equation:
* :math:`loc = \mu`: is the mean of the underlying Normal distribution.
* :math:`scale = \sigma`: is the standard deviation of the underlying Normal distribution.
Returns:
Tensor: Shannon entropy of lognormal distribution.
"""
return self._base.entropy() + self._base.mean
def probs(self, value):
"""Probability density/mass function.
Args:
value (Tensor): The input tensor.
Returns:
Tensor: probability. The data type is the same as :attr:`value` .
"""
return paddle.exp(self.log_prob(value))
def kl_divergence(self, other):
r"""The KL-divergence between two lognormal distributions.
The KL divergence is
.. math::
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio})
.. math::
ratio = \frac{\sigma_0}{\sigma_1}
.. math::
diff = \mu_1 - \mu_0
In the above equation:
* :math:`loc = \mu_0`: is the mean of the current underlying Normal distribution.
* :math:`scale = \sigma_0`: is the standard deviation of the current underlying Normal distribution.
* :math:`loc = \mu_1`: is the mean of the other underlying Normal distribution.
* :math:`scale = \sigma_1`: is the standard deviation of the other underlying Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.
Args:
other (LogNormal): instance of LogNormal.
Returns:
Tensor: kl-divergence between two lognormal distributions.
"""
return self._base.kl_divergence(other._base)
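As an extra sanity check of the closed-form mean, variance and entropy above, one could compare them against scipy.stats.lognorm (a sketch assuming scipy is available; parameter values and tolerance are illustrative):

import numpy as np
import scipy.stats
import paddle
from paddle.distribution import LogNormal

loc, scale = 0.3, 0.7
dist = LogNormal(loc=paddle.to_tensor([loc]), scale=paddle.to_tensor([scale]))
ref = scipy.stats.lognorm(s=scale, scale=np.exp(loc))
np.testing.assert_allclose(dist.mean.numpy(), ref.mean(), rtol=1e-5)
np.testing.assert_allclose(dist.variance.numpy(), ref.var(), rtol=1e-5)
np.testing.assert_allclose(dist.entropy().numpy(), ref.entropy(), rtol=1e-5)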
@@ -16,6 +16,7 @@ import math
import warnings
import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.distribution import distribution
from paddle.fluid import core
@@ -25,6 +26,10 @@ from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
try:
from collections.abc import Iterable
except:
from collections import Iterable
class Normal(distribution.Distribution):
@@ -128,21 +133,42 @@ class Normal(distribution.Distribution):
self.scale = tensor.cast(self.scale, dtype=self.dtype)
super(Normal, self).__init__(self.loc.shape)
@property
def mean(self):
"""Mean of normal distribution.
Returns:
Tensor: mean value.
"""
return self.loc
@property
def variance(self):
"""Variance of normal distribution.
Returns:
Tensor: variance value.
"""
return self.scale.pow(2)
def sample(self, shape=(), seed=0):
"""Generate samples of the specified shape.
Args:
shape (Sequence[int], optional): Shape of the generated samples.
seed (int): Python integer number.
Returns:
Tensor, A tensor with prepended dimensions shape. The data type is float32.
"""
if not isinstance(shape, Iterable):
raise TypeError('sample shape must be Iterable object.')
if not _non_static_mode():
check_type(seed, 'seed', (int), 'sample')
shape = list(shape)
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'
@@ -162,14 +188,32 @@ class Normal(distribution.Distribution):
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(
output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * (
tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
def rsample(self, shape=()):
"""Generate reparameterized samples of the specified shape.
Args:
shape (Sequence[int], optional): Shape of the generated samples.
Returns:
Tensor: A tensor with prepended dimensions shape. The data type is float32.
"""
if not isinstance(shape, Iterable):
raise TypeError('sample shape must be Iterable object.')
shape = self._extend_shape(tuple(shape))
eps = paddle.normal(shape=shape)
return (self.loc + eps * self.scale)
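Unlike sample, rsample expresses each draw as loc + eps * scale with eps drawn from a standard Normal, so gradients can propagate back to loc and scale (the reparameterization trick). A minimal dygraph sketch of that use (the surrogate loss and values are illustrative):

import paddle
from paddle.distribution import Normal

loc = paddle.to_tensor([0.5], stop_gradient=False)
scale = paddle.to_tensor([1.2], stop_gradient=False)
dist = Normal(loc=loc, scale=scale)
draws = dist.rsample((1000, ))  # differentiable w.r.t. loc and scale
loss = (draws * draws).mean()   # Monte-Carlo estimate of E[X^2]
loss.backward()
print(loc.grad, scale.grad)     # both gradients are populated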
def entropy(self):
r"""Shannon entropy in nats.
@@ -204,7 +248,7 @@ class Normal(distribution.Distribution):
value (Tensor): The input tensor.
Returns:
Tensor: log probability. The data type is same with :attr:`value` .
"""
name = self.name + '_log_prob'
@@ -224,7 +268,7 @@ class Normal(distribution.Distribution):
value (Tensor): The input tensor.
Returns:
Tensor, probability. The data type is same with :attr:`value` .
"""
name = self.name + '_probs'
...
@@ -142,7 +142,7 @@ class Transform(object):
input, [self])
if isinstance(input, Transform):
return ChainTransform([self, input])
return self.forward(input)
def forward(self, x):
"""Forward transformation with mapping :math:`y = f(x)`.
@@ -285,7 +285,7 @@ class Transform(object):
if hasattr(self, '_forward_log_det_jacobian'):
return self._forward_log_det_jacobian(x)
if hasattr(self, '_inverse_log_det_jacobian'):
return -self._inverse_log_det_jacobian(self.forward(x))
raise NotImplementedError(
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian'
'is implemented. One of them is required.')
@@ -1133,8 +1133,8 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0] * 2 * (len(x.shape) - 1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1)
def _inverse(self, y):
y_crop = y[..., :-1]
...
@@ -61,9 +61,10 @@ class TransformedDistribution(distribution.Distribution):
raise TypeError("All element of transforms must be Transform type.")
chain = transform.ChainTransform(transforms)
base_shape = base.batch_shape + base.event_shape
if len(base_shape) < chain._domain.event_rank:
raise ValueError(
f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
)
if chain._domain.event_rank > len(base.event_shape):
base = independent.Independent(
@@ -74,7 +75,7 @@ class TransformedDistribution(distribution.Distribution):
transformed_shape = chain.forward_shape(base.batch_shape +
base.event_shape)
transformed_event_rank = chain._codomain.event_rank + \
max(len(base.event_shape) - chain._domain.event_rank, 0)
super(TransformedDistribution, self).__init__(
transformed_shape[:len(transformed_shape) - transformed_event_rank],
transformed_shape[len(transformed_shape) - transformed_event_rank:])
@@ -83,7 +84,7 @@ class TransformedDistribution(distribution.Distribution):
"""Sample from ``TransformedDistribution``.
Args:
shape (Sequence[int], optional): The sample shape. Defaults to ().
Returns:
[Tensor]: The sample result.
@@ -93,6 +94,20 @@ class TransformedDistribution(distribution.Distribution):
x = t.forward(x)
return x
def rsample(self, shape=()):
"""Reparameterized sample from ``TransformedDistribution``.
Args:
shape (Sequence[int], optional): The sample shape. Defaults to ().
Returns:
[Tensor]: The sample result.
"""
x = self._base.rsample(shape)
for t in self._transforms:
x = t.forward(x)
return x
def log_prob(self, value):
"""The log probability evaluated at value.
@@ -110,7 +125,7 @@ class TransformedDistribution(distribution.Distribution):
event_rank += t._domain.event_rank - t._codomain.event_rank
log_prob = log_prob - \
_sum_rightmost(t.forward_log_det_jacobian(
x), event_rank - t._domain.event_rank)
y = x
log_prob += _sum_rightmost(self._base.log_prob(y),
event_rank - len(self._base.event_shape))
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import config
import numpy as np
import paddle
from paddle.distribution.kl import kl_divergence
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.normal import Normal
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import scipy.stats
from test_distribution import DistributionNumpy
class LogNormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc)
self.scale = np.array(scale)
if str(self.loc.dtype) not in ['float32', 'float64']:
self.loc = self.loc.astype('float32')
self.scale = self.scale.astype('float32')
@property
def mean(self):
var = self.scale * self.scale
return np.exp(self.loc + var / 2)
@property
def variance(self):
var = self.scale * self.scale
return (np.exp(var) - 1) * np.exp(2 * self.loc + var)
def log_prob(self, value):
var = self.scale * self.scale
log_scale = np.log(self.scale)
return -(
(np.log(value) - self.loc) *
(np.log(value) - self.loc)) / (2. * var) - log_scale - math.log(
math.sqrt(2. * math.pi)) - np.log(value)
def probs(self, value):
var = self.scale * self.scale
return np.exp(
-1. * ((np.log(value) - self.loc) * (np.log(value) - self.loc)) /
(2. * var)) / (math.sqrt(2 * math.pi) * self.scale * value)
def entropy(self):
return 0.5 + self.loc + 0.5 * np.log(
np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale)
def kl_divergence(self, other):
var_ratio = (self.scale / other.scale)
var_ratio = var_ratio * var_ratio
t1 = ((self.loc - other.loc) / other.scale)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
@place(config.DEVICES)
@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale', 'value'),
[('one-dim', xrand((2, )), xrand((2, )), xrand((2, ))),
('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3)))])
class LogNormalTest(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc),
scale=paddle.to_tensor(self.scale))
self.np_lognormal = LogNormalNumpy(self.loc, self.scale)
def test_mean(self):
mean = self.paddle_lognormal.mean
np_mean = self.np_lognormal.mean
self.assertEqual(mean.numpy().dtype, np_mean.dtype)
np.testing.assert_allclose(mean,
np_mean,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_variance(self):
var = self.paddle_lognormal.variance
np_var = self.np_lognormal.variance
self.assertEqual(var.numpy().dtype, np_var.dtype)
np.testing.assert_allclose(var,
np_var,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_entropy(self):
entropy = self.paddle_lognormal.entropy()
np_entropy = self.np_lognormal.entropy()
self.assertEqual(entropy.numpy().dtype, np_entropy.dtype)
np.testing.assert_allclose(entropy,
np_entropy,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_probs(self):
with paddle.fluid.dygraph.guard(self.place):
probs = self.paddle_lognormal.probs(paddle.to_tensor(self.value))
np_probs = self.np_lognormal.probs(self.value)
np.testing.assert_allclose(
probs,
np_probs,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_log_prob(self):
with paddle.fluid.dygraph.guard(self.place):
log_prob = self.paddle_lognormal.log_prob(
paddle.to_tensor(self.value))
np_log_prob = self.np_lognormal.log_prob(self.value)
np.testing.assert_allclose(
log_prob,
np_log_prob,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
@place(config.DEVICES)
@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand(
(4, )), xrand((4, ), min=0, max=1))])
class TestLogNormalSample(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale)
n = 100000
self.sample_shape = (n, )
self.rsample_shape = (n, )
self.samples = self.paddle_lognormal.sample(self.sample_shape)
self.rsamples = self.paddle_lognormal.rsample(self.rsample_shape)
def test_sample(self):
samples_mean = self.samples.mean(axis=0)
samples_var = self.samples.var(axis=0)
np.testing.assert_allclose(samples_mean,
self.paddle_lognormal.mean,
rtol=0.1,
atol=0)
np.testing.assert_allclose(samples_var,
self.paddle_lognormal.variance,
rtol=0.1,
atol=0)
rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0)
np.testing.assert_allclose(rsamples_mean,
self.paddle_lognormal.mean,
rtol=0.1,
atol=0)
np.testing.assert_allclose(rsamples_var,
self.paddle_lognormal.variance,
rtol=0.1,
atol=0)
batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.samples.shape,
list(self.sample_shape + batch_shape))
self.assertEqual(self.rsamples.shape,
list(self.rsample_shape + batch_shape))
for i in range(len(self.scale)):
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.samples[:, i]))
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i]))
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = scipy.stats.kstest(
samples,
scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf)
return ks < 0.02
@place(config.DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'),
[('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))),
('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))])
class TestLogNormalKL(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.ln_a = LogNormal(loc=paddle.to_tensor(self.loc1),
scale=paddle.to_tensor(self.scale1))
self.ln_b = LogNormal(loc=paddle.to_tensor(self.loc2),
scale=paddle.to_tensor(self.scale2))
self.normal_a = Normal(loc=paddle.to_tensor(self.loc1),
scale=paddle.to_tensor(self.scale1))
self.normal_b = Normal(loc=paddle.to_tensor(self.loc2),
scale=paddle.to_tensor(self.scale2))
def test_kl_divergence(self):
kl0 = self.ln_a.kl_divergence(self.ln_b)
kl1 = kl_divergence(self.ln_a, self.ln_b)
kl_normal = kl_divergence(self.normal_a, self.normal_b)
kl_formula = self._kl(self.ln_a, self.ln_b)
self.assertEqual(tuple(kl0.shape), self.scale1.shape)
self.assertEqual(tuple(kl1.shape), self.scale1.shape)
np.testing.assert_allclose(kl0,
kl_formula,
rtol=config.RTOL.get(str(self.scale1.dtype)),
atol=config.ATOL.get(str(self.scale1.dtype)))
np.testing.assert_allclose(kl1,
kl_formula,
rtol=config.RTOL.get(str(self.scale1.dtype)),
atol=config.ATOL.get(str(self.scale1.dtype)))
np.testing.assert_allclose(kl_normal,
kl_formula,
rtol=config.RTOL.get(str(self.scale1.dtype)),
atol=config.ATOL.get(str(self.scale1.dtype)))
def _kl(self, dist1, dist2):
loc1 = np.array(dist1.loc)
loc2 = np.array(dist2.loc)
scale1 = np.array(dist1.scale)
scale2 = np.array(dist2.scale)
var_ratio = (scale1 / scale2)
var_ratio = var_ratio * var_ratio
t1 = ((loc1 - loc2) / scale2)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import config
import numpy as np
import paddle
from paddle.distribution.kl import kl_divergence
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.normal import Normal
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import scipy.stats
from test_distribution_lognormal import LogNormalNumpy
@place(config.DEVICES)
@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale', 'value'),
[('one-dim', xrand((2, )), xrand((2, )), xrand((2, ))),
('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3)))])
class TestLogNormal(unittest.TestCase):
def setUp(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
value = paddle.static.data('value', self.value.shape,
self.value.dtype)
self.paddle_lognormal = LogNormal(loc=loc, scale=scale)
self.np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale)
mean = self.paddle_lognormal.mean
var = self.paddle_lognormal.variance
entropy = self.paddle_lognormal.entropy()
probs = self.paddle_lognormal.probs(value)
log_prob = self.paddle_lognormal.log_prob(value)
fetch_list = [mean, var, entropy, probs, log_prob]
self.feeds = {'loc': self.loc, 'scale': self.scale, 'value': self.value}
executor.run(startup_program)
[self.mean, self.var, self.entropy, self.probs,
self.log_prob] = executor.run(main_program,
feed=self.feeds,
fetch_list=fetch_list)
def test_mean(self):
np_mean = self.np_lognormal.mean
self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype)
np.testing.assert_allclose(self.mean,
np_mean,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_var(self):
np_var = self.np_lognormal.variance
self.assertEqual(str(self.var.dtype).split('.')[-1], self.scale.dtype)
np.testing.assert_allclose(self.var,
np_var,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_entropy(self):
np_entropy = self.np_lognormal.entropy()
self.assertEqual(
str(self.entropy.dtype).split('.')[-1], self.scale.dtype)
np.testing.assert_allclose(self.entropy,
np_entropy,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_probs(self):
np_probs = self.np_lognormal.probs(self.value)
np.testing.assert_allclose(self.probs,
np_probs,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
def test_log_prob(self):
np_log_prob = self.np_lognormal.log_prob(self.value)
np.testing.assert_allclose(self.log_prob,
np_log_prob,
rtol=config.RTOL.get(str(self.scale.dtype)),
atol=config.ATOL.get(str(self.scale.dtype)))
@place(config.DEVICES)
@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand(
(4, )), xrand((4, ), min=0, max=1))])
class TestLogNormalSample(unittest.TestCase):
def setUp(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
n = 100000
self.sample_shape = (n, )
self.rsample_shape = (n, )
self.paddle_lognormal = LogNormal(loc=loc, scale=scale)
mean = self.paddle_lognormal.mean
variance = self.paddle_lognormal.variance
samples = self.paddle_lognormal.sample(self.sample_shape)
rsamples = self.paddle_lognormal.rsample(self.rsample_shape)
fetch_list = [mean, variance, samples, rsamples]
self.feeds = {'loc': self.loc, 'scale': self.scale}
executor.run(startup_program)
[self.mean, self.variance, self.samples,
self.rsamples] = executor.run(main_program,
feed=self.feeds,
fetch_list=fetch_list)
def test_sample(self):
samples_mean = self.samples.mean(axis=0)
samples_var = self.samples.var(axis=0)
np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0)
rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0)
np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose(rsamples_var,
self.variance,
rtol=0.1,
atol=0)
batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.samples.shape, self.sample_shape + batch_shape)
self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape)
for i in range(len(self.scale)):
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.samples[:, i]))
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i]))
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = scipy.stats.kstest(
samples,
scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf)
return ks < 0.02
@place(config.DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'),
[('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))),
('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))])
class TestLogNormalKL(unittest.TestCase):
def setUp(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc1 = paddle.static.data('loc1', self.loc1.shape, self.loc1.dtype)
scale1 = paddle.static.data('scale1', self.scale1.shape,
self.scale1.dtype)
loc2 = paddle.static.data('loc2', self.loc2.shape, self.loc2.dtype)
scale2 = paddle.static.data('scale2', self.scale2.shape,
self.scale2.dtype)
self.ln_a = LogNormal(loc=loc1, scale=scale1)
self.ln_b = LogNormal(loc=loc2, scale=scale2)
self.normal_a = Normal(loc=loc1, scale=scale1)
self.normal_b = Normal(loc=loc2, scale=scale2)
kl0 = self.ln_a.kl_divergence(self.ln_b)
kl1 = kl_divergence(self.ln_a, self.ln_b)
kl_normal = kl_divergence(self.normal_a, self.normal_b)
kl_formula = self._kl(self.ln_a, self.ln_b)
fetch_list = [kl0, kl1, kl_normal, kl_formula]
self.feeds = {
'loc1': self.loc1,
'scale1': self.scale1,
'loc2': self.loc2,
'scale2': self.scale2
}
executor.run(startup_program)
[self.kl0, self.kl1, self.kl_normal,
self.kl_formula] = executor.run(main_program,
feed=self.feeds,
fetch_list=fetch_list)
def test_kl_divergence(self):
np.testing.assert_allclose(self.kl0,
self.kl_formula,
rtol=config.RTOL.get(str(self.scale1.dtype)),
atol=config.ATOL.get(str(self.scale1.dtype)))
np.testing.assert_allclose(self.kl1,
self.kl_formula,
rtol=config.RTOL.get(str(self.scale1.dtype)),
atol=config.ATOL.get(str(self.scale1.dtype)))
np.testing.assert_allclose(self.kl_normal,
self.kl_formula,
rtol=config.RTOL.get(str(self.scale1.dtype)),
atol=config.ATOL.get(str(self.scale1.dtype)))
def _kl(self, dist1, dist2):
loc1 = dist1.loc
loc2 = dist2.loc
scale1 = (dist1.scale)
scale2 = (dist2.scale)
var_ratio = (scale1 / scale2)
var_ratio = var_ratio * var_ratio
t1 = ((loc1 - loc2) / scale2)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
if __name__ == '__main__':
unittest.main()
@@ -154,13 +154,13 @@ class TestMultinomialException(unittest.TestCase):
self.main_program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.main_program, startup_program):
probs = paddle.static.data('probs', self.probs.shape,
self.probs.dtype)
dist = paddle.distribution.Multinomial(self.total_count, probs)
self.feed = {'probs': self.probs}
self.executor.run(startup_program)
def TestInit(self):
with self.assertRaises(ValueError):
...
@@ -15,12 +15,14 @@
import math
import unittest
import config
import numpy as np
import paddle
from paddle import fluid
from paddle.distribution import *
from paddle.fluid import layers
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
import scipy.stats
from test_distribution import DistributionNumpy
np.random.seed(2022)
@@ -131,7 +133,6 @@ class NormalTest(unittest.TestCase):
# There is a loss of accuracy in this conversion.
# So set the tolerance from 1e-6 to 1e-4.
log_tolerance = 1e-4
np.testing.assert_equal(sample.shape, np_sample.shape)
np.testing.assert_allclose(entropy,
np_entropy,
@@ -499,5 +500,123 @@ class NormalTest10(NormalTest):
dtype='float32')
@place(config.DEVICES)
@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand(
(4, )), xrand((4, )))])
class TestNormalSampleDygraph(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.paddle_normal = Normal(loc=self.loc, scale=self.scale)
n = 100000
self.sample_shape = (n, )
self.rsample_shape = (n, )
self.samples = self.paddle_normal.sample(self.sample_shape)
self.rsamples = self.paddle_normal.rsample(self.rsample_shape)
def test_sample(self):
samples_mean = self.samples.mean(axis=0)
samples_var = self.samples.var(axis=0)
np.testing.assert_allclose(samples_mean,
self.paddle_normal.mean,
rtol=0.1,
atol=0)
np.testing.assert_allclose(samples_var,
self.paddle_normal.variance,
rtol=0.1,
atol=0)
rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0)
np.testing.assert_allclose(rsamples_mean,
self.paddle_normal.mean,
rtol=0.1,
atol=0)
np.testing.assert_allclose(rsamples_var,
self.paddle_normal.variance,
rtol=0.1,
atol=0)
batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.samples.shape,
list(self.sample_shape + batch_shape))
self.assertEqual(self.rsamples.shape,
list(self.rsample_shape + batch_shape))
for i in range(len(self.scale)):
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.samples[:, i]))
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i]))
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = scipy.stats.kstest(samples,
scipy.stats.norm(loc=loc, scale=scale).cdf)
return ks < 0.02
@place(config.DEVICES)
@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand(
(4, )), xrand((4, )))])
class TestNormalSampleStaic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
n = 100000
self.sample_shape = (n, )
self.rsample_shape = (n, )
self.paddle_normal = Normal(loc=loc, scale=scale)
mean = self.paddle_normal.mean
variance = self.paddle_normal.variance
samples = self.paddle_normal.sample(self.sample_shape)
rsamples = self.paddle_normal.rsample(self.rsample_shape)
fetch_list = [mean, variance, samples, rsamples]
self.feeds = {'loc': self.loc, 'scale': self.scale}
executor.run(startup_program)
[self.mean, self.variance, self.samples,
self.rsamples] = executor.run(main_program,
feed=self.feeds,
fetch_list=fetch_list)
def test_sample(self):
samples_mean = self.samples.mean(axis=0)
samples_var = self.samples.var(axis=0)
np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0)
rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0)
np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0)
np.testing.assert_allclose(rsamples_var,
self.variance,
rtol=0.1,
atol=0)
batch_shape = (self.loc + self.scale).shape
self.assertEqual(self.samples.shape, self.sample_shape + batch_shape)
self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape)
for i in range(len(self.scale)):
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.samples[:, i]))
self.assertTrue(
self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i]))
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
ks, _ = scipy.stats.kstest(samples,
scipy.stats.norm(loc=loc, scale=scale).cdf)
return ks < 0.02
if __name__ == '__main__':
unittest.main()
@@ -61,6 +61,13 @@ class TestIndependent(unittest.TestCase):
self.assertEqual(tuple(data.shape), expected_shape)
self.assertEqual(data.dtype, self.base.loc.dtype)
def test_rsample(self):
shape = [5, 10, 8]
expected_shape = (5, 10, 8, 1)
data = self._t.rsample(shape)
self.assertEqual(tuple(data.shape), expected_shape)
self.assertEqual(data.dtype, self.base.loc.dtype)
if __name__ == '__main__':
unittest.main()