# transformed_distribution.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from paddle.distribution import distribution
from paddle.distribution import transform
from paddle.distribution import independent


class TransformedDistribution(distribution.Distribution):
    r"""
    Applies a sequence of Transforms to a base distribution.

    Args:
        base (Distribution): The base distribution.
        transforms (Sequence[Transform]): A sequence of ``Transform`` .

    Raises:
        TypeError: If ``base`` is not a ``Distribution``, if ``transforms`` is
            not a sequence, or if any element of ``transforms`` is not a
            ``Transform``.
        ValueError: If the rank of ``base.batch_shape + base.event_shape`` is
            smaller than the event rank required by the domain of the
            chained transforms.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import transformed_distribution

            d = transformed_distribution.TransformedDistribution(
                paddle.distribution.Normal(0., 1.),
                [paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))]
            )

            print(d.sample([10]))
            # Tensor(shape=[10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.10697651,  3.33609009, -0.86234951,  5.07457638,  0.75925219,
            #         -4.17087793,  2.22579336, -0.93845034,  0.66054249,  1.50957513])
            print(d.log_prob(paddle.to_tensor(0.5)))
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-1.64333570])
    """

    def __init__(self, base, transforms):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}."
            )
        if not isinstance(transforms, typing.Sequence):
            raise TypeError(
                f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}."
            )
        if not all(isinstance(t, transform.Transform) for t in transforms):
            raise TypeError("All element of transforms must be Transform type.")

        chain = transform.ChainTransform(transforms)
        base_shape = base.batch_shape + base.event_shape
        # The *original* base is stored: ``log_prob`` itself sums the extra
        # trailing dimensions of the base log-density, so the Independent
        # wrapper built below is only needed for the shape computation.
        self._base = base
        self._transforms = transforms
        if not transforms:
            # No transforms: shapes are exactly those of the base.
            super().__init__(base.batch_shape, base.event_shape)
            return
        if len(base_shape) < chain._domain.event_rank:
            raise ValueError(
                f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
            )
        if chain._domain.event_rank > len(base.event_shape):
            # Reinterpret just enough trailing batch dimensions as event
            # dimensions for the base to match the chain's domain. The base
            # and the rank are two separate positional arguments.
            base = independent.Independent(
                base, chain._domain.event_rank - len(base.event_shape)
            )

        transformed_shape = chain.forward_shape(
            base.batch_shape + base.event_shape
        )
        # Event rank of the codomain, plus any base event dimensions the
        # chain's domain does not consume.
        transformed_event_rank = chain._codomain.event_rank + max(
            len(base.event_shape) - chain._domain.event_rank, 0
        )
        # Split the transformed shape into (batch_shape, event_shape).
        batch_rank = len(transformed_shape) - transformed_event_rank
        super().__init__(
            transformed_shape[:batch_rank],
            transformed_shape[batch_rank:],
        )

    def sample(self, shape=()):
        """Sample from ``TransformedDistribution``.

        Draws from the base distribution and applies every transform's
        ``forward`` in order.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            Tensor: The sample result.
        """
        x = self._base.sample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def rsample(self, shape=()):
        """Reparameterized sample from ``TransformedDistribution``.

        Draws from the base distribution's ``rsample`` and applies every
        transform's ``forward`` in order.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            Tensor: The sample result.
        """
        x = self._base.rsample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def log_prob(self, value):
        """The log probability evaluated at value.

        Walks the transforms in reverse, inverting ``value`` step by step and
        subtracting each transform's forward log-det-Jacobian, then adds the
        base distribution's log probability of the fully inverted value.

        Args:
            value (Tensor): The value to be evaluated.

        Returns:
            Tensor: The log probability.
        """
        log_prob = 0.0
        y = value
        event_rank = len(self.event_shape)
        for t in reversed(self._transforms):
            x = t.inverse(y)
            # Track the event rank on the input side of ``t``.
            event_rank += t._domain.event_rank - t._codomain.event_rank
            # Reduce the Jacobian term over event dimensions that the
            # transform itself does not account for.
            log_prob = log_prob - _sum_rightmost(
                t.forward_log_det_jacobian(x), event_rank - t._domain.event_rank
            )
            y = x
        # Sum the base log-density over trailing dimensions that count as
        # event dimensions here but are batch dimensions of the base.
        log_prob += _sum_rightmost(
            self._base.log_prob(y), event_rank - len(self._base.event_shape)
        )
        return log_prob


def _sum_rightmost(value, n):
    return value.sum(list(range(-n, 0))) if n > 0 else value