Unverified commit 6735a37a authored by Xiaoxu Chen, committed by GitHub

Add probability distribution transformation APIs (#40536)

* add random variable transformation APIs for paddle's distribution package

* add TransformedDistribution API for paddle's probability distribution package

* add random variable transformation unit tests for static graph

* replace math.prod, which is not supported on Python 3.7, with functools.reduce

* add Independent and TransformedDistribution distribution

* add unittests for constraint

* fix typo and AffineTransform sample code error

* add mean, variance, rsample abstract methods for Distribution
Parent 7dfd3846
......@@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .beta import Beta
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .kl import kl_divergence, register_kl
from .multinomial import Multinomial
from .normal import Normal
from .uniform import Uniform
from paddle.distribution import transform
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.independent import Independent
from paddle.distribution.kl import kl_divergence, register_kl
from paddle.distribution.multinomial import Multinomial
from paddle.distribution.normal import Normal
from paddle.distribution.transform import * # noqa: F403
from paddle.distribution.transformed_distribution import \
TransformedDistribution
from paddle.distribution.uniform import Uniform
__all__ = [ # noqa
'Beta',
......@@ -33,4 +38,8 @@ __all__ = [ # noqa
'Uniform',
'kl_divergence',
'register_kl',
'Independent',
'TransformedDistribution'
]
__all__.extend(transform.__all__)
......@@ -14,12 +14,10 @@
import numbers
import paddle
from paddle.distribution import dirichlet, exponential_family
from .dirichlet import Dirichlet
from .exponential_family import ExponentialFamily
class Beta(ExponentialFamily):
class Beta(exponential_family.ExponentialFamily):
r"""
Beta distribution parameterized by alpha and beta.
......@@ -93,7 +91,8 @@ class Beta(ExponentialFamily):
self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])
self._dirichlet = Dirichlet(paddle.stack([self.alpha, self.beta], -1))
self._dirichlet = dirichlet.Dirichlet(
paddle.stack([self.alpha, self.beta], -1))
super(Beta, self).__init__(self._dirichlet._batch_shape)
......
......@@ -18,18 +18,18 @@ import warnings
import numpy as np
import paddle
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Categorical(Distribution):
from paddle.distribution import distribution
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial
class Categorical(distribution.Distribution):
r"""
Categorical distribution is a discrete probability distribution that
describes the possible results of a random variable that can take on
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class Constraint(object):
"""Constraint condition for random variable.
"""
def __call__(self, value):
raise NotImplementedError
class Real(Constraint):
def __call__(self, value):
return value == value
class Range(Constraint):
def __init__(self, lower, upper):
self._lower = lower
self._upper = upper
super(Range, self).__init__()
def __call__(self, value):
return self._lower <= value <= self._upper
class Positive(Constraint):
def __call__(self, value):
return value >= 0.
class Simplex(Constraint):
def __call__(self, value):
return paddle.logical_and(
paddle.all(value >= 0, axis=-1), (value.sum(-1) - 1).abs() < 1e-6)
real = Real()
positive = Positive()
simplex = Simplex()
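# Illustrative usage (a minimal sketch, not part of this change; assumes
# dygraph mode):
#   positive(paddle.to_tensor([0.5, -1.]))   # -> [True, False]
#   simplex(paddle.to_tensor([0.3, 0.7]))    # -> True
#   Range(0., 1.)(paddle.to_tensor(0.5))     # -> True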
......@@ -13,14 +13,13 @@
# limitations under the License.
import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import _non_static_mode
from ..fluid.layer_helper import LayerHelper
from .exponential_family import ExponentialFamily
class Dirichlet(ExponentialFamily):
class Dirichlet(exponential_family.ExponentialFamily):
r"""
Dirichlet distribution with parameter "concentration".
......
......@@ -27,14 +27,14 @@ import warnings
import numpy as np
import paddle
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial
class Distribution(object):
......@@ -78,10 +78,24 @@ class Distribution(object):
"""
return self._event_shape
@property
def mean(self):
"""Mean of distribution"""
raise NotImplementedError
@property
def variance(self):
"""Variance of distribution"""
raise NotImplementedError
def sample(self, shape=()):
"""Sampling from the distribution."""
raise NotImplementedError
def rsample(self, shape=()):
"""reparameterized sample"""
raise NotImplementedError
def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
......@@ -96,7 +110,7 @@ class Distribution(object):
Args:
value (Tensor): value which will be evaluated
"""
raise NotImplementedError
return self.log_prob(value).exp()
def log_prob(self, value):
"""Log probability density/mass function."""
......
......@@ -13,12 +13,11 @@
# limitations under the License.
import paddle
from paddle.distribution import distribution
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from ..fluid.framework import _non_static_mode
from .distribution import Distribution
class ExponentialFamily(Distribution):
class ExponentialFamily(distribution.Distribution):
r"""
ExponentialFamily is the base class for probability distributions belonging
to exponential family, whose probability mass/density function has the
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distribution import distribution
class Independent(distribution.Distribution):
r"""
Reinterprets some of the batch dimensions of a distribution as event dimensions.
This is mainly useful for changing the shape of the result of
:meth:`log_prob`.
Args:
base (Distribution): The base distribution.
reinterpreted_batch_rank (int): The number of batch dimensions to
reinterpret as event dimensions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import independent
beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5]))
print(beta.batch_shape, beta.event_shape)
# (2,) ()
print(beta.log_prob(paddle.to_tensor(0.2)))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-0.22843921, -0.22843921])
reinterpreted_beta = independent.Independent(beta, 1)
print(reinterpreted_beta.batch_shape, reinterpreted_beta.event_shape)
# () (2,)
print(reinterpreted_beta.log_prob(paddle.to_tensor([0.2, 0.2])))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-0.45687842])
"""
def __init__(self, base, reinterpreted_batch_rank):
if not isinstance(base, distribution.Distribution):
raise TypeError(
f"Expected type of 'base' is Distribution, but got {type(base)}")
if not (0 < reinterpreted_batch_rank <= len(base.batch_shape)):
raise ValueError(
f"Expected 0 < reinterpreted_batch_rank <= {len(base.batch_shape)}, but got {reinterpreted_batch_rank}"
)
self._base = base
self._reinterpreted_batch_rank = reinterpreted_batch_rank
shape = base.batch_shape + base.event_shape
super(Independent, self).__init__(
batch_shape=shape[:len(base.batch_shape) -
reinterpreted_batch_rank],
event_shape=shape[len(base.batch_shape) -
reinterpreted_batch_rank:])
@property
def mean(self):
return self._base.mean
@property
def variance(self):
return self._base.variance
def sample(self, shape=()):
return self._base.sample(shape)
def log_prob(self, value):
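# Sum the base log-probability over the rightmost `reinterpreted_batch_rank`
# dimensions so that they behave as event dimensions.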
return self._sum_rightmost(
self._base.log_prob(value), self._reinterpreted_batch_rank)
def prob(self, value):
return self.log_prob(value).exp()
def entropy(self):
return self._sum_rightmost(self._base.entropy(),
self._reinterpreted_batch_rank)
def _sum_rightmost(self, value, n):
return value.sum(list(range(-n, 0))) if n > 0 else value
......@@ -15,15 +15,14 @@ import functools
import warnings
import paddle
from ..fluid.framework import _non_static_mode
from .beta import Beta
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .normal import Normal
from .uniform import Uniform
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.normal import Normal
from paddle.distribution.uniform import Uniform
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
__all__ = ["register_kl", "kl_divergence"]
......@@ -207,5 +206,4 @@ def _kl_expfamily_expfamily(p, q):
def _sum_rightmost(value, n):
"""Sum elements along rightmost n dim"""
return value.sum(list(range(-n, 0))) if n > 0 else value
......@@ -17,18 +17,17 @@ import warnings
import numpy as np
from paddle import _C_ops
from paddle.distribution import distribution
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Normal(Distribution):
class Normal(distribution.Distribution):
r"""The Normal distribution with location `loc` and `scale` parameters.
Mathematical details
......@@ -129,6 +128,7 @@ class Normal(Distribution):
if self.dtype != convert_dtype(self.loc.dtype):
self.loc = tensor.cast(self.loc, dtype=self.dtype)
self.scale = tensor.cast(self.scale, dtype=self.dtype)
super(Normal, self).__init__(self.loc.shape)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import functools
import math
import numbers
import operator
import typing
import paddle
import paddle.nn.functional as F
from paddle.distribution import (constraint, distribution,
transformed_distribution, variable)
__all__ = [ # noqa
'Transform',
'AbsTransform',
'AffineTransform',
'ChainTransform',
'ExpTransform',
'IndependentTransform',
'PowerTransform',
'ReshapeTransform',
'SigmoidTransform',
'SoftmaxTransform',
'StackTransform',
'StickBreakingTransform',
'TanhTransform'
]
class Type(enum.Enum):
"""Mapping type of a transformation.
"""
BIJECTION = 'bijection' # bijective(injective and surjective)
INJECTION = 'injection' # injective-only
SURJECTION = 'surjection' # surjective-only
OTHER = 'other' # general, neither injective nor surjective
@classmethod
def is_injective(cls, _type):
"""Both bijection and injection are injective mapping.
"""
return _type in (cls.BIJECTION, cls.INJECTION)
class Transform(object):
r"""Base class for the transformations of random variables.
``Transform`` can be used to represent any differentiable and injective
function from the subset of :math:`R^n` to subset of :math:`R^m`, generally
used for transforming a random sample generated by ``Distribution``
instance.
Suppose :math:`X` is a K-dimensional random variable with probability
density function :math:`p_X(x)`. A new random variable :math:`Y = f(X)` may
be defined by transforming :math:`X` with a suitably well-behaved funciton
:math:`f`. It suffices for what follows to note that if f is one-to-one and
its inverse :math:`f^{-1}` have a well-defined Jacobian, then the density of
:math:`Y` is
.. math::
p_Y(y) = p_X(f^{-1}(y)) |det J_{f^{-1}}(y)|
where det is the matrix determinant operation and :math:`J_{f^{-1}}(y)` is
the Jacobian matrix of :math:`f^{-1}` evaluated at :math:`y`.
Taking :math:`x = f^{-1}(y)`, the Jacobian matrix is defined by
.. math::
J(y) = \begin{bmatrix}
{\frac{\partial x_1}{\partial y_1}} &{\frac{\partial x_1}{\partial y_2}}
&{\cdots} &{\frac{\partial x_1}{\partial y_K}} \\
{\frac{\partial x_2}{\partial y_1}} &{\frac{\partial x_2}
{\partial y_2}}&{\cdots} &{\frac{\partial x_2}{\partial y_K}} \\
{\vdots} &{\vdots} &{\ddots} &{\vdots}\\
{\frac{\partial x_K}{\partial y_1}} &{\frac{\partial x_K}{\partial y_2}}
&{\cdots} &{\frac{\partial x_K}{\partial y_K}}
\end{bmatrix}
A ``Transform`` can be characterized by three operations:
#. forward
Forward implements :math:`x \rightarrow f(x)`, and is used to convert
one random outcome into another.
#. inverse
Undoes the transformation :math:`y \rightarrow f^{-1}(y)`.
#. log_det_jacobian
The log of the absolute value of the determinant of the matrix of all
first-order partial derivatives of the inverse function.
Subclasses typically implement the following methods:
* _forward
* _inverse
* _forward_log_det_jacobian
* _inverse_log_det_jacobian (optional)
If the transform changes the shape of the input, you must also implement:
* _forward_shape
* _inverse_shape
"""
_type = Type.INJECTION
def __init__(self):
super(Transform, self).__init__()
@classmethod
def _is_injective(cls):
"""Is the transformation type one-to-one or not.
Returns:
bool: ``True`` denotes injective. ``False`` denotes non-injective.
"""
return Type.is_injective(cls._type)
def __call__(self, input):
"""Make this instance as a callable object. The return value is
depening on the input type.
* If the input is a ``Tensor`` instance, return
``self.forward(input)`` .
* If the input is a ``Distribution`` instance, return
``TransformedDistribution(base=input, transforms=[self])`` .
* If the input is a ``Transform`` instance, return
``ChainTransform([self, input])`` .
Args:
input (Tensor|Distribution|Transform): The input value.
Returns:
[Tensor|TransformedDistribution|ChainTransform]: The return value.
"""
if isinstance(input, distribution.Distribution):
return transformed_distribution.TransformedDistribution(input,
[self])
if isinstance(input, Transform):
return ChainTransform([self, input])
return self.forward(input)
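# Illustrative dispatch sketch (not part of this change; assumes dygraph mode):
#   t = paddle.distribution.ExpTransform()
#   t(paddle.to_tensor([1., 2.]))            # Tensor -> t.forward(...)
#   t(paddle.distribution.Normal(0., 1.))    # Distribution -> TransformedDistribution
#   t(paddle.distribution.AbsTransform())    # Transform -> ChainTransform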
def forward(self, x):
"""Forward transformation with mapping :math:`y = f(x)`.
Useful for turning one random outcome into another.
Args:
x (Tensor): Input parameter, generally a sample generated
from ``Distribution``.
Returns:
Tensor: Outcome of forward transformation.
"""
if not isinstance(x, paddle.fluid.framework.Variable):
raise TypeError(
f"Expected 'x' is a Tensor or Real, but got {type(x)}.")
if x.dim() < self._domain.event_rank:
raise ValueError(
f'The dimensions of x({x.dim()}) should be '
f'greater than or equal to {self._domain.event_rank}')
return self._forward(x)
def inverse(self, y):
"""Inverse transformation :math:`x = f^{-1}(y)`. It's useful for "reversing"
a transformation to compute one probability in terms of another.
Args:
y (Tensor): Input parameter for inverse transformation.
Returns:
Tensor: Outcome of inverse transform.
"""
if not isinstance(y, paddle.fluid.framework.Variable):
raise TypeError(
f"Expected 'y' is a Tensor or Real, but got {type(y)}.")
if y.dim() < self._codomain.event_rank:
raise ValueError(
f'The dimensions of y({y.dim()}) should be '
f'greater than or equal to {self._codomain.event_rank}')
return self._inverse(y)
def forward_log_det_jacobian(self, x):
"""The log of the absolute value of the determinant of the matrix of all
first-order partial derivatives of the inverse function.
Args:
x (Tensor): Input tensor, generally is a sample generated from
``Distribution``
Returns:
Tensor: The log of the absolute value of Jacobian determinant.
"""
if not isinstance(x, paddle.fluid.framework.Variable):
raise TypeError(
f"Expected 'y' is a Tensor or Real, but got {type(x)}.")
if isinstance(x, paddle.fluid.framework.Variable) and x.dim(
) < self._domain.event_rank:
raise ValueError(
f'The dimensions of x({x.dim()}) should be '
f'greater than or equal to {self._domain.event_rank}')
if not self._is_injective():
raise NotImplementedError(
"forward_log_det_jacobian can't be implemented for non-injective"
"transforms.")
return self._call_forward_log_det_jacobian(x)
def inverse_log_det_jacobian(self, y):
"""Compute :math:`log|det J_{f^{-1}}(y)|`.
Note that ``forward_log_det_jacobian`` is the negative of this function,
evaluated at :math:`f^{-1}(y)`.
Args:
y (Tensor): The input to the ``inverse`` Jacobian determinant
evaluation.
Returns:
Tensor: The value of :math:`log|det J_{f^{-1}}(y)|`.
"""
if not isinstance(y, paddle.fluid.framework.Variable):
raise TypeError(f"Expected 'y' is a Tensor, but got {type(y)}.")
if y.dim() < self._codomain.event_rank:
raise ValueError(
f'The dimensions of y({y.dim()}) should be '
f'greater than or equal to {self._codomain.event_rank}')
return self._call_inverse_log_det_jacobian(y)
def forward_shape(self, shape):
"""Infer the shape of forward transformation.
Args:
shape (Sequence[int]): The input shape.
Returns:
Sequence[int]: The output shape.
"""
if not isinstance(shape, typing.Sequence):
raise TypeError(
f"Expected shape is Sequence[int] type, but got {type(shape)}.")
return self._forward_shape(shape)
def inverse_shape(self, shape):
"""Infer the shape of inverse transformation.
Args:
shape (Sequence[int]): The input shape of inverse transformation.
Returns:
Sequence[int]: The output shape of inverse transformation.
"""
if not isinstance(shape, typing.Sequence):
raise TypeError(
f"Expected shape is Sequence[int] type, but got {type(shape)}.")
return self._inverse_shape(shape)
@property
def _domain(self):
"""The domain of this transformation"""
return variable.real
@property
def _codomain(self):
"""The codomain of this transformation"""
return variable.real
def _forward(self, x):
"""Inner method for publid API ``forward``, subclass should
overwrite this method for supporting forward transformation.
"""
raise NotImplementedError('Forward not implemented')
def _inverse(self, y):
"""Inner method of public API ``inverse``, subclass should
overwrite this method for supporting inverse transformation.
"""
raise NotImplementedError('Inverse not implemented')
def _call_forward_log_det_jacobian(self, x):
"""Inner method called by ``forward_log_det_jacobian``."""
if hasattr(self, '_forward_log_det_jacobian'):
return self._forward_log_det_jacobian(x)
if hasattr(self, '_inverse_log_det_jacobian'):
return -self._inverse_log_det_jacobian(self.forward(x))
raise NotImplementedError(
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
'is implemented. One of them is required.')
def _call_inverse_log_det_jacobian(self, y):
"""Inner method called by ``inverse_log_det_jacobian``"""
if hasattr(self, '_inverse_log_det_jacobian'):
return self._inverse_log_det_jacobian(y)
if hasattr(self, '_forward_log_det_jacobian'):
return -self._forward_log_det_jacobian(self._inverse(y))
raise NotImplementedError(
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
'is implemented. One of them is required.')
def _forward_shape(self, shape):
"""Inner method called by ``forward_shape``, which is used to infer the
forward shape. Subclasses should override this method to support
``forward_shape``.
"""
return shape
def _inverse_shape(self, shape):
"""Inner method called by ``inverse_shape``, whic is used to infer the
invese shape. Subclass should overwrite this method for supporting
``inverse_shape``.
"""
return shape
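# A minimal numerical sketch of the change-of-variables formula in the class
# docstring above (illustrative only, not part of this change; assumes dygraph mode):
#   base = paddle.distribution.Normal(paddle.to_tensor(0.), paddle.to_tensor(1.))
#   exp = paddle.distribution.ExpTransform()
#   y = paddle.to_tensor(2.)
#   x = exp.inverse(y)                                    # x = log(y)
#   log_p_y = base.log_prob(x) - exp.forward_log_det_jacobian(x)
#   # log_p_y is the log-density of the log-normal variable Y = exp(X) at y.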
class AbsTransform(Transform):
r"""Absolute transformation with formula :math:`y = f(x) = abs(x)`,
element-wise.
This non-injective transformation allows for transformations of scalar
distributions with the absolute value function, which maps ``(-inf, inf)``
to ``[0, inf)`` .
* For ``y`` in ``(0, inf)`` , ``AbsTransform.inverse(y)`` returns the set inverse
``{x in (-inf, inf) : |x| = y}`` as a tuple, ``-y, y`` .
* For ``y`` equal to ``0`` , ``AbsTransform.inverse(0)`` returns ``0, 0``, which is not
the set inverse (the set inverse is the singleton {0}), but "works" in
conjunction with ``TransformedDistribution`` to produce a left
semi-continuous pdf.
* For ``y`` in ``(-inf, 0)`` , ``AbsTransform.inverse(y)`` returns the
wrong thing ``-y, y``. This is done for efficiency.
Examples:
.. code-block:: python
import paddle
abs = paddle.distribution.AbsTransform()
print(abs.forward(paddle.to_tensor([-1., 0., 1.])))
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 0., 1.])
print(abs.inverse(paddle.to_tensor(1.)))
# (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-1.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1.]))
# The |dX/dY| is constant 1. So Log|dX/dY| == 0
print(abs.inverse_log_det_jacobian(paddle.to_tensor(1.)))
# (Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# 0.), Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# 0.))
# Special case handling of 0.
print(abs.inverse(paddle.to_tensor(0.)))
# (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.]))
print(abs.inverse_log_det_jacobian(paddle.to_tensor(0.)))
# (Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# 0.), Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# 0.))
"""
_type = Type.SURJECTION
def _forward(self, x):
return x.abs()
def _inverse(self, y):
return -y, y
def _inverse_log_det_jacobian(self, y):
zero = paddle.zeros([1], dtype=y.dtype)
return zero, zero
@property
def _domain(self):
return variable.real
@property
def _codomain(self):
return variable.positive
class AffineTransform(Transform):
r"""Affine transformation with mapping
:math:`y = \text{loc} + \text{scale} \times x`.
Args:
loc (Tensor): The location parameter.
scale (Tensor): The scale parameter.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1., 2.])
affine = paddle.distribution.AffineTransform(paddle.to_tensor(0.), paddle.to_tensor(1.))
print(affine.forward(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 2.])
print(affine.inverse(affine.forward(x)))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 2.])
print(affine.forward_log_det_jacobian(x))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.])
"""
_type = Type.BIJECTION
def __init__(self, loc, scale):
if not isinstance(loc, paddle.fluid.framework.Variable):
raise TypeError(f"Expected 'loc' is a Tensor, but got {type(loc)}")
if not isinstance(scale, paddle.fluid.framework.Variable):
raise TypeError(
f"Expected scale is a Tensor, but got {type(scale)}")
self._loc = loc
self._scale = scale
super(AffineTransform, self).__init__()
@property
def loc(self):
return self._loc
@property
def scale(self):
return self._scale
def _forward(self, x):
return self._loc + self._scale * x
def _inverse(self, y):
return (y - self._loc) / self._scale
def _forward_log_det_jacobian(self, x):
return paddle.abs(self._scale).log()
def _forward_shape(self, shape):
return tuple(
paddle.broadcast_shape(
paddle.broadcast_shape(shape, self._loc.shape),
self._scale.shape))
def _inverse_shape(self, shape):
return tuple(
paddle.broadcast_shape(
paddle.broadcast_shape(shape, self._loc.shape),
self._scale.shape))
@property
def _domain(self):
return variable.real
@property
def _codomain(self):
return variable.real
class ChainTransform(Transform):
r"""Composes multiple transforms in a chain.
Args:
transforms (Sequence[Transform]): A sequence of transformations.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0., 1., 2., 3.])
chain = paddle.distribution.ChainTransform((
paddle.distribution.AffineTransform(
paddle.to_tensor(0.), paddle.to_tensor(1.)),
paddle.distribution.ExpTransform()
))
print(chain.forward(x))
# Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1. , 2.71828175 , 7.38905621 , 20.08553696])
print(chain.inverse(chain.forward(x)))
# Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0., 1., 2., 3.])
print(chain.forward_log_det_jacobian(x))
# Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0., 1., 2., 3.])
print(chain.inverse_log_det_jacobian(chain.forward(x)))
# Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [ 0., -1., -2., -3.])
"""
def __init__(self, transforms):
if not isinstance(transforms, typing.Sequence):
raise TypeError(
f"Expected type of 'transforms' is Sequence, but got {type(transforms)}"
)
if not all(isinstance(t, Transform) for t in transforms):
raise TypeError(
"All elements of transforms should be Transform type.")
self.transforms = transforms
super(ChainTransform, self).__init__()
def _is_injective(self):
return all(t._is_injective() for t in self.transforms)
def _forward(self, x):
for transform in self.transforms:
x = transform.forward(x)
return x
def _inverse(self, y):
for transform in reversed(self.transforms):
y = transform.inverse(y)
return y
def _forward_log_det_jacobian(self, x):
value = 0.
event_rank = self._domain.event_rank
for t in self.transforms:
value += self._sum_rightmost(
t.forward_log_det_jacobian(x),
event_rank - t._domain.event_rank)
x = t.forward(x)
event_rank += t._codomain.event_rank - t._domain.event_rank
return value
def _forward_shape(self, shape):
for transform in self.transforms:
shape = transform.forward_shape(shape)
return shape
def _inverse_shape(self, shape):
for transform in self.transforms:
shape = transform.inverse_shape(shape)
return shape
def _sum_rightmost(self, value, n):
"""sum value along rightmost n dim"""
return value.sum(list(range(-n, 0))) if n > 0 else value
@property
def _domain(self):
domain = self.transforms[0]._domain
# Compute the lower bound of input dimensions for chain transform.
#
# Suppose the dimensions of input tensor is N, and chain [t0,...ti,...tm],
# ti(in) denotes ti.domain.event_rank, ti(out) denotes ti.codomain.event_rank,
# delta(ti) denotes (ti(out) - ti(in)).
# For transform ti, N should satisfy the constraint:
# N + delta(t0) + delta(t1)...delta(t(i-1)) >= ti(in)
# So, for all transforms in the chain, N should satisfy the following constraints:
# t0: N >= t0(in)
# t1: N >= t1(in) - delta(t0)
# ...
# tm: N >= tm(in) - ... - delta(ti) - ... - delta(t0)
#
# The above problem can be solved more efficiently using dynamic programming.
# Let N(i) denote the lower bound for transform ti; then the state
# transition equation is:
# N(i) = max{N(i+1)-delta(ti), ti(in)}
event_rank = self.transforms[-1]._codomain.event_rank
for t in reversed(self.transforms):
event_rank -= t._codomain.event_rank - t._domain.event_rank
event_rank = max(event_rank, t._domain.event_rank)
return variable.Independent(domain, event_rank - domain.event_rank)
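# Worked example: for transforms = [ExpTransform(), ReshapeTransform((2, 3), (6,))],
# exp needs 0 event dims while the reshape needs 2 on its input, so working
# backwards: N = max(1 - (1 - 2), 2) = 2 for the reshape and N = max(2 - 0, 0) = 2
# for exp, giving a chain domain of Independent(real, 2).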
@property
def _codomain(self):
codomain = self.transforms[-1]._codomain
event_rank = self.transforms[0]._domain.event_rank
for t in self.transforms:
event_rank += t._codomain.event_rank - t._domain.event_rank
event_rank = max(event_rank, t._codomain.event_rank)
return variable.Independent(codomain, event_rank - codomain.event_rank)
class ExpTransform(Transform):
r"""Exponent transformation with mapping :math:`y = \exp(x)`.
Examples:
.. code-block:: python
import paddle
exp = paddle.distribution.ExpTransform()
print(exp.forward(paddle.to_tensor([1., 2., 3.])))
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [2.71828175 , 7.38905621 , 20.08553696])
print(exp.inverse(paddle.to_tensor([1., 2., 3.])))
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0. , 0.69314718, 1.09861231])
print(exp.forward_log_det_jacobian(paddle.to_tensor([1., 2., 3.])))
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 2., 3.])
print(exp.inverse_log_det_jacobian(paddle.to_tensor([1., 2., 3.])))
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [ 0. , -0.69314718, -1.09861231])
"""
_type = Type.BIJECTION
def __init__(self):
super(ExpTransform, self).__init__()
@property
def _domain(self):
return variable.real
@property
def _codomain(self):
return variable.positive
def _forward(self, x):
return x.exp()
def _inverse(self, y):
return y.log()
def _forward_log_det_jacobian(self, x):
return x
class IndependentTransform(Transform):
r"""
``IndependentTransform`` wraps a base transformation and reinterprets
some of the rightmost batch axes as event axes.
Generally, it is used to expand the event axes. This has no effect on the
forward or inverse transformation, but does sum out the
``reinterpreted_batch_rank`` rightmost dimensions when computing the determinant
of the Jacobian matrix.
To see this, consider the ``ExpTransform`` applied to a Tensor which has
sample, batch, and event ``(S,B,E)`` shape semantics. Suppose the Tensor's
partitioned shape is ``(S=[4], B=[2, 2], E=[3])`` and reinterpreted_batch_rank
is 1. Then the reinterpreted Tensor's shape is ``(S=[4], B=[2], E=[2, 3])`` .
The shape returned by ``forward`` and ``inverse`` is unchanged, i.e.,
``[4,2,2,3]`` . However, the shape returned by ``inverse_log_det_jacobian``
is ``[4,2]``, because the Jacobian determinant is a reduction over the
event dimensions.
Args:
base (Transform): The base transformation.
reinterpreted_batch_rank (int): The number of rightmost batch dimensions
that will be reinterpreted as event dimensions.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]])
# Exponential transform with event_rank = 1
multi_exp = paddle.distribution.IndependentTransform(
paddle.distribution.ExpTransform(), 1)
print(multi_exp.forward(x))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[2.71828175 , 7.38905621 , 20.08553696 ],
# [54.59814835 , 148.41316223, 403.42880249]])
print(multi_exp.forward_log_det_jacobian(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [6. , 15.])
"""
def __init__(self, base, reinterpreted_batch_rank):
if not isinstance(base, Transform):
raise TypeError(
f"Expected 'base' is Transform type, but get {type(base)}")
if reinterpreted_batch_rank <= 0:
raise ValueError(
f"Expected 'reinterpreted_batch_rank' is grater than zero, but got {reinterpreted_batch_rank}"
)
self._base = base
self._reinterpreted_batch_rank = reinterpreted_batch_rank
super(IndependentTransform, self).__init__()
def _is_injective(self):
return self._base._is_injective()
def _forward(self, x):
if x.dim() < self._domain.event_rank:
raise ValueError("Input dimensions is less than event dimensions.")
return self._base.forward(x)
def _inverse(self, y):
if y.dim() < self._codomain.event_rank:
raise ValueError("Input dimensions is less than event dimensions.")
return self._base.inverse(y)
def _forward_log_det_jacobian(self, x):
return self._base.forward_log_det_jacobian(x).sum(
list(range(-self._reinterpreted_batch_rank, 0)))
def _forward_shape(self, shape):
return self._base.forward_shape(shape)
def _inverse_shape(self, shape):
return self._base.inverse_shape(shape)
@property
def _domain(self):
return variable.Independent(self._base._domain,
self._reinterpreted_batch_rank)
@property
def _codomain(self):
return variable.Independent(self._base._codomain,
self._reinterpreted_batch_rank)
class PowerTransform(Transform):
r"""
Power transformation with mapping :math:`y = x^{\text{power}}`.
Args:
power (Tensor): The power parameter.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1., 2.])
power = paddle.distribution.PowerTransform(paddle.to_tensor(2.))
print(power.forward(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 4.])
print(power.inverse(power.forward(x)))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 2.])
print(power.forward_log_det_jacobian(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.69314718, 1.38629436])
"""
_type = Type.BIJECTION
def __init__(self, power):
if not isinstance(power, paddle.fluid.framework.Variable):
raise TypeError(
f"Expected 'power' is a tensor, but got {type(power)}")
self._power = power
super(PowerTransform, self).__init__()
@property
def power(self):
return self._power
@property
def _domain(self):
return variable.real
@property
def _codomain(self):
return variable.positive
def _forward(self, x):
return x.pow(self._power)
def _inverse(self, y):
return y.pow(1 / self._power)
def _forward_log_det_jacobian(self, x):
return (self._power * x.pow(self._power - 1)).abs().log()
def _forward_shape(self, shape):
return tuple(paddle.broadcast_shape(shape, self._power.shape))
def _inverse_shape(self, shape):
return tuple(paddle.broadcast_shape(shape, self._power.shape))
class ReshapeTransform(Transform):
r"""Reshape the event shape of a tensor.
Note that ``in_event_shape`` and ``out_event_shape`` must have the same
number of elements.
Args:
in_event_shape(Sequence[int]): The input event shape.
out_event_shape(Sequence[int]): The output event shape.
Examples:
.. code-block:: python
import paddle
x = paddle.ones((1,2,3))
reshape_transform = paddle.distribution.ReshapeTransform((2, 3), (3, 2))
print(reshape_transform.forward_shape((1,2,3)))
# (1, 3, 2)
print(reshape_transform.forward(x))
# Tensor(shape=[1, 3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[[1., 1.],
# [1., 1.],
# [1., 1.]]])
print(reshape_transform.inverse(reshape_transform.forward(x)))
# Tensor(shape=[1, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[[1., 1., 1.],
# [1., 1., 1.]]])
print(reshape_transform.forward_log_det_jacobian(x))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.])
"""
_type = Type.BIJECTION
def __init__(self, in_event_shape, out_event_shape):
if not isinstance(in_event_shape, typing.Sequence) or not isinstance(
out_event_shape, typing.Sequence):
raise TypeError(
f"Expected type of 'in_event_shape' and 'out_event_shape' is "
f"Squence[int], but got 'in_event_shape': {in_event_shape}, "
f"'out_event_shape': {out_event_shape}")
if functools.reduce(operator.mul, in_event_shape) != functools.reduce(
operator.mul, out_event_shape):
raise ValueError(
f"The numel of 'in_event_shape' should be 'out_event_shape', "
f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}"
)
self._in_event_shape = tuple(in_event_shape)
self._out_event_shape = tuple(out_event_shape)
super(ReshapeTransform, self).__init__()
@property
def in_event_shape(self):
return self._in_event_shape
@property
def out_event_shape(self):
return self._out_event_shape
@property
def _domain(self):
return variable.Independent(variable.real, len(self._in_event_shape))
@property
def _codomain(self):
return variable.Independent(variable.real, len(self._out_event_shape))
def _forward(self, x):
return x.reshape(
tuple(x.shape)[:x.dim() - len(self._in_event_shape)] +
self._out_event_shape)
def _inverse(self, y):
return y.reshape(
tuple(y.shape)[:y.dim() - len(self._out_event_shape)] +
self._in_event_shape)
def _forward_shape(self, shape):
if len(shape) < len(self._in_event_shape):
raise ValueError(
f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}"
)
if shape[-len(self._in_event_shape):] != self._in_event_shape:
raise ValueError(
f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}"
)
return tuple(shape[:-len(self._in_event_shape)]) + self._out_event_shape
def _inverse_shape(self, shape):
if len(shape) < len(self._out_event_shape):
raise ValueError(
f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}"
)
if shape[-len(self._out_event_shape):] != self._out_event_shape:
raise ValueError(
f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}"
)
return tuple(shape[:-len(self._out_event_shape)]) + self._in_event_shape
def _forward_log_det_jacobian(self, x):
# paddle.zeros does not support zero-dimensional Tensors.
shape = x.shape[:x.dim() - len(self._in_event_shape)] or [1]
return paddle.zeros(shape, dtype=x.dtype)
class SigmoidTransform(Transform):
r"""Sigmoid transformation with mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
Examples:
.. code-block:: python
import paddle
x = paddle.ones((2,3))
t = paddle.distribution.SigmoidTransform()
print(t.forward(x))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0.73105860, 0.73105860, 0.73105860],
# [0.73105860, 0.73105860, 0.73105860]])
print(t.inverse(t.forward(x)))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[1.00000012, 1.00000012, 1.00000012],
# [1.00000012, 1.00000012, 1.00000012]])
print(t.forward_log_det_jacobian(x))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[-1.62652326, -1.62652326, -1.62652326],
# [-1.62652326, -1.62652326, -1.62652326]])
"""
@property
def _domain(self):
return variable.real
@property
def _codomain(self):
return variable.Variable(False, 0, constraint.Range(0., 1.))
def _forward(self, x):
return F.sigmoid(x)
def _inverse(self, y):
return y.log() - (-y).log1p()
def _forward_log_det_jacobian(self, x):
return -F.softplus(-x) - F.softplus(x)
class SoftmaxTransform(Transform):
r"""Softmax transformation with mapping :math:`y=\exp(x)` then normalizing.
It's generally used to convert unconstrained space to simplex. This mapping
is not injective, so ``forward_log_det_jacobian`` and
``inverse_log_det_jacobian`` are not implemented.
Examples:
.. code-block:: python
import paddle
x = paddle.ones((2,3))
t = paddle.distribution.SoftmaxTransform()
print(t.forward(x))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0.33333334, 0.33333334, 0.33333334],
# [0.33333334, 0.33333334, 0.33333334]])
print(t.inverse(t.forward(x)))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[-1.09861231, -1.09861231, -1.09861231],
# [-1.09861231, -1.09861231, -1.09861231]])
"""
_type = Type.OTHER
@property
def _domain(self):
return variable.Independent(variable.real, 1)
@property
def _codomain(self):
return variable.Variable(False, 1, constraint.simplex)
def _forward(self, x):
x = (x - x.max(-1, keepdim=True)[0]).exp()
return x / x.sum(-1, keepdim=True)
def _inverse(self, y):
return y.log()
def _forward_shape(self, shape):
if len(shape) < 1:
raise ValueError(
f"Expected length of shape is grater than 1, but got {len(shape)}"
)
return shape
def _inverse_shape(self, shape):
if len(shape) < 1:
raise ValueError(
f"Expected length of shape is grater than 1, but got {len(shape)}"
)
return shape
class StackTransform(Transform):
r""" ``StackTransform`` applies a sequence of transformations along the
specific axis.
Args:
transforms(Sequence[Transform]): The sequence of transformations.
axis(int): The axis along which will be transformed.
Examples:
.. code-block:: python
import paddle
x = paddle.stack(
(paddle.to_tensor([1., 2., 3.]), paddle.to_tensor([1, 2., 3.])), 1)
t = paddle.distribution.StackTransform(
(paddle.distribution.ExpTransform(),
paddle.distribution.PowerTransform(paddle.to_tensor(2.))),
1
)
print(t.forward(x))
# Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[2.71828175 , 1. ],
# [7.38905621 , 4. ],
# [20.08553696, 9. ]])
print(t.inverse(t.forward(x)))
# Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[1., 1.],
# [2., 2.],
# [3., 3.]])
print(t.forward_log_det_jacobian(x))
# Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[1. , 0.69314718],
# [2. , 1.38629436],
# [3. , 1.79175949]])
"""
def __init__(self, transforms, axis=0):
if not transforms or not isinstance(transforms, typing.Sequence):
raise TypeError(
f"Expected 'transforms' is Sequence[Transform], but got {type(transforms)}."
)
if not all(isinstance(t, Transform) for t in transforms):
raise TypeError(
'Expected all element in transforms is Transform Type.')
if not isinstance(axis, int):
raise TypeError(f"Expected 'axis' is int, but got{type(axis)}.")
self._transforms = transforms
self._axis = axis
def _is_injective(self):
return all(t._is_injective() for t in self._transforms)
@property
def transforms(self):
return self._transforms
@property
def axis(self):
return self._axis
def _forward(self, x):
self._check_size(x)
return paddle.stack([
t.forward(v)
for v, t in zip(paddle.unstack(x, self._axis), self._transforms)
], self._axis)
def _inverse(self, y):
self._check_size(y)
return paddle.stack([
t.inverse(v)
for v, t in zip(paddle.unstack(y, self._axis), self._transforms)
], self._axis)
def _forward_log_det_jacobian(self, x):
self._check_size(x)
return paddle.stack([
t.forward_log_det_jacobian(v)
for v, t in zip(paddle.unstack(x, self._axis), self._transforms)
], self._axis)
def _check_size(self, v):
if not (-v.dim() <= self._axis < v.dim()):
raise ValueError(
f'Input dimensions {v.dim()} should be greater than stack '
f'transform axis {self._axis}.')
if v.shape[self._axis] != len(self._transforms):
raise ValueError(
f'Input size along {self._axis} should be equal to the '
f'length of transforms.')
@property
def _domain(self):
return variable.Stack([t._domain for t in self._transforms], self._axis)
@property
def _codomain(self):
return variable.Stack([t._codomain for t in self._transforms],
self._axis)
class StickBreakingTransform(Transform):
r"""Convert an unconstrained vector to the simplex with one additional
dimension by the stick-breaking construction.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.,2.,3.])
t = paddle.distribution.StickBreakingTransform()
print(t.forward(x))
# Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.47536686, 0.41287899, 0.10645414, 0.00530004])
print(t.inverse(t.forward(x)))
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.99999988, 2. , 2.99999881])
print(t.forward_log_det_jacobian(x))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-9.10835075])
"""
_type = Type.BIJECTION
def _forward(self, x):
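# Stick-breaking construction: shift each logit by a position-dependent
# log-offset (so a zero input maps to the uniform point of the simplex),
# squash it into a breaking proportion z with sigmoid, and scale by the
# remaining stick length, i.e. the cumulative product of (1 - z) over the
# preceding coordinates.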
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1)
def _inverse(self, y):
y_crop = y[..., :-1]
offset = y.shape[-1] - paddle.ones([y_crop.shape[-1]]).cumsum(-1)
sf = 1 - y_crop.cumsum(-1)
x = y_crop.log() - sf.log() + offset.log()
return x
def _forward_log_det_jacobian(self, x):
y = self.forward(x)
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
x = x - offset.log()
return (-x + F.log_sigmoid(x) + y[..., :-1].log()).sum(-1)
def _forward_shape(self, shape):
if not shape:
raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
return shape[:-1] + (shape[-1] + 1, )
def _inverse_shape(self, shape):
if not shape:
raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
return shape[:-1] + (shape[-1] - 1, )
@property
def _domain(self):
return variable.Independent(variable.real, 1)
@property
def _codomain(self):
return variable.Variable(False, 1, constraint.simplex)
class TanhTransform(Transform):
r"""Tanh transformation with mapping :math:`y = \tanh(x)`.
Examples:
.. code-block:: python
import paddle
tanh = paddle.distribution.TanhTransform()
x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]])
print(tanh.forward(x))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0.76159418, 0.96402758, 0.99505478],
# [0.99932933, 0.99990922, 0.99998772]])
print(tanh.inverse(tanh.forward(x)))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[1.00000012, 2. , 3.00000286],
# [4.00002146, 5.00009823, 6.00039864]])
print(tanh.forward_log_det_jacobian(x))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[-0.86756170 , -2.65000558 , -4.61865711 ],
# [-6.61437654 , -8.61379623 , -10.61371803]])
print(tanh.inverse_log_det_jacobian(tanh.forward(x)))
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0.86756176 , 2.65000558 , 4.61866283 ],
# [6.61441946 , 8.61399269 , 10.61451530]])
"""
_type = Type.BIJECTION
@property
def _domain(self):
return variable.real
@property
def _codomain(self):
return variable.Variable(False, 0, constraint.Range(-1.0, 1.0))
def _forward(self, x):
return x.tanh()
def _inverse(self, y):
return y.atanh()
def _forward_log_det_jacobian(self, x):
"""We implicitly rely on _forward_log_det_jacobian rather than
explicitly implement ``_inverse_log_det_jacobian`` since directly using
``-tf.math.log1p(-tf.square(y))`` has lower numerical precision.
See details: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
"""
return 2. * (math.log(2.) - x - F.softplus(-2. * x))
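# Derivation of the formula above: d tanh(x)/dx = 1 - tanh(x)^2
# = 4 * exp(-2x) / (1 + exp(-2x))^2, so
# log|d tanh(x)/dx| = log(4) - 2x - 2 * log(1 + exp(-2x))
#                   = 2 * (log(2) - x - softplus(-2x)).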
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from paddle.distribution import distribution
from paddle.distribution import transform
from paddle.distribution import independent
class TransformedDistribution(distribution.Distribution):
r"""
Applies a sequence of Transforms to a base distribution.
Args:
base (Distribution): The base distribution.
transforms (Sequence[Transform]): A sequence of ``Transform`` .
Examples:
.. code-block:: python
import paddle
from paddle.distribution import transformed_distribution
d = transformed_distribution.TransformedDistribution(
paddle.distribution.Normal(0., 1.),
[paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))]
)
print(d.sample([10]))
# Tensor(shape=[10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-0.10697651, 3.33609009, -0.86234951, 5.07457638, 0.75925219,
# -4.17087793, 2.22579336, -0.93845034, 0.66054249, 1.50957513])
print(d.log_prob(paddle.to_tensor(0.5)))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-1.64333570])
"""
def __init__(self, base, transforms):
if not isinstance(base, distribution.Distribution):
raise TypeError(
f"Expected type of 'base' is Distribution, but got {type(base)}."
)
if not isinstance(transforms, typing.Sequence):
raise TypeError(
f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}."
)
if not all(isinstance(t, transform.Transform) for t in transforms):
raise TypeError("All element of transforms must be Transform type.")
chain = transform.ChainTransform(transforms)
if len(base.batch_shape + base.event_shape) < chain._domain.event_rank:
raise ValueError(
f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}."
)
if chain._domain.event_rank > len(base.event_shape):
base = independent.Independent(
base, chain._domain.event_rank - len(base.event_shape))
self._base = base
self._transforms = transforms
transformed_shape = chain.forward_shape(base.batch_shape +
base.event_shape)
transformed_event_rank = chain._codomain.event_rank + \
max(len(base.event_shape)-chain._domain.event_rank, 0)
super(TransformedDistribution, self).__init__(
transformed_shape[:len(transformed_shape) - transformed_event_rank],
transformed_shape[len(transformed_shape) - transformed_event_rank:])
def sample(self, shape=()):
"""Sample from ``TransformedDistribution``.
Args:
shape (tuple, optional): The sample shape. Defaults to ().
Returns:
[Tensor]: The sample result.
"""
x = self._base.sample(shape)
for t in self._transforms:
x = t.forward(x)
return x
def log_prob(self, value):
"""The log probability evaluated at value.
Args:
value (Tensor): The value to be evaluated.
Returns:
Tensor: The log probability.
"""
log_prob = 0.0
y = value
event_rank = len(self.event_shape)
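# Walk the chain backwards, applying the change-of-variables formula one
# transform at a time: log p_Y(y) = log p_X(f^{-1}(y)) - log|det J_f(f^{-1}(y))|,
# summing each Jacobian term over the event dimensions of that stage.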
for t in reversed(self._transforms):
x = t.inverse(y)
event_rank += t._domain.event_rank - t._codomain.event_rank
log_prob = log_prob - \
_sum_rightmost(t.forward_log_det_jacobian(
x), event_rank-t._domain.event_rank)
y = x
log_prob += _sum_rightmost(
self._base.log_prob(y), event_rank - len(self._base.event_shape))
return log_prob
def _sum_rightmost(value, n):
return value.sum(list(range(-n, 0))) if n > 0 else value
......@@ -17,18 +17,18 @@ import warnings
import numpy as np
from paddle import _C_ops
from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution
class Uniform(Distribution):
from paddle.distribution import distribution
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial
class Uniform(distribution.Distribution):
r"""Uniform distribution with `low` and `high` parameters.
Mathematical Details
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distribution import constraint
class Variable(object):
"""Random variable of probability distribution.
Args:
is_discrete (bool): Is the variable discrete or continuous.
event_rank (int): The rank of event dimensions.
"""
def __init__(self, is_discrete=False, event_rank=0, constraint=None):
self._is_discrete = is_discrete
self._event_rank = event_rank
self._constraint = constraint
@property
def is_discrete(self):
return self._is_discrete
@property
def event_rank(self):
return self._event_rank
def constraint(self, value):
"""Check whether the 'value' meet the constraint conditions of this
random variable."""
return self._constraint(value)
class Real(Variable):
def __init__(self, event_rank=0):
super(Real, self).__init__(False, event_rank, constraint.real)
class Positive(Variable):
def __init__(self, event_rank=0):
super(Positive, self).__init__(False, event_rank, constraint.positive)
class Independent(Variable):
"""Reinterprets some of the batch axes of variable as event axes.
Args:
base (Variable): Base variable.
reinterpreted_batch_rank (int): The rightmost batch rank to be
reinterpreted.
"""
def __init__(self, base, reinterpreted_batch_rank):
self._base = base
self._reinterpreted_batch_rank = reinterpreted_batch_rank
super(Independent, self).__init__(
base.is_discrete, base.event_rank + reinterpreted_batch_rank)
def constraint(self, value):
ret = self._base.constraint(value)
if ret.dim() < self._reinterpreted_batch_rank:
raise ValueError(
"Input dimensions must be equal or grater than {}".format(
self._reinterpreted_batch_rank))
return ret.reshape(ret.shape[:ret.dim() - self._reinterpreted_batch_rank]
+ (-1, )).all(-1)
class Stack(Variable):
def __init__(self, vars, axis=0):
self._vars = vars
self._axis = axis
@property
def is_discrete(self):
return any(var.is_discrete for var in self._vars)
@property
def event_rank(self):
rank = max(var.event_rank for var in self._vars)
if self._axis + rank < 0:
rank += 1
return rank
def constraint(self, value):
if not (-value.dim() <= self._axis < value.dim()):
raise ValueError(
f'Input dimensions {value.dim()} should be greater than stack '
f'constraint axis {self._axis}.')
return paddle.stack([
var.constraint(value)
for var, value in zip(self._vars, paddle.unstack(value, self._axis))
], self._axis)
real = Real()
positive = Positive()
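# Illustrative sketch of how these variables are consumed (not part of this
# change; assumes dygraph mode):
#   v = Independent(real, 1)            # reinterpret the last axis as an event axis
#   v.event_rank                        # -> 1
#   v.constraint(paddle.ones([2, 3]))   # checks elementwise, then reduces the last axis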
......@@ -11,11 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import sys
import numpy as np
import paddle
DEVICES = [paddle.CPUPlace()]
......@@ -34,66 +29,3 @@ RTOL = {
'complex128': 1e-5
}
ATOL = {'float32': 0.0, 'complex64': 0, 'float64': 0.0, 'complex128': 0}
def xrand(shape=(10, 10, 10), dtype=DEFAULT_DTYPE, min=1.0, max=10.0):
return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min)
def place(devices, key='place'):
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def parameterize(fields, values=None):
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]
def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for k, v in enumerate(params):
test_cls = dict(cls.__dict__)
test_cls.update(v)
name = cls.__name__ + str(k)
name = name + '.' + v.get('suffix') if v.get('suffix') else name
test_cls_module[name] = type(name, (cls, ), test_cls)
for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls
return decorate
@contextlib.contextmanager
def stgraph(func, *args):
"""static graph exec context"""
paddle.enable_static()
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
input = paddle.static.data('input', x.shape, dtype=x.dtype)
output = func(input, n, axes, norm)
exe = paddle.static.Executor(place)
exe.run(sp)
[output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
yield output
paddle.disable_static()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import inspect
import re
import sys
import unittest
import numpy as np
import paddle
import config
TEST_CASE_NAME = 'suffix'
def xrand(shape=(10, 10, 10), dtype=config.DEFAULT_DTYPE, min=1.0, max=10.0):
return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min)
def place(devices, key='place'):
def decorate(cls):
module = sys.modules[cls.__module__].__dict__
raw_classes = {
k: v
for k, v in module.items() if k.startswith(cls.__name__)
}
for raw_name, raw_cls in raw_classes.items():
for d in devices:
test_cls = dict(raw_cls.__dict__)
test_cls.update({key: d})
new_name = raw_name + '.' + d.__class__.__name__
module[new_name] = type(new_name, (raw_cls, ), test_cls)
del module[raw_name]
return cls
return decorate
def parameterize_cls(fields, values=None):
fields = [fields] if isinstance(fields, str) else fields
params = [dict(zip(fields, vals)) for vals in values]
def decorate(cls):
test_cls_module = sys.modules[cls.__module__].__dict__
for k, v in enumerate(params):
test_cls = dict(cls.__dict__)
test_cls.update(v)
name = cls.__name__ + str(k)
name = name + '.' + v.get('suffix') if v.get('suffix') else name
test_cls_module[name] = type(name, (cls, ), test_cls)
for m in list(cls.__dict__):
if m.startswith("test"):
delattr(cls, m)
return cls
return decorate
def parameterize_func(input, name_func=None, doc_func=None,
skip_on_empty=False):
doc_func = doc_func or default_doc_func
name_func = name_func or default_name_func
def wrapper(f, instance=None):
frame_locals = inspect.currentframe().f_back.f_locals
parameters = input_as_callable(input)()
if not parameters:
if not skip_on_empty:
raise ValueError(
"Parameters iterable is empty (hint: use "
"`parameterized.expand([], skip_on_empty=True)` to skip "
"this test when the input is empty)")
            return functools.wraps(f)(skip_on_empty_helper)
digits = len(str(len(parameters) - 1))
for num, p in enumerate(parameters):
name = name_func(
f, "{num:0>{digits}}".format(
digits=digits, num=num), p)
# If the original function has patches applied by 'mock.patch',
# re-construct all patches on the just former decoration layer
# of param_as_standalone_func so as not to share
# patch objects between new functions
nf = reapply_patches_if_need(f)
frame_locals[name] = param_as_standalone_func(p, nf, name)
frame_locals[name].__doc__ = doc_func(f, num, p)
# Delete original patches to prevent new function from evaluating
# original patching object as well as re-constructed patches.
delete_patches_if_need(f)
f.__test__ = False
return wrapper
def reapply_patches_if_need(func):
def dummy_wrapper(orgfunc):
        @functools.wraps(orgfunc)
def dummy_func(*args, **kwargs):
return orgfunc(*args, **kwargs)
return dummy_func
if hasattr(func, 'patchings'):
func = dummy_wrapper(func)
tmp_patchings = func.patchings
delattr(func, 'patchings')
for patch_obj in tmp_patchings:
func = patch_obj.decorate_callable(func)
return func
def delete_patches_if_need(func):
if hasattr(func, 'patchings'):
func.patchings[:] = []
def default_name_func(func, num, p):
base_name = func.__name__
name_suffix = "_%s" % (num, )
if len(p.args) > 0 and isinstance(p.args[0], str):
name_suffix += "_" + to_safe_name(p.args[0])
return base_name + name_suffix
def default_doc_func(func, num, p):
if func.__doc__ is None:
return None
all_args_with_values = parameterized_argument_value_pairs(func, p)
# Assumes that the function passed is a bound method.
descs = ["%s=%s" % (n, short_repr(v)) for n, v in all_args_with_values]
# The documentation might be a multiline string, so split it
# and just work with the first string, ignoring the period
# at the end if there is one.
first, nl, rest = func.__doc__.lstrip().partition("\n")
suffix = ""
if first.endswith("."):
suffix = "."
first = first[:-1]
args = "%s[with %s]" % (len(first) and " " or "", ", ".join(descs))
return "".join(to_text(x) for x in [first.rstrip(), args, suffix, nl, rest])
def param_as_standalone_func(p, func, name):
@functools.wraps(func)
def standalone_func(*a):
return func(*(a + p.args), **p.kwargs)
standalone_func.__name__ = name
# place_as is used by py.test to determine what source file should be
# used for this test.
standalone_func.place_as = func
# Remove __wrapped__ because py.test will try to look at __wrapped__
# to determine which parameters should be used with this test case,
# and obviously we don't need it to do any parameterization.
try:
del standalone_func.__wrapped__
except AttributeError:
pass
return standalone_func
def input_as_callable(input):
if callable(input):
return lambda: check_input_values(input())
input_values = check_input_values(input)
return lambda: input_values
def check_input_values(input_values):
if not isinstance(input_values, list):
input_values = list(input_values)
return [param.from_decorator(p) for p in input_values]
def skip_on_empty_helper(*a, **kw):
    raise unittest.SkipTest("parameterized input is empty")
_param = collections.namedtuple("param", "args kwargs")
class param(_param):
def __new__(cls, *args, **kwargs):
return _param.__new__(cls, args, kwargs)
@classmethod
def explicit(cls, args=None, kwargs=None):
""" Creates a ``param`` by explicitly specifying ``args`` and
``kwargs``::
>>> param.explicit([1,2,3])
param(*(1, 2, 3))
>>> param.explicit(kwargs={"foo": 42})
param(*(), **{"foo": "42"})
"""
args = args or ()
kwargs = kwargs or {}
return cls(*args, **kwargs)
@classmethod
def from_decorator(cls, args):
""" Returns an instance of ``param()`` for ``@parameterized`` argument
``args``::
>>> param.from_decorator((42, ))
param(args=(42, ), kwargs={})
>>> param.from_decorator("foo")
param(args=("foo", ), kwargs={})
"""
if isinstance(args, param):
return args
elif isinstance(args, str):
args = (args, )
try:
return cls(*args)
except TypeError as e:
if "after * must be" not in str(e):
raise
raise TypeError(
"Parameters must be tuples, but %r is not (hint: use '(%r, )')"
% (args, args), )
def __repr__(self):
return "param(*%r, **%r)" % self
def to_safe_name(s):
return str(re.sub("[^a-zA-Z0-9_]+", "_", s))
@contextlib.contextmanager
def stgraph(func, *args):
    """Static graph exec context.

    Expects ``args`` to be ``(x, n, axes, norm, place)`` and yields the
    fetched output of ``func`` run under a fresh static program.
    """
    x, n, axes, norm, place = args
    paddle.enable_static()
    mp, sp = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(mp, sp):
        input = paddle.static.data('input', x.shape, dtype=x.dtype)
        output = func(input, n, axes, norm)

    exe = paddle.static.Executor(place)
    exe.run(sp)
    [output] = exe.run(mp, feed={'input': x}, fetch_list=[output])
    yield output
    paddle.disable_static()
# alias
parameterize = parameterize_func
param_cls = parameterize_cls
param_func = parameterize_func
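# Illustrative sketch of how these helpers are combined by the unit tests in
# this change (class name and parameter values below are made up):
#
#     import unittest
#     import config
#     import parameterize as param
#
#     @param.place(config.DEVICES)
#     @param.param_cls((param.TEST_CASE_NAME, 'alpha'), [('positive', 2.0)])
#     class TestSomething(unittest.TestCase):
#         def test_alpha(self):
#             self.assertGreater(self.alpha, 0.)
#
# `place` duplicates the decorated test class once per device and injects the
# device as the `place` attribute; `param_cls` then copies the class once per
# parameter tuple, injecting each field (here `alpha`) as a class attribute
# and stripping the test methods from the undecorated original so only the
# generated copies run.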
......@@ -22,6 +22,7 @@ from paddle.distribution import *
from paddle.fluid import layers
import config
import parameterize
paddle.enable_static()
......@@ -132,11 +133,12 @@ class DistributionTestName(unittest.TestCase):
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'batch_shape', 'event_shape'),
[('test-tuple', (10, 20),
(10, 20)), ('test-list', [100, 100], [100, 200, 300]),
('test-null-eventshape', (100, 100), ())])
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'batch_shape', 'event_shape'),
[('test-tuple', (10, 20), (10, 20)),
('test-list', [100, 100], [100, 200, 300]), ('test-null-eventshape',
(100, 100), ())])
class TestDistributionShape(unittest.TestCase):
def setUp(self):
paddle.disable_static()
......@@ -156,7 +158,7 @@ class TestDistributionShape(unittest.TestCase):
def test_prob(self):
with self.assertRaises(NotImplementedError):
self.dist.prob(paddle.to_tensor(config.xrand()))
self.dist.prob(paddle.to_tensor(parameterize.xrand()))
def test_extend_shape(self):
shapes = [(34, 20), (56, ), ()]
......@@ -164,3 +166,24 @@ class TestDistributionShape(unittest.TestCase):
self.assertTrue(
self.dist._extend_shape(shape),
shape + self.dist.batch_shape + self.dist.event_shape)
class TestDistributionException(unittest.TestCase):
def setUp(self):
self._d = paddle.distribution.Distribution()
def test_mean(self):
with self.assertRaises(NotImplementedError):
self._d.mean
def test_variance(self):
with self.assertRaises(NotImplementedError):
self._d.variance
def test_rsample(self):
with self.assertRaises(NotImplementedError):
self._d.rsample(())
if __name__ == '__main__':
unittest.main()
......@@ -18,14 +18,15 @@ import numpy as np
import paddle
import scipy.stats
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
import config
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'alpha', 'beta'),
[('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()),
('test-broadcast', xrand((2, 1)), xrand((2, 5)))])
@parameterize_cls((TEST_CASE_NAME, 'alpha', 'beta'),
[('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()),
('test-broadcast', xrand((2, 1)), xrand((2, 5)))])
class TestBeta(unittest.TestCase):
def setUp(self):
# scale no need convert to tensor for scale input unittest
......@@ -98,3 +99,7 @@ class TestBeta(unittest.TestCase):
self.assertTrue(
self._paddle_beta.sample(case.get('input')).shape ==
case.get('expect'))
if __name__ == '__main__':
unittest.main()
......@@ -18,16 +18,19 @@ import numpy as np
import paddle
import scipy.stats
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
import config
import parameterize as param
from config import ATOL, RTOL
from parameterize import xrand
paddle.enable_static()
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand(
(10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand(
(2, 5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))])
@param.place(config.DEVICES)
@param.parameterize_cls(
(param.TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand(
(10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand(
(2, 5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))])
class TestBeta(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
......
......@@ -439,3 +439,7 @@ class DistributionTestError(unittest.TestCase):
cat.log_prob(value)
self.assertRaises(ValueError, test_shape_not_match_error)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.distribution import constraint
import config
import parameterize as param
@param.param_cls((param.TEST_CASE_NAME, 'value'),
[('NotImplement', np.random.rand(2, 3))])
class TestConstraint(unittest.TestCase):
def setUp(self):
self._constraint = constraint.Constraint()
    def test_constraint(self):
with self.assertRaises(NotImplementedError):
self._constraint(self.value)
@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'),
[('real', 1., True)])
class TestReal(unittest.TestCase):
def setUp(self):
self._constraint = constraint.Real()
    def test_constraint(self):
self.assertEqual(self._constraint(self.value), self.expect)
@param.param_cls((param.TEST_CASE_NAME, 'lower', 'upper', 'value', 'expect'),
[('in_range', 0, 1, 0.5, True), ('out_range', 0, 1, 2, False)])
class TestRange(unittest.TestCase):
def setUp(self):
self._constraint = constraint.Range(self.lower, self.upper)
    def test_constraint(self):
self.assertEqual(self._constraint(self.value), self.expect)
@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'),
[('positive', 1, True), ('negative', -1, False)])
class TestPositive(unittest.TestCase):
def setUp(self):
self._constraint = constraint.Positive()
    def test_constraint(self):
self.assertEqual(self._constraint(self.value), self.expect)
@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'),
[('simplex', paddle.to_tensor([0.5, 0.5]), True),
('non_simplex', paddle.to_tensor([-0.5, 0.5]), False)])
class TestSimplex(unittest.TestCase):
def setUp(self):
self._constraint = constraint.Simplex()
    def test_constraint(self):
self.assertEqual(self._constraint(self.value), self.expect)
if __name__ == '__main__':
unittest.main()
......@@ -19,15 +19,15 @@ import paddle
import scipy.stats
import config
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
from config import ATOL, DEVICES, RTOL
import parameterize as param
@place(DEVICES)
@parameterize(
(TEST_CASE_NAME, 'concentration'),
@param.place(DEVICES)
@param.param_cls(
(param.TEST_CASE_NAME, 'concentration'),
[
('test-one-dim', config.xrand((89, ))),
('test-one-dim', param.xrand((89, ))),
# ('test-multi-dim', config.xrand((10, 20, 30)))
])
class TestDirichlet(unittest.TestCase):
......@@ -91,14 +91,18 @@ class TestDirichlet(unittest.TestCase):
self.assertTrue(
np.all(
self._paddle_diric._log_normalizer(
paddle.to_tensor(config.xrand((100, 100, 100)))).numpy() <
paddle.to_tensor(param.xrand((100, 100, 100)))).numpy() <
0.0))
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'concentration'),
[('test-zero-dim', np.array(1.0))])
@param.place(DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'concentration'),
[('test-zero-dim', np.array(1.0))])
class TestDirichletException(unittest.TestCase):
def TestInit(self):
with self.assertRaises(ValueError):
paddle.distribution.Dirichlet(
paddle.squeeze(self.concentration))
if __name__ == '__main__':
unittest.main()
......@@ -18,15 +18,15 @@ import numpy as np
import paddle
import scipy.stats
from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place,
xrand)
from config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand
paddle.enable_static()
@place(DEVICES)
@parameterize((TEST_CASE_NAME, 'concentration'),
[('test-one-dim', np.random.rand(89) + 5.0)])
@parameterize_cls((TEST_CASE_NAME, 'concentration'),
[('test-one-dim', np.random.rand(89) + 5.0)])
class TestDirichlet(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
......
......@@ -20,14 +20,15 @@ import scipy.stats
import config
import mock_data as mock
import parameterize
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'dist'), [('test-mock-exp',
mock.Exponential(rate=paddle.rand(
[100, 200, 99],
dtype=config.DEFAULT_DTYPE)))])
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'dist'), [('test-mock-exp',
mock.Exponential(rate=paddle.rand(
[100, 200, 99],
dtype=config.DEFAULT_DTYPE)))])
class TestExponentialFamily(unittest.TestCase):
def test_entropy(self):
np.testing.assert_allclose(
......@@ -37,15 +38,15 @@ class TestExponentialFamily(unittest.TestCase):
atol=config.ATOL.get(config.DEFAULT_DTYPE))
@config.place(config.DEVICES)
@config.parameterize(
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(config.TEST_CASE_NAME, 'dist'),
[('test-dummy', mock.DummyExpFamily(0.5, 0.5)),
('test-dirichlet',
paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand()))), (
paddle.distribution.Dirichlet(paddle.to_tensor(parameterize.xrand()))), (
'test-beta', paddle.distribution.Beta(
paddle.to_tensor(config.xrand()),
paddle.to_tensor(config.xrand())))])
paddle.to_tensor(parameterize.xrand()),
paddle.to_tensor(parameterize.xrand())))])
class TestExponentialFamilyException(unittest.TestCase):
def test_entropy_exception(self):
with self.assertRaises(NotImplementedError):
......
......@@ -20,17 +20,18 @@ import scipy.stats
import config
import mock_data as mock
import parameterize
paddle.enable_static()
@config.place(config.DEVICES)
@parameterize.place(config.DEVICES)
class TestExponentialFamily(unittest.TestCase):
def setUp(self):
self.program = paddle.static.Program()
self.executor = paddle.static.Executor()
with paddle.static.program_guard(self.program):
rate_np = config.xrand((100, 200, 99))
rate_np = parameterize.xrand((100, 200, 99))
rate = paddle.static.data('rate', rate_np.shape, rate_np.dtype)
self.mock_dist = mock.Exponential(rate)
self.feeds = {'rate': rate_np}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import parameterize as param
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank'),
[('base_beta', paddle.distribution.Beta(
paddle.rand([1, 2]), paddle.rand([1, 2])), 1)])
class TestIndependent(unittest.TestCase):
def setUp(self):
self._t = paddle.distribution.Independent(self.base,
self.reinterpreted_batch_rank)
def test_mean(self):
np.testing.assert_allclose(
self.base.mean,
self._t.mean,
rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))
def test_variance(self):
np.testing.assert_allclose(
self.base.variance,
self._t.variance,
rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))
def test_entropy(self):
np.testing.assert_allclose(
self._np_sum_rightmost(self.base.entropy().numpy(),
self.reinterpreted_batch_rank),
self._t.entropy(),
rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))
def _np_sum_rightmost(self, value, n):
return np.sum(value, tuple(range(-n, 0))) if n > 0 else value
def test_log_prob(self):
value = np.random.rand(1)
np.testing.assert_allclose(
self._np_sum_rightmost(
self.base.log_prob(paddle.to_tensor(value)).numpy(),
self.reinterpreted_batch_rank),
self._t.log_prob(paddle.to_tensor(value)).numpy(),
rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)),
atol=config.ATOL.get(str(self.base.alpha.numpy().dtype)))
# TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
def test_sample(self):
shape = (5, 10, 8)
expected_shape = (5, 10, 8, 1, 2)
data = self._t.sample(shape)
self.assertEqual(tuple(data.shape), expected_shape)
self.assertEqual(data.dtype, self.base.alpha.dtype)
@param.place(config.DEVICES)
@param.param_cls(
(param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank',
'expected_exception'),
[('base_not_transform', '', 1, TypeError),
('rank_less_than_zero', paddle.distribution.Transform(), -1, ValueError)])
class TestIndependentException(unittest.TestCase):
def test_init(self):
with self.assertRaises(self.expected_exception):
paddle.distribution.IndependentTransform(
self.base, self.reinterpreted_batch_rank)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import parameterize as param
paddle.enable_static()
@param.place(config.DEVICES)
@param.param_cls(
(param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'alpha', 'beta'),
[('base_beta', paddle.distribution.Beta, 1, np.random.rand(1, 2),
np.random.rand(1, 2))])
class TestIndependent(unittest.TestCase):
def setUp(self):
value = np.random.rand(1)
self.dtype = value.dtype
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
alpha = paddle.static.data('alpha', self.alpha.shape,
self.alpha.dtype)
beta = paddle.static.data('beta', self.beta.shape, self.beta.dtype)
self.base = self.base(alpha, beta)
t = paddle.distribution.Independent(self.base,
self.reinterpreted_batch_rank)
mean = t.mean
variance = t.variance
entropy = t.entropy()
static_value = paddle.static.data('value', value.shape, value.dtype)
log_prob = t.log_prob(static_value)
base_mean = self.base.mean
base_variance = self.base.variance
base_entropy = self.base.entropy()
base_log_prob = self.base.log_prob(static_value)
fetch_list = [
mean, variance, entropy, log_prob, base_mean, base_variance,
base_entropy, base_log_prob
]
exe.run(sp)
[
self.mean, self.variance, self.entropy, self.log_prob,
self.base_mean, self.base_variance, self.base_entropy,
self.base_log_prob
] = exe.run(
mp,
feed={'value': value,
'alpha': self.alpha,
'beta': self.beta},
fetch_list=fetch_list)
def test_mean(self):
np.testing.assert_allclose(
self.mean,
self.base_mean,
rtol=config.RTOL.get(str(self.dtype)),
atol=config.ATOL.get(str(self.dtype)))
def test_variance(self):
np.testing.assert_allclose(
self.variance,
self.base_variance,
rtol=config.RTOL.get(str(self.dtype)),
atol=config.ATOL.get(str(self.dtype)))
def test_entropy(self):
np.testing.assert_allclose(
self._np_sum_rightmost(self.base_entropy,
self.reinterpreted_batch_rank),
self.entropy,
rtol=config.RTOL.get(str(self.dtype)),
atol=config.ATOL.get(str(self.dtype)))
def _np_sum_rightmost(self, value, n):
return np.sum(value, tuple(range(-n, 0))) if n > 0 else value
def test_log_prob(self):
np.testing.assert_allclose(
self._np_sum_rightmost(self.base_log_prob,
self.reinterpreted_batch_rank),
self.log_prob,
rtol=config.RTOL.get(str(self.dtype)),
atol=config.ATOL.get(str(self.dtype)))
if __name__ == '__main__':
unittest.main()
......@@ -19,15 +19,17 @@ import paddle
import scipy.stats
import config
import parameterize
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('one-dim', 10, config.xrand((3, ))),
('multi-dim', 9, config.xrand((10, 20))),
('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])),
('prob-sum-non-one', 10, np.array([2., 3., 5.])),
])
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [
('one-dim', 10, parameterize.xrand((3, ))),
('multi-dim', 9, parameterize.xrand((10, 20))),
('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])),
('prob-sum-non-one', 10, np.array([2., 3., 5.])),
])
class TestMultinomial(unittest.TestCase):
def setUp(self):
self._dist = paddle.distribution.Multinomial(
......@@ -98,9 +100,9 @@ class TestMultinomial(unittest.TestCase):
return scipy.stats.multinomial.entropy(self.total_count, probs)
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
[
('value-float', 10, np.array([0.2, 0.3, 0.5]), np.array([2., 3., 5.])),
('value-int', 10, np.array([0.2, 0.3, 0.5]), np.array([2, 3, 5])),
......@@ -122,12 +124,13 @@ class TestMultinomialPmf(unittest.TestCase):
atol=config.ATOL.get(str(self.probs.dtype)))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('total_count_le_one', 0, np.array([0.3, 0.7])),
('total_count_float', np.array([0.3, 0.7])),
('probs_zero_dim', np.array(0)),
])
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(config.TEST_CASE_NAME, 'total_count', 'probs'), [
('total_count_le_one', 0, np.array([0.3, 0.7])),
('total_count_float', np.array([0.3, 0.7])),
('probs_zero_dim', np.array(0)),
])
class TestMultinomialException(unittest.TestCase):
def TestInit(self):
with self.assertRaises(ValueError):
......
......@@ -19,17 +19,19 @@ import paddle
import scipy.stats
import config
import parameterize
paddle.enable_static()
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('one-dim', 5, config.xrand((3, ))),
('multi-dim', 9, config.xrand((2, 3))),
('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])),
('prob-sum-non-one', 5, np.array([2., 3., 5.])),
])
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [
('one-dim', 5, parameterize.xrand((3, ))),
('multi-dim', 9, parameterize.xrand((2, 3))),
('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])),
('prob-sum-non-one', 5, np.array([2., 3., 5.])),
])
class TestMultinomial(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
......@@ -99,9 +101,9 @@ class TestMultinomial(unittest.TestCase):
return scipy.stats.multinomial.entropy(self.total_count, probs)
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
[
('value-float', 5, np.array([0.2, 0.3, 0.5]), np.array([1., 1., 3.])),
('value-int', 5, np.array([0.2, 0.3, 0.5]), np.array([2, 2, 1])),
......@@ -139,12 +141,13 @@ class TestMultinomialPmf(unittest.TestCase):
atol=config.ATOL.get(str(self.probs.dtype)))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('total_count_le_one', 0, np.array([0.3, 0.7])),
('total_count_float', np.array([0.3, 0.7])),
('probs_zero_dim', np.array(0)),
])
@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [
('total_count_le_one', 0, np.array([0.3, 0.7])),
('total_count_float', np.array([0.3, 0.7])),
('probs_zero_dim', np.array(0)),
])
class TestMultinomialException(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
......
......@@ -454,3 +454,7 @@ class NormalTest10(NormalTest):
with fluid.program_guard(self.test_program):
self.static_values = layers.data(
name='values', shape=[dims], dtype='float32')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import unittest
import numpy as np
import paddle
from paddle.distribution import constraint, transform, variable
import config
import parameterize as param
@param.place(config.DEVICES)
class TestTransform(unittest.TestCase):
def setUp(self):
self._t = transform.Transform()
@param.param_func([
(paddle.distribution.Distribution(),
paddle.distribution.TransformedDistribution),
(paddle.distribution.ExpTransform(), paddle.distribution.ChainTransform)
])
def test_call(self, input, expected_type):
t = transform.Transform()
self.assertIsInstance(t(input), expected_type)
@param.param_func(
[(transform.Type.BIJECTION, True), (transform.Type.INJECTION, True),
(transform.Type.SURJECTION, False), (transform.Type.OTHER, False)])
def test_is_injective(self, type, expected):
transform.Transform._type = type
self.assertEqual(self._t._is_injective(), expected)
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Real))
@param.param_func([(0, TypeError), (paddle.rand((2, 3)),
NotImplementedError)])
def test_forward(self, input, expected):
with self.assertRaises(expected):
self._t.forward(input)
@param.param_func([(0, TypeError), (paddle.rand((2, 3)),
NotImplementedError)])
def test_inverse(self, input, expected):
with self.assertRaises(expected):
self._t.inverse(input)
@param.param_func([(0, TypeError), (paddle.rand((2, 3)),
NotImplementedError)])
def test_forward_log_det_jacobian(self, input, expected):
with self.assertRaises(expected):
self._t.forward_log_det_jacobian(input)
@param.param_func([(0, TypeError), (paddle.rand((2, 3)),
NotImplementedError)])
def test_inverse_log_det_jacobian(self, input, expected):
with self.assertRaises(expected):
self._t.inverse_log_det_jacobian(input)
@param.param_func([(0, TypeError)])
def test_forward_shape(self, shape, expected):
with self.assertRaises(expected):
self._t.forward_shape(shape)
@param.param_func([(0, TypeError)])
def test_inverse_shape(self, shape, expected):
with self.assertRaises(expected):
self._t.inverse_shape(shape)
@param.place(config.DEVICES)
class TestAbsTransform(unittest.TestCase):
def setUp(self):
self._t = transform.AbsTransform()
def test_is_injective(self):
self.assertFalse(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Positive))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
@param.param_func([(np.array([-1., 1., 0.]), np.array([1., 1., 0.])),
(np.array([[1., -1., -0.1], [-3., -0.1, 0]]),
np.array([[1., 1., 0.1], [3., 0.1, 0]]))])
def test_forward(self, input, expected):
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array(1.), (-np.array(1.), np.array(1.)))])
def test_inverse(self, input, expected):
actual0, actual1 = self._t.inverse(paddle.to_tensor(input))
expected0, expected1 = expected
np.testing.assert_allclose(
actual0.numpy(),
expected0,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
np.testing.assert_allclose(
actual1.numpy(),
expected1,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def test_forward_log_det_jacobian(self):
with self.assertRaises(NotImplementedError):
self._t.forward_log_det_jacobian(paddle.rand((10, )))
@param.param_func([(np.array(1.), (np.array(0.), np.array(0.))), ])
def test_inverse_log_det_jacobian(self, input, expected):
actual0, actual1 = self._t.inverse_log_det_jacobian(
paddle.to_tensor(input))
expected0, expected1 = expected
np.testing.assert_allclose(
actual0.numpy(),
expected0,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
np.testing.assert_allclose(
actual1.numpy(),
expected1,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'loc', 'scale'), [
('normal', np.random.rand(8, 10), np.random.rand(8, 10)),
('broadcast', np.random.rand(2, 10), np.random.rand(10)),
])
class TestAffineTransform(unittest.TestCase):
def setUp(self):
self._t = transform.AffineTransform(
paddle.to_tensor(self.loc), paddle.to_tensor(self.scale))
@param.param_func([
(paddle.rand([1]), 0, TypeError),
(0, paddle.rand([1]), TypeError),
])
def test_init_exception(self, loc, scale, exc):
with self.assertRaises(exc):
paddle.distribution.AffineTransform(loc, scale)
def test_scale(self):
np.testing.assert_allclose(self._t.scale, self.scale)
def test_loc(self):
np.testing.assert_allclose(self._t.loc, self.loc)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Real))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
def test_forward(self):
x = np.random.random(self.loc.shape)
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(x)).numpy(),
self._np_forward(x),
rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)),
atol=config.ATOL.get(str(self._t.loc.numpy().dtype)))
def test_inverse(self):
y = np.random.random(self.loc.shape)
np.testing.assert_allclose(
self._t.inverse(paddle.to_tensor(y)).numpy(),
self._np_inverse(y),
rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)),
atol=config.ATOL.get(str(self._t.loc.numpy().dtype)))
def _np_forward(self, x):
return self.loc + self.scale * x
def _np_inverse(self, y):
return (y - self.loc) / self.scale
def _np_forward_jacobian(self, x):
return np.log(np.abs(self.scale))
def _np_inverse_jacobian(self, y):
return -self._np_forward_jacobian(self._np_inverse(y))
def test_inverse_log_det_jacobian(self):
y = np.random.random(self.scale.shape)
np.testing.assert_allclose(
self._t.inverse_log_det_jacobian(paddle.to_tensor(y)).numpy(),
self._np_inverse_jacobian(y),
rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)),
atol=config.ATOL.get(str(self._t.loc.numpy().dtype)))
def test_forward_log_det_jacobian(self):
x = np.random.random(self.scale.shape)
np.testing.assert_allclose(
self._t.forward_log_det_jacobian(paddle.to_tensor(x)).numpy(),
self._np_forward_jacobian(x),
rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)),
atol=config.ATOL.get(str(self._t.loc.numpy().dtype)))
def test_forward_shape(self):
shape = self.loc.shape
self.assertEqual(
tuple(self._t.forward_shape(shape)),
np.broadcast(np.random.random(shape), self.loc, self.scale).shape)
def test_inverse_shape(self):
shape = self.scale.shape
self.assertEqual(
            tuple(self._t.inverse_shape(shape)),
np.broadcast(np.random.random(shape), self.loc, self.scale).shape)
@param.place(config.DEVICES)
class TestExpTransform(unittest.TestCase):
def setUp(self):
self._t = transform.ExpTransform()
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Positive))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
@param.param_func(
[(np.array([0., 1., 2., 3.]), np.exp(np.array([0., 1., 2., 3.]))),
(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]),
np.exp(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))])
def test_forward(self, input, expected):
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1., 2., 3.]), np.log(np.array([1., 2., 3.]))),
(np.array([[1., 2., 3.], [6., 7., 8.]]),
np.log(np.array([[1., 2., 3.], [6., 7., 8.]])))])
def test_inverse(self, input, expected):
np.testing.assert_allclose(
self._t.inverse(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_forward_log_det_jacobian(self, input):
np.testing.assert_allclose(
self._t.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(),
self._np_forward_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def _np_forward_jacobian(self, x):
return x
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_inverse_log_det_jacobian(self, input):
np.testing.assert_allclose(
self._t.inverse_log_det_jacobian(paddle.to_tensor(input)).numpy(),
self._np_inverse_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def _np_inverse_jacobian(self, y):
return -self._np_forward_jacobian(np.log(y))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
class TestChainTransform(unittest.TestCase):
@param.param_func([(paddle.distribution.Transform, TypeError),
([0], TypeError)])
def test_init_exception(self, transforms, exception):
with self.assertRaises(exception):
paddle.distribution.ChainTransform(transforms)
@param.param_func((
(transform.ChainTransform(
(transform.AbsTransform(),
transform.AffineTransform(paddle.rand([1]), paddle.rand([1])))),
False), (transform.ChainTransform((
transform.AffineTransform(paddle.rand([1]), paddle.rand([1])),
transform.ExpTransform(), )), True)))
def test_is_injective(self, chain, expected):
self.assertEqual(chain._is_injective(), expected)
@param.param_func(((transform.ChainTransform(
(transform.IndependentTransform(transform.ExpTransform(), 1),
transform.IndependentTransform(transform.ExpTransform(), 10),
transform.IndependentTransform(transform.ExpTransform(), 8))),
variable.Independent(variable.real, 10)), ))
def test_domain(self, input, expected):
self.assertIsInstance(input._domain, type(expected))
self.assertEqual(input._domain.event_rank, expected.event_rank)
self.assertEqual(input._domain.is_discrete, expected.is_discrete)
@param.param_func(((transform.ChainTransform(
(transform.IndependentTransform(transform.ExpTransform(), 9),
transform.IndependentTransform(transform.ExpTransform(), 4),
transform.IndependentTransform(transform.ExpTransform(), 5))),
variable.Independent(variable.real, 9)), ))
def test_codomain(self, input, expected):
self.assertIsInstance(input._codomain, variable.Independent)
self.assertEqual(input._codomain.event_rank, expected.event_rank)
self.assertEqual(input._codomain.is_discrete, expected.is_discrete)
@param.param_func(
[(transform.ChainTransform((transform.AffineTransform(
paddle.to_tensor(0.0), paddle.to_tensor(1.0)),
transform.ExpTransform())),
np.array([0., 1., 2., 3.]), np.exp(np.array([0., 1., 2., 3.]) * 1.0)),
(transform.ChainTransform((transform.ExpTransform(),
transform.TanhTransform())),
np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]),
np.tanh(np.exp(np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]))))])
def test_forward(self, chain, input, expected):
np.testing.assert_allclose(
chain.forward(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func(
[(transform.ChainTransform(
(transform.AffineTransform(
paddle.to_tensor(0.0), paddle.to_tensor(-1.0)),
transform.ExpTransform())), np.array([0., 1., 2., 3.]),
np.log(np.array([0., 1., 2., 3.])) / (-1.0)),
(transform.ChainTransform((transform.ExpTransform(),
transform.TanhTransform())),
np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]),
np.log(np.arctanh(np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]))))])
def test_inverse(self, chain, input, expected):
np.testing.assert_allclose(
chain.inverse(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([
(transform.ChainTransform(
(transform.AffineTransform(
paddle.to_tensor(0.0), paddle.to_tensor(-1.0)),
transform.PowerTransform(paddle.to_tensor(2.0)))),
np.array([1., 2., 3.]), np.log(2. * np.array([1., 2., 3.]))),
])
def test_forward_log_det_jacobian(self, chain, input, expected):
np.testing.assert_allclose(
chain.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(transform.ChainTransform((transform.AffineTransform(
paddle.to_tensor(0.0),
paddle.to_tensor(-1.0)), transform.ExpTransform())), (2, 3, 5),
(2, 3, 5)), ])
def test_forward_shape(self, chain, shape, expected_shape):
self.assertEqual(chain.forward_shape(shape), expected_shape)
@param.param_func([(transform.ChainTransform((transform.AffineTransform(
paddle.to_tensor(0.0),
paddle.to_tensor(-1.0)), transform.ExpTransform())), (2, 3, 5),
(2, 3, 5)), ])
def test_inverse_shape(self, chain, shape, expected_shape):
self.assertEqual(chain.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
@param.param_cls(
(param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'x'),
[('rank-over-zero', transform.ExpTransform(), 2, np.random.rand(2, 3, 3)),
])
class TestIndependentTransform(unittest.TestCase):
def setUp(self):
self._t = transform.IndependentTransform(self.base,
self.reinterpreted_batch_rank)
@param.param_func([(0, 0, TypeError),
(paddle.distribution.Transform(), -1, ValueError)])
def test_init_exception(self, base, rank, exc):
with self.assertRaises(exc):
paddle.distribution.IndependentTransform(base, rank)
def test_is_injective(self):
self.assertEqual(self._t._is_injective(), self.base._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Independent))
self.assertEqual(
self._t._domain.event_rank,
self.base._domain.event_rank + self.reinterpreted_batch_rank)
self.assertEqual(self._t._domain.is_discrete,
self.base._domain.is_discrete)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Independent))
self.assertEqual(
self._t._codomain.event_rank,
self.base._codomain.event_rank + self.reinterpreted_batch_rank)
self.assertEqual(self._t._codomain.is_discrete,
self.base._codomain.is_discrete)
def test_forward(self):
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(self.x)).numpy(),
self.base.forward(paddle.to_tensor(self.x)).numpy(),
rtol=config.RTOL.get(str(self.x.dtype)),
atol=config.ATOL.get(str(self.x.dtype)))
def test_inverse(self):
np.testing.assert_allclose(
self._t.inverse(paddle.to_tensor(self.x)).numpy(),
self.base.inverse(paddle.to_tensor(self.x)).numpy(),
rtol=config.RTOL.get(str(self.x.dtype)),
atol=config.ATOL.get(str(self.x.dtype)))
def test_forward_log_det_jacobian(self):
actual = self._t.forward_log_det_jacobian(paddle.to_tensor(self.x))
self.assertEqual(
tuple(actual.shape), self.x.shape[:-self.reinterpreted_batch_rank])
expected = self.base.forward_log_det_jacobian(
paddle.to_tensor(self.x)).sum(
list(range(-self.reinterpreted_batch_rank, 0)))
np.testing.assert_allclose(
actual.numpy(),
expected.numpy(),
rtol=config.RTOL.get(str(self.x.dtype)),
atol=config.ATOL.get(str(self.x.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
class TestPowerTransform(unittest.TestCase):
def setUp(self):
self._t = transform.PowerTransform(paddle.to_tensor(2.))
def test_init(self):
with self.assertRaises(TypeError):
transform.PowerTransform(1.)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Positive))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
@param.param_func([(np.array([2.]), np.array([0., -1., 2.]), np.power(
np.array([0., -1., 2.]),
2.)), (np.array([[0.], [3.]]), np.array([[1., 0.], [5., 6.]]), np.power(
np.array([[1., 0.], [5., 6.]]), np.array([[0.], [3.]])))])
def test_forward(self, power, x, y):
t = transform.PowerTransform(paddle.to_tensor(power))
np.testing.assert_allclose(
t.forward(paddle.to_tensor(x)).numpy(),
y,
rtol=config.RTOL.get(str(x.dtype)),
atol=config.ATOL.get(str(x.dtype)))
@param.param_func([(np.array([2.]), np.array([4.]), np.array([2.]))])
def test_inverse(self, power, y, x):
t = transform.PowerTransform(paddle.to_tensor(power))
np.testing.assert_allclose(
t.inverse(paddle.to_tensor(y)).numpy(),
x,
rtol=config.RTOL.get(str(x.dtype)),
atol=config.ATOL.get(str(x.dtype)))
@param.param_func(((np.array([2.]), np.array([3., 1.4, 0.8])), ))
def test_forward_log_det_jacobian(self, power, x):
t = transform.PowerTransform(paddle.to_tensor(power))
np.testing.assert_allclose(
t.forward_log_det_jacobian(paddle.to_tensor(x)).numpy(),
self._np_forward_jacobian(power, x),
rtol=config.RTOL.get(str(x.dtype)),
atol=config.ATOL.get(str(x.dtype)))
def _np_forward_jacobian(self, alpha, x):
return np.abs(np.log(alpha * np.power(x, alpha - 1)))
@param.param_func([((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
class TestTanhTransform(unittest.TestCase):
def setUp(self):
self._t = transform.TanhTransform()
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Variable))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
self.assertEqual(self._t._codomain._constraint._lower, -1)
self.assertEqual(self._t._codomain._constraint._upper, 1)
@param.param_func(
[(np.array([0., 1., 2., 3.]), np.tanh(np.array([0., 1., 2., 3.]))),
(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]),
np.tanh(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))])
def test_forward(self, input, expected):
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func(
[(np.array([1., 2., 3.]), np.arctanh(np.array([1., 2., 3.]))),
(np.array([[1., 2., 3.], [6., 7., 8.]]),
np.arctanh(np.array([[1., 2., 3.], [6., 7., 8.]])))])
def test_inverse(self, input, expected):
np.testing.assert_allclose(
self._t.inverse(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_forward_log_det_jacobian(self, input):
np.testing.assert_allclose(
self._t.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(),
self._np_forward_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def _np_forward_jacobian(self, x):
return 2. * (np.log(2.) - x - self._np_softplus(-2. * x))
def _np_softplus(self, x, beta=1., threshold=20.):
if np.any(beta * x > threshold):
return x
return 1. / beta * np.log1p(np.exp(beta * x))
def _np_inverse_jacobian(self, y):
return -self._np_forward_jacobian(np.arctanh(y))
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_inverse_log_det_jacobian(self, input):
np.testing.assert_allclose(
self._t.inverse_log_det_jacobian(paddle.to_tensor(input)).numpy(),
self._np_inverse_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'in_event_shape', 'out_event_shape'), [
('regular_shape', (2, 3), (3, 2)),
])
class TestReshapeTransform(unittest.TestCase):
def setUp(self):
self._t = transform.ReshapeTransform(self.in_event_shape,
self.out_event_shape)
@param.param_func([(0, 0, TypeError), ((1, 2), (1, 3), ValueError)])
def test_init_exception(self, in_event_shape, out_event_shape, exc):
with self.assertRaises(exc):
paddle.distribution.ReshapeTransform(in_event_shape,
out_event_shape)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Independent))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Independent))
def test_forward(self):
x = paddle.ones(self.in_event_shape)
np.testing.assert_allclose(
self._t.forward(x),
paddle.ones(self.out_event_shape),
rtol=config.RTOL.get(str(x.numpy().dtype)),
atol=config.ATOL.get(str(x.numpy().dtype)))
def test_inverse(self):
x = paddle.ones(self.out_event_shape)
np.testing.assert_allclose(
self._t.inverse(x).numpy(),
paddle.ones(self.in_event_shape).numpy(),
rtol=config.RTOL.get(str(x.numpy().dtype)),
atol=config.ATOL.get(str(x.numpy().dtype)))
def test_forward_log_det_jacobian(self):
x = paddle.ones(self.in_event_shape)
np.testing.assert_allclose(
self._t.forward_log_det_jacobian(x).numpy(),
paddle.zeros([1]).numpy(),
rtol=config.RTOL.get(str(x.numpy().dtype)),
atol=config.ATOL.get(str(x.numpy().dtype)))
def test_in_event_shape(self):
self.assertEqual(self._t.in_event_shape, self.in_event_shape)
def test_out_event_shape(self):
self.assertEqual(self._t.out_event_shape, self.out_event_shape)
@param.param_func([((), ValueError), ((1, 2), ValueError)])
def test_forward_shape_exception(self, shape, exc):
with self.assertRaises(exc):
self._t.forward_shape(shape)
@param.param_func([((), ValueError), ((1, 2), ValueError)])
def test_inverse_shape_exception(self, shape, exc):
with self.assertRaises(exc):
self._t.inverse_shape(shape)
def _np_softplus(x, beta=1., threshold=20.):
if np.any(beta * x > threshold):
return x
return 1. / beta * np.log1p(np.exp(beta * x))
class TestSigmoidTransform(unittest.TestCase):
def setUp(self):
self._t = transform.SigmoidTransform()
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Variable))
@param.param_func(((np.ones((5, 10)),
1 / (1 + np.exp(-np.ones((5, 10))))), ))
def test_forward(self, input, expected):
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(input)),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func((
(np.ones(10), np.log(np.ones(10)) - np.log1p(-np.ones(10))), ))
def test_inverse(self, input, expected):
np.testing.assert_allclose(
self._t.inverse(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func((
(np.ones(10),
-_np_softplus(-np.ones(10)) - _np_softplus(np.ones(10))), ))
def test_forward_log_det_jacobian(self, input, expected):
np.testing.assert_allclose(
self._t.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(),
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
class TestSoftmaxTransform(unittest.TestCase):
def setUp(self):
self._t = transform.SoftmaxTransform()
def test_is_injective(self):
self.assertFalse(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Independent))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Variable))
@param.param_func(((np.random.random((5, 10)), ), ))
def test_forward(self, input):
np.testing.assert_allclose(
self._t.forward(paddle.to_tensor(input)),
self._np_forward(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func(((np.random.random(10), ), ))
def test_inverse(self, input):
np.testing.assert_allclose(
self._t.inverse(paddle.to_tensor(input)),
self._np_inverse(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def _np_forward(self, x):
x = np.exp(x - np.max(x, -1, keepdims=True)[0])
return x / np.sum(x, -1, keepdims=True)
def _np_inverse(self, y):
return np.log(y)
def test_forward_log_det_jacobian(self):
with self.assertRaises(NotImplementedError):
self._t.forward_log_det_jacobian(paddle.rand((2, 3)))
@param.param_func([((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ValueError)])
def test_forward_shape_exception(self, shape, exc):
with self.assertRaises(exc):
self._t.forward_shape(shape)
@param.param_func([((), ValueError)])
def test_inverse_shape_exception(self, shape, exc):
with self.assertRaises(exc):
self._t.inverse_shape(shape)
@param.param_func([((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.inverse_shape(shape), expected_shape)
class TestStickBreakingTransform(unittest.TestCase):
def setUp(self):
self._t = transform.StickBreakingTransform()
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Independent))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Variable))
@param.param_func(((np.random.random((10)), ), ))
def test_forward(self, input):
np.testing.assert_allclose(
self._t.inverse(self._t.forward(paddle.to_tensor(input))),
input,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
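    # Stick-breaking maps an unconstrained vector of length K to a point on
    # the K-simplex with K + 1 coordinates, so forward grows the last
    # dimension by one and inverse shrinks it by one.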
@param.param_func([((2, 3, 5), (2, 3, 6))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((2, 3, 5), (2, 3, 4))])
def test_inverse_shape(self, shape, expected_shape):
self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.param_func(((np.random.random((10)), ), ))
def test_forward_log_det_jacobian(self, x):
self.assertEqual(
self._t.forward_log_det_jacobian(paddle.to_tensor(x)).shape, [1])
# Todo
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'transforms', 'axis'), [
('simple_one_transform', [transform.ExpTransform()], 0),
])
class TestStackTransform(unittest.TestCase):
def setUp(self):
self._t = transform.StackTransform(self.transforms, self.axis)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Stack))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Stack))
@param.param_func([(np.array([[0., 1., 2., 3.]]), ),
(np.array([[-5., 6., 7., 8.]]), )])
def test_forward(self, input):
self.assertEqual(
tuple(self._t.forward(paddle.to_tensor(input)).shape), input.shape)
@param.param_func([(np.array([[1., 2., 3.]]), ),
(np.array([[6., 7., 8.]], ), )])
def test_inverse(self, input):
self.assertEqual(
tuple(self._t.inverse(paddle.to_tensor(input)).shape), input.shape)
@param.param_func([(np.array([[1., 2., 3.]]), ),
(np.array([[6., 7., 8.]]), )])
def test_forward_log_det_jacobian(self, input):
self.assertEqual(
tuple(
self._t.forward_log_det_jacobian(paddle.to_tensor(input))
.shape), input.shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
def test_axis(self):
self.assertEqual(self._t.axis, self.axis)
@param.param_func(
[(0, 0, TypeError), ([0], 0, TypeError),
([paddle.distribution.ExpTransform()], 'axis', TypeError)])
def test_init_exception(self, transforms, axis, exc):
with self.assertRaises(exc):
paddle.distribution.StackTransform(transforms, axis)
def test_transforms(self):
self.assertIsInstance((self._t.transforms), typing.Sequence)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.distribution import transform, variable, constraint
import config
import parameterize as param
paddle.enable_static()
@param.place(config.DEVICES)
class TestTransform(unittest.TestCase):
def setUp(self):
self._t = transform.Transform()
@param.param_func(
[(transform.Type.BIJECTION, True), (transform.Type.INJECTION, True),
(transform.Type.SURJECTION, False), (transform.Type.OTHER, False)])
def test_is_injective(self, type, expected):
transform.Transform._type = type
self.assertEqual(self._t._is_injective(), expected)
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Real))
@param.param_func([(np.array(0), NotImplementedError), (np.random.random(
(2, 3)), NotImplementedError)])
def test_forward(self, input, expected):
with self.assertRaises(expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.Transform()
static_input = paddle.static.data('input', input.shape,
input.dtype)
output = t.forward(static_input)
exe.run(sp)
exe.run(mp, feed={'input': input}, fetch_list=[output])
@param.param_func([(np.array(0), NotImplementedError), (np.random.random(
(2, 3)), NotImplementedError)])
def test_inverse(self, input, expected):
with self.assertRaises(expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.Transform()
static_input = paddle.static.data('input', input.shape,
input.dtype)
output = t.inverse(static_input)
exe.run(sp)
exe.run(mp, feed={'input': input}, fetch_list=[output])
@param.param_func([(np.array(0), NotImplementedError), (paddle.rand(
(2, 3)), NotImplementedError)])
def test_forward_log_det_jacobian(self, input, expected):
with self.assertRaises(expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.Transform()
static_input = paddle.static.data('input', input.shape,
input.dtype)
output = t.forward_log_det_jacobian(static_input)
exe.run(sp)
exe.run(mp, feed={'input': input}, fetch_list=[output])
@param.param_func([(np.array(0), NotImplementedError), (paddle.rand(
(2, 3)), NotImplementedError)])
def test_inverse_log_det_jacobian(self, input, expected):
with self.assertRaises(expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.Transform()
static_input = paddle.static.data('input', input.shape,
input.dtype)
output = t.inverse_log_det_jacobian(static_input)
exe.run(sp)
exe.run(mp, feed={'input': input}, fetch_list=[output])
@param.param_func([(0, TypeError)])
def test_forward_shape(self, shape, expected):
with self.assertRaises(expected):
self._t.forward_shape(shape)
@param.param_func([(0, TypeError)])
def test_inverse_shape(self, shape, expected):
with self.assertRaises(expected):
            self._t.inverse_shape(shape)
@param.place(config.DEVICES)
class TestAbsTransform(unittest.TestCase):
def setUp(self):
self._t = transform.AbsTransform()
def test_is_injective(self):
self.assertFalse(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Positive))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
@param.param_func([(np.array([-1., 1., 0.]), np.array([1., 1., 0.])),
(np.array([[1., -1., -0.1], [-3., -0.1, 0]]),
np.array([[1., 1., 0.1], [3., 0.1, 0]]))])
def test_forward(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.AbsTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1.]), (-np.array([1.]), np.array([1.])))])
def test_inverse(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.AbsTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
actual0, actual1 = t.inverse(static_input)
exe.run(sp)
[actual0, actual1] = exe.run(mp,
feed={'input': input},
fetch_list=[actual0, actual1])
expected0, expected1 = expected
np.testing.assert_allclose(
actual0,
expected0,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
np.testing.assert_allclose(
actual1,
expected1,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def test_forward_log_det_jacobian(self):
input = np.random.random((10, ))
with self.assertRaises(NotImplementedError):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.AbsTransform()
static_input = paddle.static.data('input', input.shape,
input.dtype)
output = t.forward_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
@param.param_func([(np.array([1.]), (np.array([0.]), np.array([0.]))), ])
def test_inverse_log_det_jacobian(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.AbsTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
actual0, actual1 = t.inverse_log_det_jacobian(static_input)
exe.run(sp)
[actual0, actual1] = exe.run(mp,
feed={'input': input},
fetch_list=[actual0, actual1])
expected0, expected1 = expected
np.testing.assert_allclose(
actual0,
expected0,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
np.testing.assert_allclose(
actual1,
expected1,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'loc', 'scale'), [
('normal', np.random.rand(8, 10), np.random.rand(8, 10)),
('broadcast', np.random.rand(2, 10), np.random.rand(10)),
])
class TestAffineTransform(unittest.TestCase):
def setUp(self):
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
self._t = transform.AffineTransform(loc, scale)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Real))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
def test_forward(self):
input = np.random.random(self.loc.shape)
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
t = transform.AffineTransform(loc, scale)
static_input = paddle.static.data('input', self.loc.shape,
self.loc.dtype)
output = t.forward(static_input)
exe.run(sp)
[output] = exe.run(
mp,
feed={'input': input,
'loc': self.loc,
'scale': self.scale},
fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_forward(input),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_inverse(self):
input = np.random.random(self.loc.shape)
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
t = transform.AffineTransform(loc, scale)
static_input = paddle.static.data('input', self.loc.shape,
self.loc.dtype)
output = t.inverse(static_input)
exe.run(sp)
[output] = exe.run(
mp,
feed={'input': input,
'loc': self.loc,
'scale': self.scale},
fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_inverse(input),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def _np_forward(self, x):
return self.loc + self.scale * x
def _np_inverse(self, y):
return (y - self.loc) / self.scale
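    # For y = loc + scale * x the Jacobian is diag(scale), so the elementwise
    # log-determinant is log|scale|.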
def _np_forward_jacobian(self, x):
return np.log(np.abs(self.scale))
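    # By the inverse function theorem, log|det J_inv(y)| equals
    # -log|det J(x)| evaluated at x = inverse(y).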
def _np_inverse_jacobian(self, y):
return -self._np_forward_jacobian(self._np_inverse(y))
def test_inverse_log_det_jacobian(self):
input = np.random.random(self.scale.shape)
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
t = transform.AffineTransform(loc, scale)
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(
mp,
feed={'input': input,
'loc': self.loc,
'scale': self.scale},
fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_inverse_jacobian(input),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_forward_log_det_jacobian(self):
input = np.random.random(self.scale.shape)
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype)
scale = paddle.static.data('scale', self.scale.shape,
self.scale.dtype)
t = transform.AffineTransform(loc, scale)
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(
mp,
feed={'input': input,
'loc': self.loc,
'scale': self.scale},
fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_forward_jacobian(input),
rtol=config.RTOL.get(str(self.loc.dtype)),
atol=config.ATOL.get(str(self.loc.dtype)))
def test_forward_shape(self):
shape = self.loc.shape
self.assertEqual(
tuple(self._t.forward_shape(shape)),
np.broadcast(np.random.random(shape), self.loc, self.scale).shape)
def test_inverse_shape(self):
shape = self.scale.shape
self.assertEqual(
            tuple(self._t.inverse_shape(shape)),
np.broadcast(np.random.random(shape), self.loc, self.scale).shape)
@param.place(config.DEVICES)
class TestExpTransform(unittest.TestCase):
def setUp(self):
self._t = transform.ExpTransform()
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Positive))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
@param.param_func(
[(np.array([0., 1., 2., 3.]), np.exp(np.array([0., 1., 2., 3.]))),
(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]),
np.exp(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))])
def test_forward(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.ExpTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1., 2., 3.]), np.log(np.array([1., 2., 3.]))),
(np.array([[1., 2., 3.], [6., 7., 8.]]),
np.log(np.array([[1., 2., 3.], [6., 7., 8.]])))])
def test_inverse(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.ExpTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_forward_log_det_jacobian(self, input):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.ExpTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_forward_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
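    # d/dx exp(x) = exp(x), so the log-determinant of the Jacobian is x itself.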
def _np_forward_jacobian(self, x):
return x
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_inverse_log_det_jacobian(self, input):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.ExpTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_inverse_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
def _np_inverse_jacobian(self, y):
return -self._np_forward_jacobian(np.log(y))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
class TestChainTransform(unittest.TestCase):
@param.param_func((
(transform.ChainTransform(
(transform.AbsTransform(),
transform.AffineTransform(paddle.rand([1]), paddle.rand([1])))),
False), (transform.ChainTransform((
transform.AffineTransform(paddle.rand([1]), paddle.rand([1])),
transform.ExpTransform(), )), True)))
def test_is_injective(self, chain, expected):
self.assertEqual(chain._is_injective(), expected)
@param.param_func(((transform.ChainTransform(
(transform.IndependentTransform(transform.ExpTransform(), 1),
transform.IndependentTransform(transform.ExpTransform(), 10),
transform.IndependentTransform(transform.ExpTransform(), 8))),
variable.Independent(variable.real, 10)), ))
def test_domain(self, input, expected):
self.assertIsInstance(input._domain, type(expected))
self.assertEqual(input._domain.event_rank, expected.event_rank)
self.assertEqual(input._domain.is_discrete, expected.is_discrete)
@param.param_func(((transform.ChainTransform(
(transform.IndependentTransform(transform.ExpTransform(), 9),
transform.IndependentTransform(transform.ExpTransform(), 4),
transform.IndependentTransform(transform.ExpTransform(), 5))),
variable.Independent(variable.real, 9)), ))
def test_codomain(self, input, expected):
self.assertIsInstance(input._codomain, variable.Independent)
self.assertEqual(input._codomain.event_rank, expected.event_rank)
self.assertEqual(input._codomain.is_discrete, expected.is_discrete)
@param.param_func(
[(transform.ChainTransform((transform.ExpTransform(),
transform.TanhTransform())),
np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]),
np.tanh(np.exp(np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]))))])
def test_forward(self, chain, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = chain
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func(
[(transform.ChainTransform((transform.ExpTransform(),
transform.TanhTransform())),
np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]),
np.log(np.arctanh(np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]))))])
def test_inverse(self, chain, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = chain
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(transform.ChainTransform((transform.AffineTransform(
paddle.full([1], 0.0),
paddle.full([1], -1.0)), transform.ExpTransform())), (2, 3, 5),
(2, 3, 5)), ])
def test_forward_shape(self, chain, shape, expected_shape):
self.assertEqual(chain.forward_shape(shape), expected_shape)
@param.param_func([(transform.ChainTransform((transform.AffineTransform(
paddle.full([1], 0.0),
paddle.full([1], -1.0)), transform.ExpTransform())), (2, 3, 5),
(2, 3, 5)), ])
def test_inverse_shape(self, chain, shape, expected_shape):
        self.assertEqual(chain.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
@param.param_cls(
(param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'x'),
[('rank-over-zero', transform.ExpTransform(), 2, np.random.rand(2, 3, 3)),
])
class TestIndependentTransform(unittest.TestCase):
def setUp(self):
self._t = transform.IndependentTransform(self.base,
self.reinterpreted_batch_rank)
def test_is_injective(self):
self.assertEqual(self._t._is_injective(), self.base._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Independent))
self.assertEqual(
self._t._domain.event_rank,
self.base._domain.event_rank + self.reinterpreted_batch_rank)
self.assertEqual(self._t._domain.is_discrete,
self.base._domain.is_discrete)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Independent))
self.assertEqual(
self._t._codomain.event_rank,
self.base._codomain.event_rank + self.reinterpreted_batch_rank)
self.assertEqual(self._t._codomain.is_discrete,
self.base._codomain.is_discrete)
def test_forward(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.IndependentTransform(self.base,
self.reinterpreted_batch_rank)
static_input = paddle.static.data('input', self.x.shape,
self.x.dtype)
output = t.forward(static_input)
expected = self.base.forward(static_input)
exe.run(sp)
[output, expected] = exe.run(mp,
feed={'input': self.x},
fetch_list=[output, expected])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(self.x.dtype)),
atol=config.ATOL.get(str(self.x.dtype)))
def test_inverse(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.IndependentTransform(self.base,
self.reinterpreted_batch_rank)
static_input = paddle.static.data('input', self.x.shape,
self.x.dtype)
output = t.inverse(static_input)
expected = self.base.inverse(static_input)
exe.run(sp)
[output, expected] = exe.run(mp,
feed={'input': self.x},
fetch_list=[output, expected])
np.testing.assert_allclose(
expected,
output,
rtol=config.RTOL.get(str(self.x.dtype)),
atol=config.ATOL.get(str(self.x.dtype)))
def test_forward_log_det_jacobian(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.IndependentTransform(self.base,
self.reinterpreted_batch_rank)
static_input = paddle.static.data('input', self.x.shape,
self.x.dtype)
output = t.forward_log_det_jacobian(static_input)
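            # ExpTransform's elementwise log det Jacobian equals its input, so
            # summing the input over the reinterpreted dims reproduces the
            # summed Jacobian expected from IndependentTransform.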
expected = self.base.forward_log_det_jacobian(
static_input.sum(
list(range(-self.reinterpreted_batch_rank, 0))))
exe.run(sp)
[actual, expected] = exe.run(mp,
feed={'input': self.x},
fetch_list=[output, expected])
self.assertEqual(
tuple(actual.shape), self.x.shape[:-self.reinterpreted_batch_rank])
np.testing.assert_allclose(
actual,
expected,
rtol=config.RTOL.get(str(self.x.dtype)),
atol=config.ATOL.get(str(self.x.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
class TestPowerTransform(unittest.TestCase):
def setUp(self):
self._t = transform.PowerTransform(paddle.full([1], 2.))
def test_init(self):
with self.assertRaises(TypeError):
transform.PowerTransform(1.)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Positive))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
@param.param_func([(np.array([2.]), np.array([0., -1., 2.]), np.power(
np.array([0., -1., 2.]),
2.)), (np.array([[0.], [3.]]), np.array([[1., 0.], [5., 6.]]), np.power(
np.array([[1., 0.], [5., 6.]]), np.array([[0.], [3.]])))])
def test_forward(self, power, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_power = paddle.static.data('power', power.shape, power.dtype)
t = transform.PowerTransform(static_power)
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward(static_input)
exe.run(sp)
[output] = exe.run(mp,
feed={'input': input,
'power': power},
fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([2.]), np.array([4.]), np.array([2.]))])
def test_inverse(self, power, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_power = paddle.static.data('power', power.shape, power.dtype)
t = transform.PowerTransform(static_power)
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse(static_input)
exe.run(sp)
[output] = exe.run(mp,
feed={'input': input,
'power': power},
fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func(((np.array([2.]), np.array([3., 1.4, 0.8])), ))
def test_forward_log_det_jacobian(self, power, input):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_power = paddle.static.data('power', power.shape, power.dtype)
t = transform.PowerTransform(static_power)
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(mp,
feed={'input': input,
'power': power},
fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_forward_jacobian(power, input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
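    # For y = x ** alpha, dy/dx = alpha * x ** (alpha - 1), hence
    # log|det J| = log|alpha * x ** (alpha - 1)|.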
def _np_forward_jacobian(self, alpha, x):
        return np.log(np.abs(alpha * np.power(x, alpha - 1)))
@param.param_func([((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
class TestTanhTransform(unittest.TestCase):
def setUp(self):
self._t = transform.TanhTransform()
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Real))
self.assertEqual(self._t._domain.event_rank, 0)
self.assertEqual(self._t._domain.is_discrete, False)
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Variable))
self.assertEqual(self._t._codomain.event_rank, 0)
self.assertEqual(self._t._codomain.is_discrete, False)
self.assertEqual(self._t._codomain._constraint._lower, -1)
self.assertEqual(self._t._codomain._constraint._upper, 1)
@param.param_func(
[(np.array([0., 1., 2., 3.]), np.tanh(np.array([0., 1., 2., 3.]))),
(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]),
np.tanh(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))])
def test_forward(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.TanhTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func(
[(np.array([1., 2., 3.]), np.arctanh(np.array([1., 2., 3.]))),
(np.array([[1., 2., 3.], [6., 7., 8.]]),
np.arctanh(np.array([[1., 2., 3.], [6., 7., 8.]])))])
def test_inverse(self, input, expected):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.TanhTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_forward_log_det_jacobian(self, input):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.TanhTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.forward_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_forward_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
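    # log|d tanh(x)/dx| = log(1 - tanh(x)^2), rewritten as
    # 2 * (log 2 - x - softplus(-2x)) for numerical stability.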
def _np_forward_jacobian(self, x):
return 2. * (np.log(2.) - x - self._np_softplus(-2. * x))
    def _np_softplus(self, x, beta=1., threshold=20.):
        # Elementwise numerically stable softplus: once beta * x exceeds the
        # threshold, softplus(x) is effectively x, which also keeps the exp
        # from overflowing.
        scaled = np.minimum(beta * x, threshold)
        return np.where(beta * x > threshold, x, np.log1p(np.exp(scaled)) / beta)
def _np_inverse_jacobian(self, y):
return -self._np_forward_jacobian(np.arctanh(y))
@param.param_func([(np.array([1., 2., 3.]), ),
(np.array([[1., 2., 3.], [6., 7., 8.]]), )])
def test_inverse_log_det_jacobian(self, input):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
t = transform.TanhTransform()
static_input = paddle.static.data('input', input.shape, input.dtype)
output = t.inverse_log_det_jacobian(static_input)
exe.run(sp)
[output] = exe.run(mp, feed={'input': input}, fetch_list=[output])
np.testing.assert_allclose(
output,
self._np_inverse_jacobian(input),
rtol=config.RTOL.get(str(input.dtype)),
atol=config.ATOL.get(str(input.dtype)))
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_forward_shape(self, shape, expected_shape):
self.assertEqual(self._t.forward_shape(shape), expected_shape)
@param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
def test_inverse_shape(self, shape, expected_shape):
        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'in_event_shape', 'out_event_shape'), [
('regular_shape', (2, 3), (3, 2)),
])
class TestReshapeTransform(unittest.TestCase):
def setUp(self):
self._t = transform.ReshapeTransform(self.in_event_shape,
self.out_event_shape)
def test_is_injective(self):
self.assertTrue(self._t._is_injective())
def test_domain(self):
self.assertTrue(isinstance(self._t._domain, variable.Independent))
def test_codomain(self):
self.assertTrue(isinstance(self._t._codomain, variable.Independent))
def test_forward(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones(self.in_event_shape)
t = transform.ReshapeTransform(self.in_event_shape,
self.out_event_shape)
            output = t.forward(x)
exe.run(sp)
[output] = exe.run(mp, feed={}, fetch_list=[output])
expected = np.ones(self.out_event_shape)
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(expected.dtype)),
atol=config.ATOL.get(str(expected.dtype)))
def test_inverse(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones(self.out_event_shape)
t = transform.ReshapeTransform(self.in_event_shape,
self.out_event_shape)
            output = t.inverse(x)
exe.run(sp)
[output] = exe.run(mp, feed={}, fetch_list=[output])
expected = np.ones(self.in_event_shape)
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(expected.dtype)),
atol=config.ATOL.get(str(expected.dtype)))
def test_forward_log_det_jacobian(self):
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones(self.in_event_shape)
t = transform.ReshapeTransform(self.in_event_shape,
self.out_event_shape)
            output = t.forward_log_det_jacobian(x)
exe.run(sp)
[output] = exe.run(mp, feed={}, fetch_list=[output])
expected = np.zeros([1])
np.testing.assert_allclose(
output,
expected,
rtol=config.RTOL.get(str(expected.dtype)),
atol=config.ATOL.get(str(expected.dtype)))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import parameterize as param
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'transforms'),
[('base_normal', paddle.distribution.Normal(0., 1.),
[paddle.distribution.ExpTransform()])])
class TestIndependent(unittest.TestCase):
def setUp(self):
self._t = paddle.distribution.TransformedDistribution(self.base,
self.transforms)
def _np_sum_rightmost(self, value, n):
return np.sum(value, tuple(range(-n, 0))) if n > 0 else value
def test_log_prob(self):
value = paddle.to_tensor(0.5)
np.testing.assert_allclose(
self.simple_log_prob(value, self.base, self.transforms),
self._t.log_prob(value),
rtol=config.RTOL.get(str(value.numpy().dtype)),
atol=config.ATOL.get(str(value.numpy().dtype)))
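    # Reference via the change-of-variables formula:
    # log p_Y(y) = log p_X(x) - sum_i log|det J_{t_i}(x_i)|, applying the
    # transforms' inverses from the last transform back to the first.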
def simple_log_prob(self, value, base, transforms):
log_prob = 0.0
y = value
for t in reversed(transforms):
x = t.inverse(y)
log_prob = log_prob - t.forward_log_det_jacobian(x)
y = x
log_prob += base.log_prob(y)
return log_prob
# TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
def test_sample(self):
shape = [5, 10, 8]
expected_shape = (5, 10, 8)
data = self._t.sample(shape)
self.assertEqual(tuple(data.shape), expected_shape)
self.assertEqual(data.dtype, self.base.loc.dtype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import unittest
import numpy as np
import paddle
import scipy.stats
import config
import parameterize as param
paddle.enable_static()
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'transforms'),
[('base_normal', paddle.distribution.Normal,
[paddle.distribution.ExpTransform()])])
class TestIndependent(unittest.TestCase):
def setUp(self):
value = np.array([0.5])
loc = np.array([0.])
scale = np.array([1.])
shape = [5, 10, 8]
self.dtype = value.dtype
exe = paddle.static.Executor()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
static_value = paddle.static.data('value', value.shape, value.dtype)
static_loc = paddle.static.data('loc', loc.shape, loc.dtype)
static_scale = paddle.static.data('scale', scale.shape, scale.dtype)
self.base = self.base(static_loc, static_scale)
self._t = paddle.distribution.TransformedDistribution(
self.base, self.transforms)
actual_log_prob = self._t.log_prob(static_value)
expected_log_prob = self.transformed_log_prob(
static_value, self.base, self.transforms)
sample_data = self._t.sample(shape)
exe.run(sp)
[self.actual_log_prob, self.expected_log_prob,
self.sample_data] = exe.run(
mp,
feed={'value': value,
'loc': loc,
'scale': scale},
fetch_list=[actual_log_prob, expected_log_prob, sample_data])
def test_log_prob(self):
np.testing.assert_allclose(
self.actual_log_prob,
self.expected_log_prob,
rtol=config.RTOL.get(str(self.dtype)),
atol=config.ATOL.get(str(self.dtype)))
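    # Reference log-probability computed by inverting the transform chain and
    # subtracting each forward log det Jacobian before adding the base
    # distribution's log_prob.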
def transformed_log_prob(self, value, base, transforms):
log_prob = 0.0
y = value
for t in reversed(transforms):
x = t.inverse(y)
log_prob = log_prob - t.forward_log_det_jacobian(x)
y = x
log_prob += base.log_prob(y)
return log_prob
# TODO(cxxly): Add Kolmogorov-Smirnov test for sample result.
def test_sample(self):
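        # Sample shape is sample_shape + batch_shape; the base Normal is built
        # from a length-1 loc/scale, hence the trailing dimension of 1.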
expected_shape = (5, 10, 8, 1)
self.assertEqual(tuple(self.sample_data.shape), expected_shape)
self.assertEqual(self.sample_data.dtype, self.dtype)
if __name__ == '__main__':
unittest.main()
......@@ -343,3 +343,7 @@ class UniformTestSample2(UniformTestSample):
def init_param(self):
self.low = -5.0
self.high = 2.0
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.distribution import variable
from paddle.distribution import constraint
import config
import parameterize as param
@param.param_cls(
(param.TEST_CASE_NAME, 'is_discrete', 'event_rank', 'constraint'),
[('NotImplement', False, 0, constraint.Constraint())])
class TestVariable(unittest.TestCase):
def setUp(self):
self._var = variable.Variable(self.is_discrete, self.event_rank,
self.constraint)
@param.param_func([(1, )])
    def test_constraint(self, value):
with self.assertRaises(NotImplementedError):
self._var.constraint(value)
@param.param_cls((param.TEST_CASE_NAME, 'base', 'rank'),
[('real_base', variable.real, 10)])
class TestIndependent(unittest.TestCase):
def setUp(self):
self._var = variable.Independent(self.base, self.rank)
@param.param_func([(paddle.rand([2, 3, 4]), ValueError), ])
    def test_constraint(self, value, expect):
with self.assertRaises(expect):
self._var.constraint(value)
@param.param_cls((param.TEST_CASE_NAME, 'vars', 'axis'),
[('real_base', [variable.real], 10)])
class TestStack(unittest.TestCase):
def setUp(self):
self._var = variable.Stack(self.vars, self.axis)
def test_is_discrete(self):
self.assertEqual(self._var.is_discrete, False)
@param.param_func([(paddle.rand([2, 3, 4]), ValueError), ])
    def test_constraint(self, value, expect):
with self.assertRaises(expect):
self._var.constraint(value)
if __name__ == '__main__':
unittest.main()
......@@ -23,12 +23,13 @@ from paddle.distribution import kl
import config
import mock_data as mock
import parameterize as param
paddle.set_default_dtype('float64')
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
@param.place(config.DEVICES)
@param.parameterize_cls((param.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
('test_regular_input', 6.0 * np.random.random((4, 5)) + 1e-4,
6.0 * np.random.random((4, 5)) + 1e-4, 6.0 * np.random.random(
(4, 5)) + 1e-4, 6.0 * np.random.random((4, 5)) + 1e-4),
......@@ -55,8 +56,8 @@ class TestKLBetaBeta(unittest.TestCase):
(a2 - a1 + b2 - b1) * scipy.special.digamma(a1 + b1))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'conc1', 'conc2'), [
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'conc1', 'conc2'), [
('test-regular-input', np.random.random((5, 7, 8, 10)), np.random.random(
(5, 7, 8, 10))),
])
......@@ -88,23 +89,22 @@ class DummyDistribution(paddle.distribution.Distribution):
pass
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'p', 'q'),
[('test-unregister', DummyDistribution(), DummyDistribution)])
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'p', 'q'),
[('test-unregister', DummyDistribution(), DummyDistribution)])
class TestDispatch(unittest.TestCase):
def test_dispatch_with_unregister(self):
with self.assertRaises(NotImplementedError):
paddle.distribution.kl_divergence(self.p, self.q)
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'p', 'q'),
[('test-diff-dist', mock.Exponential(paddle.rand((100, 200, 100)) + 1.0),
mock.Exponential(paddle.rand((100, 200, 100)) + 2.0)),
('test-same-dist', mock.Exponential(paddle.to_tensor(1.0)),
mock.Exponential(paddle.to_tensor(1.0)))])
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'p', 'q'),
[('test-diff-dist',
mock.Exponential(paddle.rand((100, 200, 100)) + 1.0),
mock.Exponential(paddle.rand((100, 200, 100)) + 2.0)),
('test-same-dist', mock.Exponential(paddle.to_tensor(1.0)),
mock.Exponential(paddle.to_tensor(1.0)))])
class TestKLExpfamilyExpFamily(unittest.TestCase):
def test_kl_expfamily_expfamily(self):
np.testing.assert_allclose(
......
......@@ -22,13 +22,14 @@ import scipy.stats
from paddle.distribution import kl
import config
import parameterize as param
import mock_data as mock
paddle.enable_static()
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
('test_regular_input', 6.0 * np.random.random((4, 5)) + 1e-4,
6.0 * np.random.random((4, 5)) + 1e-4, 6.0 * np.random.random(
(4, 5)) + 1e-4, 6.0 * np.random.random((4, 5)) + 1e-4),
......@@ -75,8 +76,8 @@ class TestKLBetaBeta(unittest.TestCase):
(a2 - a1 + b2 - b1) * scipy.special.digamma(a1 + b1))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'conc1', 'conc2'), [
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'conc1', 'conc2'), [
('test-regular-input', np.random.random((5, 7, 8, 10)), np.random.random(
(5, 7, 8, 10))),
])
......@@ -123,9 +124,9 @@ class DummyDistribution(paddle.distribution.Distribution):
pass
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'p', 'q'),
[('test-dispatch-exception')])
@param.place(config.DEVICES)
@param.param_cls((param.TEST_CASE_NAME, 'p', 'q'),
[('test-dispatch-exception')])
class TestDispatch(unittest.TestCase):
def setUp(self):
self.mp = paddle.static.Program()
......@@ -143,11 +144,11 @@ class TestDispatch(unittest.TestCase):
self.executor.run(self.mp, feed={}, fetch_list=[out])
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'rate1', 'rate2'),
[('test-diff-dist', np.random.rand(100, 200, 100) + 1.0,
np.random.rand(100, 200, 100) + 2.0),
('test-same-dist', np.array([1.0]), np.array([1.0]))])
@param.place(config.DEVICES)
@param.param_cls((config.TEST_CASE_NAME, 'rate1', 'rate2'),
[('test-diff-dist', np.random.rand(100, 200, 100) + 1.0,
np.random.rand(100, 200, 100) + 2.0),
('test-same-dist', np.array([1.0]), np.array([1.0]))])
class TestKLExpfamilyExpFamily(unittest.TestCase):
def setUp(self):
self.mp = paddle.static.Program()
......@@ -176,3 +177,7 @@ class TestKLExpfamilyExpFamily(unittest.TestCase):
out2,
rtol=config.RTOL.get(config.DEFAULT_DTYPE),
atol=config.ATOL.get(config.DEFAULT_DTYPE))
if __name__ == '__main__':
unittest.main()