# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distribution import constraint


class Variable:
    """Random variable of probability distribution.

    Args:
        is_discrete (bool): Is the variable discrete or continuous.
        event_rank (int): The rank of event dimensions.
        constraint (callable, optional): Callable invoked by
            :meth:`constraint` with the value to check; its result is
            returned unchanged. Defaults to None, in which case calling
            :meth:`constraint` fails unless a subclass overrides it.
    """

    def __init__(self, is_discrete=False, event_rank=0, constraint=None):
        self._is_discrete = is_discrete
        self._event_rank = event_rank
        self._constraint = constraint

    @property
    def is_discrete(self):
        """Value of ``is_discrete`` given at construction."""
        return self._is_discrete

    @property
    def event_rank(self):
        """Value of ``event_rank`` given at construction."""
        return self._event_rank

    def constraint(self, value):
        """Check whether the 'value' meet the constraint conditions of this
        random variable.

        Args:
            value: Object forwarded to the stored constraint callable.

        Returns:
            Whatever the stored constraint callable returns for ``value``.
        """
        return self._constraint(value)


class Real(Variable):
    """Non-discrete variable whose values are checked by ``constraint.real``.

    Args:
        event_rank (int): The rank of event dimensions. Defaults to 0.
    """

    def __init__(self, event_rank=0):
        super().__init__(
            is_discrete=False,
            event_rank=event_rank,
            constraint=constraint.real,
        )


class Positive(Variable):
    """Non-discrete variable whose values are checked by
    ``constraint.positive``.

    Args:
        event_rank (int): The rank of event dimensions. Defaults to 0.
    """

    def __init__(self, event_rank=0):
        super().__init__(
            is_discrete=False,
            event_rank=event_rank,
            constraint=constraint.positive,
        )


class Independent(Variable):
    """Reinterprets some of the batch axes of variable as event axes.

    Args:
        base (Variable): Base variable.
        reinterpreted_batch_rank (int): The rightmost batch rank to be
            reinterpreted.
    """

    def __init__(self, base, reinterpreted_batch_rank):
        self._base = base
        self._reinterpreted_batch_rank = reinterpreted_batch_rank
        # The event rank grows by the number of reinterpreted batch axes.
        super().__init__(
            base.is_discrete, base.event_rank + reinterpreted_batch_rank
        )

    def constraint(self, value):
        """Check 'value' with the base variable and reduce the result over
        the reinterpreted axes.

        Args:
            value: Object forwarded to ``base.constraint``.

        Returns:
            The base check result with its rightmost
            ``reinterpreted_batch_rank`` axes flattened into one axis and
            reduced with ``all``.

        Raises:
            ValueError: If the base check result has fewer dimensions than
                ``reinterpreted_batch_rank``.
        """
        ret = self._base.constraint(value)
        if ret.dim() < self._reinterpreted_batch_rank:
            raise ValueError(
                "Input dimensions must be equal or grater than  {}".format(
                    self._reinterpreted_batch_rank
                )
            )
        # Build the target shape as a list so the concatenation works
        # whether ``ret.shape`` is a list or a tuple.
        kept_shape = list(
            ret.shape[: ret.dim() - self._reinterpreted_batch_rank]
        )
        return ret.reshape(kept_shape + [-1]).all(-1)


class Stack(Variable):
    """Variable built from several variables stacked along one axis.

    Args:
        vars (iterable of Variable): Component variables; the i-th one
            checks the i-th slice of a value along ``axis``.
        axis (int): Axis along which values are unstacked and the check
            results are stacked back. Defaults to 0.
    """

    def __init__(self, vars, axis=0):
        # The base initializer is not called: every base accessor
        # (is_discrete, event_rank, constraint) is overridden below.
        self._vars = vars
        self._axis = axis

    @property
    def is_discrete(self):
        """True if any component variable is discrete."""
        return any(var.is_discrete for var in self._vars)

    @property
    def event_rank(self):
        """Largest component event rank, plus one when ``axis + rank`` is
        negative."""
        rank = max(var.event_rank for var in self._vars)
        if self._axis + rank < 0:
            rank += 1
        return rank

    def constraint(self, value):
        """Check each slice of 'value' along the stack axis with the
        matching component variable.

        Args:
            value: Tensor to check; it is split with ``paddle.unstack``
                along ``axis``.

        Returns:
            The per-slice check results joined with ``paddle.stack`` along
            ``axis``.

        Raises:
            ValueError: If ``axis`` is outside
                ``[-value.dim(), value.dim())``.
        """
        if not (-value.dim() <= self._axis < value.dim()):
            raise ValueError(
                f'Input dimensions {value.dim()} should be grater than stack '
                f'constraint axis {self._axis}.'
            )

        # Variables expose their check as ``constraint``; pairing stops at
        # the shorter of the variables and the slices.
        return paddle.stack(
            [
                var.constraint(part)
                for var, part in zip(
                    self._vars, paddle.unstack(value, self._axis)
                )
            ],
            self._axis,
        )


# Shared module-level instances with an event rank of 0.
real = Real(event_rank=0)
positive = Positive(event_rank=0)