diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 334b0d4a3a5d5fbd0fae18a09c62678ed6326f9b..3a9af812add6e27ccf69ce5d0f93614c9d16e833 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .beta import Beta -from .categorical import Categorical -from .dirichlet import Dirichlet -from .distribution import Distribution -from .exponential_family import ExponentialFamily -from .kl import kl_divergence, register_kl -from .multinomial import Multinomial -from .normal import Normal -from .uniform import Uniform +from paddle.distribution import transform +from paddle.distribution.beta import Beta +from paddle.distribution.categorical import Categorical +from paddle.distribution.dirichlet import Dirichlet +from paddle.distribution.distribution import Distribution +from paddle.distribution.exponential_family import ExponentialFamily +from paddle.distribution.independent import Independent +from paddle.distribution.kl import kl_divergence, register_kl +from paddle.distribution.multinomial import Multinomial +from paddle.distribution.normal import Normal +from paddle.distribution.transform import * # noqa: F403 +from paddle.distribution.transformed_distribution import \ + TransformedDistribution +from paddle.distribution.uniform import Uniform __all__ = [ # noqa 'Beta', @@ -33,4 +38,8 @@ __all__ = [ # noqa 'Uniform', 'kl_divergence', 'register_kl', + 'Independent', + 'TransformedDistribution' ] + +__all__.extend(transform.__all__) diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py index 82f3dced1072885d9bd780b897d97d2c660d8662..e371b56eb66b891a2f866d3238f68bb800c385be 100644 --- a/python/paddle/distribution/beta.py +++ b/python/paddle/distribution/beta.py @@ -14,12 +14,10 @@ import numbers import paddle +from paddle.distribution import dirichlet, exponential_family -from .dirichlet import Dirichlet -from .exponential_family import ExponentialFamily - -class Beta(ExponentialFamily): +class Beta(exponential_family.ExponentialFamily): r""" Beta distribution parameterized by alpha and beta. 
@@ -93,7 +91,8 @@ class Beta(ExponentialFamily): self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta]) - self._dirichlet = Dirichlet(paddle.stack([self.alpha, self.beta], -1)) + self._dirichlet = dirichlet.Dirichlet( + paddle.stack([self.alpha, self.beta], -1)) super(Beta, self).__init__(self._dirichlet._batch_shape) diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index 7b9aa694e80935ca88e090803ef1ebee497d38ef..b181a25fbcee1ecebb7241bd991fc78e152cbea3 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -18,18 +18,18 @@ import warnings import numpy as np import paddle from paddle import _C_ops - -from ..fluid import core -from ..fluid.data_feeder import (check_dtype, check_type, - check_variable_and_dtype, convert_dtype) -from ..fluid.framework import _non_static_mode -from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, - elementwise_mul, elementwise_sub, nn, ops, tensor) -from ..tensor import arange, concat, gather_nd, multinomial -from .distribution import Distribution - - -class Categorical(Distribution): +from paddle.distribution import distribution +from paddle.fluid import core +from paddle.fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, + tensor) +from paddle.tensor import arange, concat, gather_nd, multinomial + + +class Categorical(distribution.Distribution): r""" Categorical distribution is a discrete probability distribution that describes the possible results of a random variable that can take on diff --git a/python/paddle/distribution/constraint.py b/python/paddle/distribution/constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..d094a7607da9612c40bfe2070a100eb41f85dd4c --- /dev/null +++ b/python/paddle/distribution/constraint.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle + + +class Constraint(object): + """Constraint condition for random variable. + """ + + def __call__(self, value): + raise NotImplementedError + + +class Real(Constraint): + def __call__(self, value): + return value == value + + +class Range(Constraint): + def __init__(self, lower, upper): + self._lower = lower + self._upper = upper + super(Range, self).__init__() + + def __call__(self, value): + return self._lower <= value <= self._upper + + +class Positive(Constraint): + def __call__(self, value): + return value >= 0. 
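For orientation: each of these constraint classes is a small callable returning an elementwise boolean check of its argument. A minimal usage sketch, assuming the module is importable as ``paddle.distribution.constraint`` and using the singletons defined at the end of this file; printed values are illustrative only:

    import paddle
    from paddle.distribution import constraint

    value = paddle.to_tensor([-1., 0.5, 2.])
    print(constraint.real(value))      # elementwise `value == value` -> [True, True, True]
    print(constraint.positive(value))  # elementwise `value >= 0.`    -> [False, True, True]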
+ + +class Simplex(Constraint): + def __call__(self, value): + return paddle.all(value >= 0, axis=-1) and ( + (value.sum(-1) - 1).abs() < 1e-6) + + +real = Real() +positive = Positive() +simplex = Simplex() diff --git a/python/paddle/distribution/dirichlet.py b/python/paddle/distribution/dirichlet.py index ab7379a590e08dc5c86f4d569014bfa4518fa774..740f850b7c1da2bcd73ca7df24ba5e921286b210 100644 --- a/python/paddle/distribution/dirichlet.py +++ b/python/paddle/distribution/dirichlet.py @@ -13,14 +13,13 @@ # limitations under the License. import paddle +from paddle.distribution import exponential_family +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.layer_helper import LayerHelper -from ..fluid.data_feeder import check_variable_and_dtype -from ..fluid.framework import _non_static_mode -from ..fluid.layer_helper import LayerHelper -from .exponential_family import ExponentialFamily - -class Dirichlet(ExponentialFamily): +class Dirichlet(exponential_family.ExponentialFamily): r""" Dirichlet distribution with parameter "concentration". diff --git a/python/paddle/distribution/distribution.py b/python/paddle/distribution/distribution.py index 404ce933c402ce8fefe6033b597c06cd88cb593d..1c8edfa138d2e46bcc5056c58313fbfa5b146eb4 100644 --- a/python/paddle/distribution/distribution.py +++ b/python/paddle/distribution/distribution.py @@ -27,14 +27,14 @@ import warnings import numpy as np import paddle from paddle import _C_ops - -from ..fluid import core -from ..fluid.data_feeder import (check_dtype, check_type, - check_variable_and_dtype, convert_dtype) -from ..fluid.framework import _non_static_mode -from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, - elementwise_mul, elementwise_sub, nn, ops, tensor) -from ..tensor import arange, concat, gather_nd, multinomial +from paddle.fluid import core +from paddle.fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, + tensor) +from paddle.tensor import arange, concat, gather_nd, multinomial class Distribution(object): @@ -78,10 +78,24 @@ class Distribution(object): """ return self._event_shape + @property + def mean(self): + """Mean of distribution""" + raise NotImplementedError + + @property + def variance(self): + """Variance of distribution""" + raise NotImplementedError + def sample(self, shape=()): """Sampling from the distribution.""" raise NotImplementedError + def rsample(self, shape=()): + """reparameterized sample""" + raise NotImplementedError + def entropy(self): """The entropy of the distribution.""" raise NotImplementedError @@ -96,7 +110,7 @@ class Distribution(object): Args: value (Tensor): value which will be evaluated """ - raise NotImplementedError + return self.log_prob(value).exp() def log_prob(self, value): """Log probability density/mass function.""" diff --git a/python/paddle/distribution/exponential_family.py b/python/paddle/distribution/exponential_family.py index 72a71edfba5b305ef0b292f24a6bf99f256841b2..e0236f9e6e2be77a436c4f697f49d21b13b15b94 100644 --- a/python/paddle/distribution/exponential_family.py +++ b/python/paddle/distribution/exponential_family.py @@ -13,12 +13,11 @@ # limitations under the License. 
import paddle +from paddle.distribution import distribution +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode -from ..fluid.framework import _non_static_mode -from .distribution import Distribution - -class ExponentialFamily(Distribution): +class ExponentialFamily(distribution.Distribution): r""" ExponentialFamily is the base class for probability distributions belonging to exponential family, whose probability mass/density function has the diff --git a/python/paddle/distribution/independent.py b/python/paddle/distribution/independent.py new file mode 100644 index 0000000000000000000000000000000000000000..3534a31591b275fc856d9f25c564fd8c464beaba --- /dev/null +++ b/python/paddle/distribution/independent.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distribution import distribution + + +class Independent(distribution.Distribution): + r""" + Reinterprets some of the batch dimensions of a distribution as event dimensions. + + This is mainly useful for changing the shape of the result of + :meth:`log_prob`. + + Args: + base (Distribution): The base distribution. + reinterpreted_batch_rank (int): The number of batch dimensions to + reinterpret as event dimensions. + + Examples: + + .. 
code-block:: python + + import paddle + from paddle.distribution import independent + + beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5])) + print(beta.batch_shape, beta.event_shape) + # (2,) () + print(beta.log_prob(paddle.to_tensor(0.2))) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [-0.22843921, -0.22843921]) + reinterpreted_beta = independent.Independent(beta, 1) + print(reinterpreted_beta.batch_shape, reinterpreted_beta.event_shape) + # () (2,) + print(reinterpreted_beta.log_prob(paddle.to_tensor([0.2, 0.2]))) + # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [-0.45687842]) + """ + + def __init__(self, base, reinterpreted_batch_rank): + if not isinstance(base, distribution.Distribution): + raise TypeError( + f"Expected type of 'base' is Distribution, but got {type(base)}") + if not (0 < reinterpreted_batch_rank <= len(base.batch_shape)): + raise ValueError( + f"Expected 0 < reinterpreted_batch_rank <= {len(base.batch_shape)}, but got {reinterpreted_batch_rank}" + ) + self._base = base + self._reinterpreted_batch_rank = reinterpreted_batch_rank + + shape = base.batch_shape + base.event_shape + super(Independent, self).__init__( + batch_shape=shape[:len(base.batch_shape) - + reinterpreted_batch_rank], + event_shape=shape[len(base.batch_shape) - + reinterpreted_batch_rank:]) + + @property + def mean(self): + return self._base.mean + + @property + def variance(self): + return self._base.variance + + def sample(self, shape=()): + return self._base.sample(shape) + + def log_prob(self, value): + return self._sum_rightmost( + self._base.log_prob(value), self._reinterpreted_batch_rank) + + def prob(self, value): + return self.log_prob(value).exp() + + def entropy(self): + return self._sum_rightmost(self._base.entropy(), + self._reinterpreted_batch_rank) + + def _sum_rightmost(self, value, n): + return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 166244c51163f15d4864659734ec2867264c32c3..6310214117e9df2a11283caaaba9f7876d2ca0ac 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -15,15 +15,14 @@ import functools import warnings import paddle - -from ..fluid.framework import _non_static_mode -from .beta import Beta -from .categorical import Categorical -from .dirichlet import Dirichlet -from .distribution import Distribution -from .exponential_family import ExponentialFamily -from .normal import Normal -from .uniform import Uniform +from paddle.distribution.beta import Beta +from paddle.distribution.categorical import Categorical +from paddle.distribution.dirichlet import Dirichlet +from paddle.distribution.distribution import Distribution +from paddle.distribution.exponential_family import ExponentialFamily +from paddle.distribution.normal import Normal +from paddle.distribution.uniform import Uniform +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode __all__ = ["register_kl", "kl_divergence"] @@ -207,5 +206,4 @@ def _kl_expfamily_expfamily(p, q): def _sum_rightmost(value, n): - """Sum elements along rightmost n dim""" return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 76d40131b9ea493d9de706a0ce566311e41e01d0..51a180271c63b917324ac8d1796d8c0274f82e41 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ 
-17,18 +17,17 @@ import warnings import numpy as np from paddle import _C_ops +from paddle.distribution import distribution +from paddle.fluid import core +from paddle.fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, + tensor) -from ..fluid import core -from ..fluid.data_feeder import (check_dtype, check_type, - check_variable_and_dtype, convert_dtype) -from ..fluid.framework import _non_static_mode -from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, - elementwise_mul, elementwise_sub, nn, ops, tensor) -from ..tensor import arange, concat, gather_nd, multinomial -from .distribution import Distribution - -class Normal(Distribution): +class Normal(distribution.Distribution): r"""The Normal distribution with location `loc` and `scale` parameters. Mathematical details @@ -129,6 +128,7 @@ class Normal(Distribution): if self.dtype != convert_dtype(self.loc.dtype): self.loc = tensor.cast(self.loc, dtype=self.dtype) self.scale = tensor.cast(self.scale, dtype=self.dtype) + super(Normal, self).__init__(self.loc.shape) def sample(self, shape, seed=0): """Generate samples of the specified shape. diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0e63f048ea5d942b97060b3ec24f6daac68362 --- /dev/null +++ b/python/paddle/distribution/transform.py @@ -0,0 +1,1231 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import functools +import math +import numbers +import operator +import typing + +import paddle +import paddle.nn.functional as F +from paddle.distribution import (constraint, distribution, + transformed_distribution, variable) + +__all__ = [ # noqa + 'Transform', + 'AbsTransform', + 'AffineTransform', + 'ChainTransform', + 'ExpTransform', + 'IndependentTransform', + 'PowerTransform', + 'ReshapeTransform', + 'SigmoidTransform', + 'SoftmaxTransform', + 'StackTransform', + 'StickBreakingTransform', + 'TanhTransform' +] + + +class Type(enum.Enum): + """Mapping type of a transformation. + """ + BIJECTION = 'bijection' # bijective(injective and surjective) + INJECTION = 'injection' # injective-only + SURJECTION = 'surjection' # surjective-only + OTHER = 'other' # general, neither injective nor surjective + + @classmethod + def is_injective(cls, _type): + """Both bijection and injection are injective mapping. + """ + return _type in (cls.BIJECTION, cls.INJECTION) + + +class Transform(object): + r"""Base class for the transformations of random variables. 
+
+    ``Transform`` can be used to represent any differentiable and injective
+    function from a subset of :math:`R^n` to a subset of :math:`R^m`, generally
+    used for transforming a random sample generated by a ``Distribution``
+    instance.
+
+    Suppose :math:`X` is a K-dimensional random variable with probability
+    density function :math:`p_X(x)`. A new random variable :math:`Y = f(X)` may
+    be defined by transforming :math:`X` with a suitably well-behaved function
+    :math:`f`. It suffices for what follows to note that if :math:`f` is
+    one-to-one and its inverse :math:`f^{-1}` has a well-defined Jacobian,
+    then the density of :math:`Y` is
+
+    .. math::
+
+        p_Y(y) = p_X(f^{-1}(y)) |det J_{f^{-1}}(y)|
+
+    where det is the matrix determinant operation and :math:`J_{f^{-1}}(y)` is
+    the Jacobian matrix of :math:`f^{-1}` evaluated at :math:`y`.
+    Taking :math:`x = f^{-1}(y)`, the Jacobian matrix is defined by
+
+    .. math::
+
+        J(y) = \begin{bmatrix}
+        {\frac{\partial x_1}{\partial y_1}} &{\frac{\partial x_1}{\partial y_2}}
+        &{\cdots} &{\frac{\partial x_1}{\partial y_K}} \\
+        {\frac{\partial x_2}{\partial y_1}} &{\frac{\partial x_2}
+        {\partial y_2}}&{\cdots} &{\frac{\partial x_2}{\partial y_K}} \\
+        {\vdots} &{\vdots} &{\ddots} &{\vdots}\\
+        {\frac{\partial x_K}{\partial y_1}} &{\frac{\partial x_K}{\partial y_2}}
+        &{\cdots} &{\frac{\partial x_K}{\partial y_K}}
+        \end{bmatrix}
+
+    A ``Transform`` can be characterized by three operations:
+
+    #. forward
+       Forward implements :math:`x \rightarrow f(x)`, and is used to convert
+       one random outcome into another.
+    #. inverse
+       Undoes the transformation :math:`y \rightarrow f^{-1}(y)`.
+    #. log_det_jacobian
+       The log of the absolute value of the determinant of the matrix of all
+       first-order partial derivatives of the inverse function.
+
+    Subclasses typically implement the following methods:
+
+    * _forward
+    * _inverse
+    * _forward_log_det_jacobian
+    * _inverse_log_det_jacobian (optional)
+
+    If the transform changes the shape of the input, you must also implement:
+
+    * _forward_shape
+    * _inverse_shape
+
+    """
+    _type = Type.INJECTION
+
+    def __init__(self):
+        super(Transform, self).__init__()
+
+    @classmethod
+    def _is_injective(cls):
+        """Is the transformation type one-to-one or not.
+
+        Returns:
+            bool: ``True`` denotes injective. ``False`` denotes non-injective.
+        """
+        return Type.is_injective(cls._type)
+
+    def __call__(self, input):
+        """Make this instance a callable object. The return value depends
+        on the input type.
+
+        * If the input is a ``Tensor`` instance, return
+          ``self.forward(input)`` .
+        * If the input is a ``Distribution`` instance, return
+          ``TransformedDistribution(base=input, transforms=[self])`` .
+        * If the input is a ``Transform`` instance, return
+          ``ChainTransform([self, input])`` .
+
+        Args:
+            input (Tensor|Distribution|Transform): The input value.
+
+        Returns:
+            [Tensor|TransformedDistribution|ChainTransform]: The return value.
+        """
+        if isinstance(input, distribution.Distribution):
+            return transformed_distribution.TransformedDistribution(input,
+                                                                    [self])
+        if isinstance(input, Transform):
+            return ChainTransform([self, input])
+        return self.forward(input)
+
+    def forward(self, x):
+        """Forward transformation with mapping :math:`y = f(x)`.
+
+        Useful for turning one random outcome into another.
+
+        Args:
+            x (Tensor): Input parameter, generally a sample generated
+                from ``Distribution``.
+
+        Returns:
+            Tensor: Outcome of forward transformation.
+ """ + if not isinstance(x, paddle.fluid.framework.Variable): + raise TypeError( + f"Expected 'x' is a Tensor or Real, but got {type(x)}.") + if x.dim() < self._domain.event_rank: + raise ValueError( + f'The dimensions of x({x.dim()}) should be ' + f'grater than or equal to {self._domain.event_rank}') + return self._forward(x) + + def inverse(self, y): + """Inverse transformation :math:`x = f^{-1}(y)`. It's useful for "reversing" + a transformation to compute one probability in terms of another. + + Args: + y (Tensor): Input parameter for inverse transformation. + + Returns: + Tensor: Outcome of inverse transform. + """ + if not isinstance(y, paddle.fluid.framework.Variable): + raise TypeError( + f"Expected 'y' is a Tensor or Real, but got {type(y)}.") + if y.dim() < self._codomain.event_rank: + raise ValueError( + f'The dimensions of y({y.dim()}) should be ' + f'grater than or equal to {self._codomain.event_rank}') + return self._inverse(y) + + def forward_log_det_jacobian(self, x): + """The log of the absolute value of the determinant of the matrix of all + first-order partial derivatives of the inverse function. + + Args: + x (Tensor): Input tensor, generally is a sample generated from + ``Distribution`` + + Returns: + Tensor: The log of the absolute value of Jacobian determinant. + """ + if not isinstance(x, paddle.fluid.framework.Variable): + raise TypeError( + f"Expected 'y' is a Tensor or Real, but got {type(x)}.") + if isinstance(x, paddle.fluid.framework.Variable) and x.dim( + ) < self._domain.event_rank: + raise ValueError( + f'The dimensions of x({x.dim()}) should be ' + f'grater than or equal to {self._domain.event_rank}') + if not self._is_injective(): + raise NotImplementedError( + "forward_log_det_jacobian can't be implemented for non-injective" + "transforms.") + + return self._call_forward_log_det_jacobian(x) + + def inverse_log_det_jacobian(self, y): + """Compute :math:`log|det J_{f^{-1}}(y)|`. + Note that ``forward_log_det_jacobian`` is the negative of this function, + evaluated at :math:`f^{-1}(y)`. + + Args: + y (Tensor): The input to the ``inverse`` Jacobian determinant + evaluation. + + Returns: + Tensor: The value of :math:`log|det J_{f^{-1}}(y)|`. + """ + if not isinstance(y, paddle.fluid.framework.Variable): + raise TypeError(f"Expected 'y' is a Tensor, but got {type(y)}.") + if y.dim() < self._codomain.event_rank: + raise ValueError( + f'The dimensions of y({y.dim()}) should be ' + f'grater than or equal to {self._codomain.event_rank}') + return self._call_inverse_log_det_jacobian(y) + + def forward_shape(self, shape): + """Infer the shape of forward transformation. + + Args: + shape (Sequence[int]): The input shape. + + Returns: + Sequence[int]: The output shape. + """ + if not isinstance(shape, typing.Sequence): + raise TypeError( + f"Expected shape is Sequence[int] type, but got {type(shape)}.") + return self._forward_shape(shape) + + def inverse_shape(self, shape): + """Infer the shape of inverse transformation. + + Args: + shape (Sequence[int]): The input shape of inverse transformation. + + Returns: + Sequence[int]: The output shape of inverse transformation. 
+ """ + if not isinstance(shape, typing.Sequence): + raise TypeError( + f"Expected shape is Sequence[int] type, but got {type(shape)}.") + return self._inverse_shape(shape) + + @property + def _domain(self): + """The domain of this transformation""" + return variable.real + + @property + def _codomain(self): + """The codomain of this transformation""" + return variable.real + + def _forward(self, x): + """Inner method for publid API ``forward``, subclass should + overwrite this method for supporting forward transformation. + """ + raise NotImplementedError('Forward not implemented') + + def _inverse(self, y): + """Inner method of public API ``inverse``, subclass should + overwrite this method for supporting inverse transformation. + """ + raise NotImplementedError('Inverse not implemented') + + def _call_forward_log_det_jacobian(self, x): + """Inner method called by ``forward_log_det_jacobian``.""" + if hasattr(self, '_forward_log_det_jacobian'): + return self._forward_log_det_jacobian(x) + if hasattr(self, '_inverse_log_det_jacobian'): + return -self._inverse_log_det_jacobian(self.forward(y)) + raise NotImplementedError( + 'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian' + 'is implemented. One of them is required.') + + def _call_inverse_log_det_jacobian(self, y): + """Inner method called by ``inverse_log_det_jacobian``""" + if hasattr(self, '_inverse_log_det_jacobian'): + return self._inverse_log_det_jacobian(y) + if hasattr(self, '_forward_log_det_jacobian'): + return -self._forward_log_det_jacobian(self._inverse(y)) + raise NotImplementedError( + 'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian ' + 'is implemented. One of them is required') + + def _forward_shape(self, shape): + """Inner method called by ``forward_shape``, which is used to infer the + forward shape. Subclass should overwrite this method for supporting + ``forward_shape``. + """ + return shape + + def _inverse_shape(self, shape): + """Inner method called by ``inverse_shape``, whic is used to infer the + invese shape. Subclass should overwrite this method for supporting + ``inverse_shape``. + """ + return shape + + +class AbsTransform(Transform): + r"""Absolute transformation with formula :math:`y = f(x) = abs(x)`, + element-wise. + + This non-injective transformation allows for transformations of scalar + distributions with the absolute value function, which maps ``(-inf, inf)`` + to ``[0, inf)`` . + + * For ``y`` in ``(0, inf)`` , ``AbsTransform.inverse(y)`` returns the set invese + ``{x in (-inf, inf) : |x| = y}`` as a tuple, ``-y, y`` . + * For ``y`` equal ``0`` , ``AbsTransform.inverse(0)`` returns ``0, 0``, which is not + the set inverse (the set inverse is the singleton {0}), but "works" in + conjunction with ``TransformedDistribution`` to produce a left + semi-continuous pdf. + * For ``y`` in ``(-inf, 0)`` , ``AbsTransform.inverse(y)`` returns the + wrong thing ``-y, y``. This is done for efficiency. + + Examples: + + .. code-block:: python + + import paddle + + abs = paddle.distribution.AbsTransform() + + print(abs.forward(paddle.to_tensor([-1., 0., 1.]))) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1., 0., 1.]) + + print(abs.inverse(paddle.to_tensor(1.))) + # (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [-1.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1.])) + + # The |dX/dY| is constant 1. 
So Log|dX/dY| == 0 + print(abs.inverse_log_det_jacobian(paddle.to_tensor(1.))) + # (Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # 0.), Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # 0.)) + + #Special case handling of 0. + print(abs.inverse(paddle.to_tensor(0.))) + # (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.])) + print(abs.inverse_log_det_jacobian(paddle.to_tensor(0.))) + # (Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # 0.), Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # 0.)) + + """ + _type = Type.SURJECTION + + def _forward(self, x): + return x.abs() + + def _inverse(self, y): + return -y, y + + def _inverse_log_det_jacobian(self, y): + zero = paddle.zeros([1], dtype=y.dtype) + return zero, zero + + @property + def _domain(self): + return variable.real + + @property + def _codomain(self): + return variable.positive + + +class AffineTransform(Transform): + r"""Affine transformation with mapping + :math:`y = \text{loc} + \text{scale} \times x`. + + Args: + loc (Tensor): The location parameter. + scale (Tensor): The scale parameter. + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.to_tensor([1., 2.]) + affine = paddle.distribution.AffineTransform(paddle.to_tensor(0.), paddle.to_tensor(1.)) + + print(affine.forward(x)) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1., 2.]) + print(affine.inverse(affine.forward(x))) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1., 2.]) + print(affine.forward_log_det_jacobian(x)) + # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.]) + """ + _type = Type.BIJECTION + + def __init__(self, loc, scale): + if not isinstance(loc, paddle.fluid.framework.Variable): + raise TypeError(f"Expected 'loc' is a Tensor, but got {type(loc)}") + if not isinstance(scale, paddle.fluid.framework.Variable): + raise TypeError( + f"Expected scale is a Tensor, but got {type(scale)}") + self._loc = loc + self._scale = scale + super(AffineTransform, self).__init__() + + @property + def loc(self): + return self._loc + + @property + def scale(self): + return self._scale + + def _forward(self, x): + return self._loc + self._scale * x + + def _inverse(self, y): + return (y - self._loc) / self._scale + + def _forward_log_det_jacobian(self, x): + return paddle.abs(self._scale).log() + + def _forward_shape(self, shape): + return tuple( + paddle.broadcast_shape( + paddle.broadcast_shape(shape, self._loc.shape), + self._scale.shape)) + + def _inverse_shape(self, shape): + return tuple( + paddle.broadcast_shape( + paddle.broadcast_shape(shape, self._loc.shape), + self._scale.shape)) + + @property + def _domain(self): + return variable.real + + @property + def _codomain(self): + return variable.real + + +class ChainTransform(Transform): + r"""Composes multiple transforms in a chain. + + Args: + transforms (Sequence[Transform]): A sequence of transformations. + + Examples: + + .. 
code-block:: python + + import paddle + + + x = paddle.to_tensor([0., 1., 2., 3.]) + + chain = paddle.distribution.ChainTransform(( + paddle.distribution.AffineTransform( + paddle.to_tensor(0.), paddle.to_tensor(1.)), + paddle.distribution.ExpTransform() + )) + print(chain.forward(x)) + # Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1. , 2.71828175 , 7.38905621 , 20.08553696]) + print(chain.inverse(chain.forward(x))) + # Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0., 1., 2., 3.]) + print(chain.forward_log_det_jacobian(x)) + # Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0., 1., 2., 3.]) + print(chain.inverse_log_det_jacobian(chain.forward(x))) + # Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [ 0., -1., -2., -3.]) + """ + + def __init__(self, transforms): + if not isinstance(transforms, typing.Sequence): + raise TypeError( + f"Expected type of 'transforms' is Sequence, but got {type(transforms)}" + ) + if not all(isinstance(t, Transform) for t in transforms): + raise TypeError( + "All elements of transforms should be Transform type.") + + self.transforms = transforms + super(ChainTransform, self).__init__() + + def _is_injective(self): + return all(t._is_injective() for t in self.transforms) + + def _forward(self, x): + for transform in self.transforms: + x = transform.forward(x) + return x + + def _inverse(self, y): + for transform in reversed(self.transforms): + y = transform.inverse(y) + return y + + def _forward_log_det_jacobian(self, x): + value = 0. + event_rank = self._domain.event_rank + for t in self.transforms: + value += self._sum_rightmost( + t.forward_log_det_jacobian(x), + event_rank - t._domain.event_rank) + x = t.forward(x) + event_rank += t._codomain.event_rank - t._domain.event_rank + return value + + def _forward_shape(self, shape): + for transform in self.transforms: + shape = transform.forward_shape(shape) + return shape + + def _inverse_shape(self, shape): + for transform in self.transforms: + shape = transform.inverse_shape(shape) + return shape + + def _sum_rightmost(self, value, n): + """sum value along rightmost n dim""" + return value.sum(list(range(-n, 0))) if n > 0 else value + + @property + def _domain(self): + domain = self.transforms[0]._domain + + # Compute the lower bound of input dimensions for chain transform. + # + # Suppose the dimensions of input tensor is N, and chain [t0,...ti,...tm], + # ti(in) denotes ti.domain.event_rank, ti(out) denotes ti.codomain.event_rank, + # delta(ti) denotes (ti(out) - ti(in)). + # For transform ti, N shoud satisfy the constraint: + # N + delta(t0) + delta(t1)...delta(t(i-1)) >= ti(in) + # So, for all transform in chain, N shoud satisfy follow constraints: + # t0: N >= t0(in) + # t1: N >= t1(in) - delta(t0) + # ... + # tm: N >= tm(in) - ... - delta(ti) - ... - delta(t0) + # + # Above problem can be solved more effectively use dynamic programming. 
+ # Let N(i) denotes lower bound of transform ti, than the state + # transition equation is: + # N(i) = max{N(i+1)-delta(ti), ti(in)} + event_rank = self.transforms[-1]._codomain.event_rank + for t in reversed(self.transforms): + event_rank -= t._codomain.event_rank - t._domain.event_rank + event_rank = max(event_rank, t._domain.event_rank) + + return variable.Independent(domain, event_rank - domain.event_rank) + + @property + def _codomain(self): + codomain = self.transforms[-1]._codomain + + event_rank = self.transforms[0]._domain.event_rank + for t in self.transforms: + event_rank += t._codomain.event_rank - t._domain.event_rank + event_rank = max(event_rank, t._codomain.event_rank) + + return variable.Independent(codomain, event_rank - codomain.event_rank) + + +class ExpTransform(Transform): + r"""Exponent transformation with mapping :math:`y = \exp(x)`. + + Exapmles: + + .. code-block:: python + + import paddle + + exp = paddle.distribution.ExpTransform() + print(exp.forward(paddle.to_tensor([1., 2., 3.]))) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [2.71828175 , 7.38905621 , 20.08553696]) + + print(exp.inverse(paddle.to_tensor([1., 2., 3.]))) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0. , 0.69314718, 1.09861231]) + + print(exp.forward_log_det_jacobian(paddle.to_tensor([1., 2., 3.]))) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1., 2., 3.]) + + print(exp.inverse_log_det_jacobian(paddle.to_tensor([1., 2., 3.]))) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [ 0. , -0.69314718, -1.09861231]) + """ + _type = Type.BIJECTION + + def __init__(self): + super(ExpTransform, self).__init__() + + @property + def _domain(self): + return variable.real + + @property + def _codomain(self): + return variable.positive + + def _forward(self, x): + return x.exp() + + def _inverse(self, y): + return y.log() + + def _forward_log_det_jacobian(self, x): + return x + + +class IndependentTransform(Transform): + r""" + ``IndependentTransform`` wraps a base transformation, reinterprets + some of the rightmost batch axes as event axes. + + Generally, it is used to expand the event axes. This has no effect on the + forward or inverse transformaion, but does sum out the + ``reinterpretd_bach_rank`` rightmost dimensions in computing the determinant + of Jacobian matrix. + + To see this, consider the ``ExpTransform`` applied to a Tensor which has + sample, batch, and event ``(S,B,E)`` shape semantics. Suppose the Tensor's + paritioned-shape is ``(S=[4], B=[2, 2], E=[3])`` , reinterpreted_batch_rank + is 1. Then the reinterpreted Tensor's shape is ``(S=[4], B=[2], E=[2, 3])`` . + The shape returned by ``forward`` and ``inverse`` is unchanged, ie, + ``[4,2,2,3]`` . However the shape returned by ``inverse_log_det_jacobian`` + is ``[4,2]``, because the Jacobian determinant is a reduction over the + event dimensions. + + Args: + base (Transform): The base transformation. + reinterpreted_batch_rank (int): The num of rightmost batch rank that + will be reinterpreted as event rank. + + Examples: + + .. 
code-block:: python + + import paddle + + x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]]) + + # Exponential transform with event_rank = 1 + multi_exp = paddle.distribution.IndependentTransform( + paddle.distribution.ExpTransform(), 1) + print(multi_exp.forward(x)) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[2.71828175 , 7.38905621 , 20.08553696 ], + # [54.59814835 , 148.41316223, 403.42880249]]) + print(multi_exp.forward_log_det_jacobian(x)) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [6. , 15.]) + """ + + def __init__(self, base, reinterpreted_batch_rank): + if not isinstance(base, Transform): + raise TypeError( + f"Expected 'base' is Transform type, but get {type(base)}") + if reinterpreted_batch_rank <= 0: + raise ValueError( + f"Expected 'reinterpreted_batch_rank' is grater than zero, but got {reinterpreted_batch_rank}" + ) + + self._base = base + self._reinterpreted_batch_rank = reinterpreted_batch_rank + super(IndependentTransform, self).__init__() + + def _is_injective(self): + return self._base._is_injective() + + def _forward(self, x): + if x.dim() < self._domain.event_rank: + raise ValueError("Input dimensions is less than event dimensions.") + return self._base.forward(x) + + def _inverse(self, y): + if y.dim() < self._codomain.event_rank: + raise ValueError("Input dimensions is less than event dimensions.") + return self._base.inverse(y) + + def _forward_log_det_jacobian(self, x): + return self._base.forward_log_det_jacobian(x).sum( + list(range(-self._reinterpreted_batch_rank, 0))) + + def _forward_shape(self, shape): + return self._base.forward_shape(shape) + + def _inverse_shape(self, shape): + return self._base.inverse_shape(shape) + + @property + def _domain(self): + return variable.Independent(self._base._domain, + self._reinterpreted_batch_rank) + + @property + def _codomain(self): + return variable.Independent(self._base._codomain, + self._reinterpreted_batch_rank) + + +class PowerTransform(Transform): + r""" + Power transformation with mapping :math:`y = x^{\text{power}}`. + + Args: + power (Tensor): The power parameter. + + Examples: + + .. 
code-block:: python + + import paddle + + x = paddle.to_tensor([1., 2.]) + power = paddle.distribution.PowerTransform(paddle.to_tensor(2.)) + + print(power.forward(x)) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1., 4.]) + print(power.inverse(power.forward(x))) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [1., 2.]) + print(power.forward_log_det_jacobian(x)) + # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.69314718, 1.38629436]) + """ + _type = Type.BIJECTION + + def __init__(self, power): + if not isinstance(power, paddle.fluid.framework.Variable): + raise TypeError( + f"Expected 'power' is a tensor, but got {type(power)}") + self._power = power + super(PowerTransform, self).__init__() + + @property + def power(self): + return self._power + + @property + def _domain(self): + return variable.real + + @property + def _codomain(self): + return variable.positive + + def _forward(self, x): + return x.pow(self._power) + + def _inverse(self, y): + return y.pow(1 / self._power) + + def _forward_log_det_jacobian(self, x): + return (self._power * x.pow(self._power - 1)).abs().log() + + def _forward_shape(self, shape): + return tuple(paddle.broadcast_shape(shape, self._power.shape)) + + def _inverse_shape(self, shape): + return tuple(paddle.broadcast_shape(shape, self._power.shape)) + + +class ReshapeTransform(Transform): + r"""Reshape the event shape of a tensor. + + Note that ``in_event_shape`` and ``out_event_shape`` must have the same + number of elements. + + Args: + in_event_shape(Sequence[int]): The input event shape. + out_event_shape(Sequence[int]): The output event shape. + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.ones((1,2,3)) + reshape_transform = paddle.distribution.ReshapeTransform((2, 3), (3, 2)) + print(reshape_transform.forward_shape((1,2,3))) + # (5, 2, 6) + print(reshape_transform.forward(x)) + # Tensor(shape=[1, 3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[[1., 1.], + # [1., 1.], + # [1., 1.]]]) + print(reshape_transform.inverse(reshape_transform.forward(x))) + # Tensor(shape=[1, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[[1., 1., 1.], + # [1., 1., 1.]]]) + print(reshape_transform.forward_log_det_jacobian(x)) + # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.]) + """ + _type = Type.BIJECTION + + def __init__(self, in_event_shape, out_event_shape): + if not isinstance(in_event_shape, typing.Sequence) or not isinstance( + out_event_shape, typing.Sequence): + raise TypeError( + f"Expected type of 'in_event_shape' and 'out_event_shape' is " + f"Squence[int], but got 'in_event_shape': {in_event_shape}, " + f"'out_event_shape': {out_event_shape}") + if functools.reduce(operator.mul, in_event_shape) != functools.reduce( + operator.mul, out_event_shape): + raise ValueError( + f"The numel of 'in_event_shape' should be 'out_event_shape', " + f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}" + ) + + self._in_event_shape = tuple(in_event_shape) + self._out_event_shape = tuple(out_event_shape) + super(ReshapeTransform, self).__init__() + + @property + def in_event_shape(self): + return self._in_event_shape + + @property + def out_event_shape(self): + return self._out_event_shape + + @property + def _domain(self): + return variable.Independent(variable.real, len(self._in_event_shape)) + + @property + def 
_codomain(self): + return variable.Independent(variable.real, len(self._out_event_shape)) + + def _forward(self, x): + return x.reshape( + tuple(x.shape)[:x.dim() - len(self._in_event_shape)] + + self._out_event_shape) + + def _inverse(self, y): + return y.reshape( + tuple(y.shape)[:y.dim() - len(self._out_event_shape)] + + self._in_event_shape) + + def _forward_shape(self, shape): + if len(shape) < len(self._in_event_shape): + raise ValueError( + f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}" + ) + if shape[-len(self._in_event_shape):] != self._in_event_shape: + raise ValueError( + f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}" + ) + return tuple(shape[:-len(self._in_event_shape)]) + self._out_event_shape + + def _inverse_shape(self, shape): + if len(shape) < len(self._out_event_shape): + raise ValueError( + f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}" + ) + if shape[-len(self._out_event_shape):] != self._out_event_shape: + raise ValueError( + f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}" + ) + return tuple(shape[:-len(self._out_event_shape)]) + self._in_event_shape + + def _forward_log_det_jacobian(self, x): + # paddle.zeros not support zero dimension Tensor. + shape = x.shape[:x.dim() - len(self._in_event_shape)] or [1] + return paddle.zeros(shape, dtype=x.dtype) + + +class SigmoidTransform(Transform): + r"""Sigmoid transformation with mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`. + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.ones((2,3)) + t = paddle.distribution.SigmoidTransform() + print(t.forward(x)) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0.73105860, 0.73105860, 0.73105860], + # [0.73105860, 0.73105860, 0.73105860]]) + print(t.inverse(t.forward(x))) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[1.00000012, 1.00000012, 1.00000012], + # [1.00000012, 1.00000012, 1.00000012]]) + print(t.forward_log_det_jacobian(x)) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[-1.62652326, -1.62652326, -1.62652326], + # [-1.62652326, -1.62652326, -1.62652326]]) + """ + + @property + def _domain(self): + return variable.real + + @property + def _codomain(self): + return variable.Variable(False, 0, constraint.Range(0., 1.)) + + def _forward(self, x): + return F.sigmoid(x) + + def _inverse(self, y): + return y.log() - (-y).log1p() + + def _forward_log_det_jacobian(self, x): + return -F.softplus(-x) - F.softplus(x) + + +class SoftmaxTransform(Transform): + r"""Softmax transformation with mapping :math:`y=\exp(x)` then normalizing. + + It's generally used to convert unconstrained space to simplex. This mapping + is not injective, so ``forward_log_det_jacobian`` and + ``inverse_log_det_jacobian`` are not implemented. + + Examples: + + .. 
code-block:: python + + import paddle + + x = paddle.ones((2,3)) + t = paddle.distribution.SoftmaxTransform() + print(t.forward(x)) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0.33333334, 0.33333334, 0.33333334], + # [0.33333334, 0.33333334, 0.33333334]]) + print(t.inverse(t.forward(x))) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[-1.09861231, -1.09861231, -1.09861231], + # [-1.09861231, -1.09861231, -1.09861231]]) + """ + _type = Type.OTHER + + @property + def _domain(self): + return variable.Independent(variable.real, 1) + + @property + def _codomain(self): + return variable.Variable(False, 1, constraint.simplex) + + def _forward(self, x): + x = (x - x.max(-1, keepdim=True)[0]).exp() + return x / x.sum(-1, keepdim=True) + + def _inverse(self, y): + return y.log() + + def _forward_shape(self, shape): + if len(shape) < 1: + raise ValueError( + f"Expected length of shape is grater than 1, but got {len(shape)}" + ) + return shape + + def _inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError( + f"Expected length of shape is grater than 1, but got {len(shape)}" + ) + return shape + + +class StackTransform(Transform): + r""" ``StackTransform`` applies a sequence of transformations along the + specific axis. + + Args: + transforms(Sequence[Transform]): The sequence of transformations. + axis(int): The axis along which will be transformed. + + Examples: + + .. code-block:: python + + import paddle + + + x = paddle.stack( + (paddle.to_tensor([1., 2., 3.]), paddle.to_tensor([1, 2., 3.])), 1) + t = paddle.distribution.StackTransform( + (paddle.distribution.ExpTransform(), + paddle.distribution.PowerTransform(paddle.to_tensor(2.))), + 1 + ) + print(t.forward(x)) + # Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[2.71828175 , 1. ], + # [7.38905621 , 4. ], + # [20.08553696, 9. ]]) + print(t.inverse(t.forward(x))) + # Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[1., 1.], + # [2., 2.], + # [3., 3.]]) + print(t.forward_log_det_jacobian(x)) + # Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[1. , 0.69314718], + # [2. , 1.38629436], + # [3. , 1.79175949]]) + """ + + def __init__(self, transforms, axis=0): + if not transforms or not isinstance(transforms, typing.Sequence): + raise TypeError( + f"Expected 'transforms' is Sequence[Transform], but got {type(transforms)}." 
+ ) + if not all(isinstance(t, Transform) for t in transforms): + raise TypeError( + 'Expected all element in transforms is Transform Type.') + if not isinstance(axis, int): + raise TypeError(f"Expected 'axis' is int, but got{type(axis)}.") + + self._transforms = transforms + self._axis = axis + + def _is_injective(self): + return all(t._is_injective() for t in self._transforms) + + @property + def transforms(self): + return self._transforms + + @property + def axis(self): + return self._axis + + def _forward(self, x): + self._check_size(x) + return paddle.stack([ + t.forward(v) + for v, t in zip(paddle.unstack(x, self._axis), self._transforms) + ], self._axis) + + def _inverse(self, y): + self._check_size(y) + return paddle.stack([ + t.inverse(v) + for v, t in zip(paddle.unstack(y, self._axis), self._transforms) + ], self._axis) + + def _forward_log_det_jacobian(self, x): + self._check_size(x) + return paddle.stack([ + t.forward_log_det_jacobian(v) + for v, t in zip(paddle.unstack(x, self._axis), self._transforms) + ], self._axis) + + def _check_size(self, v): + if not (-v.dim() <= self._axis < v.dim()): + raise ValueError( + f'Input dimensions {v.dim()} should be grater than stack ' + f'transform axis {self._axis}.') + if v.shape[self._axis] != len(self._transforms): + raise ValueError( + f'Input size along {self._axis} should be equal to the ' + f'length of transforms.') + + @property + def _domain(self): + return variable.Stack([t._domain for t in self._transforms], self._axis) + + @property + def _codomain(self): + return variable.Stack([t._codomain for t in self._transforms], + self._axis) + + +class StickBreakingTransform(Transform): + r"""Convert an unconstrained vector to the simplex with one additional + dimension by the stick-breaking construction. + + Examples: + + .. code-block:: python + + import paddle + + + x = paddle.to_tensor([1.,2.,3.]) + t = paddle.distribution.StickBreakingTransform() + print(t.forward(x)) + # Tensor(shape=[4], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.47536686, 0.41287899, 0.10645414, 0.00530004]) + print(t.inverse(t.forward(x))) + # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [0.99999988, 2. 
, 2.99999881]) + print(t.forward_log_det_jacobian(x)) + # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [-9.10835075]) + """ + + _type = Type.BIJECTION + + def _forward(self, x): + offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) + z = F.sigmoid(x - offset.log()) + z_cumprod = (1 - z).cumprod(-1) + return F.pad(z, [0, 1], value=1) * F.pad(z_cumprod, [1, 0], value=1) + + def _inverse(self, y): + y_crop = y[..., :-1] + offset = y.shape[-1] - paddle.ones([y_crop.shape[-1]]).cumsum(-1) + sf = 1 - y_crop.cumsum(-1) + x = y_crop.log() - sf.log() + offset.log() + return x + + def _forward_log_det_jacobian(self, x): + y = self.forward(x) + offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) + x = x - offset.log() + return (-x + F.log_sigmoid(x) + y[..., :-1].log()).sum(-1) + + def _forward_shape(self, shape): + if not shape: + raise ValueError(f"Expected 'shape' is not empty, but got {shape}") + return shape[:-1] + (shape[-1] + 1, ) + + def _inverse_shape(self, shape): + if not shape: + raise ValueError(f"Expected 'shape' is not empty, but got {shape}") + return shape[:-1] + (shape[-1] - 1, ) + + @property + def _domain(self): + return variable.Independent(variable.real, 1) + + @property + def _codomain(self): + return variable.Variable(False, 1, constraint.simplex) + + +class TanhTransform(Transform): + r"""Tanh transformation with mapping :math:`y = \tanh(x)`. + + Examples + + .. code-block:: python + + import paddle + + tanh = paddle.distribution.TanhTransform() + + x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]]) + + print(tanh.forward(x)) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0.76159418, 0.96402758, 0.99505478], + # [0.99932933, 0.99990922, 0.99998772]]) + print(tanh.inverse(tanh.forward(x))) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[1.00000012, 2. , 3.00000286], + # [4.00002146, 5.00009823, 6.00039864]]) + print(tanh.forward_log_det_jacobian(x)) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[-0.86756170 , -2.65000558 , -4.61865711 ], + # [-6.61437654 , -8.61379623 , -10.61371803]]) + print(tanh.inverse_log_det_jacobian(tanh.forward(x))) + # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0.86756176 , 2.65000558 , 4.61866283 ], + # [6.61441946 , 8.61399269 , 10.61451530]]) + """ + _type = Type.BIJECTION + + @property + def _domain(self): + return variable.real + + @property + def _codomain(self): + return variable.Variable(False, 0, constraint.Range(-1.0, 1.0)) + + def _forward(self, x): + return x.tanh() + + def _inverse(self, y): + return y.atanh() + + def _forward_log_det_jacobian(self, x): + """We implicitly rely on _forward_log_det_jacobian rather than + explicitly implement ``_inverse_log_det_jacobian`` since directly using + ``-tf.math.log1p(-tf.square(y))`` has lower numerical precision. + + See details: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 + """ + return 2. * (math.log(2.) - x - F.softplus(-2. * x)) diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7aa5886ae245731cd0aa9a9c802c4f14d21b9a --- /dev/null +++ b/python/paddle/distribution/transformed_distribution.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+
+from paddle.distribution import distribution
+from paddle.distribution import transform
+from paddle.distribution import independent
+
+
+class TransformedDistribution(distribution.Distribution):
+    r"""
+    Applies a sequence of Transforms to a base distribution.
+
+    Args:
+        base (Distribution): The base distribution.
+        transforms (Sequence[Transform]): A sequence of ``Transform`` .
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            from paddle.distribution import transformed_distribution
+
+            d = transformed_distribution.TransformedDistribution(
+                paddle.distribution.Normal(0., 1.),
+                [paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))]
+            )
+
+            print(d.sample([10]))
+            # Tensor(shape=[10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [-0.10697651,  3.33609009, -0.86234951,  5.07457638,  0.75925219,
+            #         -4.17087793,  2.22579336, -0.93845034,  0.66054249,  1.50957513])
+            print(d.log_prob(paddle.to_tensor(0.5)))
+            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [-1.64333570])
+    """
+
+    def __init__(self, base, transforms):
+        if not isinstance(base, distribution.Distribution):
+            raise TypeError(
+                f"Expected type of 'base' is Distribution, but got {type(base)}."
+            )
+        if not isinstance(transforms, typing.Sequence):
+            raise TypeError(
+                f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}."
+            )
+        if not all(isinstance(t, transform.Transform) for t in transforms):
+            raise TypeError(
+                "All elements of transforms must be Transform type.")
+
+        chain = transform.ChainTransform(transforms)
+        base_shape = base.batch_shape + base.event_shape
+        if len(base_shape) < chain._domain.event_rank:
+            raise ValueError(
+                f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
+            )
+        if chain._domain.event_rank > len(base.event_shape):
+            base = independent.Independent(
+                base, chain._domain.event_rank - len(base.event_shape))
+        self._base = base
+        self._transforms = transforms
+
+        transformed_shape = chain.forward_shape(base.batch_shape +
+                                                base.event_shape)
+        transformed_event_rank = chain._codomain.event_rank + \
+            max(len(base.event_shape)-chain._domain.event_rank, 0)
+        super(TransformedDistribution, self).__init__(
+            transformed_shape[:len(transformed_shape) - transformed_event_rank],
+            transformed_shape[len(transformed_shape) - transformed_event_rank:])
+
+    def sample(self, shape=()):
+        """Sample from ``TransformedDistribution``.
+
+        Args:
+            shape (tuple, optional): The sample shape. Defaults to ().
+
+        Returns:
+            [Tensor]: The sample result.
+        """
+        x = self._base.sample(shape)
+        for t in self._transforms:
+            x = t.forward(x)
+        return x
+
+    def log_prob(self, value):
+        """The log probability evaluated at value.
+
+        Args:
+            value (Tensor): The value to be evaluated.
+
+        Returns:
+            Tensor: The log probability.
+ """ + log_prob = 0.0 + y = value + event_rank = len(self.event_shape) + for t in reversed(self._transforms): + x = t.inverse(y) + event_rank += t._domain.event_rank - t._codomain.event_rank + log_prob = log_prob - \ + _sum_rightmost(t.forward_log_det_jacobian( + x), event_rank-t._domain.event_rank) + y = x + log_prob += _sum_rightmost( + self._base.log_prob(y), event_rank - len(self._base.event_shape)) + return log_prob + + +def _sum_rightmost(value, n): + return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/distribution/uniform.py b/python/paddle/distribution/uniform.py index 61f07bf9f70f7824e835e161bf583e55d1328934..5957dab14ef38182e3c4a00c36a78b3dfbd9e03c 100644 --- a/python/paddle/distribution/uniform.py +++ b/python/paddle/distribution/uniform.py @@ -17,18 +17,18 @@ import warnings import numpy as np from paddle import _C_ops - -from ..fluid import core -from ..fluid.data_feeder import (check_dtype, check_type, - check_variable_and_dtype, convert_dtype) -from ..fluid.framework import _non_static_mode -from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, - elementwise_mul, elementwise_sub, nn, ops, tensor) -from ..tensor import arange, concat, gather_nd, multinomial -from .distribution import Distribution - - -class Uniform(Distribution): +from paddle.distribution import distribution +from paddle.fluid import core +from paddle.fluid.data_feeder import (check_dtype, check_type, + check_variable_and_dtype, convert_dtype) +from paddle.fluid.framework import _non_static_mode, in_dygraph_mode +from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, + elementwise_mul, elementwise_sub, nn, ops, + tensor) +from paddle.tensor import arange, concat, gather_nd, multinomial + + +class Uniform(distribution.Distribution): r"""Uniform distribution with `low` and `high` parameters. Mathematical Details diff --git a/python/paddle/distribution/variable.py b/python/paddle/distribution/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..6ece1c3a1d83e9a97156cf7a22a862ef597c0f60 --- /dev/null +++ b/python/paddle/distribution/variable.py @@ -0,0 +1,109 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distribution import constraint + + +class Variable(object): + """Random variable of probability distribution. + + Args: + is_discrete (bool): Is the variable discrete or continuous. + event_rank (int): The rank of event dimensions. 
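+        constraint (Constraint, optional): The constraint condition of this
+            random variable. Defaults to None.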
+    """
+
+    def __init__(self, is_discrete=False, event_rank=0, constraint=None):
+        self._is_discrete = is_discrete
+        self._event_rank = event_rank
+        self._constraint = constraint
+
+    @property
+    def is_discrete(self):
+        return self._is_discrete
+
+    @property
+    def event_rank(self):
+        return self._event_rank
+
+    def constraint(self, value):
+        """Check whether the 'value' meets the constraint conditions of this
+        random variable."""
+        return self._constraint(value)
+
+
+class Real(Variable):
+    def __init__(self, event_rank=0):
+        super(Real, self).__init__(False, event_rank, constraint.real)
+
+
+class Positive(Variable):
+    def __init__(self, event_rank=0):
+        super(Positive, self).__init__(False, event_rank, constraint.positive)
+
+
+class Independent(Variable):
+    """Reinterprets some of the batch axes of a variable as event axes.
+
+    Args:
+        base (Variable): Base variable.
+        reinterpreted_batch_rank (int): The rightmost batch rank to be
+            reinterpreted.
+    """
+
+    def __init__(self, base, reinterpreted_batch_rank):
+        self._base = base
+        self._reinterpreted_batch_rank = reinterpreted_batch_rank
+        super(Independent, self).__init__(
+            base.is_discrete, base.event_rank + reinterpreted_batch_rank)
+
+    def constraint(self, value):
+        ret = self._base.constraint(value)
+        if ret.dim() < self._reinterpreted_batch_rank:
+            raise ValueError(
+                "Input dimensions must be equal or greater than {}".format(
+                    self._reinterpreted_batch_rank))
+        return ret.reshape(
+            ret.shape[:ret.dim() - self._reinterpreted_batch_rank] +
+            (-1, )).all(-1)
+
+
+class Stack(Variable):
+    def __init__(self, vars, axis=0):
+        self._vars = vars
+        self._axis = axis
+
+    @property
+    def is_discrete(self):
+        return any(var.is_discrete for var in self._vars)
+
+    @property
+    def event_rank(self):
+        rank = max(var.event_rank for var in self._vars)
+        if self._axis + rank < 0:
+            rank += 1
+        return rank
+
+    def constraint(self, value):
+        if not (-value.dim() <= self._axis < value.dim()):
+            raise ValueError(
+                f'Input dimensions {value.dim()} should be greater than stack '
+                f'constraint axis {self._axis}.')
+
+        return paddle.stack([
+            var.constraint(value)
+            for var, value in zip(self._vars, paddle.unstack(value, self._axis))
+        ], self._axis)
+
+
+real = Real()
+positive = Positive()
diff --git a/python/paddle/fluid/tests/unittests/distribution/config.py b/python/paddle/fluid/tests/unittests/distribution/config.py
index 809dfb2b56d6636c1cd602a46ffd76aec6ae978b..aee76250e5d142c02a08cd494b65435c06be9747 100644
--- a/python/paddle/fluid/tests/unittests/distribution/config.py
+++ b/python/paddle/fluid/tests/unittests/distribution/config.py
@@ -11,11 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import contextlib
-import sys
-
-import numpy as np
 import paddle
 
 DEVICES = [paddle.CPUPlace()]
@@ -34,66 +29,3 @@ RTOL = {
     'complex128': 1e-5
 }
 ATOL = {'float32': 0.0, 'complex64': 0, 'float64': 0.0, 'complex128': 0}
-
-
-def xrand(shape=(10, 10, 10), dtype=DEFAULT_DTYPE, min=1.0, max=10.0):
-    return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min)
-
-
-def place(devices, key='place'):
-    def decorate(cls):
-        module = sys.modules[cls.__module__].__dict__
-        raw_classes = {
-            k: v
-            for k, v in module.items() if k.startswith(cls.__name__)
-        }
-
-        for raw_name, raw_cls in raw_classes.items():
-            for d in devices:
-                test_cls = dict(raw_cls.__dict__)
-                test_cls.update({key: d})
-                new_name = raw_name + '.'
+ d.__class__.__name__ - module[new_name] = type(new_name, (raw_cls, ), test_cls) - del module[raw_name] - return cls - - return decorate - - -def parameterize(fields, values=None): - - fields = [fields] if isinstance(fields, str) else fields - params = [dict(zip(fields, vals)) for vals in values] - - def decorate(cls): - test_cls_module = sys.modules[cls.__module__].__dict__ - for k, v in enumerate(params): - test_cls = dict(cls.__dict__) - test_cls.update(v) - name = cls.__name__ + str(k) - name = name + '.' + v.get('suffix') if v.get('suffix') else name - - test_cls_module[name] = type(name, (cls, ), test_cls) - - for m in list(cls.__dict__): - if m.startswith("test"): - delattr(cls, m) - return cls - - return decorate - - -@contextlib.contextmanager -def stgraph(func, *args): - """static graph exec context""" - paddle.enable_static() - mp, sp = paddle.static.Program(), paddle.static.Program() - with paddle.static.program_guard(mp, sp): - input = paddle.static.data('input', x.shape, dtype=x.dtype) - output = func(input, n, axes, norm) - - exe = paddle.static.Executor(place) - exe.run(sp) - [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) - yield output - paddle.disable_static() diff --git a/python/paddle/fluid/tests/unittests/distribution/parameterize.py b/python/paddle/fluid/tests/unittests/distribution/parameterize.py new file mode 100644 index 0000000000000000000000000000000000000000..09aa241b15dfe729994b7494952983c497bac997 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/parameterize.py @@ -0,0 +1,272 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import contextlib +import functools +import inspect +import re +import sys + +import numpy as np +import config + +TEST_CASE_NAME = 'suffix' + + +def xrand(shape=(10, 10, 10), dtype=config.DEFAULT_DTYPE, min=1.0, max=10.0): + return ((np.random.rand(*shape).astype(dtype)) * (max - min) + min) + + +def place(devices, key='place'): + def decorate(cls): + module = sys.modules[cls.__module__].__dict__ + raw_classes = { + k: v + for k, v in module.items() if k.startswith(cls.__name__) + } + + for raw_name, raw_cls in raw_classes.items(): + for d in devices: + test_cls = dict(raw_cls.__dict__) + test_cls.update({key: d}) + new_name = raw_name + '.' + d.__class__.__name__ + module[new_name] = type(new_name, (raw_cls, ), test_cls) + del module[raw_name] + return cls + + return decorate + + +def parameterize_cls(fields, values=None): + fields = [fields] if isinstance(fields, str) else fields + params = [dict(zip(fields, vals)) for vals in values] + + def decorate(cls): + test_cls_module = sys.modules[cls.__module__].__dict__ + for k, v in enumerate(params): + test_cls = dict(cls.__dict__) + test_cls.update(v) + name = cls.__name__ + str(k) + name = name + '.' 
+ v.get('suffix') if v.get('suffix') else name + + test_cls_module[name] = type(name, (cls, ), test_cls) + + for m in list(cls.__dict__): + if m.startswith("test"): + delattr(cls, m) + return cls + + return decorate + + +def parameterize_func(input, name_func=None, doc_func=None, + skip_on_empty=False): + doc_func = doc_func or default_doc_func + name_func = name_func or default_name_func + + def wrapper(f, instance=None): + frame_locals = inspect.currentframe().f_back.f_locals + + parameters = input_as_callable(input)() + + if not parameters: + if not skip_on_empty: + raise ValueError( + "Parameters iterable is empty (hint: use " + "`parameterized.expand([], skip_on_empty=True)` to skip " + "this test when the input is empty)") + return wraps(f)(skip_on_empty_helper) + + digits = len(str(len(parameters) - 1)) + for num, p in enumerate(parameters): + name = name_func( + f, "{num:0>{digits}}".format( + digits=digits, num=num), p) + # If the original function has patches applied by 'mock.patch', + # re-construct all patches on the just former decoration layer + # of param_as_standalone_func so as not to share + # patch objects between new functions + nf = reapply_patches_if_need(f) + frame_locals[name] = param_as_standalone_func(p, nf, name) + frame_locals[name].__doc__ = doc_func(f, num, p) + + # Delete original patches to prevent new function from evaluating + # original patching object as well as re-constructed patches. + delete_patches_if_need(f) + + f.__test__ = False + + return wrapper + + +def reapply_patches_if_need(func): + def dummy_wrapper(orgfunc): + @wraps(orgfunc) + def dummy_func(*args, **kwargs): + return orgfunc(*args, **kwargs) + + return dummy_func + + if hasattr(func, 'patchings'): + func = dummy_wrapper(func) + tmp_patchings = func.patchings + delattr(func, 'patchings') + for patch_obj in tmp_patchings: + func = patch_obj.decorate_callable(func) + return func + + +def delete_patches_if_need(func): + if hasattr(func, 'patchings'): + func.patchings[:] = [] + + +def default_name_func(func, num, p): + base_name = func.__name__ + name_suffix = "_%s" % (num, ) + + if len(p.args) > 0 and isinstance(p.args[0], str): + name_suffix += "_" + to_safe_name(p.args[0]) + return base_name + name_suffix + + +def default_doc_func(func, num, p): + if func.__doc__ is None: + return None + + all_args_with_values = parameterized_argument_value_pairs(func, p) + + # Assumes that the function passed is a bound method. + descs = ["%s=%s" % (n, short_repr(v)) for n, v in all_args_with_values] + + # The documentation might be a multiline string, so split it + # and just work with the first string, ignoring the period + # at the end if there is one. + first, nl, rest = func.__doc__.lstrip().partition("\n") + suffix = "" + if first.endswith("."): + suffix = "." + first = first[:-1] + args = "%s[with %s]" % (len(first) and " " or "", ", ".join(descs)) + return "".join(to_text(x) for x in [first.rstrip(), args, suffix, nl, rest]) + + +def param_as_standalone_func(p, func, name): + @functools.wraps(func) + def standalone_func(*a): + return func(*(a + p.args), **p.kwargs) + + standalone_func.__name__ = name + + # place_as is used by py.test to determine what source file should be + # used for this test. + standalone_func.place_as = func + + # Remove __wrapped__ because py.test will try to look at __wrapped__ + # to determine which parameters should be used with this test case, + # and obviously we don't need it to do any parameterization. 
+ try: + del standalone_func.__wrapped__ + except AttributeError: + pass + return standalone_func + + +def input_as_callable(input): + if callable(input): + return lambda: check_input_values(input()) + input_values = check_input_values(input) + return lambda: input_values + + +def check_input_values(input_values): + if not isinstance(input_values, list): + input_values = list(input_values) + return [param.from_decorator(p) for p in input_values] + + +def skip_on_empty_helper(*a, **kw): + raise SkipTest("parameterized input is empty") + + +_param = collections.namedtuple("param", "args kwargs") + + +class param(_param): + def __new__(cls, *args, **kwargs): + return _param.__new__(cls, args, kwargs) + + @classmethod + def explicit(cls, args=None, kwargs=None): + """ Creates a ``param`` by explicitly specifying ``args`` and + ``kwargs``:: + >>> param.explicit([1,2,3]) + param(*(1, 2, 3)) + >>> param.explicit(kwargs={"foo": 42}) + param(*(), **{"foo": "42"}) + """ + args = args or () + kwargs = kwargs or {} + return cls(*args, **kwargs) + + @classmethod + def from_decorator(cls, args): + """ Returns an instance of ``param()`` for ``@parameterized`` argument + ``args``:: + >>> param.from_decorator((42, )) + param(args=(42, ), kwargs={}) + >>> param.from_decorator("foo") + param(args=("foo", ), kwargs={}) + """ + if isinstance(args, param): + return args + elif isinstance(args, str): + args = (args, ) + try: + return cls(*args) + except TypeError as e: + if "after * must be" not in str(e): + raise + raise TypeError( + "Parameters must be tuples, but %r is not (hint: use '(%r, )')" + % (args, args), ) + + def __repr__(self): + return "param(*%r, **%r)" % self + + +def to_safe_name(s): + return str(re.sub("[^a-zA-Z0-9_]+", "_", s)) + + +@contextlib.contextmanager +def stgraph(func, *args): + """static graph exec context""" + paddle.enable_static() + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + input = paddle.static.data('input', x.shape, dtype=x.dtype) + output = func(input, n, axes, norm) + + exe = paddle.static.Executor(place) + exe.run(sp) + [output] = exe.run(mp, feed={'input': x}, fetch_list=[output]) + yield output + paddle.disable_static() + + +# alias +parameterize = parameterize_func +param_cls = parameterize_cls +param_func = parameterize_func diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution.py index 42b658cdce4d2e0d448f34ba04116a4575b3cce5..7a1cb25b96f469f2c7b3f4dca1d6e96c1d0974b4 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution.py @@ -22,6 +22,7 @@ from paddle.distribution import * from paddle.fluid import layers import config +import parameterize paddle.enable_static() @@ -132,11 +133,12 @@ class DistributionTestName(unittest.TestCase): self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') -@config.place(config.DEVICES) -@config.parameterize((config.TEST_CASE_NAME, 'batch_shape', 'event_shape'), - [('test-tuple', (10, 20), - (10, 20)), ('test-list', [100, 100], [100, 200, 300]), - ('test-null-eventshape', (100, 100), ())]) +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'batch_shape', 'event_shape'), + [('test-tuple', (10, 20), (10, 20)), + ('test-list', [100, 100], [100, 200, 300]), ('test-null-eventshape', + (100, 100), ())]) class 
TestDistributionShape(unittest.TestCase): def setUp(self): paddle.disable_static() @@ -156,7 +158,7 @@ class TestDistributionShape(unittest.TestCase): def test_prob(self): with self.assertRaises(NotImplementedError): - self.dist.prob(paddle.to_tensor(config.xrand())) + self.dist.prob(paddle.to_tensor(parameterize.xrand())) def test_extend_shape(self): shapes = [(34, 20), (56, ), ()] @@ -164,3 +166,24 @@ class TestDistributionShape(unittest.TestCase): self.assertTrue( self.dist._extend_shape(shape), shape + self.dist.batch_shape + self.dist.event_shape) + + +class TestDistributionException(unittest.TestCase): + def setUp(self): + self._d = paddle.distribution.Distribution() + + def test_mean(self): + with self.assertRaises(NotImplementedError): + self._d.mean + + def test_variance(self): + with self.assertRaises(NotImplementedError): + self._d.variance + + def test_rsample(self): + with self.assertRaises(NotImplementedError): + self._d.rsample(()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py index 1e5267405e8b84869fe618f04178252a2f1b27bc..fb0c37e3d659d82e9186cd4438b614d066e726a3 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py @@ -18,14 +18,15 @@ import numpy as np import paddle import scipy.stats -from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place, - xrand) +import config +from config import ATOL, DEVICES, RTOL +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'alpha', 'beta'), - [('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()), - ('test-broadcast', xrand((2, 1)), xrand((2, 5)))]) +@parameterize_cls((TEST_CASE_NAME, 'alpha', 'beta'), + [('test-scale', 1.0, 2.0), ('test-tensor', xrand(), xrand()), + ('test-broadcast', xrand((2, 1)), xrand((2, 5)))]) class TestBeta(unittest.TestCase): def setUp(self): # scale no need convert to tensor for scale input unittest @@ -98,3 +99,7 @@ class TestBeta(unittest.TestCase): self.assertTrue( self._paddle_beta.sample(case.get('input')).shape == case.get('expect')) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta_static.py index b8d72336807a4003a3d72b6df4d59fa34497eb47..e8fe0f17600c44a92aca47e78df5ef71ba2a111b 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_beta_static.py @@ -18,16 +18,19 @@ import numpy as np import paddle import scipy.stats -from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place, - xrand) +import config +import parameterize as param +from config import ATOL, RTOL +from parameterize import xrand paddle.enable_static() -@place(DEVICES) -@parameterize((TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand( - (10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand( - (2, 5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))]) +@param.place(config.DEVICES) +@param.parameterize_cls( + (param.TEST_CASE_NAME, 'alpha', 'beta'), [('test-tensor', xrand( + (10, 10)), xrand((10, 10))), ('test-broadcast', xrand((2, 1)), xrand( + (2, 
5))), ('test-larger-data', xrand((10, 20)), xrand((10, 20)))]) class TestBeta(unittest.TestCase): def setUp(self): self.program = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_categorical.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_categorical.py index d92ec52edaeb8f1e909986ceee48d863e09eb279..f43ac7bea763f95d21ec495ed421f42f0737832f 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_categorical.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_categorical.py @@ -439,3 +439,7 @@ class DistributionTestError(unittest.TestCase): cat.log_prob(value) self.assertRaises(ValueError, test_shape_not_match_error) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_constraint.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..c31d2124193eec7e232d0d7b9d3926db7e8cd921 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_constraint.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import paddle +from paddle.distribution import constraint + +import config +import parameterize as param + + +@param.param_cls((param.TEST_CASE_NAME, 'value'), + [('NotImplement', np.random.rand(2, 3))]) +class TestConstraint(unittest.TestCase): + def setUp(self): + self._constraint = constraint.Constraint() + + def test_costraint(self): + with self.assertRaises(NotImplementedError): + self._constraint(self.value) + + +@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'), + [('real', 1., True)]) +class TestReal(unittest.TestCase): + def setUp(self): + self._constraint = constraint.Real() + + def test_costraint(self): + self.assertEqual(self._constraint(self.value), self.expect) + + +@param.param_cls((param.TEST_CASE_NAME, 'lower', 'upper', 'value', 'expect'), + [('in_range', 0, 1, 0.5, True), ('out_range', 0, 1, 2, False)]) +class TestRange(unittest.TestCase): + def setUp(self): + self._constraint = constraint.Range(self.lower, self.upper) + + def test_costraint(self): + self.assertEqual(self._constraint(self.value), self.expect) + + +@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'), + [('positive', 1, True), ('negative', -1, False)]) +class TestPositive(unittest.TestCase): + def setUp(self): + self._constraint = constraint.Positive() + + def test_costraint(self): + self.assertEqual(self._constraint(self.value), self.expect) + + +@param.param_cls((param.TEST_CASE_NAME, 'value', 'expect'), + [('simplex', paddle.to_tensor([0.5, 0.5]), True), + ('non_simplex', paddle.to_tensor([-0.5, 0.5]), False)]) +class TestSimplex(unittest.TestCase): + def setUp(self): + self._constraint = constraint.Simplex() + + def test_costraint(self): + self.assertEqual(self._constraint(self.value), self.expect) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet.py index 8baddfa2e9be1ebec383e4e44382871cba94cc02..9caec312b33821eddd9955010ef5d2ae382072ef 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet.py @@ -19,15 +19,15 @@ import paddle import scipy.stats import config -from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place, - xrand) +from config import ATOL, DEVICES, RTOL +import parameterize as param -@place(DEVICES) -@parameterize( - (TEST_CASE_NAME, 'concentration'), +@param.place(DEVICES) +@param.param_cls( + (param.TEST_CASE_NAME, 'concentration'), [ - ('test-one-dim', config.xrand((89, ))), + ('test-one-dim', param.xrand((89, ))), # ('test-multi-dim', config.xrand((10, 20, 30))) ]) class TestDirichlet(unittest.TestCase): @@ -91,14 +91,18 @@ class TestDirichlet(unittest.TestCase): self.assertTrue( np.all( self._paddle_diric._log_normalizer( - paddle.to_tensor(config.xrand((100, 100, 100)))).numpy() < + paddle.to_tensor(param.xrand((100, 100, 100)))).numpy() < 0.0)) - @place(DEVICES) - @parameterize((TEST_CASE_NAME, 'concentration'), - [('test-zero-dim', np.array(1.0))]) + @param.place(DEVICES) + @param.param_cls((param.TEST_CASE_NAME, 'concentration'), + [('test-zero-dim', np.array(1.0))]) class TestDirichletException(unittest.TestCase): def TestInit(self): with self.assertRaises(ValueError): paddle.distribution.Dirichlet( paddle.squeeze(self.concentration)) + + +if __name__ == '__main__': + unittest.main() diff --git 
a/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet_static.py index c84da943cf6cd64b3a663ffe3b2d68ab63895d27..f7096d295eeb57b3737ef0aea08fc63bc3d5f3d0 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_dirichlet_static.py @@ -18,15 +18,15 @@ import numpy as np import paddle import scipy.stats -from config import (ATOL, DEVICES, RTOL, TEST_CASE_NAME, parameterize, place, - xrand) +from config import ATOL, DEVICES, RTOL +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand paddle.enable_static() @place(DEVICES) -@parameterize((TEST_CASE_NAME, 'concentration'), - [('test-one-dim', np.random.rand(89) + 5.0)]) +@parameterize_cls((TEST_CASE_NAME, 'concentration'), + [('test-one-dim', np.random.rand(89) + 5.0)]) class TestDirichlet(unittest.TestCase): def setUp(self): self.program = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily.py index 341ec852c52197f689870f0a6c45141ebe318301..b601ac285840a00993288feb8b5b737715fb67e5 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily.py @@ -20,14 +20,15 @@ import scipy.stats import config import mock_data as mock +import parameterize -@config.place(config.DEVICES) -@config.parameterize( - (config.TEST_CASE_NAME, 'dist'), [('test-mock-exp', - mock.Exponential(rate=paddle.rand( - [100, 200, 99], - dtype=config.DEFAULT_DTYPE)))]) +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'dist'), [('test-mock-exp', + mock.Exponential(rate=paddle.rand( + [100, 200, 99], + dtype=config.DEFAULT_DTYPE)))]) class TestExponentialFamily(unittest.TestCase): def test_entropy(self): np.testing.assert_allclose( @@ -37,15 +38,15 @@ class TestExponentialFamily(unittest.TestCase): atol=config.ATOL.get(config.DEFAULT_DTYPE)) -@config.place(config.DEVICES) -@config.parameterize( +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( (config.TEST_CASE_NAME, 'dist'), [('test-dummy', mock.DummyExpFamily(0.5, 0.5)), ('test-dirichlet', - paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand()))), ( + paddle.distribution.Dirichlet(paddle.to_tensor(parameterize.xrand()))), ( 'test-beta', paddle.distribution.Beta( - paddle.to_tensor(config.xrand()), - paddle.to_tensor(config.xrand())))]) + paddle.to_tensor(parameterize.xrand()), + paddle.to_tensor(parameterize.xrand())))]) class TestExponentialFamilyException(unittest.TestCase): def test_entropy_exception(self): with self.assertRaises(NotImplementedError): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily_static.py index bb6317f1d56fd9f3733e938c77298d57e3a7ce75..28c337b617b2ee0d730f506f1841db90b0beb158 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_expfamily_static.py @@ -20,17 +20,18 @@ import scipy.stats import config import mock_data as mock +import parameterize paddle.enable_static() 
-@config.place(config.DEVICES) +@parameterize.place(config.DEVICES) class TestExponentialFamily(unittest.TestCase): def setUp(self): self.program = paddle.static.Program() self.executor = paddle.static.Executor() with paddle.static.program_guard(self.program): - rate_np = config.xrand((100, 200, 99)) + rate_np = parameterize.xrand((100, 200, 99)) rate = paddle.static.data('rate', rate_np.shape, rate_np.dtype) self.mock_dist = mock.Exponential(rate) self.feeds = {'rate': rate_np} diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_independent.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_independent.py new file mode 100644 index 0000000000000000000000000000000000000000..f67c260cbcc318584ab4dbf3824253837ef629e0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_independent.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +import unittest + +import numpy as np +import paddle +import scipy.stats + +import config +import parameterize as param + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank'), + [('base_beta', paddle.distribution.Beta( + paddle.rand([1, 2]), paddle.rand([1, 2])), 1)]) +class TestIndependent(unittest.TestCase): + def setUp(self): + self._t = paddle.distribution.Independent(self.base, + self.reinterpreted_batch_rank) + + def test_mean(self): + np.testing.assert_allclose( + self.base.mean, + self._t.mean, + rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)), + atol=config.ATOL.get(str(self.base.alpha.numpy().dtype))) + + def test_variance(self): + np.testing.assert_allclose( + self.base.variance, + self._t.variance, + rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)), + atol=config.ATOL.get(str(self.base.alpha.numpy().dtype))) + + def test_entropy(self): + np.testing.assert_allclose( + self._np_sum_rightmost(self.base.entropy().numpy(), + self.reinterpreted_batch_rank), + self._t.entropy(), + rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)), + atol=config.ATOL.get(str(self.base.alpha.numpy().dtype))) + + def _np_sum_rightmost(self, value, n): + return np.sum(value, tuple(range(-n, 0))) if n > 0 else value + + def test_log_prob(self): + value = np.random.rand(1) + np.testing.assert_allclose( + self._np_sum_rightmost( + self.base.log_prob(paddle.to_tensor(value)).numpy(), + self.reinterpreted_batch_rank), + self._t.log_prob(paddle.to_tensor(value)).numpy(), + rtol=config.RTOL.get(str(self.base.alpha.numpy().dtype)), + atol=config.ATOL.get(str(self.base.alpha.numpy().dtype))) + + # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result. 
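+    # The sample result prepends the requested shape to the distribution's
+    # batch and event shape, i.e. (5, 10, 8) + (1,) + (2,) for this case.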
+ def test_sample(self): + shape = (5, 10, 8) + expected_shape = (5, 10, 8, 1, 2) + data = self._t.sample(shape) + self.assertEqual(tuple(data.shape), expected_shape) + self.assertEqual(data.dtype, self.base.alpha.dtype) + + +@param.place(config.DEVICES) +@param.param_cls( + (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', + 'expected_exception'), + [('base_not_transform', '', 1, TypeError), + ('rank_less_than_zero', paddle.distribution.Transform(), -1, ValueError)]) +class TestIndependentException(unittest.TestCase): + def test_init(self): + with self.assertRaises(self.expected_exception): + paddle.distribution.IndependentTransform( + self.base, self.reinterpreted_batch_rank) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_independent_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_independent_static.py new file mode 100644 index 0000000000000000000000000000000000000000..eb078160a03e062b9faf3c6e6d51fca58293e6fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_independent_static.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numbers +import unittest + +import numpy as np +import paddle +import scipy.stats + +import config +import parameterize as param + +paddle.enable_static() + + +@param.place(config.DEVICES) +@param.param_cls( + (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'alpha', 'beta'), + [('base_beta', paddle.distribution.Beta, 1, np.random.rand(1, 2), + np.random.rand(1, 2))]) +class TestIndependent(unittest.TestCase): + def setUp(self): + value = np.random.rand(1) + self.dtype = value.dtype + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + alpha = paddle.static.data('alpha', self.alpha.shape, + self.alpha.dtype) + beta = paddle.static.data('beta', self.beta.shape, self.beta.dtype) + self.base = self.base(alpha, beta) + t = paddle.distribution.Independent(self.base, + self.reinterpreted_batch_rank) + mean = t.mean + variance = t.variance + entropy = t.entropy() + static_value = paddle.static.data('value', value.shape, value.dtype) + log_prob = t.log_prob(static_value) + + base_mean = self.base.mean + base_variance = self.base.variance + base_entropy = self.base.entropy() + base_log_prob = self.base.log_prob(static_value) + + fetch_list = [ + mean, variance, entropy, log_prob, base_mean, base_variance, + base_entropy, base_log_prob + ] + exe.run(sp) + [ + self.mean, self.variance, self.entropy, self.log_prob, + self.base_mean, self.base_variance, self.base_entropy, + self.base_log_prob + ] = exe.run( + mp, + feed={'value': value, + 'alpha': self.alpha, + 'beta': self.beta}, + fetch_list=fetch_list) + + def test_mean(self): + np.testing.assert_allclose( + self.mean, + self.base_mean, + rtol=config.RTOL.get(str(self.dtype)), + atol=config.ATOL.get(str(self.dtype))) + + def test_variance(self): + np.testing.assert_allclose( + self.variance, + self.base_variance, + rtol=config.RTOL.get(str(self.dtype)), + atol=config.ATOL.get(str(self.dtype))) + + def test_entropy(self): + np.testing.assert_allclose( + self._np_sum_rightmost(self.base_entropy, + self.reinterpreted_batch_rank), + self.entropy, + rtol=config.RTOL.get(str(self.dtype)), + atol=config.ATOL.get(str(self.dtype))) + + def _np_sum_rightmost(self, value, n): + return np.sum(value, tuple(range(-n, 0))) if n > 0 else value + + def test_log_prob(self): + np.testing.assert_allclose( + self._np_sum_rightmost(self.base_log_prob, + self.reinterpreted_batch_rank), + self.log_prob, + rtol=config.RTOL.get(str(self.dtype)), + atol=config.ATOL.get(str(self.dtype))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial.py index bff723dfa29fc80fdececa34a1fbef592481cfb2..851645a96d4059e3c2ad8913725f74a5339e4ad5 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial.py @@ -19,15 +19,17 @@ import paddle import scipy.stats import config +import parameterize -@config.place(config.DEVICES) -@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [ - ('one-dim', 10, config.xrand((3, ))), - ('multi-dim', 9, config.xrand((10, 20))), - ('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])), - ('prob-sum-non-one', 10, np.array([2., 3., 5.])), -]) +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 
'probs'), [ + ('one-dim', 10, parameterize.xrand((3, ))), + ('multi-dim', 9, parameterize.xrand((10, 20))), + ('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])), + ('prob-sum-non-one', 10, np.array([2., 3., 5.])), + ]) class TestMultinomial(unittest.TestCase): def setUp(self): self._dist = paddle.distribution.Multinomial( @@ -98,9 +100,9 @@ class TestMultinomial(unittest.TestCase): return scipy.stats.multinomial.entropy(self.total_count, probs) -@config.place(config.DEVICES) -@config.parameterize( - (config.TEST_CASE_NAME, 'total_count', 'probs', 'value'), +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'), [ ('value-float', 10, np.array([0.2, 0.3, 0.5]), np.array([2., 3., 5.])), ('value-int', 10, np.array([0.2, 0.3, 0.5]), np.array([2, 3, 5])), @@ -122,12 +124,13 @@ class TestMultinomialPmf(unittest.TestCase): atol=config.ATOL.get(str(self.probs.dtype))) -@config.place(config.DEVICES) -@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [ - ('total_count_le_one', 0, np.array([0.3, 0.7])), - ('total_count_float', np.array([0.3, 0.7])), - ('probs_zero_dim', np.array(0)), -]) +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (config.TEST_CASE_NAME, 'total_count', 'probs'), [ + ('total_count_le_one', 0, np.array([0.3, 0.7])), + ('total_count_float', np.array([0.3, 0.7])), + ('probs_zero_dim', np.array(0)), + ]) class TestMultinomialException(unittest.TestCase): def TestInit(self): with self.assertRaises(ValueError): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py index 2eb5b9769dfe9e96d8a4b2c31e4aea9dfed105de..ac86ad8d3e1854fbe460f88c78a6ce9aaa713055 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py @@ -19,17 +19,19 @@ import paddle import scipy.stats import config +import parameterize paddle.enable_static() -@config.place(config.DEVICES) -@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [ - ('one-dim', 5, config.xrand((3, ))), - ('multi-dim', 9, config.xrand((2, 3))), - ('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])), - ('prob-sum-non-one', 5, np.array([2., 3., 5.])), -]) +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [ + ('one-dim', 5, parameterize.xrand((3, ))), + ('multi-dim', 9, parameterize.xrand((2, 3))), + ('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])), + ('prob-sum-non-one', 5, np.array([2., 3., 5.])), + ]) class TestMultinomial(unittest.TestCase): def setUp(self): startup_program = paddle.static.Program() @@ -99,9 +101,9 @@ class TestMultinomial(unittest.TestCase): return scipy.stats.multinomial.entropy(self.total_count, probs) -@config.place(config.DEVICES) -@config.parameterize( - (config.TEST_CASE_NAME, 'total_count', 'probs', 'value'), +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'), [ ('value-float', 5, np.array([0.2, 0.3, 0.5]), np.array([1., 1., 3.])), ('value-int', 5, np.array([0.2, 0.3, 0.5]), np.array([2, 2, 1])), @@ -139,12 +141,13 @@ class TestMultinomialPmf(unittest.TestCase): atol=config.ATOL.get(str(self.probs.dtype))) -@config.place(config.DEVICES) 
-@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [ - ('total_count_le_one', 0, np.array([0.3, 0.7])), - ('total_count_float', np.array([0.3, 0.7])), - ('probs_zero_dim', np.array(0)), -]) +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 'probs'), [ + ('total_count_le_one', 0, np.array([0.3, 0.7])), + ('total_count_float', np.array([0.3, 0.7])), + ('probs_zero_dim', np.array(0)), + ]) class TestMultinomialException(unittest.TestCase): def setUp(self): startup_program = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py index d1ded2256e250faaad45de88ddefcac2f6ad523e..0c23e367f98f7e85354fc2b0c224533658de01cb 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py @@ -454,3 +454,7 @@ class NormalTest10(NormalTest): with fluid.program_guard(self.test_program): self.static_values = layers.data( name='values', shape=[dims], dtype='float32') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b1304a52ef354924c419f629a47d5920036f26ef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform.py @@ -0,0 +1,917 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import typing +import unittest + +import numpy as np +import paddle +from paddle.distribution import constraint, transform, variable + +import config +import parameterize as param + + +@param.place(config.DEVICES) +class TestTransform(unittest.TestCase): + def setUp(self): + self._t = transform.Transform() + + @param.param_func([ + (paddle.distribution.Distribution(), + paddle.distribution.TransformedDistribution), + (paddle.distribution.ExpTransform(), paddle.distribution.ChainTransform) + ]) + def test_call(self, input, expected_type): + t = transform.Transform() + self.assertIsInstance(t(input), expected_type) + + @param.param_func( + [(transform.Type.BIJECTION, True), (transform.Type.INJECTION, True), + (transform.Type.SURJECTION, False), (transform.Type.OTHER, False)]) + def test_is_injective(self, type, expected): + transform.Transform._type = type + self.assertEqual(self._t._is_injective(), expected) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Real)) + + @param.param_func([(0, TypeError), (paddle.rand((2, 3)), + NotImplementedError)]) + def test_forward(self, input, expected): + with self.assertRaises(expected): + self._t.forward(input) + + @param.param_func([(0, TypeError), (paddle.rand((2, 3)), + NotImplementedError)]) + def test_inverse(self, input, expected): + with self.assertRaises(expected): + self._t.inverse(input) + + @param.param_func([(0, TypeError), (paddle.rand((2, 3)), + NotImplementedError)]) + def test_forward_log_det_jacobian(self, input, expected): + with self.assertRaises(expected): + self._t.forward_log_det_jacobian(input) + + @param.param_func([(0, TypeError), (paddle.rand((2, 3)), + NotImplementedError)]) + def test_inverse_log_det_jacobian(self, input, expected): + with self.assertRaises(expected): + self._t.inverse_log_det_jacobian(input) + + @param.param_func([(0, TypeError)]) + def test_forward_shape(self, shape, expected): + with self.assertRaises(expected): + self._t.forward_shape(shape) + + @param.param_func([(0, TypeError)]) + def test_inverse_shape(self, shape, expected): + with self.assertRaises(expected): + self._t.inverse_shape(shape) + + +@param.place(config.DEVICES) +class TestAbsTransform(unittest.TestCase): + def setUp(self): + self._t = transform.AbsTransform() + + def test_is_injective(self): + self.assertFalse(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Positive)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + @param.param_func([(np.array([-1., 1., 0.]), np.array([1., 1., 0.])), + (np.array([[1., -1., -0.1], [-3., -0.1, 0]]), + np.array([[1., 1., 0.1], [3., 0.1, 0]]))]) + def test_forward(self, input, expected): + np.testing.assert_allclose( + self._t.forward(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array(1.), (-np.array(1.), np.array(1.)))]) + def test_inverse(self, input, expected): + actual0, actual1 = self._t.inverse(paddle.to_tensor(input)) + expected0, expected1 = expected + np.testing.assert_allclose( + actual0.numpy(), + expected0, + 
rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + np.testing.assert_allclose( + actual1.numpy(), + expected1, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def test_forward_log_det_jacobian(self): + with self.assertRaises(NotImplementedError): + self._t.forward_log_det_jacobian(paddle.rand((10, ))) + + @param.param_func([(np.array(1.), (np.array(0.), np.array(0.))), ]) + def test_inverse_log_det_jacobian(self, input, expected): + actual0, actual1 = self._t.inverse_log_det_jacobian( + paddle.to_tensor(input)) + expected0, expected1 = expected + np.testing.assert_allclose( + actual0.numpy(), + expected0, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + np.testing.assert_allclose( + actual1.numpy(), + expected1, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'loc', 'scale'), [ + ('normal', np.random.rand(8, 10), np.random.rand(8, 10)), + ('broadcast', np.random.rand(2, 10), np.random.rand(10)), +]) +class TestAffineTransform(unittest.TestCase): + def setUp(self): + self._t = transform.AffineTransform( + paddle.to_tensor(self.loc), paddle.to_tensor(self.scale)) + + @param.param_func([ + (paddle.rand([1]), 0, TypeError), + (0, paddle.rand([1]), TypeError), + ]) + def test_init_exception(self, loc, scale, exc): + with self.assertRaises(exc): + paddle.distribution.AffineTransform(loc, scale) + + def test_scale(self): + np.testing.assert_allclose(self._t.scale, self.scale) + + def test_loc(self): + np.testing.assert_allclose(self._t.loc, self.loc) + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Real)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + def test_forward(self): + x = np.random.random(self.loc.shape) + np.testing.assert_allclose( + self._t.forward(paddle.to_tensor(x)).numpy(), + self._np_forward(x), + rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)), + atol=config.ATOL.get(str(self._t.loc.numpy().dtype))) + + def test_inverse(self): + y = np.random.random(self.loc.shape) + np.testing.assert_allclose( + self._t.inverse(paddle.to_tensor(y)).numpy(), + self._np_inverse(y), + rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)), + atol=config.ATOL.get(str(self._t.loc.numpy().dtype))) + + def _np_forward(self, x): + return self.loc + self.scale * x + + def _np_inverse(self, y): + return (y - self.loc) / self.scale + + def _np_forward_jacobian(self, x): + return np.log(np.abs(self.scale)) + + def _np_inverse_jacobian(self, y): + return -self._np_forward_jacobian(self._np_inverse(y)) + + def test_inverse_log_det_jacobian(self): + y = np.random.random(self.scale.shape) + np.testing.assert_allclose( + 
self._t.inverse_log_det_jacobian(paddle.to_tensor(y)).numpy(), + self._np_inverse_jacobian(y), + rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)), + atol=config.ATOL.get(str(self._t.loc.numpy().dtype))) + + def test_forward_log_det_jacobian(self): + x = np.random.random(self.scale.shape) + np.testing.assert_allclose( + self._t.forward_log_det_jacobian(paddle.to_tensor(x)).numpy(), + self._np_forward_jacobian(x), + rtol=config.RTOL.get(str(self._t.loc.numpy().dtype)), + atol=config.ATOL.get(str(self._t.loc.numpy().dtype))) + + def test_forward_shape(self): + shape = self.loc.shape + self.assertEqual( + tuple(self._t.forward_shape(shape)), + np.broadcast(np.random.random(shape), self.loc, self.scale).shape) + + def test_inverse_shape(self): + shape = self.scale.shape + self.assertEqual( + tuple(self._t.forward_shape(shape)), + np.broadcast(np.random.random(shape), self.loc, self.scale).shape) + + +@param.place(config.DEVICES) +class TestExpTransform(unittest.TestCase): + def setUp(self): + self._t = transform.ExpTransform() + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Positive)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + @param.param_func( + [(np.array([0., 1., 2., 3.]), np.exp(np.array([0., 1., 2., 3.]))), + (np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]), + np.exp(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))]) + def test_forward(self, input, expected): + np.testing.assert_allclose( + self._t.forward(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1., 2., 3.]), np.log(np.array([1., 2., 3.]))), + (np.array([[1., 2., 3.], [6., 7., 8.]]), + np.log(np.array([[1., 2., 3.], [6., 7., 8.]])))]) + def test_inverse(self, input, expected): + np.testing.assert_allclose( + self._t.inverse(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_forward_log_det_jacobian(self, input): + np.testing.assert_allclose( + self._t.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(), + self._np_forward_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_forward_jacobian(self, x): + return x + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_inverse_log_det_jacobian(self, input): + np.testing.assert_allclose( + self._t.inverse_log_det_jacobian(paddle.to_tensor(input)).numpy(), + self._np_inverse_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_inverse_jacobian(self, y): + return -self._np_forward_jacobian(np.log(y)) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), 
expected_shape) + + +@param.place(config.DEVICES) +class TestChainTransform(unittest.TestCase): + @param.param_func([(paddle.distribution.Transform, TypeError), + ([0], TypeError)]) + def test_init_exception(self, transforms, exception): + with self.assertRaises(exception): + paddle.distribution.ChainTransform(transforms) + + @param.param_func(( + (transform.ChainTransform( + (transform.AbsTransform(), + transform.AffineTransform(paddle.rand([1]), paddle.rand([1])))), + False), (transform.ChainTransform(( + transform.AffineTransform(paddle.rand([1]), paddle.rand([1])), + transform.ExpTransform(), )), True))) + def test_is_injective(self, chain, expected): + self.assertEqual(chain._is_injective(), expected) + + @param.param_func(((transform.ChainTransform( + (transform.IndependentTransform(transform.ExpTransform(), 1), + transform.IndependentTransform(transform.ExpTransform(), 10), + transform.IndependentTransform(transform.ExpTransform(), 8))), + variable.Independent(variable.real, 10)), )) + def test_domain(self, input, expected): + self.assertIsInstance(input._domain, type(expected)) + self.assertEqual(input._domain.event_rank, expected.event_rank) + self.assertEqual(input._domain.is_discrete, expected.is_discrete) + + @param.param_func(((transform.ChainTransform( + (transform.IndependentTransform(transform.ExpTransform(), 9), + transform.IndependentTransform(transform.ExpTransform(), 4), + transform.IndependentTransform(transform.ExpTransform(), 5))), + variable.Independent(variable.real, 9)), )) + def test_codomain(self, input, expected): + self.assertIsInstance(input._codomain, variable.Independent) + self.assertEqual(input._codomain.event_rank, expected.event_rank) + self.assertEqual(input._codomain.is_discrete, expected.is_discrete) + + @param.param_func( + [(transform.ChainTransform((transform.AffineTransform( + paddle.to_tensor(0.0), paddle.to_tensor(1.0)), + transform.ExpTransform())), + np.array([0., 1., 2., 3.]), np.exp(np.array([0., 1., 2., 3.]) * 1.0)), + (transform.ChainTransform((transform.ExpTransform(), + transform.TanhTransform())), + np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]), + np.tanh(np.exp(np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]))))]) + def test_forward(self, chain, input, expected): + np.testing.assert_allclose( + chain.forward(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func( + [(transform.ChainTransform( + (transform.AffineTransform( + paddle.to_tensor(0.0), paddle.to_tensor(-1.0)), + transform.ExpTransform())), np.array([0., 1., 2., 3.]), + np.log(np.array([0., 1., 2., 3.])) / (-1.0)), + (transform.ChainTransform((transform.ExpTransform(), + transform.TanhTransform())), + np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]), + np.log(np.arctanh(np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]))))]) + def test_inverse(self, chain, input, expected): + np.testing.assert_allclose( + chain.inverse(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([ + (transform.ChainTransform( + (transform.AffineTransform( + paddle.to_tensor(0.0), paddle.to_tensor(-1.0)), + transform.PowerTransform(paddle.to_tensor(2.0)))), + np.array([1., 2., 3.]), np.log(2. 
* np.array([1., 2., 3.]))), + ]) + def test_forward_log_det_jacobian(self, chain, input, expected): + np.testing.assert_allclose( + chain.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(transform.ChainTransform((transform.AffineTransform( + paddle.to_tensor(0.0), + paddle.to_tensor(-1.0)), transform.ExpTransform())), (2, 3, 5), + (2, 3, 5)), ]) + def test_forward_shape(self, chain, shape, expected_shape): + self.assertEqual(chain.forward_shape(shape), expected_shape) + + @param.param_func([(transform.ChainTransform((transform.AffineTransform( + paddle.to_tensor(0.0), + paddle.to_tensor(-1.0)), transform.ExpTransform())), (2, 3, 5), + (2, 3, 5)), ]) + def test_inverse_shape(self, chain, shape, expected_shape): + self.assertEqual(chain.inverse_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +@param.param_cls( + (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'x'), + [('rank-over-zero', transform.ExpTransform(), 2, np.random.rand(2, 3, 3)), + ]) +class TestIndependentTransform(unittest.TestCase): + def setUp(self): + self._t = transform.IndependentTransform(self.base, + self.reinterpreted_batch_rank) + + @param.param_func([(0, 0, TypeError), + (paddle.distribution.Transform(), -1, ValueError)]) + def test_init_exception(self, base, rank, exc): + with self.assertRaises(exc): + paddle.distribution.IndependentTransform(base, rank) + + def test_is_injective(self): + self.assertEqual(self._t._is_injective(), self.base._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Independent)) + self.assertEqual( + self._t._domain.event_rank, + self.base._domain.event_rank + self.reinterpreted_batch_rank) + self.assertEqual(self._t._domain.is_discrete, + self.base._domain.is_discrete) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Independent)) + self.assertEqual( + self._t._codomain.event_rank, + self.base._codomain.event_rank + self.reinterpreted_batch_rank) + self.assertEqual(self._t._codomain.is_discrete, + self.base._codomain.is_discrete) + + def test_forward(self): + np.testing.assert_allclose( + self._t.forward(paddle.to_tensor(self.x)).numpy(), + self.base.forward(paddle.to_tensor(self.x)).numpy(), + rtol=config.RTOL.get(str(self.x.dtype)), + atol=config.ATOL.get(str(self.x.dtype))) + + def test_inverse(self): + np.testing.assert_allclose( + self._t.inverse(paddle.to_tensor(self.x)).numpy(), + self.base.inverse(paddle.to_tensor(self.x)).numpy(), + rtol=config.RTOL.get(str(self.x.dtype)), + atol=config.ATOL.get(str(self.x.dtype))) + + def test_forward_log_det_jacobian(self): + actual = self._t.forward_log_det_jacobian(paddle.to_tensor(self.x)) + self.assertEqual( + tuple(actual.shape), self.x.shape[:-self.reinterpreted_batch_rank]) + expected = self.base.forward_log_det_jacobian( + paddle.to_tensor(self.x)).sum( + list(range(-self.reinterpreted_batch_rank, 0))) + np.testing.assert_allclose( + actual.numpy(), + expected.numpy(), + rtol=config.RTOL.get(str(self.x.dtype)), + atol=config.ATOL.get(str(self.x.dtype))) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) 
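+
+# The numpy oracle in TestPowerTransform below relies on the identity
+# d(x ** a) / dx = a * x ** (a - 1) when building its expected forward
+# log-det Jacobian.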
+ + +@param.place(config.DEVICES) +class TestPowerTransform(unittest.TestCase): + def setUp(self): + self._t = transform.PowerTransform(paddle.to_tensor(2.)) + + def test_init(self): + with self.assertRaises(TypeError): + transform.PowerTransform(1.) + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Positive)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + @param.param_func([(np.array([2.]), np.array([0., -1., 2.]), np.power( + np.array([0., -1., 2.]), + 2.)), (np.array([[0.], [3.]]), np.array([[1., 0.], [5., 6.]]), np.power( + np.array([[1., 0.], [5., 6.]]), np.array([[0.], [3.]])))]) + def test_forward(self, power, x, y): + t = transform.PowerTransform(paddle.to_tensor(power)) + np.testing.assert_allclose( + t.forward(paddle.to_tensor(x)).numpy(), + y, + rtol=config.RTOL.get(str(x.dtype)), + atol=config.ATOL.get(str(x.dtype))) + + @param.param_func([(np.array([2.]), np.array([4.]), np.array([2.]))]) + def test_inverse(self, power, y, x): + t = transform.PowerTransform(paddle.to_tensor(power)) + np.testing.assert_allclose( + t.inverse(paddle.to_tensor(y)).numpy(), + x, + rtol=config.RTOL.get(str(x.dtype)), + atol=config.ATOL.get(str(x.dtype))) + + @param.param_func(((np.array([2.]), np.array([3., 1.4, 0.8])), )) + def test_forward_log_det_jacobian(self, power, x): + t = transform.PowerTransform(paddle.to_tensor(power)) + np.testing.assert_allclose( + t.forward_log_det_jacobian(paddle.to_tensor(x)).numpy(), + self._np_forward_jacobian(power, x), + rtol=config.RTOL.get(str(x.dtype)), + atol=config.ATOL.get(str(x.dtype))) + + def _np_forward_jacobian(self, alpha, x): + return np.abs(np.log(alpha * np.power(x, alpha - 1))) + + @param.param_func([((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +class TestTanhTransform(unittest.TestCase): + def setUp(self): + self._t = transform.TanhTransform() + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Variable)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + self.assertEqual(self._t._codomain._constraint._lower, -1) + self.assertEqual(self._t._codomain._constraint._upper, 1) + + @param.param_func( + [(np.array([0., 1., 2., 3.]), np.tanh(np.array([0., 1., 2., 3.]))), + (np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]), + np.tanh(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))]) + def test_forward(self, input, expected): + np.testing.assert_allclose( + self._t.forward(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func( + 
[(np.array([1., 2., 3.]), np.arctanh(np.array([1., 2., 3.]))), + (np.array([[1., 2., 3.], [6., 7., 8.]]), + np.arctanh(np.array([[1., 2., 3.], [6., 7., 8.]])))]) + def test_inverse(self, input, expected): + np.testing.assert_allclose( + self._t.inverse(paddle.to_tensor(input)).numpy(), + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_forward_log_det_jacobian(self, input): + np.testing.assert_allclose( + self._t.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(), + self._np_forward_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_forward_jacobian(self, x): + return 2. * (np.log(2.) - x - self._np_softplus(-2. * x)) + + def _np_softplus(self, x, beta=1., threshold=20.): + if np.any(beta * x > threshold): + return x + return 1. / beta * np.log1p(np.exp(beta * x)) + + def _np_inverse_jacobian(self, y): + return -self._np_forward_jacobian(np.arctanh(y)) + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_inverse_log_det_jacobian(self, input): + np.testing.assert_allclose( + self._t.inverse_log_det_jacobian(paddle.to_tensor(input)).numpy(), + self._np_inverse_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'in_event_shape', 'out_event_shape'), [ + ('regular_shape', (2, 3), (3, 2)), +]) +class TestReshapeTransform(unittest.TestCase): + def setUp(self): + self._t = transform.ReshapeTransform(self.in_event_shape, + self.out_event_shape) + + @param.param_func([(0, 0, TypeError), ((1, 2), (1, 3), ValueError)]) + def test_init_exception(self, in_event_shape, out_event_shape, exc): + with self.assertRaises(exc): + paddle.distribution.ReshapeTransform(in_event_shape, + out_event_shape) + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Independent)) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Independent)) + + def test_forward(self): + x = paddle.ones(self.in_event_shape) + np.testing.assert_allclose( + self._t.forward(x), + paddle.ones(self.out_event_shape), + rtol=config.RTOL.get(str(x.numpy().dtype)), + atol=config.ATOL.get(str(x.numpy().dtype))) + + def test_inverse(self): + x = paddle.ones(self.out_event_shape) + np.testing.assert_allclose( + self._t.inverse(x).numpy(), + paddle.ones(self.in_event_shape).numpy(), + rtol=config.RTOL.get(str(x.numpy().dtype)), + atol=config.ATOL.get(str(x.numpy().dtype))) + + def test_forward_log_det_jacobian(self): + x = paddle.ones(self.in_event_shape) + np.testing.assert_allclose( + self._t.forward_log_det_jacobian(x).numpy(), + paddle.zeros([1]).numpy(), + rtol=config.RTOL.get(str(x.numpy().dtype)), + atol=config.ATOL.get(str(x.numpy().dtype))) + + def test_in_event_shape(self): + self.assertEqual(self._t.in_event_shape, self.in_event_shape) + + def 
test_out_event_shape(self):
+        self.assertEqual(self._t.out_event_shape, self.out_event_shape)
+
+    @param.param_func([((), ValueError), ((1, 2), ValueError)])
+    def test_forward_shape_exception(self, shape, exc):
+        with self.assertRaises(exc):
+            self._t.forward_shape(shape)
+
+    @param.param_func([((), ValueError), ((1, 2), ValueError)])
+    def test_inverse_shape_exception(self, shape, exc):
+        with self.assertRaises(exc):
+            self._t.inverse_shape(shape)
+
+
+def _np_softplus(x, beta=1., threshold=20.):
+    if np.any(beta * x > threshold):
+        return x
+    return 1. / beta * np.log1p(np.exp(beta * x))
+
+
+class TestSigmoidTransform(unittest.TestCase):
+    def setUp(self):
+        self._t = transform.SigmoidTransform()
+
+    def test_is_injective(self):
+        self.assertTrue(self._t._is_injective())
+
+    def test_domain(self):
+        self.assertTrue(isinstance(self._t._domain, variable.Real))
+
+    def test_codomain(self):
+        self.assertTrue(isinstance(self._t._codomain, variable.Variable))
+
+    @param.param_func(((np.ones((5, 10)),
+                        1 / (1 + np.exp(-np.ones((5, 10))))), ))
+    def test_forward(self, input, expected):
+        np.testing.assert_allclose(
+            self._t.forward(paddle.to_tensor(input)),
+            expected,
+            rtol=config.RTOL.get(str(input.dtype)),
+            atol=config.ATOL.get(str(input.dtype)))
+
+    @param.param_func((
+        (np.ones(10), np.log(np.ones(10)) - np.log1p(-np.ones(10))), ))
+    def test_inverse(self, input, expected):
+        np.testing.assert_allclose(
+            self._t.inverse(paddle.to_tensor(input)).numpy(),
+            expected,
+            rtol=config.RTOL.get(str(input.dtype)),
+            atol=config.ATOL.get(str(input.dtype)))
+
+    @param.param_func((
+        (np.ones(10),
+         -_np_softplus(-np.ones(10)) - _np_softplus(np.ones(10))), ))
+    def test_forward_log_det_jacobian(self, input, expected):
+        np.testing.assert_allclose(
+            self._t.forward_log_det_jacobian(paddle.to_tensor(input)).numpy(),
+            expected,
+            rtol=config.RTOL.get(str(input.dtype)),
+            atol=config.ATOL.get(str(input.dtype)))
+
+    @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
+    def test_forward_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.forward_shape(shape), expected_shape)
+
+    @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
+    def test_inverse_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
+
+
+class TestSoftmaxTransform(unittest.TestCase):
+    def setUp(self):
+        self._t = transform.SoftmaxTransform()
+
+    def test_is_injective(self):
+        self.assertFalse(self._t._is_injective())
+
+    def test_domain(self):
+        self.assertTrue(isinstance(self._t._domain, variable.Independent))
+
+    def test_codomain(self):
+        self.assertTrue(isinstance(self._t._codomain, variable.Variable))
+
+    @param.param_func(((np.random.random((5, 10)), ), ))
+    def test_forward(self, input):
+        np.testing.assert_allclose(
+            self._t.forward(paddle.to_tensor(input)),
+            self._np_forward(input),
+            rtol=config.RTOL.get(str(input.dtype)),
+            atol=config.ATOL.get(str(input.dtype)))
+
+    @param.param_func(((np.random.random(10), ), ))
+    def test_inverse(self, input):
+        np.testing.assert_allclose(
+            self._t.inverse(paddle.to_tensor(input)),
+            self._np_inverse(input),
+            rtol=config.RTOL.get(str(input.dtype)),
+            atol=config.ATOL.get(str(input.dtype)))
+
+    def _np_forward(self, x):
+        # Subtract the per-row maximum for numerical stability; softmax is
+        # invariant to this shift.
+        x = np.exp(x - np.max(x, -1, keepdims=True))
+        return x / np.sum(x, -1, keepdims=True)
+
+    def _np_inverse(self, y):
+        return np.log(y)
+
+    def test_forward_log_det_jacobian(self):
+        with self.assertRaises(NotImplementedError):
+            self._t.forward_log_det_jacobian(paddle.rand((2, 3)))
+
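+    # Softmax maps an n-vector to a point on the simplex without changing its
+    # shape, so the checks below expect forward_shape and inverse_shape to
+    # return the input shape unchanged and to reject an empty shape.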
+    @param.param_func([((2, 3, 5), (2, 3, 5))])
+    def test_forward_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.forward_shape(shape), expected_shape)
+
+    @param.param_func([((), ValueError)])
+    def test_forward_shape_exception(self, shape, exc):
+        with self.assertRaises(exc):
+            self._t.forward_shape(shape)
+
+    @param.param_func([((), ValueError)])
+    def test_inverse_shape_exception(self, shape, exc):
+        with self.assertRaises(exc):
+            self._t.inverse_shape(shape)
+
+    @param.param_func([((2, 3, 5), (2, 3, 5))])
+    def test_inverse_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
+
+
+class TestStickBreakingTransform(unittest.TestCase):
+    def setUp(self):
+        self._t = transform.StickBreakingTransform()
+
+    def test_is_injective(self):
+        self.assertTrue(self._t._is_injective())
+
+    def test_domain(self):
+        self.assertTrue(isinstance(self._t._domain, variable.Independent))
+
+    def test_codomain(self):
+        self.assertTrue(isinstance(self._t._codomain, variable.Variable))
+
+    @param.param_func(((np.random.random((10)), ), ))
+    def test_forward(self, input):
+        np.testing.assert_allclose(
+            self._t.inverse(self._t.forward(paddle.to_tensor(input))),
+            input,
+            rtol=config.RTOL.get(str(input.dtype)),
+            atol=config.ATOL.get(str(input.dtype)))
+
+    @param.param_func([((2, 3, 5), (2, 3, 6))])
+    def test_forward_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.forward_shape(shape), expected_shape)
+
+    @param.param_func([((2, 3, 5), (2, 3, 4))])
+    def test_inverse_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
+
+    @param.param_func(((np.random.random((10)), ), ))
+    def test_forward_log_det_jacobian(self, x):
+        self.assertEqual(
+            self._t.forward_log_det_jacobian(paddle.to_tensor(x)).shape, [1])
+
+
+# Todo
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'transforms', 'axis'), [
+    ('simple_one_transform', [transform.ExpTransform()], 0),
+])
+class TestStackTransform(unittest.TestCase):
+    def setUp(self):
+        self._t = transform.StackTransform(self.transforms, self.axis)
+
+    def test_is_injective(self):
+        self.assertTrue(self._t._is_injective())
+
+    def test_domain(self):
+        self.assertTrue(isinstance(self._t._domain, variable.Stack))
+
+    def test_codomain(self):
+        self.assertTrue(isinstance(self._t._codomain, variable.Stack))
+
+    @param.param_func([(np.array([[0., 1., 2., 3.]]), ),
+                       (np.array([[-5., 6., 7., 8.]]), )])
+    def test_forward(self, input):
+        self.assertEqual(
+            tuple(self._t.forward(paddle.to_tensor(input)).shape), input.shape)
+
+    @param.param_func([(np.array([[1., 2., 3.]]), ),
+                       (np.array([[6., 7., 8.]]), )])
+    def test_inverse(self, input):
+        self.assertEqual(
+            tuple(self._t.inverse(paddle.to_tensor(input)).shape), input.shape)
+
+    @param.param_func([(np.array([[1., 2., 3.]]), ),
+                       (np.array([[6., 7., 8.]]), )])
+    def test_forward_log_det_jacobian(self, input):
+        self.assertEqual(
+            tuple(
+                self._t.forward_log_det_jacobian(paddle.to_tensor(input))
+                .shape), input.shape)
+
+    @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
+    def test_forward_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.forward_shape(shape), expected_shape)
+
+    @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))])
+    def test_inverse_shape(self, shape, expected_shape):
+        self.assertEqual(self._t.inverse_shape(shape), expected_shape)
+
+    def test_axis(self):
+        self.assertEqual(self._t.axis, self.axis)
+
+    @param.param_func(
+        [(0, 0, TypeError), ([0], 0,
TypeError), + ([paddle.distribution.ExpTransform()], 'axis', TypeError)]) + def test_init_exception(self, transforms, axis, exc): + with self.assertRaises(exc): + paddle.distribution.StackTransform(transforms, axis) + + def test_transforms(self): + self.assertIsInstance((self._t.transforms), typing.Sequence) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform_static.py new file mode 100644 index 0000000000000000000000000000000000000000..fa5742fb261034b4e926435e6b0726b809a0dac1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transform_static.py @@ -0,0 +1,974 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle +from paddle.distribution import transform, variable, constraint + +import config +import parameterize as param + +paddle.enable_static() + + +@param.place(config.DEVICES) +class TestTransform(unittest.TestCase): + def setUp(self): + self._t = transform.Transform() + + @param.param_func( + [(transform.Type.BIJECTION, True), (transform.Type.INJECTION, True), + (transform.Type.SURJECTION, False), (transform.Type.OTHER, False)]) + def test_is_injective(self, type, expected): + transform.Transform._type = type + self.assertEqual(self._t._is_injective(), expected) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Real)) + + @param.param_func([(np.array(0), NotImplementedError), (np.random.random( + (2, 3)), NotImplementedError)]) + def test_forward(self, input, expected): + with self.assertRaises(expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.Transform() + static_input = paddle.static.data('input', input.shape, + input.dtype) + output = t.forward(static_input) + exe.run(sp) + exe.run(mp, feed={'input': input}, fetch_list=[output]) + + @param.param_func([(np.array(0), NotImplementedError), (np.random.random( + (2, 3)), NotImplementedError)]) + def test_inverse(self, input, expected): + with self.assertRaises(expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.Transform() + static_input = paddle.static.data('input', input.shape, + input.dtype) + output = t.inverse(static_input) + exe.run(sp) + exe.run(mp, feed={'input': input}, fetch_list=[output]) + + @param.param_func([(np.array(0), NotImplementedError), (paddle.rand( + (2, 3)), NotImplementedError)]) + def test_forward_log_det_jacobian(self, input, expected): + with self.assertRaises(expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp 
= paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.Transform() + static_input = paddle.static.data('input', input.shape, + input.dtype) + output = t.forward_log_det_jacobian(static_input) + exe.run(sp) + exe.run(mp, feed={'input': input}, fetch_list=[output]) + + @param.param_func([(np.array(0), NotImplementedError), (paddle.rand( + (2, 3)), NotImplementedError)]) + def test_inverse_log_det_jacobian(self, input, expected): + with self.assertRaises(expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.Transform() + static_input = paddle.static.data('input', input.shape, + input.dtype) + output = t.inverse_log_det_jacobian(static_input) + exe.run(sp) + exe.run(mp, feed={'input': input}, fetch_list=[output]) + + @param.param_func([(0, TypeError)]) + def test_forward_shape(self, shape, expected): + with self.assertRaises(expected): + self._t.forward_shape(shape) + + @param.param_func([(0, TypeError)]) + def test_inverse_shape(self, shape, expected): + with self.assertRaises(expected): + self._t.forward_shape(shape) + + +@param.place(config.DEVICES) +class TestAbsTransform(unittest.TestCase): + def setUp(self): + self._t = transform.AbsTransform() + + def test_is_injective(self): + self.assertFalse(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Positive)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + @param.param_func([(np.array([-1., 1., 0.]), np.array([1., 1., 0.])), + (np.array([[1., -1., -0.1], [-3., -0.1, 0]]), + np.array([[1., 1., 0.1], [3., 0.1, 0]]))]) + def test_forward(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.AbsTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1.]), (-np.array([1.]), np.array([1.])))]) + def test_inverse(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.AbsTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + actual0, actual1 = t.inverse(static_input) + exe.run(sp) + [actual0, actual1] = exe.run(mp, + feed={'input': input}, + fetch_list=[actual0, actual1]) + expected0, expected1 = expected + np.testing.assert_allclose( + actual0, + expected0, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + np.testing.assert_allclose( + actual1, + expected1, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def test_forward_log_det_jacobian(self): + input = np.random.random((10, )) + with self.assertRaises(NotImplementedError): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with 
paddle.static.program_guard(mp, sp): + t = transform.AbsTransform() + static_input = paddle.static.data('input', input.shape, + input.dtype) + output = t.forward_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + + @param.param_func([(np.array([1.]), (np.array([0.]), np.array([0.]))), ]) + def test_inverse_log_det_jacobian(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.AbsTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + actual0, actual1 = t.inverse_log_det_jacobian(static_input) + exe.run(sp) + [actual0, actual1] = exe.run(mp, + feed={'input': input}, + fetch_list=[actual0, actual1]) + expected0, expected1 = expected + np.testing.assert_allclose( + actual0, + expected0, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + np.testing.assert_allclose( + actual1, + expected1, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'loc', 'scale'), [ + ('normal', np.random.rand(8, 10), np.random.rand(8, 10)), + ('broadcast', np.random.rand(2, 10), np.random.rand(10)), +]) +class TestAffineTransform(unittest.TestCase): + def setUp(self): + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + self._t = transform.AffineTransform(loc, scale) + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Real)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + def test_forward(self): + input = np.random.random(self.loc.shape) + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + t = transform.AffineTransform(loc, scale) + static_input = paddle.static.data('input', self.loc.shape, + self.loc.dtype) + output = t.forward(static_input) + exe.run(sp) + [output] = exe.run( + mp, + feed={'input': input, + 'loc': self.loc, + 'scale': self.scale}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_forward(input), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_inverse(self): + input = np.random.random(self.loc.shape) + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + 
loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + t = transform.AffineTransform(loc, scale) + static_input = paddle.static.data('input', self.loc.shape, + self.loc.dtype) + output = t.inverse(static_input) + exe.run(sp) + [output] = exe.run( + mp, + feed={'input': input, + 'loc': self.loc, + 'scale': self.scale}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_inverse(input), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def _np_forward(self, x): + return self.loc + self.scale * x + + def _np_inverse(self, y): + return (y - self.loc) / self.scale + + def _np_forward_jacobian(self, x): + return np.log(np.abs(self.scale)) + + def _np_inverse_jacobian(self, y): + return -self._np_forward_jacobian(self._np_inverse(y)) + + def test_inverse_log_det_jacobian(self): + input = np.random.random(self.scale.shape) + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + t = transform.AffineTransform(loc, scale) + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run( + mp, + feed={'input': input, + 'loc': self.loc, + 'scale': self.scale}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_inverse_jacobian(input), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_forward_log_det_jacobian(self): + input = np.random.random(self.scale.shape) + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + t = transform.AffineTransform(loc, scale) + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run( + mp, + feed={'input': input, + 'loc': self.loc, + 'scale': self.scale}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_forward_jacobian(input), + rtol=config.RTOL.get(str(self.loc.dtype)), + atol=config.ATOL.get(str(self.loc.dtype))) + + def test_forward_shape(self): + shape = self.loc.shape + self.assertEqual( + tuple(self._t.forward_shape(shape)), + np.broadcast(np.random.random(shape), self.loc, self.scale).shape) + + def test_inverse_shape(self): + shape = self.scale.shape + self.assertEqual( + tuple(self._t.forward_shape(shape)), + np.broadcast(np.random.random(shape), self.loc, self.scale).shape) + + +@param.place(config.DEVICES) +class TestExpTransform(unittest.TestCase): + def setUp(self): + self._t = transform.ExpTransform() + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Positive)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + 
@param.param_func( + [(np.array([0., 1., 2., 3.]), np.exp(np.array([0., 1., 2., 3.]))), + (np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]), + np.exp(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))]) + def test_forward(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.ExpTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1., 2., 3.]), np.log(np.array([1., 2., 3.]))), + (np.array([[1., 2., 3.], [6., 7., 8.]]), + np.log(np.array([[1., 2., 3.], [6., 7., 8.]])))]) + def test_inverse(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.ExpTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_forward_log_det_jacobian(self, input): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.ExpTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_forward_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_forward_jacobian(self, x): + return x + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_inverse_log_det_jacobian(self, input): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.ExpTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_inverse_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_inverse_jacobian(self, y): + return -self._np_forward_jacobian(np.log(y)) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +class TestChainTransform(unittest.TestCase): + @param.param_func(( + (transform.ChainTransform( + (transform.AbsTransform(), + transform.AffineTransform(paddle.rand([1]), paddle.rand([1])))), + False), 
(transform.ChainTransform(( + transform.AffineTransform(paddle.rand([1]), paddle.rand([1])), + transform.ExpTransform(), )), True))) + def test_is_injective(self, chain, expected): + self.assertEqual(chain._is_injective(), expected) + + @param.param_func(((transform.ChainTransform( + (transform.IndependentTransform(transform.ExpTransform(), 1), + transform.IndependentTransform(transform.ExpTransform(), 10), + transform.IndependentTransform(transform.ExpTransform(), 8))), + variable.Independent(variable.real, 10)), )) + def test_domain(self, input, expected): + self.assertIsInstance(input._domain, type(expected)) + self.assertEqual(input._domain.event_rank, expected.event_rank) + self.assertEqual(input._domain.is_discrete, expected.is_discrete) + + @param.param_func(((transform.ChainTransform( + (transform.IndependentTransform(transform.ExpTransform(), 9), + transform.IndependentTransform(transform.ExpTransform(), 4), + transform.IndependentTransform(transform.ExpTransform(), 5))), + variable.Independent(variable.real, 9)), )) + def test_codomain(self, input, expected): + self.assertIsInstance(input._codomain, variable.Independent) + self.assertEqual(input._codomain.event_rank, expected.event_rank) + self.assertEqual(input._codomain.is_discrete, expected.is_discrete) + + @param.param_func( + [(transform.ChainTransform((transform.ExpTransform(), + transform.TanhTransform())), + np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]), + np.tanh(np.exp(np.array([[0., -1., 2., -3.], [-5., 6., 7., -8.]]))))]) + def test_forward(self, chain, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = chain + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func( + [(transform.ChainTransform((transform.ExpTransform(), + transform.TanhTransform())), + np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]), + np.log(np.arctanh(np.array([[0., 1., 2., 3.], [5., 6., 7., 8.]]))))]) + def test_inverse(self, chain, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = chain + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(transform.ChainTransform((transform.AffineTransform( + paddle.full([1], 0.0), + paddle.full([1], -1.0)), transform.ExpTransform())), (2, 3, 5), + (2, 3, 5)), ]) + def test_forward_shape(self, chain, shape, expected_shape): + self.assertEqual(chain.forward_shape(shape), expected_shape) + + @param.param_func([(transform.ChainTransform((transform.AffineTransform( + paddle.full([1], 0.0), + paddle.full([1], -1.0)), transform.ExpTransform())), (2, 3, 5), + (2, 3, 5)), ]) + def test_inverse_shape(self, chain, shape, expected_shape): + self.assertEqual(chain.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +@param.param_cls( + (param.TEST_CASE_NAME, 'base', 'reinterpreted_batch_rank', 'x'), + 
[('rank-over-zero', transform.ExpTransform(), 2, np.random.rand(2, 3, 3)), + ]) +class TestIndependentTransform(unittest.TestCase): + def setUp(self): + self._t = transform.IndependentTransform(self.base, + self.reinterpreted_batch_rank) + + def test_is_injective(self): + self.assertEqual(self._t._is_injective(), self.base._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Independent)) + self.assertEqual( + self._t._domain.event_rank, + self.base._domain.event_rank + self.reinterpreted_batch_rank) + self.assertEqual(self._t._domain.is_discrete, + self.base._domain.is_discrete) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Independent)) + self.assertEqual( + self._t._codomain.event_rank, + self.base._codomain.event_rank + self.reinterpreted_batch_rank) + self.assertEqual(self._t._codomain.is_discrete, + self.base._codomain.is_discrete) + + def test_forward(self): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.IndependentTransform(self.base, + self.reinterpreted_batch_rank) + static_input = paddle.static.data('input', self.x.shape, + self.x.dtype) + output = t.forward(static_input) + expected = self.base.forward(static_input) + exe.run(sp) + [output, expected] = exe.run(mp, + feed={'input': self.x}, + fetch_list=[output, expected]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(self.x.dtype)), + atol=config.ATOL.get(str(self.x.dtype))) + + def test_inverse(self): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.IndependentTransform(self.base, + self.reinterpreted_batch_rank) + static_input = paddle.static.data('input', self.x.shape, + self.x.dtype) + output = t.inverse(static_input) + expected = self.base.inverse(static_input) + exe.run(sp) + [output, expected] = exe.run(mp, + feed={'input': self.x}, + fetch_list=[output, expected]) + np.testing.assert_allclose( + expected, + output, + rtol=config.RTOL.get(str(self.x.dtype)), + atol=config.ATOL.get(str(self.x.dtype))) + + def test_forward_log_det_jacobian(self): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.IndependentTransform(self.base, + self.reinterpreted_batch_rank) + static_input = paddle.static.data('input', self.x.shape, + self.x.dtype) + output = t.forward_log_det_jacobian(static_input) + expected = self.base.forward_log_det_jacobian( + static_input.sum( + list(range(-self.reinterpreted_batch_rank, 0)))) + exe.run(sp) + [actual, expected] = exe.run(mp, + feed={'input': self.x}, + fetch_list=[output, expected]) + self.assertEqual( + tuple(actual.shape), self.x.shape[:-self.reinterpreted_batch_rank]) + np.testing.assert_allclose( + actual, + expected, + rtol=config.RTOL.get(str(self.x.dtype)), + atol=config.ATOL.get(str(self.x.dtype))) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +class TestPowerTransform(unittest.TestCase): + def setUp(self): + self._t = 
transform.PowerTransform(paddle.full([1], 2.)) + + def test_init(self): + with self.assertRaises(TypeError): + transform.PowerTransform(1.) + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Positive)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + + @param.param_func([(np.array([2.]), np.array([0., -1., 2.]), np.power( + np.array([0., -1., 2.]), + 2.)), (np.array([[0.], [3.]]), np.array([[1., 0.], [5., 6.]]), np.power( + np.array([[1., 0.], [5., 6.]]), np.array([[0.], [3.]])))]) + def test_forward(self, power, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + static_power = paddle.static.data('power', power.shape, power.dtype) + t = transform.PowerTransform(static_power) + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward(static_input) + exe.run(sp) + [output] = exe.run(mp, + feed={'input': input, + 'power': power}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([2.]), np.array([4.]), np.array([2.]))]) + def test_inverse(self, power, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + static_power = paddle.static.data('power', power.shape, power.dtype) + t = transform.PowerTransform(static_power) + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse(static_input) + exe.run(sp) + [output] = exe.run(mp, + feed={'input': input, + 'power': power}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func(((np.array([2.]), np.array([3., 1.4, 0.8])), )) + def test_forward_log_det_jacobian(self, power, input): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + static_power = paddle.static.data('power', power.shape, power.dtype) + t = transform.PowerTransform(static_power) + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run(mp, + feed={'input': input, + 'power': power}, + fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_forward_jacobian(power, input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_forward_jacobian(self, alpha, x): + return np.abs(np.log(alpha * np.power(x, alpha - 1))) + + @param.param_func([((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +class TestTanhTransform(unittest.TestCase): + def setUp(self): + self._t 
= transform.TanhTransform() + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Real)) + self.assertEqual(self._t._domain.event_rank, 0) + self.assertEqual(self._t._domain.is_discrete, False) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Variable)) + self.assertEqual(self._t._codomain.event_rank, 0) + self.assertEqual(self._t._codomain.is_discrete, False) + self.assertEqual(self._t._codomain._constraint._lower, -1) + self.assertEqual(self._t._codomain._constraint._upper, 1) + + @param.param_func( + [(np.array([0., 1., 2., 3.]), np.tanh(np.array([0., 1., 2., 3.]))), + (np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]]), + np.tanh(np.array([[0., 1., 2., 3.], [-5., 6., 7., 8.]])))]) + def test_forward(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.TanhTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func( + [(np.array([1., 2., 3.]), np.arctanh(np.array([1., 2., 3.]))), + (np.array([[1., 2., 3.], [6., 7., 8.]]), + np.arctanh(np.array([[1., 2., 3.], [6., 7., 8.]])))]) + def test_inverse(self, input, expected): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.TanhTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_forward_log_det_jacobian(self, input): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.TanhTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.forward_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_forward_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + def _np_forward_jacobian(self, x): + return 2. * (np.log(2.) - x - self._np_softplus(-2. * x)) + + def _np_softplus(self, x, beta=1., threshold=20.): + if np.any(beta * x > threshold): + return x + return 1. 
/ beta * np.log1p(np.exp(beta * x)) + + def _np_inverse_jacobian(self, y): + return -self._np_forward_jacobian(np.arctanh(y)) + + @param.param_func([(np.array([1., 2., 3.]), ), + (np.array([[1., 2., 3.], [6., 7., 8.]]), )]) + def test_inverse_log_det_jacobian(self, input): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + t = transform.TanhTransform() + static_input = paddle.static.data('input', input.shape, input.dtype) + output = t.inverse_log_det_jacobian(static_input) + exe.run(sp) + [output] = exe.run(mp, feed={'input': input}, fetch_list=[output]) + np.testing.assert_allclose( + output, + self._np_inverse_jacobian(input), + rtol=config.RTOL.get(str(input.dtype)), + atol=config.ATOL.get(str(input.dtype))) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_forward_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + @param.param_func([((), ()), ((2, 3, 5), (2, 3, 5))]) + def test_inverse_shape(self, shape, expected_shape): + self.assertEqual(self._t.forward_shape(shape), expected_shape) + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'in_event_shape', 'out_event_shape'), [ + ('regular_shape', (2, 3), (3, 2)), +]) +class TestReshapeTransform(unittest.TestCase): + def setUp(self): + self._t = transform.ReshapeTransform(self.in_event_shape, + self.out_event_shape) + + def test_is_injective(self): + self.assertTrue(self._t._is_injective()) + + def test_domain(self): + self.assertTrue(isinstance(self._t._domain, variable.Independent)) + + def test_codomain(self): + self.assertTrue(isinstance(self._t._codomain, variable.Independent)) + + def test_forward(self): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.ones(self.in_event_shape) + t = transform.ReshapeTransform(self.in_event_shape, + self.out_event_shape) + output = self._t.forward(x) + exe.run(sp) + [output] = exe.run(mp, feed={}, fetch_list=[output]) + expected = np.ones(self.out_event_shape) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(expected.dtype)), + atol=config.ATOL.get(str(expected.dtype))) + + def test_inverse(self): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.ones(self.out_event_shape) + t = transform.ReshapeTransform(self.in_event_shape, + self.out_event_shape) + output = self._t.inverse(x) + exe.run(sp) + [output] = exe.run(mp, feed={}, fetch_list=[output]) + expected = np.ones(self.in_event_shape) + + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(expected.dtype)), + atol=config.ATOL.get(str(expected.dtype))) + + def test_forward_log_det_jacobian(self): + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.ones(self.in_event_shape) + t = transform.ReshapeTransform(self.in_event_shape, + self.out_event_shape) + output = self._t.forward_log_det_jacobian(x) + exe.run(sp) + [output] = exe.run(mp, feed={}, fetch_list=[output]) + expected = np.zeros([1]) + np.testing.assert_allclose( + output, + expected, + rtol=config.RTOL.get(str(expected.dtype)), + atol=config.ATOL.get(str(expected.dtype))) + + +if __name__ == '__main__': + unittest.main() diff --git 
a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..2f7bb61e38d1389b5cb6d0c5ae6615b15bb15ec3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +import unittest + +import numpy as np +import paddle +import scipy.stats + +import config +import parameterize as param + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'base', 'transforms'), + [('base_normal', paddle.distribution.Normal(0., 1.), + [paddle.distribution.ExpTransform()])]) +class TestIndependent(unittest.TestCase): + def setUp(self): + self._t = paddle.distribution.TransformedDistribution(self.base, + self.transforms) + + def _np_sum_rightmost(self, value, n): + return np.sum(value, tuple(range(-n, 0))) if n > 0 else value + + def test_log_prob(self): + value = paddle.to_tensor(0.5) + np.testing.assert_allclose( + self.simple_log_prob(value, self.base, self.transforms), + self._t.log_prob(value), + rtol=config.RTOL.get(str(value.numpy().dtype)), + atol=config.ATOL.get(str(value.numpy().dtype))) + + def simple_log_prob(self, value, base, transforms): + log_prob = 0.0 + y = value + for t in reversed(transforms): + x = t.inverse(y) + log_prob = log_prob - t.forward_log_det_jacobian(x) + y = x + log_prob += base.log_prob(y) + return log_prob + + # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result. + def test_sample(self): + shape = [5, 10, 8] + expected_shape = (5, 10, 8) + data = self._t.sample(shape) + self.assertEqual(tuple(data.shape), expected_shape) + self.assertEqual(data.dtype, self.base.loc.dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution_static.py new file mode 100644 index 0000000000000000000000000000000000000000..f07205a62680a755dcc886f53308166f3286870b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution_static.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +import unittest + +import numpy as np +import paddle +import scipy.stats + +import config +import parameterize as param + +paddle.enable_static() + + +@param.place(config.DEVICES) +@param.param_cls((param.TEST_CASE_NAME, 'base', 'transforms'), + [('base_normal', paddle.distribution.Normal, + [paddle.distribution.ExpTransform()])]) +class TestIndependent(unittest.TestCase): + def setUp(self): + value = np.array([0.5]) + loc = np.array([0.]) + scale = np.array([1.]) + shape = [5, 10, 8] + self.dtype = value.dtype + exe = paddle.static.Executor() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + static_value = paddle.static.data('value', value.shape, value.dtype) + static_loc = paddle.static.data('loc', loc.shape, loc.dtype) + static_scale = paddle.static.data('scale', scale.shape, scale.dtype) + self.base = self.base(static_loc, static_scale) + self._t = paddle.distribution.TransformedDistribution( + self.base, self.transforms) + actual_log_prob = self._t.log_prob(static_value) + expected_log_prob = self.transformed_log_prob( + static_value, self.base, self.transforms) + sample_data = self._t.sample(shape) + + exe.run(sp) + [self.actual_log_prob, self.expected_log_prob, + self.sample_data] = exe.run( + mp, + feed={'value': value, + 'loc': loc, + 'scale': scale}, + fetch_list=[actual_log_prob, expected_log_prob, sample_data]) + + def test_log_prob(self): + np.testing.assert_allclose( + self.actual_log_prob, + self.expected_log_prob, + rtol=config.RTOL.get(str(self.dtype)), + atol=config.ATOL.get(str(self.dtype))) + + def transformed_log_prob(self, value, base, transforms): + log_prob = 0.0 + y = value + for t in reversed(transforms): + x = t.inverse(y) + log_prob = log_prob - t.forward_log_det_jacobian(x) + y = x + log_prob += base.log_prob(y) + return log_prob + + # TODO(cxxly): Add Kolmogorov-Smirnov test for sample result. + def test_sample(self): + expected_shape = (5, 10, 8, 1) + self.assertEqual(tuple(self.sample_data.shape), expected_shape) + self.assertEqual(self.sample_data.dtype, self.dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py index e6076764b04fe3c0a752fe38ddd980336a3ba62b..d8fe23b9c1bdac37acb041cefbb1d580956e3e58 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py @@ -343,3 +343,7 @@ class UniformTestSample2(UniformTestSample): def init_param(self): self.low = -5.0 self.high = 2.0 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_variable.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd50157207fd160a1085beebcba0cf669534af4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_variable.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+from paddle.distribution import variable
+from paddle.distribution import constraint
+
+import config
+import parameterize as param
+
+
+@param.param_cls(
+    (param.TEST_CASE_NAME, 'is_discrete', 'event_rank', 'constraint'),
+    [('NotImplement', False, 0, constraint.Constraint())])
+class TestVariable(unittest.TestCase):
+    def setUp(self):
+        self._var = variable.Variable(self.is_discrete, self.event_rank,
+                                      self.constraint)
+
+    @param.param_func([(1, )])
+    def test_constraint(self, value):
+        with self.assertRaises(NotImplementedError):
+            self._var.constraint(value)
+
+
+@param.param_cls((param.TEST_CASE_NAME, 'base', 'rank'),
+                 [('real_base', variable.real, 10)])
+class TestIndependent(unittest.TestCase):
+    def setUp(self):
+        self._var = variable.Independent(self.base, self.rank)
+
+    @param.param_func([(paddle.rand([2, 3, 4]), ValueError), ])
+    def test_constraint(self, value, expect):
+        with self.assertRaises(expect):
+            self._var.constraint(value)
+
+
+@param.param_cls((param.TEST_CASE_NAME, 'vars', 'axis'),
+                 [('real_base', [variable.real], 10)])
+class TestStack(unittest.TestCase):
+    def setUp(self):
+        self._var = variable.Stack(self.vars, self.axis)
+
+    def test_is_discrete(self):
+        self.assertEqual(self._var.is_discrete, False)
+
+    @param.param_func([(paddle.rand([2, 3, 4]), ValueError), ])
+    def test_constraint(self, value, expect):
+        with self.assertRaises(expect):
+            self._var.constraint(value)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_kl.py b/python/paddle/fluid/tests/unittests/distribution/test_kl.py
index 55358380c8b23fdfd512b259aca06901d5623e38..635f5446c8ef25f19a4ffe3df328b65d84cb45b7 100644
--- a/python/paddle/fluid/tests/unittests/distribution/test_kl.py
+++ b/python/paddle/fluid/tests/unittests/distribution/test_kl.py
@@ -23,12 +23,13 @@ from paddle.distribution import kl
 
 import config
 import mock_data as mock
+import parameterize as param
 
 paddle.set_default_dtype('float64')
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
+@param.place(config.DEVICES)
+@param.parameterize_cls((param.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
     ('test_regular_input', 6.0 * np.random.random((4, 5)) + 1e-4,
      6.0 * np.random.random((4, 5)) + 1e-4, 6.0 * np.random.random(
         (4, 5)) + 1e-4, 6.0 * np.random.random((4, 5)) + 1e-4),
@@ -55,8 +56,8 @@ class TestKLBetaBeta(unittest.TestCase):
             (a2 - a1 + b2 - b1) * scipy.special.digamma(a1 + b1))
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'conc1', 'conc2'), [
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'conc1', 'conc2'), [
     ('test-regular-input', np.random.random((5, 7, 8, 10)), np.random.random(
         (5, 7, 8, 10))),
 ])
@@ -88,23 +89,22 @@ class DummyDistribution(paddle.distribution.Distribution):
     pass
 
 
-@config.place(config.DEVICES)
-@config.parameterize(
-    (config.TEST_CASE_NAME, 'p', 'q'),
-    [('test-unregister', DummyDistribution(), DummyDistribution)])
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'p', 'q'),
+                 [('test-unregister', DummyDistribution(), DummyDistribution)])
 class TestDispatch(unittest.TestCase):
     def test_dispatch_with_unregister(self):
         with self.assertRaises(NotImplementedError):
             paddle.distribution.kl_divergence(self.p, self.q)
 
 
-@config.place(config.DEVICES)
-@config.parameterize(
-    (config.TEST_CASE_NAME, 'p', 'q'),
-    [('test-diff-dist', mock.Exponential(paddle.rand((100, 200, 100)) + 1.0),
-      mock.Exponential(paddle.rand((100, 200, 100)) + 2.0)),
-     ('test-same-dist', mock.Exponential(paddle.to_tensor(1.0)),
-      mock.Exponential(paddle.to_tensor(1.0)))])
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'p', 'q'),
+                 [('test-diff-dist',
+                   mock.Exponential(paddle.rand((100, 200, 100)) + 1.0),
+                   mock.Exponential(paddle.rand((100, 200, 100)) + 2.0)),
+                  ('test-same-dist', mock.Exponential(paddle.to_tensor(1.0)),
+                   mock.Exponential(paddle.to_tensor(1.0)))])
 class TestKLExpfamilyExpFamily(unittest.TestCase):
     def test_kl_expfamily_expfamily(self):
         np.testing.assert_allclose(
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_kl_static.py b/python/paddle/fluid/tests/unittests/distribution/test_kl_static.py
index 828a7320d474fcb6be3a4003d5110c980f245404..b061650a53b9e8114f49a548a30260a70ffbd219 100644
--- a/python/paddle/fluid/tests/unittests/distribution/test_kl_static.py
+++ b/python/paddle/fluid/tests/unittests/distribution/test_kl_static.py
@@ -22,13 +22,14 @@ import scipy.stats
 from paddle.distribution import kl
 
 import config
+import parameterize as param
 import mock_data as mock
 
 paddle.enable_static()
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'a1', 'b1', 'a2', 'b2'), [
     ('test_regular_input', 6.0 * np.random.random((4, 5)) + 1e-4,
      6.0 * np.random.random((4, 5)) + 1e-4, 6.0 * np.random.random(
        (4, 5)) + 1e-4, 6.0 * np.random.random((4, 5)) + 1e-4),
@@ -75,8 +76,8 @@ class TestKLBetaBeta(unittest.TestCase):
             (a2 - a1 + b2 - b1) * scipy.special.digamma(a1 + b1))
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'conc1', 'conc2'), [
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'conc1', 'conc2'), [
     ('test-regular-input', np.random.random((5, 7, 8, 10)), np.random.random(
         (5, 7, 8, 10))),
 ])
@@ -123,9 +124,9 @@ class DummyDistribution(paddle.distribution.Distribution):
     pass
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'p', 'q'),
-                     [('test-dispatch-exception')])
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'p', 'q'),
+                 [('test-dispatch-exception')])
 class TestDispatch(unittest.TestCase):
     def setUp(self):
         self.mp = paddle.static.Program()
@@ -143,11 +144,11 @@ class TestDispatch(unittest.TestCase):
             self.executor.run(self.mp, feed={}, fetch_list=[out])
 
 
-@config.place(config.DEVICES)
-@config.parameterize((config.TEST_CASE_NAME, 'rate1', 'rate2'),
-                     [('test-diff-dist', np.random.rand(100, 200, 100) + 1.0,
-                       np.random.rand(100, 200, 100) + 2.0),
-                      ('test-same-dist', np.array([1.0]), np.array([1.0]))])
+@param.place(config.DEVICES)
+@param.param_cls((param.TEST_CASE_NAME, 'rate1', 'rate2'),
+                 [('test-diff-dist', np.random.rand(100, 200, 100) + 1.0,
+                   np.random.rand(100, 200, 100) + 2.0),
+                  ('test-same-dist', np.array([1.0]), np.array([1.0]))])
 class TestKLExpfamilyExpFamily(unittest.TestCase):
     def setUp(self):
         self.mp = paddle.static.Program()
@@ -176,3 +177,7 @@ class TestKLExpfamilyExpFamily(unittest.TestCase):
             out2,
             rtol=config.RTOL.get(config.DEFAULT_DTYPE),
             atol=config.ATOL.get(config.DEFAULT_DTYPE))
+
+
+if __name__ == '__main__':
+    unittest.main()
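
For reference, the simple_log_prob/transformed_log_prob helpers added by the TransformedDistribution tests above implement the change-of-variables identity log p_Y(y) = log p_X(T^{-1}(y)) - log|det J_T(T^{-1}(y))|, applied one transform at a time. The snippet below is a minimal sketch of that identity, independent of Paddle, for the Normal(0., 1.) base and exp transform used in the test parameterizations; the helper name and lambdas here are illustrative only and are not part of the patch. Since exp of a standard normal variable is log-normal, the result can be cross-checked against scipy.stats.lognorm.

import numpy as np
import scipy.stats


def transformed_log_prob(y, base_logpdf, inverse, forward_log_det):
    # Walk the single transform backwards, as the test helpers do: map y back
    # to the base space, subtract the forward log-determinant, add the base
    # log-density.
    x = inverse(y)
    return base_logpdf(x) - forward_log_det(x)


y = 0.5
log_prob = transformed_log_prob(
    y,
    base_logpdf=scipy.stats.norm(0., 1.).logpdf,
    inverse=np.log,               # inverse of exp
    forward_log_det=lambda x: x)  # log|d exp(x) / dx| = x

# exp(Normal(0, 1)) is LogNormal(s=1), so both values should agree closely.
np.testing.assert_allclose(log_prob, scipy.stats.lognorm(s=1.).logpdf(y))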