Unverified commit 01f606b4, authored by Xiaoxu Chen, committed by GitHub

Add multinomial probability distribution (#38820)

* Add multinomial probability distribution
* Fix Categorical sample bug when logits are less than zero (see the sketch below)
* Fix Categorical sample failing the hypothesis test, and an entropy shape error
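A minimal sketch (editorial, not part of the diff, assuming the Paddle 2.x API) of why sampling straight from logits was wrong: `paddle.multinomial` expects non-negative weights, so logits with negative entries must first be turned into probabilities, which is what the `_logits_to_probs` call in the diff below does via softmax.

```python
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([-1.0, 0.5, -3.0])       # logits may be negative
probs = F.softmax(logits, axis=-1)                 # valid non-negative weights
samples = paddle.multinomial(probs, num_samples=5, replacement=True)
print(samples)                                     # category indices in {0, 1, 2}
```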
Parent 9b3b53ba
@@ -18,6 +18,7 @@ from .dirichlet import Dirichlet
 from .distribution import Distribution
 from .exponential_family import ExponentialFamily
 from .kl import kl_divergence, register_kl
+from .multinomial import Multinomial
 from .normal import Normal
 from .uniform import Uniform
@@ -27,8 +28,9 @@ __all__ = [  # noqa
     'Dirichlet',
     'Distribution',
     'ExponentialFamily',
+    'Multinomial',
     'Normal',
     'Uniform',
     'kl_divergence',
-    'register_kl'
+    'register_kl',
 ]
@@ -21,7 +21,14 @@ from .exponential_family import ExponentialFamily
 class Beta(ExponentialFamily):
     r"""
-    Beta distribution parameterized by alpha and beta
+    Beta distribution parameterized by alpha and beta.
+
+    In probability theory and statistics, the beta distribution is a family of
+    continuous probability distributions defined on the interval [0, 1],
+    parameterized by two positive shape parameters, denoted by alpha and beta,
+    that appear as exponents of the random variable and control the shape of
+    the distribution. The generalization to multiple variables is called a
+    Dirichlet distribution.

     The probability density function (pdf) is
@@ -37,8 +44,14 @@ class Beta(ExponentialFamily):
     Args:
-        alpha (float|Tensor): alpha parameter of beta distribution, positive(>0).
-        beta (float|Tensor): beta parameter of beta distribution, positive(>0).
+        alpha (float|Tensor): Alpha parameter. It supports broadcast semantics.
+            The value of alpha must be positive (>0). When the parameter is a
+            tensor, it represents multiple independent distributions with a
+            batch_shape (refer to ``Distribution``).
+        beta (float|Tensor): Beta parameter. It supports broadcast semantics.
+            The value of beta must be positive (>0). When the parameter is a
+            tensor, it represents multiple independent distributions with a
+            batch_shape (refer to ``Distribution``).

     Examples:
@@ -86,56 +99,56 @@ class Beta(ExponentialFamily):
     @property
     def mean(self):
-        """mean of beta distribution.
+        """Mean of beta distribution.
         """
         return self.alpha / (self.alpha + self.beta)

     @property
     def variance(self):
-        """variance of beat distribution
+        """Variance of beta distribution.
         """
         sum = self.alpha + self.beta
         return self.alpha * self.beta / (sum.pow(2) * (sum + 1))

     def prob(self, value):
-        """probability density funciotn evaluated at value
+        """Probability density function evaluated at value.

         Args:
-            value (Tensor): value to be evaluated.
+            value (Tensor): Value to be evaluated.

         Returns:
-            Tensor: probability.
+            Tensor: Probability.
         """
         return paddle.exp(self.log_prob(value))

     def log_prob(self, value):
-        """log probability density funciton evaluated at value
+        """Log probability density function evaluated at value.

         Args:
-            value (Tensor): value to be evaluated
+            value (Tensor): Value to be evaluated.

         Returns:
-            Tensor: log probability.
+            Tensor: Log probability.
         """
         return self._dirichlet.log_prob(paddle.stack([value, 1.0 - value], -1))

     def sample(self, shape=()):
-        """sample from beta distribution with sample shape.
+        """Sample from beta distribution with sample shape.

         Args:
-            shape (Sequence[int], optional): sample shape.
+            shape (Sequence[int], optional): Sample shape.

         Returns:
-            sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
+            Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
         """
         shape = shape if isinstance(shape, tuple) else tuple(shape)
-        return paddle.squeeze(self._dirichlet.sample(shape)[..., 0])
+        return paddle.squeeze(self._dirichlet.sample(shape)[..., 0], axis=-1)

     def entropy(self):
-        """entropy of dirichlet distribution
+        """Entropy of beta distribution.

         Returns:
-            Tensor: entropy.
+            Tensor: Entropy.
         """
         return self._dirichlet.entropy()
...
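A short usage sketch (editorial, assuming the `paddle.distribution.Beta` API above) of the batch semantics described in the new Args section: tensor parameters define a batch of independent distributions.

```python
import paddle

# Two independent Beta distributions, batch_shape=[2].
beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 2.0]),
                                paddle.to_tensor([0.5, 1.0]))
print(beta.mean.shape)           # [2], one mean per batch element
print(beta.sample((3,)).shape)   # [3, 2] = sample_shape + batch_shape
```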
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -16,6 +16,7 @@ import math
 import warnings

 import numpy as np
+import paddle
 from paddle import _C_ops

 from ..fluid import core
@@ -123,7 +124,7 @@ class Categorical(Distribution):
         Returns:
             Tensor: A tensor with prepended dimensions shape.

         Examples:
             .. code-block:: python
@@ -153,14 +154,22 @@ class Categorical(Distribution):
         logits_shape = list(self.logits.shape)
         if len(logits_shape) > 1:
             sample_shape = shape + logits_shape[:-1]
-            logits = nn.reshape(self.logits,
-                                [np.prod(logits_shape[:-1]), logits_shape[-1]])
+            logits = paddle.reshape(
+                self.logits, [np.prod(logits_shape[:-1]), logits_shape[-1]])
         else:
             sample_shape = shape
             logits = self.logits

-        sample_index = multinomial(logits, num_samples, True)
-        return nn.reshape(sample_index, sample_shape, name=name)
+        sample_index = multinomial(
+            self._logits_to_probs(logits), num_samples, True)
+
+        # multinomial sample shape is (logits.shape[:-1], num_samples); we need
+        # to transpose it to (num_samples, logits.shape[:-1]).
+        permute = list(range(sample_index.dim()))
+        permute.insert(0, permute.pop(-1))
+        sample_index = sample_index.transpose(permute)
+
+        return paddle.reshape(sample_index, sample_shape, name=name)
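A sketch (hypothetical shapes, editorial only) of the permutation above: `paddle.multinomial` returns `num_samples` as the last axis, while `Categorical.sample` promises the sample dimension first.

```python
import paddle

probs = paddle.to_tensor([[0.1, 0.9], [0.7, 0.3], [0.5, 0.5]])    # 3 distributions
idx = paddle.multinomial(probs, num_samples=4, replacement=True)  # shape [3, 4]
perm = list(range(idx.dim()))
perm.insert(0, perm.pop(-1))                                      # [1, 0]
print(idx.transpose(perm).shape)                                  # [4, 3]
```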
     def kl_divergence(self, other):
         """The KL-divergence between two Categorical distributions.
@@ -170,7 +179,7 @@ class Categorical(Distribution):
         Returns:
             Tensor: kl-divergence between two Categorical distributions.

         Examples:
             .. code-block:: python
@@ -200,19 +209,20 @@ class Categorical(Distribution):
         if not in_dygraph_mode():
             check_type(other, 'other', Categorical, 'kl_divergence')

-        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
-        other_logits = other.logits - nn.reduce_max(
-            other.logits, dim=-1, keep_dim=True)
+        logits = self.logits - \
+            paddle.max(self.logits, axis=-1, keepdim=True)
+        other_logits = other.logits - paddle.max(
+            other.logits, axis=-1, keepdim=True)
         e_logits = ops.exp(logits)
         other_e_logits = ops.exp(other_logits)
-        z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
-        other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
+        z = paddle.sum(e_logits, axis=-1, keepdim=True)
+        other_z = paddle.sum(other_e_logits, axis=-1, keepdim=True)
         prob = e_logits / z
-        kl = nn.reduce_sum(
-            prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
-            dim=-1,
-            keep_dim=True,
-            name=name)
+        kl = paddle.sum(prob * (
+            logits - paddle.log(z) - other_logits + paddle.log(other_z)),
+            axis=-1,
+            keepdim=True,
+            name=name)

         return kl
@@ -221,7 +231,7 @@ class Categorical(Distribution):
         Returns:
             Tensor: Shannon entropy of Categorical distribution. The data type is float32.

         Examples:
             .. code-block:: python
@@ -241,14 +251,14 @@ class Categorical(Distribution):
         """
         name = self.name + '_entropy'
-        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
+        logits = self.logits - \
+            paddle.max(self.logits, axis=-1, keepdim=True)
         e_logits = ops.exp(logits)
-        z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
+        z = paddle.sum(e_logits, axis=-1, keepdim=True)
         prob = e_logits / z
-        neg_entropy = nn.reduce_sum(
-            prob * (logits - nn.log(z)), dim=-1, keep_dim=True)
-        entropy = nn.scale(neg_entropy, scale=-1.0, name=name)
+        neg_entropy = paddle.sum(prob * (logits - paddle.log(z)), axis=-1)
+        entropy = paddle.scale(neg_entropy, scale=-1.0, name=name)

         return entropy
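Dropping `keepdim` here is the entropy shape fix from the commit message; a quick NumPy check (illustrative values only) of the expected shape:

```python
import numpy as np

probs = np.array([[0.2, 0.3, 0.5], [0.5, 0.25, 0.25]])  # batch of 2 distributions
entropy = -np.sum(probs * np.log(probs), axis=-1)        # no keepdims
print(entropy.shape)                                     # (2,) == batch_shape
```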
     def probs(self, value):
@@ -266,7 +276,7 @@ class Categorical(Distribution):
         Returns:
             Tensor: probability according to the category index.

         Examples:
             .. code-block:: python
@@ -288,33 +298,33 @@ class Categorical(Distribution):
         """
         name = self.name + '_probs'

-        dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
+        dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
         prob = self.logits / dist_sum

         shape = list(prob.shape)
         value_shape = list(value.shape)
         if len(shape) == 1:
             num_value_in_one_dist = np.prod(value_shape)
-            index_value = nn.reshape(value, [num_value_in_one_dist, 1])
+            index_value = paddle.reshape(value, [num_value_in_one_dist, 1])
             index = index_value
         else:
             num_dist = np.prod(shape[:-1])
             num_value_in_one_dist = value_shape[-1]
-            prob = nn.reshape(prob, [num_dist, shape[-1]])
+            prob = paddle.reshape(prob, [num_dist, shape[-1]])
             if len(value_shape) == 1:
                 value = nn.expand(value, [num_dist])
                 value_shape = shape[:-1] + value_shape
-            index_value = nn.reshape(value, [num_dist, -1, 1])
+            index_value = paddle.reshape(value, [num_dist, -1, 1])
             if shape[:-1] != value_shape[:-1]:
                 raise ValueError(
                     "shape of value {} must match shape of logits {}".format(
                         str(value_shape[:-1]), str(shape[:-1])))

-            index_prefix = nn.unsqueeze(
+            index_prefix = paddle.unsqueeze(
                 arange(
-                    num_dist, dtype=index_value.dtype), axes=-1)
+                    num_dist, dtype=index_value.dtype), axis=-1)
             index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
-            index_prefix = nn.unsqueeze(index_prefix, axes=-1)
+            index_prefix = paddle.unsqueeze(index_prefix, axis=-1)

             if index_value.dtype != index_prefix.dtype:
                 tensor.cast(index_prefix, dtype=index_value.dtype)
@@ -322,7 +332,7 @@ class Categorical(Distribution):
         # value is the category index to search for the corresponding probability.
         select_prob = gather_nd(prob, index)
-        return nn.reshape(select_prob, value_shape, name=name)
+        return paddle.reshape(select_prob, value_shape, name=name)
     def log_prob(self, value):
         """Log probabilities of the given category. Refer to ``probs`` method.
@@ -332,7 +342,7 @@ class Categorical(Distribution):
         Returns:
             Tensor: Log probability.

         Examples:
             .. code-block:: python
@@ -354,4 +364,4 @@ class Categorical(Distribution):
         """
         name = self.name + '_log_prob'

-        return nn.log(self.probs(value), name=name)
+        return paddle.log(self.probs(value), name=name)
@@ -22,23 +22,37 @@ from .exponential_family import ExponentialFamily
 class Dirichlet(ExponentialFamily):
     r"""
-    Dirichlet distribution with parameter concentration
+    Dirichlet distribution with parameter "concentration".

     The Dirichlet distribution is defined over the `(k-1)-simplex` using a
     positive, length-k vector concentration (`k > 1`).
     The Dirichlet is identically the Beta distribution when `k = 2`.

+    For independent and identically distributed continuous random variables
+    :math:`\boldsymbol X \in R_k` with support
+    :math:`\boldsymbol X \in (0,1), ||\boldsymbol X|| = 1`,
     The probability density function (pdf) is

     .. math::

-        f(x_1,...,x_k; \alpha_1,...,\alpha_k) = \frac{1}{B(\alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}
+        f(\boldsymbol X; \boldsymbol \alpha) = \frac{1}{B(\boldsymbol \alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}

-    The normalizing constant is the multivariate beta function.
+    where :math:`\boldsymbol \alpha = (\alpha_1,...,\alpha_k), k \ge 2` is the
+    parameter and the normalizing constant is the multivariate beta function,
+
+    .. math::
+
+        B(\boldsymbol \alpha) = \frac{\prod_{i=1}^{k} \Gamma(\alpha_i)}{\Gamma(\alpha_0)}
+
+    where :math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of the parameters
+    and :math:`\Gamma(\alpha)` is the gamma function.

     Args:
-        concentration (Tensor): concentration parameter of dirichlet
-            distribution
+        concentration (Tensor): "Concentration" parameter of the dirichlet
+            distribution, also called :math:`\alpha`. When it has more than one
+            dimension, the last axis denotes the parameter of the distribution,
+            ``event_shape=concentration.shape[-1:]``, and axes other than the
+            last are considered batch dimensions with
+            ``batch_shape=concentration.shape[:-1]``.

     Examples:
@@ -68,59 +82,59 @@ class Dirichlet(ExponentialFamily):
     @property
     def mean(self):
-        """mean of Dirichelt distribution.
+        """Mean of Dirichlet distribution.

         Returns:
-            mean value of distribution.
+            Mean value of distribution.
         """
         return self.concentration / self.concentration.sum(-1, keepdim=True)

     @property
     def variance(self):
-        """variance of Dirichlet distribution.
+        """Variance of Dirichlet distribution.

         Returns:
-            variance value of distribution.
+            Variance value of distribution.
         """
         concentration0 = self.concentration.sum(-1, keepdim=True)
         return (self.concentration * (concentration0 - self.concentration)) / (
             concentration0.pow(2) * (concentration0 + 1))

     def sample(self, shape=()):
-        """sample from dirichlet distribution.
+        """Sample from dirichlet distribution.

         Args:
-            shape (Sequence[int], optional): sample shape. Defaults to empty tuple.
+            shape (Sequence[int], optional): Sample shape. Defaults to empty tuple.
         """
         shape = shape if isinstance(shape, tuple) else tuple(shape)
         return _dirichlet(self.concentration.expand(self._extend_shape(shape)))

     def prob(self, value):
-        """Probability density function(pdf) evaluated at value.
+        """Probability density function (PDF) evaluated at value.

         Args:
-            value (Tensor): value to be evaluated.
+            value (Tensor): Value to be evaluated.

         Returns:
-            pdf evaluated at value.
+            PDF evaluated at value.
         """
         return paddle.exp(self.log_prob(value))

     def log_prob(self, value):
-        """log of probability densitiy function.
+        """Log of probability density function.

         Args:
-            value (Tensor): value to be evaluated.
+            value (Tensor): Value to be evaluated.
         """
         return ((paddle.log(value) * (self.concentration - 1.0)
                  ).sum(-1) + paddle.lgamma(self.concentration.sum(-1)) -
                 paddle.lgamma(self.concentration).sum(-1))

     def entropy(self):
-        """entropy of Dirichlet distribution.
+        """Entropy of Dirichlet distribution.

         Returns:
-            entropy of distribution.
+            Entropy of distribution.
         """
         concentration0 = self.concentration.sum(-1)
         k = self.concentration.shape[-1]
...
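A small sketch (editorial, assuming the `paddle.distribution.Dirichlet` API above) of the shape convention in the new Args text: the last axis of `concentration` is the event dimension, the rest are batch dimensions.

```python
import paddle

conc = paddle.rand((4, 3)) + 0.5                 # batch_shape=[4], event_shape=[3]
dirichlet = paddle.distribution.Dirichlet(conc)
print(dirichlet.sample().shape)                  # [4, 3]
print(dirichlet.sample((2,)).shape)              # [2, 4, 3]
```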
@@ -25,6 +25,7 @@ import math
 import warnings

 import numpy as np
+import paddle
 from paddle import _C_ops

 from ..fluid import core
@@ -102,7 +103,13 @@ class Distribution(object):
         raise NotImplementedError

     def probs(self, value):
-        """Probability density/mass function."""
+        """Probability density/mass function.
+
+        .. note::
+
+            This method will be deprecated in the future, please use `prob`
+            instead.
+        """
         raise NotImplementedError

     def _extend_shape(self, sample_shape):
@@ -212,3 +219,22 @@ class Distribution(object):
             )
             return tensor.cast(value, dtype=param.dtype)
         return value
+
+    def _probs_to_logits(self, probs, is_binary=False):
+        r"""
+        Converts probabilities into logits. For the binary case, `probs`
+        denotes the probability of occurrence of the event indexed by `1`. For
+        the multi-dimensional case, the values of the last axis denote the
+        probabilities of occurrence of each of the events.
+        """
+        return (paddle.log(probs) - paddle.log1p(-probs)) \
+            if is_binary else paddle.log(probs)
+
+    def _logits_to_probs(self, logits, is_binary=False):
+        r"""
+        Converts logits into probabilities. For the binary case, each value
+        denotes log odds, whereas for the multi-dimensional case, the values
+        along the last dimension denote the log probabilities of the events.
+        """
+        return paddle.nn.functional.sigmoid(logits) \
+            if is_binary else paddle.nn.functional.softmax(logits, axis=-1)
@@ -33,6 +33,8 @@ class ExponentialFamily(Distribution):
     where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes
     the sufficient statistic, :math:`F(\theta)` is the log normalizer function
     for a given family and :math:`k(x)` is the carrier measure.
+
+    Distributions that belong to the exponential family are listed at
+    https://en.wikipedia.org/wiki/Exponential_family
     """

     @property
...
@@ -43,10 +43,7 @@ def kl_divergence(p, q):
         q (Distribution): ``Distribution`` object.

     Returns:
-        Tensor: batchwise KL-divergence between distribution p and q.
-
-    Raises:
-        NotImplementedError: can't find register function for KL(p||Q).
+        Tensor: Batchwise KL-divergence between distribution p and q.

     Examples:
@@ -68,9 +65,15 @@ def kl_divergence(p, q):
 def register_kl(cls_p, cls_q):
     """Decorator for registering a KL divergence implementation function.
+
+    The ``kl_divergence(p, q)`` function searches for concrete implementation
+    functions registered by ``register_kl``, following a multiple-dispatch
+    pattern. If an implementation function is found, it returns its result;
+    otherwise, it raises a ``NotImplementedError`` exception. Users can
+    register an implementation function with this decorator.

     Args:
-        cls_p(Distribution): subclass derived from ``Distribution``.
-        cls_q(Distribution): subclass derived from ``Distribution``.
+        cls_p(Distribution): Subclass derived from ``Distribution``.
+        cls_q(Distribution): Subclass derived from ``Distribution``.

     Examples:
         .. code-block:: python
@@ -93,7 +96,7 @@ def register_kl(cls_p, cls_q):
 def _dispatch(cls_p, cls_q):
-    """multiple dispatch into concrete implement function"""
+    """Multiple dispatch into a concrete implementation function."""
     # find all matched super class pairs of p and q
     matchs = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE
@@ -167,8 +170,7 @@ def _kl_uniform_uniform(p, q):
 @register_kl(ExponentialFamily, ExponentialFamily)
 def _kl_expfamily_expfamily(p, q):
-    """compute kl-divergence using `Bregman divergences`
-    https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf
+    """Compute kl-divergence using `Bregman divergences
+    <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_.
     """
     if not type(p) == type(q):
         raise NotImplementedError
@@ -205,5 +207,5 @@ def _kl_expfamily_expfamily(p, q):
 def _sum_rightmost(value, n):
-    """sum value along rightmost n dim"""
+    """Sum elements along the rightmost n dims."""
     return value.sum(list(range(-n, 0))) if n > 0 else value
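A small sketch (editorial) of the dispatch behaviour described in the new `register_kl` docstring: `kl_divergence` picks the most specific registered rule for the pair of argument types and raises `NotImplementedError` when none matches.

```python
import paddle
from paddle.distribution import Beta, kl_divergence

p = Beta(alpha=paddle.to_tensor(0.5), beta=paddle.to_tensor(0.5))
q = Beta(alpha=paddle.to_tensor(0.3), beta=paddle.to_tensor(0.7))
print(kl_divergence(p, q))   # dispatched to a registered Beta/Beta rule
```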
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import paddle
from paddle.distribution import categorical, distribution
class Multinomial(distribution.Distribution):
r"""
Multinomial distribution parameterized by :attr:`total_count` and
:attr:`probs`.
In probability theory, the multinomial distribution is a generalization of
the binomial distribution. It models the probability of counts for each side
of a k-sided die rolled n times. When k is 2 and n is 1, the multinomial is
the Bernoulli distribution; when k is 2 and n is greater than 1, it is the
binomial distribution; when k is greater than 2 and n is 1, it is the
categorical distribution.
The probability mass function (PMF) for multinomial is
.. math::
f(x_1, ..., x_k; n, p_1,...,p_k) = \frac{n!}{x_1!...x_k!}p_1^{x_1}...p_k^{x_k}
where :math:`n` is the number of trials, :math:`k` is the number of categories,
:math:`p_i` denotes the probability of a trial falling into each category,
:math:`{\textstyle \sum_{i=1}^{k}p_i=1}, p_i \ge 0`, and :math:`x_i` denotes
the count for each category.
Args:
total_count (int): Number of trials.
probs (Tensor): Probability of a trial falling into each category. The last
    axis of probs indexes over categories, while other axes index over batches.
    Values of probs should lie in [0, 1] and sum to 1 along the last axis;
    if they do not, they are normalized to sum to 1 along the last axis.
Examples:
.. code-block:: python
import paddle
multinomial = paddle.distribution.Multinomial(10, paddle.to_tensor([0.2, 0.3, 0.5]))
print(multinomial.sample((2, 3)))
# Tensor(shape=[2, 3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[[1., 4., 5.],
# [0., 2., 8.],
# [2., 4., 4.]],
# [[1., 6., 3.],
# [3., 3., 4.],
# [3., 4., 3.]]])
"""
def __init__(self, total_count, probs):
if not isinstance(total_count, int) or total_count < 1:
raise ValueError(
'input parameter total_count must be an int and greater than zero.'
)
if probs.dim() < 1:
raise ValueError(
'probs parameter should not be None and must have at least one dimension')
self.probs = probs / probs.sum(-1, keepdim=True)
self.total_count = total_count
self._categorical = categorical.Categorical(
logits=self._probs_to_logits(probs))
super(Multinomial, self).__init__(probs.shape[:-1], probs.shape[-1:])
@property
def mean(self):
"""mean of multinomial distribuion.
Returns:
Tensor: mean value.
"""
return self.probs * self.total_count
@property
def variance(self):
"""variance of multinomial distribution.
Returns:
Tensor: variance value.
"""
return self.total_count * self.probs * (1 - self.probs)
def prob(self, value):
"""probability mass function evaluated at value.
Args:
value (Tensor): value to be evaluated.
Returns:
Tensor: probability of value.
"""
return paddle.exp(self.log_prob(value))
def log_prob(self, value):
"""probability mass function evaluated at value
Args:
value (Tensor): value to be evaluated.
Returns:
Tensor: probability of value.
"""
if paddle.is_integer(value):
value = paddle.cast(value, self.probs.dtype)
logits, value = paddle.broadcast_tensors(
[paddle.log(self.probs), value])
logits[(value == 0) & (paddle.isinf(logits))] = 0
return (paddle.lgamma(value.sum(-1) + 1) -
paddle.lgamma(value + 1).sum(-1) + (value * logits).sum(-1))
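A hedged numerical check (illustrative values, not part of the file) of the log PMF formula implemented above against scipy:

```python
import numpy as np
from scipy import special, stats

n = 10
probs = np.array([0.2, 0.3, 0.5])
x = np.array([2., 3., 5.])
log_pmf = (special.gammaln(n + 1) - special.gammaln(x + 1).sum()
           + (x * np.log(probs)).sum())
print(np.exp(log_pmf))                      # ~0.08505
print(stats.multinomial.pmf(x, n, probs))   # same value
```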
def sample(self, shape=()):
"""draw sample data from multinomial distribution
Args:
sample_shape (tuple, optional): [description]. Defaults to ().
"""
if not isinstance(shape, collections.Iterable):
raise TypeError('sample shape must be Iterable object.')
samples = self._categorical.sample([self.total_count, ] + list(shape))
return paddle.nn.functional.one_hot(
samples, self.probs.shape[-1]).cast(self.probs.dtype).sum(0)
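A sketch (editorial) of the sampling scheme used above: draw `total_count` categorical indices, one-hot encode them, and sum over the trial axis to obtain per-category counts.

```python
import paddle

probs = paddle.to_tensor([0.2, 0.3, 0.5])
cat = paddle.distribution.Categorical(paddle.log(probs))
draws = cat.sample([10])                                   # 10 trials
counts = paddle.nn.functional.one_hot(
    draws.astype('int64'), num_classes=3).sum(0)
print(counts)                                              # sums to 10
```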
def entropy(self):
"""entropy of multinomial distribution
Returns:
Tensor: entropy value
"""
n = paddle.full(
shape=[1], fill_value=self.total_count, dtype=self.probs.dtype)
support = paddle.arange(
self.total_count + 1, dtype=self.probs.dtype).reshape((-1, ) + (
1, ) * len(self.probs.shape))[1:]
binomial_pmf = paddle.exp(self._binomial_logpmf(n, support))
return ((n * self._categorical.entropy() - paddle.lgamma(n + 1)) + (
(binomial_pmf * paddle.lgamma(support + 1)).sum([0, -1])))
def _binomial_logpmf(self, count, value):
logits = self._probs_to_logits(self.probs, is_binary=True)
factor_n = paddle.lgamma(count + 1)
factor_k = paddle.lgamma(value + 1)
factor_nmk = paddle.lgamma(count - value + 1)
norm = (count * _clip_by_zero(logits) + count *
paddle.log1p(paddle.exp(-paddle.abs(logits))) - factor_n)
return value * logits - factor_k - factor_nmk - norm
def _binomial_support(count, dtype):
return paddle.arange(count + 1, dtype=dtype)
def _clip_by_zero(x):
# like clip(x, min=0) but grad at 0 is 0.5
return (x.clip(min=0) + x - x.clip(max=0)) / 2
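The `norm` term in `_binomial_logpmf` evaluates `count * log(1 + exp(logits))` via the usual numerically stable softplus identity, `log(1 + e^x) = max(x, 0) + log1p(e^{-|x|})`; a quick NumPy check (illustrative only):

```python
import numpy as np

x = np.array([-30.0, -1.0, 0.0, 1.0, 30.0])
naive = np.log1p(np.exp(x))                                 # overflows for very large x
stable = np.maximum(x, 0.0) + np.log1p(np.exp(-np.abs(x)))
print(np.allclose(naive, stable))                           # True
```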
@@ -33,7 +33,7 @@ class CategoricalNumpy(DistributionNumpy):
         e_logits = np.exp(logits)
         z = np.sum(e_logits, axis=-1, keepdims=True)
         prob = e_logits / z
-        return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True)
+        return -1. * np.sum(prob * (logits - np.log(z)), axis=-1)

     def kl_divergence(self, other):
         logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
import config
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('one-dim', 10, config.xrand((3, ))),
('multi-dim', 9, config.xrand((10, 20))),
('prob-sum-one', 10, np.array([0.5, 0.2, 0.3])),
('prob-sum-non-one', 10, np.array([2., 3., 5.])),
])
class TestMultinomial(unittest.TestCase):
def setUp(self):
self._dist = paddle.distribution.Multinomial(
total_count=self.total_count, probs=paddle.to_tensor(self.probs))
def test_mean(self):
mean = self._dist.mean
self.assertEqual(mean.numpy().dtype, self.probs.dtype)
np.testing.assert_allclose(
mean,
self._np_mean(),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
def test_variance(self):
var = self._dist.variance
self.assertEqual(var.numpy().dtype, self.probs.dtype)
np.testing.assert_allclose(
var,
self._np_variance(),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
def test_entropy(self):
entropy = self._dist.entropy()
self.assertEqual(entropy.numpy().dtype, self.probs.dtype)
np.testing.assert_allclose(
entropy,
self._np_entropy(),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
def test_sample(self):
sample_shape = ()
samples = self._dist.sample(sample_shape)
self.assertEqual(samples.numpy().dtype, self.probs.dtype)
self.assertEqual(
tuple(samples.shape),
sample_shape + self._dist.batch_shape + self._dist.event_shape)
sample_shape = (6, )
samples = self._dist.sample(sample_shape)
self.assertEqual(samples.numpy().dtype, self.probs.dtype)
self.assertEqual(
tuple(samples.shape),
sample_shape + self._dist.batch_shape + self._dist.event_shape)
self.assertTrue(
np.all(samples.sum(-1).numpy() == self._dist.total_count))
sample_shape = (5000, )
samples = self._dist.sample(sample_shape)
sample_mean = samples.mean(axis=0)
# Tolerance value 0.2 is an empirical value which is consistent with
# TensorFlow.
np.testing.assert_allclose(
sample_mean, self._dist.mean, atol=0, rtol=0.20)
def _np_variance(self):
probs = self.probs / self.probs.sum(-1, keepdims=True)
return self.total_count * probs * (1 - probs)
def _np_mean(self):
probs = self.probs / self.probs.sum(-1, keepdims=True)
return self.total_count * probs
def _np_entropy(self):
probs = self.probs / self.probs.sum(-1, keepdims=True)
return scipy.stats.multinomial.entropy(self.total_count, probs)
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
[
('value-float', 10, np.array([0.2, 0.3, 0.5]), np.array([2., 3., 5.])),
('value-int', 10, np.array([0.2, 0.3, 0.5]), np.array([2, 3, 5])),
('value-multi-dim', 10, np.array([[0.3, 0.7], [0.5, 0.5]]),
np.array([[4., 6], [8, 2]])),
# ('value-sum-non-n', 10, np.array([0.5, 0.2, 0.3]), np.array([4,5,2])),
])
class TestMultinomialPmf(unittest.TestCase):
def setUp(self):
self._dist = paddle.distribution.Multinomial(
total_count=self.total_count, probs=paddle.to_tensor(self.probs))
def test_prob(self):
np.testing.assert_allclose(
self._dist.prob(paddle.to_tensor(self.value)),
scipy.stats.multinomial.pmf(self.value, self.total_count,
self.probs),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('total_count_le_one', 0, np.array([0.3, 0.7])),
('total_count_float', np.array([0.3, 0.7])),
('probs_zero_dim', np.array(0)),
])
class TestMultinomialException(unittest.TestCase):
def TestInit(self):
with self.assertRaises(ValueError):
paddle.distribution.Multinomial(self.total_count,
paddle.to_tensor(self.probs))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import scipy.stats
import config
paddle.enable_static()
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('one-dim', 5, config.xrand((3, ))),
('multi-dim', 9, config.xrand((2, 3))),
('prob-sum-one', 5, np.array([0.5, 0.2, 0.3])),
('prob-sum-non-one', 5, np.array([2., 3., 5.])),
])
class TestMultinomial(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
probs = paddle.static.data('probs', self.probs.shape,
self.probs.dtype)
dist = paddle.distribution.Multinomial(self.total_count, probs)
mean = dist.mean
var = dist.variance
entropy = dist.entropy()
mini_samples = dist.sample(shape=(6, ))
large_samples = dist.sample(shape=(5000, ))
fetch_list = [mean, var, entropy, mini_samples, large_samples]
feed = {'probs': self.probs}
executor.run(startup_program)
[
self.mean, self.var, self.entropy, self.mini_samples,
self.large_samples
] = executor.run(main_program, feed=feed, fetch_list=fetch_list)
def test_mean(self):
self.assertEqual(str(self.mean.dtype).split('.')[-1], self.probs.dtype)
np.testing.assert_allclose(
self.mean,
self._np_mean(),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
def test_variance(self):
self.assertEqual(str(self.var.dtype).split('.')[-1], self.probs.dtype)
np.testing.assert_allclose(
self.var,
self._np_variance(),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
def test_entropy(self):
self.assertEqual(
str(self.entropy.dtype).split('.')[-1], self.probs.dtype)
np.testing.assert_allclose(
self.entropy,
self._np_entropy(),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
def test_sample(self):
self.assertEqual(
str(self.mini_samples.dtype).split('.')[-1], self.probs.dtype)
self.assertTrue(np.all(self.mini_samples.sum(-1) == self.total_count))
sample_mean = self.large_samples.mean(axis=0)
np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.20)
def _np_variance(self):
probs = self.probs / self.probs.sum(-1, keepdims=True)
return self.total_count * probs * (1 - probs)
def _np_mean(self):
probs = self.probs / self.probs.sum(-1, keepdims=True)
return self.total_count * probs
def _np_entropy(self):
probs = self.probs / self.probs.sum(-1, keepdims=True)
return scipy.stats.multinomial.entropy(self.total_count, probs)
@config.place(config.DEVICES)
@config.parameterize(
(config.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
[
('value-float', 5, np.array([0.2, 0.3, 0.5]), np.array([1., 1., 3.])),
('value-int', 5, np.array([0.2, 0.3, 0.5]), np.array([2, 2, 1])),
('value-multi-dim', 5, np.array([[0.3, 0.7], [0.5, 0.5]]),
np.array([[1., 4.], [2., 3.]])),
# ('value-sum-non-n', 10, np.array([0.5, 0.2, 0.3]), np.array([4,5,2])),
])
class TestMultinomialPmf(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(main_program, startup_program):
probs = paddle.static.data('probs', self.probs.shape,
self.probs.dtype)
value = paddle.static.data('value', self.value.shape,
self.value.dtype)
dist = paddle.distribution.Multinomial(self.total_count, probs)
pmf = dist.prob(value)
feed = {'probs': self.probs, 'value': self.value}
fetch_list = [pmf]
executor.run(startup_program)
[self.pmf] = executor.run(main_program,
feed=feed,
fetch_list=fetch_list)
def test_prob(self):
np.testing.assert_allclose(
self.pmf,
scipy.stats.multinomial.pmf(self.value, self.total_count,
self.probs),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)))
@config.place(config.DEVICES)
@config.parameterize((config.TEST_CASE_NAME, 'total_count', 'probs'), [
('total_count_le_one', 0, np.array([0.3, 0.7])),
('total_count_float', np.array([0.3, 0.7])),
('probs_zero_dim', np.array(0)),
])
class TestMultinomialException(unittest.TestCase):
def setUp(self):
startup_program = paddle.static.Program()
self.main_program = paddle.static.Program()
self.executor = paddle.static.Executor(self.place)
with paddle.static.program_guard(self.main_program, startup_program):
probs = paddle.static.data('probs', self.probs.shape,
self.probs.dtype)
dist = paddle.distribution.Multinomial(self.total_count, probs)
self.feed = {'probs': self.probs}
self.executor.run(startup_program)
def TestInit(self):
with self.assertRaises(ValueError):
self.executor.run(self.main_program, feed=self.feed, fetch_list=[])
if __name__ == '__main__':
unittest.main()