Commit 80be7782 authored by pangyoki

add Categorical class

Parent e9dd763c
@@ -34,7 +34,7 @@ import warnings
from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
-__all__ = ['Distribution', 'Uniform', 'Normal']
+__all__ = ['Distribution', 'Uniform', 'Normal', 'Categorical']
class Distribution(object):
@@ -640,3 +640,154 @@ class Normal(Distribution):
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
class Categorical(Distribution):
"""
Categorical distribution is a discrete probability distribution that
describes the possible results of a random variable that can take on
one of K possible categories, with the probability of each category
separately specified.
The probability mass function (pmf) is:
    .. math::

        pmf(k; p_i) = \prod_{i=1}^{K} p_i^{[k=i]}

    In the above equation:

    * :math:`[k=i]` : it evaluates to 1 if :math:`k==i` , 0 otherwise.
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32.
Examples:
.. code-block:: python
import numpy as np
from paddle.fluid import layers
from paddle.fluid.layers import Categorical
a_logits_npdata = np.array([-0.602,-0.602], dtype="float32")
a_logits_tensor = layers.create_tensor(dtype="float32")
layers.assign(a_logits_npdata, a_logits_tensor)
b_logits_npdata = np.array([-0.102,-0.112], dtype="float32")
b_logits_tensor = layers.create_tensor(dtype="float32")
layers.assign(b_logits_npdata, b_logits_tensor)
a = Categorical(a_logits_tensor)
b = Categorical(b_logits_tensor)
a.entropy()
# [0.6931472] with shape: [1]
b.entropy()
# [0.6931347] with shape: [1]
a.kl_divergence(b)
# [1.2516975e-05] with shape: [1]
"""
def __init__(self, logits, name=None):
"""
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32.
"""
if not in_dygraph_mode():
check_type(logits, 'logits', (np.ndarray, tensor.Variable, list),
'Categorical')
self.name = name if name is not None else 'Categorical'
self.dtype = 'float32'
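        # a Variable keeps its own dtype; float32/float64 numpy arrays keep
        # theirs; any other accepted input is converted and cast to float32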
if self._validate_args(logits):
self.logits = logits
self.dtype = convert_dtype(logits.dtype)
else:
if isinstance(logits, np.ndarray) and str(
logits.dtype) in ['float32', 'float64']:
self.dtype = logits.dtype
self.logits = self._to_tensor(logits)[0]
if self.dtype != convert_dtype(self.logits.dtype):
self.logits = tensor.cast(self.logits, dtype=self.dtype)
def sample(self, shape):
"""Generate samples of the specified shape.
Args:
            shape (list): A 1D list of int32 values, the shape of the generated samples.
Returns:
            Tensor: Sampled category indices with the prepended dimensions `shape`. The data type is int64.
"""
name = self.name + '_sample'
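        # total number of categorical draws; sampling is done flat and the
        # result is reshaped back to the requested `shape` on return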
num_samples = np.prod(np.array(shape))
if in_dygraph_mode():
sample_index = core.ops.multinomial(
self.logits, 'num_samples', num_samples, 'replacement', True)
return nn.reshape(sample_index, shape + [-1])
check_type(shape, 'shape', (list), 'sample')
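        # static-graph path: emit a multinomial op that draws num_samples
        # indices per distribution, with replacement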
helper = LayerHelper("multinomial", **locals())
out = helper.create_variable_for_type_inference(
dtype=convert_np_dtype_to_dtype_('int64'))
helper.append_op(
type='multinomial',
inputs={"X": self.logits},
outputs={'Out': out},
attrs={'num_samples': num_samples,
'replacement': True})
return nn.reshape(out, shape + [-1], name=name)
def kl_divergence(self, other):
"""The KL-divergence between two Categorical distributions.
Args:
            other (Categorical): An instance of Categorical. The data type is float32.
Returns:
            Variable: The KL divergence between the two Categorical distributions.
"""
name = self.name + '_kl_divergence'
if not in_dygraph_mode():
check_type(other, 'other', Categorical, 'kl_divergence')
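        # subtract the per-row max from both sets of logits before
        # exponentiating, for numerical stability; softmax is invariant to
        # this shift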
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
other_logits = other.logits - nn.reduce_max(
other.logits, dim=-1, keep_dim=True)
e_logits = ops.exp(logits)
other_e_logits = ops.exp(other_logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
prob = e_logits / z
kl = nn.reduce_sum(
prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
dim=-1,
keep_dim=True,
name=name)
return kl
def entropy(self):
"""Shannon entropy in nats.
Returns:
            Variable: Shannon entropy of the Categorical distribution. The data type is float32.
"""
name = self.name + '_entropy'
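        # shift logits by their max so exp() cannot overflow; the softmax
        # computed below is unchanged by this shift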
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
e_logits = ops.exp(logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
prob = e_logits / z
entropy = -1.0 * nn.reduce_sum(
prob * (logits - nn.log(z)), dim=-1, keep_dim=True, name=name)
return entropy
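For reference, the entropy and KL values quoted in the class docstring example can be reproduced with plain NumPy (a minimal sketch, independent of Paddle; `softmax` is a local helper, not part of the API):

.. code-block:: python

    import numpy as np

    def softmax(logits):
        # shift by the max for numerical stability, as the class methods do
        e = np.exp(logits - logits.max())
        return e / e.sum()

    p = softmax(np.array([-0.602, -0.602]))  # distribution a -> [0.5, 0.5]
    q = softmax(np.array([-0.102, -0.112]))  # distribution b

    entropy_a = -(p * np.log(p)).sum()           # ~0.6931472 (= ln 2)
    entropy_b = -(q * np.log(q)).sum()           # ~0.6931347
    kl_ab = (p * (np.log(p) - np.log(q))).sum()  # ~1.25e-05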