# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from paddle.distribution import distribution, independent, transform


class TransformedDistribution(distribution.Distribution):
    r"""
    Applies a sequence of Transforms to a base distribution.

    Samples are drawn from ``base`` and pushed through every transform in
    order; ``log_prob`` inverts the transforms in reverse order and corrects
    the base density with each transform's forward log-det-Jacobian.

    Args:
        base (Distribution): The base distribution.
        transforms (Sequence[Transform]): A sequence of ``Transform`` .

    Raises:
        TypeError: If ``base`` is not a ``Distribution``, ``transforms`` is
            not a sequence, or any element of ``transforms`` is not a
            ``Transform``.
        ValueError: If ``base`` has fewer dimensions (batch + event) than the
            event rank required by the domain of the chained transforms.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import transformed_distribution

            d = transformed_distribution.TransformedDistribution(
                paddle.distribution.Normal(0., 1.),
                [paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))]
            )

            print(d.sample([10]))
            # Tensor(shape=[10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.10697651,  3.33609009, -0.86234951,  5.07457638,  0.75925219,
            #         -4.17087793,  2.22579336, -0.93845034,  0.66054249,  1.50957513])
            print(d.log_prob(paddle.to_tensor(0.5)))
            # Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        -1.64333570)
    """

    def __init__(self, base, transforms):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}."
            )
        if not isinstance(transforms, typing.Sequence):
            raise TypeError(
                f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}."
            )
        if not all(isinstance(t, transform.Transform) for t in transforms):
            raise TypeError("All element of transforms must be Transform type.")

        chain = transform.ChainTransform(transforms)
        base_shape = base.batch_shape + base.event_shape
        # The user-supplied base is what sample/rsample/log_prob operate on;
        # log_prob sums any extra event dimensions itself.
        self._base = base
        self._transforms = transforms
        if not transforms:
            # No transforms: shapes are exactly those of the base.
            super().__init__(base.batch_shape, base.event_shape)
            return

        domain_event_rank = chain._domain.event_rank
        if len(base_shape) < domain_event_rank:
            raise ValueError(
                f"'base' needs to have shape with size at least {domain_event_rank}, but got {len(base_shape)}."
            )
        if domain_event_rank > len(base.event_shape):
            # Reinterpret trailing batch dimensions as event dimensions so the
            # base event rank matches what the chain's domain expects.
            # Independent takes (base, reinterpreted_batch_rank) as two
            # separate positional arguments, not a single tuple.
            base = independent.Independent(
                base, domain_event_rank - len(base.event_shape)
            )

        transformed_shape = chain.forward_shape(
            base.batch_shape + base.event_shape
        )
        # Event rank after the chain: the codomain's rank plus any base event
        # dimensions beyond what the chain's domain consumes.
        transformed_event_rank = chain._codomain.event_rank + max(
            len(base.event_shape) - domain_event_rank, 0
        )
        batch_rank = len(transformed_shape) - transformed_event_rank
        super().__init__(
            transformed_shape[:batch_rank],
            transformed_shape[batch_rank:],
        )

    def sample(self, shape=()):
        """Sample from ``TransformedDistribution``.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            Tensor: The sample result, i.e. a base sample passed through every
            transform's ``forward`` in order.
        """
        x = self._base.sample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def rsample(self, shape=()):
        """Reparameterized sample from ``TransformedDistribution``.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            Tensor: The sample result, i.e. a base ``rsample`` passed through
            every transform's ``forward`` in order.
        """
        x = self._base.rsample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def log_prob(self, value):
        """The log probability evaluated at value.

        Args:
            value (Tensor): The value to be evaluated.

        Returns:
            Tensor: The log probability.
        """
        log_prob = 0.0
        y = value
        # Tracks the event rank of the intermediate variable as we walk the
        # transforms backwards, starting from this distribution's event rank.
        event_rank = len(self.event_shape)
        for t in reversed(self._transforms):
            x = t.inverse(y)
            event_rank += t._domain.event_rank - t._codomain.event_rank
            # Subtract the forward log-det-Jacobian, summed over the event
            # dimensions the transform itself does not already reduce.
            log_prob = log_prob - _sum_rightmost(
                t.forward_log_det_jacobian(x), event_rank - t._domain.event_rank
            )
            y = x
        # Add the base log density, summed over event dimensions the base
        # still treats as batch dimensions.
        log_prob += _sum_rightmost(
            self._base.log_prob(y), event_rank - len(self._base.event_shape)
        )
        return log_prob


def _sum_rightmost(value, n):
    """Sum ``value`` over its last ``n`` axes.

    Args:
        value: Object exposing ``sum(axes)``; only called when ``n > 0``.
        n (int): Number of trailing axes to reduce.

    Returns:
        ``value.sum([-n, ..., -1])`` when ``n > 0``; otherwise ``value``
        itself, unchanged.
    """
    if n > 0:
        trailing_axes = list(range(-n, 0))
        return value.sum(trailing_axes)
    return value