# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distribution import distribution


class Independent(distribution.Distribution):
    r"""
    Reinterprets some of the batch dimensions of a distribution as event dimensions.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`.

    Args:
        base (Distribution): The base distribution.
        reinterpreted_batch_rank (int): The number of batch dimensions to
            reinterpret as event dimensions.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import independent

            beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5]))
            print(beta.batch_shape, beta.event_shape)
            # (2,) ()
            print(beta.log_prob(paddle.to_tensor(0.2)))
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.22843921, -0.22843921])
            reinterpreted_beta = independent.Independent(beta, 1)
            print(reinterpreted_beta.batch_shape, reinterpreted_beta.event_shape)
            # () (2,)
            print(reinterpreted_beta.log_prob(paddle.to_tensor([0.2, 0.2])))
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.45687842])
    """

    def __init__(self, base, reinterpreted_batch_rank):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}"
            )
        if not (0 < reinterpreted_batch_rank <= len(base.batch_shape)):
            raise ValueError(
                f"Expected 0 < reinterpreted_batch_rank <= {len(base.batch_shape)}, but got {reinterpreted_batch_rank}"
            )
        self._base = base
        self._reinterpreted_batch_rank = reinterpreted_batch_rank

        shape = base.batch_shape + base.event_shape
        super().__init__(
            batch_shape=shape[
                : len(base.batch_shape) - reinterpreted_batch_rank
            ],
            event_shape=shape[
                len(base.batch_shape) - reinterpreted_batch_rank :
            ],
        )

    @property
    def mean(self):
        return self._base.mean

    @property
    def variance(self):
        return self._base.variance

    def sample(self, shape=()):
        return self._base.sample(shape)

    def log_prob(self, value):
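        # Delegate to the base distribution, then sum the log densities over
        # the reinterpreted batch dimensions to obtain a joint log probability.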
        return self._sum_rightmost(
            self._base.log_prob(value), self._reinterpreted_batch_rank
        )

    def prob(self, value):
        return self.log_prob(value).exp()

    def entropy(self):
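        # Entropies of independent components add, so sum the base entropy
        # over the reinterpreted batch dimensions.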
        return self._sum_rightmost(
            self._base.entropy(), self._reinterpreted_batch_rank
        )

    def _sum_rightmost(self, value, n):
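        # Sum over the rightmost n dimensions; with n == 0, return the value
        # unchanged.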
        return value.sum(list(range(-n, 0))) if n > 0 else value
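

if __name__ == "__main__":
    # Minimal usage sketch, mirroring the docstring example above; assumes
    # paddle is installed and this module is importable. Reinterpreting the
    # batch dimension of a Beta distribution as an event dimension makes
    # log_prob return a single joint density rather than per-element values.
    import paddle

    beta = paddle.distribution.Beta(
        paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5])
    )
    joint = Independent(beta, 1)
    print(joint.batch_shape, joint.event_shape)  # () (2,)
    print(joint.log_prob(paddle.to_tensor([0.2, 0.2])))  # approx. [-0.45687842]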