Unverified commit 9b3ef597 authored by pangyoki, committed by GitHub

add categorical class (#27695)

* add multinomial cpu kernel

* fix C++ notype error

* fix windows ci array len error

* let array len be const

* change array to vector

* add cuda kernel for num_distribution = 1 (replacement=False not supported)

* add multinomial python api

* support num_distribution different multinomial distributions

* add categorical class

* fix test_distribution enable_static error

* add unittests for different settings of Categorical

* optimize format

* little change

* little change

* raise an error if shapes do not match; optimize format

* fix windows CI dtype error in concat

* little changes

* little changes2

* change values type to int64

* change values type to int64

* change values type to int64
Parent 4d3eefbb
@@ -28,13 +28,14 @@ from .fluid.layers import nn
from .fluid import core
from .fluid.framework import in_dygraph_mode
from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub
from .tensor import arange, gather_nd, concat, multinomial
import math
import numpy as np
import warnings
from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = ['Distribution', 'Uniform', 'Normal']
__all__ = ['Distribution', 'Uniform', 'Normal', 'Categorical']
class Distribution(object):
@@ -640,3 +641,318 @@ class Normal(Distribution):
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
class Categorical(Distribution):
"""
Categorical distribution is a discrete probability distribution that
describes the possible results of a random variable that can take on
one of K possible categories, with the probability of each category
separately specified.
The probability mass function (pmf) is:
.. math::
pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]}
In the above equation:
* :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise.
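For example (an illustrative reading of the formula above), for a Categorical
over three categories with :math:`p = (0.2, 0.3, 0.5)`, the pmf evaluated at
:math:`x = 2` is :math:`p_2 = 0.3`.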
Args:
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
cat = Categorical(x)
cat2 = Categorical(y)
cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
cat.entropy()
# [1.71887]
cat.kl_divergence(cat2)
# [0.0278455]
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
"""
def __init__(self, logits, name=None):
"""
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64.
"""
if not in_dygraph_mode():
check_type(logits, 'logits', (np.ndarray, tensor.Variable, list),
'Categorical')
self.name = name if name is not None else 'Categorical'
self.dtype = 'float32'
if self._validate_args(logits):
self.logits = logits
self.dtype = convert_dtype(logits.dtype)
else:
if isinstance(logits, np.ndarray) and str(
logits.dtype) in ['float32', 'float64']:
self.dtype = logits.dtype
self.logits = self._to_tensor(logits)[0]
if self.dtype != convert_dtype(self.logits.dtype):
self.logits = tensor.cast(self.logits, dtype=self.dtype)
def sample(self, shape):
"""Generate samples of the specified shape.
Args:
shape (list): Shape of the generated samples.
Returns:
Tensor: A tensor whose shape is the given ``shape`` followed by the batch dimensions of ``logits``.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
"""
name = self.name + '_sample'
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
num_samples = np.prod(np.array(shape))
logits_shape = list(self.logits.shape)
if len(logits_shape) > 1:
sample_shape = shape + logits_shape[:-1]
logits = nn.reshape(self.logits,
[np.prod(logits_shape[:-1]), logits_shape[-1]])
else:
sample_shape = shape
logits = self.logits
sample_index = multinomial(logits, num_samples, True)
return nn.reshape(sample_index, sample_shape, name=name)
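# --- Illustrative aside, not part of this commit: a NumPy sketch of the
# shape handling in ``sample`` above (uses the module-level ``np`` import).
# Rows are normalized by their sum, mirroring ``probs`` below; the helper
# name ``_np_sample_sketch`` is hypothetical.
def _np_sample_sketch(logits, shape):
    logits = np.asarray(logits, dtype='float64')
    batch_shape = list(logits.shape[:-1])
    # flatten all batch dimensions into rows, one row per distribution
    flat = logits.reshape(-1, logits.shape[-1])
    flat = flat / flat.sum(axis=-1, keepdims=True)  # rows -> probabilities
    num_samples = int(np.prod(shape))
    draws = np.stack(
        [np.random.choice(flat.shape[1], size=num_samples, p=row)
         for row in flat],
        axis=-1)  # [num_samples, num_dist]
    # prepend the requested sample shape, as the final reshape above does
    return draws.reshape(list(shape) + batch_shape)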
def kl_divergence(self, other):
"""The KL-divergence between two Categorical distributions.
Args:
other (Categorical): instance of Categorical. The data type is float32.
Returns:
Variable: KL divergence between the two Categorical distributions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
cat = Categorical(x)
cat2 = Categorical(y)
cat.kl_divergence(cat2)
# [0.0278455]
"""
name = self.name + '_kl_divergence'
if not in_dygraph_mode():
check_type(other, 'other', Categorical, 'kl_divergence')
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
other_logits = other.logits - nn.reduce_max(
other.logits, dim=-1, keep_dim=True)
e_logits = ops.exp(logits)
other_e_logits = ops.exp(other_logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
prob = e_logits / z
kl = nn.reduce_sum(
prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
dim=-1,
keep_dim=True,
name=name)
return kl
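# --- Illustrative aside, not part of this commit: the computation above is
# KL(p || q) = sum_i p_i * (log p_i - log q_i), with p and q obtained from
# the two logits via a max-shifted (numerically stable) softmax. A NumPy
# re-derivation with hypothetical 1-D inputs ``lp`` and ``lq``:
def _np_kl_sketch(lp, lq):
    lp = lp - lp.max(axis=-1, keepdims=True)
    lq = lq - lq.max(axis=-1, keepdims=True)
    z_p = np.exp(lp).sum(axis=-1, keepdims=True)
    z_q = np.exp(lq).sum(axis=-1, keepdims=True)
    p = np.exp(lp) / z_p
    # log p_i = lp_i - log z_p and log q_i = lq_i - log z_q
    return (p * (lp - np.log(z_p) - lq + np.log(z_q))).sum(
        axis=-1, keepdims=True)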
def entropy(self):
"""Shannon entropy in nats.
Returns:
Variable: Shannon entropy of Categorical distribution. The data type is float32.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
cat.entropy()
# [1.71887]
"""
name = self.name + '_entropy'
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
e_logits = ops.exp(logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
prob = e_logits / z
neg_entropy = nn.reduce_sum(
prob * (logits - nn.log(z)), dim=-1, keep_dim=True)
entropy = nn.scale(neg_entropy, scale=-1.0, name=name)
return entropy
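# --- Illustrative aside, not part of this commit: ``entropy`` above is
# -sum_i p_i * log p_i, with p the max-shifted softmax of the logits and
# log p_i = shifted_i - log z. A NumPy check (hypothetical helper name):
def _np_entropy_sketch(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    z = np.exp(shifted).sum(axis=-1, keepdims=True)
    p = np.exp(shifted) / z
    return -(p * (shifted - np.log(z))).sum(axis=-1, keepdims=True)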
def probs(self, value):
"""Probabilities of the given category (``value``).
If ``logits`` is 2-D or higher, the last dimension is regarded as the
category, and the other dimensions represent different distributions.
At the same time, if ``value`` is a 1-D Tensor, it will be broadcast to
the same number of distributions as ``logits``.
If ``value`` is not a 1-D Tensor, ``value`` should cover the same number of
distributions as ``logits``. That is, ``value.shape[:-1] == logits.shape[:-1]``.
Args:
value (Tensor): The input tensor represents the selected category index.
Returns:
Tensor: probability according to the category index.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
"""
name = self.name + '_probs'
dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
prob = self.logits / dist_sum
shape = list(prob.shape)
value_shape = list(value.shape)
if len(shape) == 1:
num_value_in_one_dist = np.prod(value_shape)
index_value = nn.reshape(value, [num_value_in_one_dist, 1])
index = index_value
else:
num_dist = np.prod(shape[:-1])
num_value_in_one_dist = value_shape[-1]
prob = nn.reshape(prob, [num_dist, shape[-1]])
if len(value_shape) == 1:
value = nn.expand(value, [num_dist])
value_shape = shape[:-1] + value_shape
index_value = nn.reshape(value, [num_dist, -1, 1])
if shape[:-1] != value_shape[:-1]:
raise ValueError(
"shape of value {} must match shape of logits {}".format(
str(value_shape[:-1]), str(shape[:-1])))
index_prefix = nn.unsqueeze(
arange(
num_dist, dtype=index_value.dtype), axes=-1)
index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
index_prefix = nn.unsqueeze(index_prefix, axes=-1)
if index_value.dtype != index_prefix.dtype:
# assign the cast result so both index parts share a dtype before concat
index_prefix = tensor.cast(index_prefix, dtype=index_value.dtype)
index = concat([index_prefix, index_value], axis=-1)
# value is the category index to search for the corresponding probability.
select_prob = gather_nd(prob, index)
return nn.reshape(select_prob, value_shape, name=name)
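# --- Illustrative aside, not part of this commit: for 2-D ``logits`` the
# code above pairs a row index (``index_prefix``) with each category index
# and gathers via ``gather_nd``. A NumPy equivalent under the same
# sum-normalization, with hypothetical inputs ``logits`` of shape
# [num_dist, K] and ``value`` of shape [num_dist, m]:
def _np_probs_sketch(logits, value):
    prob = logits / logits.sum(axis=-1, keepdims=True)
    rows = np.arange(prob.shape[0])[:, None]  # plays the role of index_prefix
    return prob[rows, value]  # same lookup as gather_nd on [row, category]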
def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.
Args:
value (Tensor): The input tensor represents the selected category index.
Returns:
Tensor: Log probability.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
value = paddle.to_tensor([2,1,3])
cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
"""
name = self.name + '_log_prob'
return nn.log(self.probs(value), name=name)
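# --- Illustrative aside, not part of this commit: ``log_prob`` is the
# elementwise log of ``probs``, so the hypothetical NumPy sketch above
# reproduces it as:
#     np.log(_np_probs_sketch(logits, value))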