Unverified commit 9b3ef597 authored by pangyoki, committed by GitHub

add categorical class (#27695)

* add multinomial cpu kernel

* fix C++ notype error

* fix windows ci array len error

* let array len be const

* change array to vector

* add cuda kernel when num_distribution is 1, and not support replacement=False

* add multinomial python api

* support num_distribution different multinomial distributions

* add categorical class

* fix test_distribution enable_static error

* add unittest for different setting of Categorical

* optimize format

* little change

* little change

* add raise error if shape not match, optimize format

* fix windows CI dtype error in concat

* little changes

* little changes2

* change values type to int64

* change values type to int64

* change values type to int64
Parent 4d3eefbb
@@ -28,13 +28,14 @@ from .fluid.layers import nn
from .fluid import core
from .fluid.framework import in_dygraph_mode
from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub
from .tensor import arange, gather_nd, concat, multinomial
import math
import numpy as np
import warnings
from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = ['Distribution', 'Uniform', 'Normal']
__all__ = ['Distribution', 'Uniform', 'Normal', 'Categorical']
class Distribution(object):
@@ -640,3 +641,318 @@ class Normal(Distribution):
t1 = (t1 * t1)
return elementwise_add(
0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name)
class Categorical(Distribution):
"""
Categorical distribution is a discrete probability distribution that
describes the possible results of a random variable that can take on
one of K possible categories, with the probability of each category
separately specified.
The probability mass function (pmf) is:
.. math::
pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]}
In the above equation:
* :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise.
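For example, if the category probabilities are :math:`p = (0.2, 0.3, 0.5)`, then :math:`pmf(x=2) = p_2 = 0.3`.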
Args:
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
cat = Categorical(x)
cat2 = Categorical(y)
cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
cat.entropy()
# [1.71887]
cat.kl_divergence(cat2)
# [0.0278455]
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
"""
def __init__(self, logits, name=None):
"""
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64.
"""
if not in_dygraph_mode():
check_type(logits, 'logits', (np.ndarray, tensor.Variable, list),
'Categorical')
self.name = name if name is not None else 'Categorical'
self.dtype = 'float32'
if self._validate_args(logits):
self.logits = logits
self.dtype = convert_dtype(logits.dtype)
else:
if isinstance(logits, np.ndarray) and str(
logits.dtype) in ['float32', 'float64']:
self.dtype = logits.dtype
self.logits = self._to_tensor(logits)[0]
if self.dtype != convert_dtype(self.logits.dtype):
self.logits = tensor.cast(self.logits, dtype=self.dtype)
def sample(self, shape):
"""Generate samples of the specified shape.
Args:
shape (list): Shape of the generated samples.
Returns:
Tensor: A tensor whose shape is the sample ``shape`` prepended to the distribution's batch dimensions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
"""
name = self.name + '_sample'
if not in_dygraph_mode():
check_type(shape, 'shape', (list), 'sample')
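# Flatten any batch dimensions of logits, draw prod(shape) samples per
# distribution (with replacement) via ``multinomial``, then reshape the
# result back to ``shape + logits.shape[:-1]``.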
num_samples = np.prod(np.array(shape))
logits_shape = list(self.logits.shape)
if len(logits_shape) > 1:
sample_shape = shape + logits_shape[:-1]
logits = nn.reshape(self.logits,
[np.prod(logits_shape[:-1]), logits_shape[-1]])
else:
sample_shape = shape
logits = self.logits
sample_index = multinomial(logits, num_samples, True)
return nn.reshape(sample_index, sample_shape, name=name)
def kl_divergence(self, other):
"""The KL-divergence between two Categorical distributions.
Args:
other (Categorical): instance of Categorical. The data type is float32.
Returns:
Variable: kl-divergence between two Categorical distributions.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
cat = Categorical(x)
cat2 = Categorical(y)
cat.kl_divergence(cat2)
# [0.0278455]
"""
name = self.name + '_kl_divergence'
if not in_dygraph_mode():
check_type(other, 'other', Categorical, 'kl_divergence')
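# Shift both sets of logits by their per-row maximum before exponentiating;
# this avoids overflow in exp() and leaves the normalized probabilities
# (and hence the KL divergence) unchanged.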
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
other_logits = other.logits - nn.reduce_max(
other.logits, dim=-1, keep_dim=True)
e_logits = ops.exp(logits)
other_e_logits = ops.exp(other_logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
prob = e_logits / z
kl = nn.reduce_sum(
prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
dim=-1,
keep_dim=True,
name=name)
return kl
def entropy(self):
"""Shannon entropy in nats.
Returns:
Variable: Shannon entropy of Categorical distribution. The data type is float32.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
cat.entropy()
# [1.71887]
"""
name = self.name + '_entropy'
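# Same max-shift trick as in kl_divergence; since logits - log(z) is the
# log-probability, the reduced sum below equals -H(p).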
logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
e_logits = ops.exp(logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
prob = e_logits / z
neg_entropy = nn.reduce_sum(
prob * (logits - nn.log(z)), dim=-1, keep_dim=True)
entropy = nn.scale(neg_entropy, scale=-1.0, name=name)
return entropy
def probs(self, value):
"""Probabilities of the given category (``value``).
If ``logits`` is 2-D or of higher dimension, the last dimension is regarded as the
category dimension, and the leading dimensions index the different distributions.
If ``value`` is a 1-D Tensor, it will be broadcast to the same number of
distributions as ``logits``.
If ``value`` is not a 1-D Tensor, it must have the same number of distributions
as ``logits``; that is, ``value.shape[:-1] == logits.shape[:-1]``.
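For example, with ``logits`` of shape ``[3, 5]``, a 1-D ``value`` of shape ``[3]``
yields probabilities of shape ``[3, 3]``, while a ``value`` of shape ``[3, 2]``
yields probabilities of shape ``[3, 2]``.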
Args:
value (Tensor): The input tensor represents the selected category index.
Returns:
Tensor: probability according to the category index.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
"""
name = self.name + '_probs'
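# Unlike entropy and kl_divergence, probs treats the logits as unnormalized
# probabilities and normalizes them by their sum rather than by softmax.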
dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
prob = self.logits / dist_sum
shape = list(prob.shape)
value_shape = list(value.shape)
if len(shape) == 1:
num_value_in_one_dist = np.prod(value_shape)
index_value = nn.reshape(value, [num_value_in_one_dist, 1])
index = index_value
else:
num_dist = np.prod(shape[:-1])
num_value_in_one_dist = value_shape[-1]
prob = nn.reshape(prob, [num_dist, shape[-1]])
if len(value_shape) == 1:
value = nn.expand(value, [num_dist])
value_shape = shape[:-1] + value_shape
index_value = nn.reshape(value, [num_dist, -1, 1])
if shape[:-1] != value_shape[:-1]:
raise ValueError(
"shape of value {} must match shape of logits {}".format(
str(value_shape[:-1]), str(shape[:-1])))
index_prefix = nn.unsqueeze(
arange(
num_dist, dtype=index_value.dtype), axes=-1)
index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
index_prefix = nn.unsqueeze(index_prefix, axes=-1)
if index_value.dtype != index_prefix.dtype:
index_prefix = tensor.cast(index_prefix, dtype=index_value.dtype)
index = concat([index_prefix, index_value], axis=-1)
# value is the category index to search for the corresponding probability.
select_prob = gather_nd(prob, index)
return nn.reshape(select_prob, value_shape, name=name)
def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.
Args:
value (Tensor): The input tensor represents the selected category index.
Returns:
Tensor: Log probability.
Examples:
.. code-block:: python
import paddle
from paddle.distribution import Categorical
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
cat = Categorical(x)
value = paddle.to_tensor([2,1,3])
cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
"""
name = self.name + '_log_prob'
return nn.log(self.probs(value), name=name)
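A minimal NumPy sketch (an editorial illustration, not part of this commit) of the index construction that ``probs`` performs with ``arange``, ``concat`` and ``gather_nd`` for 2-D logits and a 1-D value; the array names are illustrative only.
import numpy as np
logits = np.random.rand(3, 5).astype('float32')    # 3 distributions over 5 categories
value = np.array([2, 1, 3], dtype='int64')          # 1-D value, broadcast to every distribution
prob = logits / logits.sum(axis=-1, keepdims=True)  # same sum-normalization as ``probs``
num_dist, num_value = prob.shape[0], value.shape[0]
# Pair each distribution index with each requested category index, mirroring
# the role of ``arange`` + ``concat`` before the ``gather_nd`` lookup above.
dist_index = np.repeat(np.arange(num_dist), num_value)
cat_index = np.tile(value, num_dist)
selected = prob[dist_index, cat_index].reshape(num_dist, num_value)
print(selected.shape)                               # (3, 3)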
@@ -65,41 +65,6 @@ class UniformNumpy(DistributionNumpy):
return np.log(self.high - self.low)
class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc)
self.scale = np.array(scale)
if str(self.loc.dtype) not in ['float32', 'float64']:
self.loc = self.loc.astype('float32')
self.scale = self.scale.astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.loc + self.scale).shape
return self.loc + (np.random.randn(*shape) * self.scale)
def log_prob(self, value):
var = self.scale * self.scale
log_scale = np.log(self.scale)
return -((value - self.loc) * (value - self.loc)) / (
2. * var) - log_scale - math.log(math.sqrt(2. * math.pi))
def probs(self, value):
var = self.scale * self.scale
return np.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)) / (math.sqrt(2 * math.pi) * self.scale)
def entropy(self):
return 0.5 + 0.5 * np.log(
np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale)
def kl_divergence(self, other):
var_ratio = (self.scale / other.scale)
var_ratio = var_ratio * var_ratio
t1 = ((self.loc - other.loc) / other.scale)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
class UniformTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=5, dims=6):
self.use_gpu = use_gpu
@@ -336,6 +301,41 @@ class UniformTest9(UniformTest):
name='values', shape=[dims], dtype='float32')
class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc)
self.scale = np.array(scale)
if str(self.loc.dtype) not in ['float32', 'float64']:
self.loc = self.loc.astype('float32')
self.scale = self.scale.astype('float32')
def sample(self, shape):
shape = tuple(shape) + (self.loc + self.scale).shape
return self.loc + (np.random.randn(*shape) * self.scale)
def log_prob(self, value):
var = self.scale * self.scale
log_scale = np.log(self.scale)
return -((value - self.loc) * (value - self.loc)) / (
2. * var) - log_scale - math.log(math.sqrt(2. * math.pi))
def probs(self, value):
var = self.scale * self.scale
return np.exp(-1. * ((value - self.loc) * (value - self.loc)) /
(2. * var)) / (math.sqrt(2 * math.pi) * self.scale)
def entropy(self):
return 0.5 + 0.5 * np.log(
np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale)
def kl_divergence(self, other):
var_ratio = (self.scale / other.scale)
var_ratio = var_ratio * var_ratio
t1 = ((self.loc - other.loc) / other.scale)
t1 = (t1 * t1)
return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio))
class NormalTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=2, dims=3):
self.use_gpu = use_gpu
@@ -559,26 +559,6 @@ class NormalTest5(NormalTest):
class NormalTest6(NormalTest):
def init_data(self, batch_size=2, dims=3):
# loc and scale are Tensor with dtype 'VarType.FP32'.
self.loc_np = np.random.randn(batch_size, dims).astype('float32')
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
while not np.all(self.scale_np > 0):
self.scale_np = np.random.randn(batch_size, dims).astype('float32')
self.values_np = np.random.randn(batch_size, dims).astype('float32')
self.loc = paddle.to_tensor(self.loc_np)
self.scale = paddle.to_tensor(self.scale_np)
self.values = paddle.to_tensor(self.values_np)
# used to construct another Normal object to calculate kl_divergence
self.other_loc_np = np.random.randn(batch_size, dims).astype('float32')
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
while not np.all(self.scale_np > 0):
self.other_scale_np = np.random.randn(batch_size,
dims).astype('float32')
self.other_loc = paddle.to_tensor(self.other_loc_np)
self.other_scale = paddle.to_tensor(self.other_scale_np)
def init_numpy_data(self, batch_size, dims):
# loc and scale are Tensor with dtype 'VarType.FP32'.
self.loc_np = np.random.randn(batch_size, dims).astype('float32')
@@ -693,6 +673,294 @@ class NormalTest8(NormalTest):
name='other_scale', shape=[dims], dtype='float64')
class CategoricalNumpy(DistributionNumpy):
def __init__(self, logits):
self.logits = np.array(logits).astype('float32')
def entropy(self):
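# H(p) = -sum_i p_i * log(p_i), computed from max-shifted logits for numerical stability.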
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True)
def kl_divergence(self, other):
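# KL(p || q) = sum_i p_i * (log(p_i) - log(q_i)), again computed from max-shifted logits.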
logits = self.logits - np.max(self.logits, axis=-1, keepdims=True)
other_logits = other.logits - np.max(
other.logits, axis=-1, keepdims=True)
e_logits = np.exp(logits)
other_e_logits = np.exp(other_logits)
z = np.sum(e_logits, axis=-1, keepdims=True)
other_z = np.sum(other_e_logits, axis=-1, keepdims=True)
prob = e_logits / z
return np.sum(prob * (logits - np.log(z) - other_logits \
+ np.log(other_z)), axis=-1, keepdims=True)
class CategoricalTest(unittest.TestCase):
def setUp(self, use_gpu=False, batch_size=3, dims=5):
self.use_gpu = use_gpu
if not use_gpu:
self.place = fluid.CPUPlace()
self.gpu_id = -1
else:
self.place = fluid.CUDAPlace(0)
self.gpu_id = 0
self.batch_size = batch_size
self.dims = dims
self.init_numpy_data(batch_size, dims)
paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
paddle.enable_static()
self.test_program = fluid.Program()
self.executor = fluid.Executor(self.place)
self.init_static_data(batch_size, dims)
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float32')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
# dist_shape = logits_shape[:-1], it represents the number of
# different distributions.
self.dist_shape = [batch_size]
# sample shape represents the number of samples
self.sample_shape = [2, 4]
# value used in probs and log_prob method
# If value is 1-D and logits is 2-D or higher dimension, value will be
# broadcast to have the same number of distributions as logits.
# If value is 2-D or higher dimension, it should have the same number
# of distributions as logits: ``value.shape[:-1] == logits.shape[:-1]``
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = paddle.to_tensor(self.logits_np)
self.other_logits = paddle.to_tensor(self.other_logits_np)
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float32')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float32')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(self.batch_size):
for j in range(3):
np_probs[i][j] = probability[i][self.value_np[j]]
return np_probs
def compare_with_numpy(self, fetch_list, tolerance=1e-6):
sample, entropy, kl, probs, log_prob = fetch_list
log_tolerance = 1e-4
np.testing.assert_equal(sample.shape,
self.sample_shape + self.dist_shape)
np_categorical = CategoricalNumpy(self.logits_np)
np_other_categorical = CategoricalNumpy(self.other_logits_np)
np_entropy = np_categorical.entropy()
np_kl = np_categorical.kl_divergence(np_other_categorical)
np.testing.assert_allclose(
entropy, np_entropy, rtol=log_tolerance, atol=log_tolerance)
np.testing.assert_allclose(
kl, np_kl, rtol=log_tolerance, atol=log_tolerance)
sum_dist = np.sum(self.logits_np, axis=-1, keepdims=True)
probability = self.logits_np / sum_dist
np_probs = self.get_numpy_selected_probs(probability)
np_log_prob = np.log(np_probs)
np.testing.assert_allclose(
probs, np_probs, rtol=tolerance, atol=tolerance)
np.testing.assert_allclose(
log_prob, np_log_prob, rtol=tolerance, atol=tolerance)
def test_categorical_distribution_dygraph(self, tolerance=1e-6):
paddle.disable_static(self.place)
categorical = Categorical(self.logits)
other_categorical = Categorical(self.other_logits)
sample = categorical.sample(self.sample_shape).numpy()
entropy = categorical.entropy().numpy()
kl = categorical.kl_divergence(other_categorical).numpy()
probs = categorical.probs(self.value).numpy()
log_prob = categorical.log_prob(self.value).numpy()
fetch_list = [sample, entropy, kl, probs, log_prob]
self.compare_with_numpy(fetch_list)
def test_categorical_distribution_static(self, tolerance=1e-6):
paddle.enable_static()
with fluid.program_guard(self.test_program):
categorical = Categorical(self.logits_static)
other_categorical = Categorical(self.other_logits_static)
sample = categorical.sample(self.sample_shape)
entropy = categorical.entropy()
kl = categorical.kl_divergence(other_categorical)
probs = categorical.probs(self.value_static)
log_prob = categorical.log_prob(self.value_static)
fetch_list = [sample, entropy, kl, probs, log_prob]
feed_vars = {
'logits': self.logits_np,
'other_logits': self.other_logits_np,
'value': self.value_np
}
self.executor.run(fluid.default_startup_program())
fetch_list = self.executor.run(program=self.test_program,
feed=feed_vars,
fetch_list=fetch_list)
self.compare_with_numpy(fetch_list)
class CategoricalTest2(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = fluid.data(
name='logits', shape=self.logits_shape, dtype='float64')
self.other_logits_static = fluid.data(
name='other_logits', shape=self.logits_shape, dtype='float64')
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest3(CategoricalTest):
def init_dynamic_data(self, batch_size, dims):
# input logits is 2-D numpy.ndarray with dtype Float32
# value used in probs and log_prob method is 1-D Tensor
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
class CategoricalTest4(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D numpy.ndarray with dtype Float64
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(batch_size, dims).astype('float64')
self.other_logits_np = np.random.rand(batch_size,
dims).astype('float64')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [batch_size, dims]
self.dist_shape = [batch_size]
self.sample_shape = [2, 4]
self.value_shape = [3]
def init_dynamic_data(self, batch_size, dims):
self.logits = self.logits_np
self.other_logits = self.other_logits_np
self.value = paddle.to_tensor(self.value_np)
def init_static_data(self, batch_size, dims):
with fluid.program_guard(self.test_program):
self.logits_static = self.logits_np
self.other_logits_static = self.other_logits_np
self.value_static = fluid.data(
name='value', shape=self.value_shape, dtype='int64')
# test shape of logits and value used in probs and log_prob method
class CategoricalTest5(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 1-D Tensor
# value used in probs and log_prob method is 1-D Tensor
self.logits_np = np.random.rand(dims).astype('float32')
self.other_logits_np = np.random.rand(dims).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [dims]
self.dist_shape = []
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
np_probs[i] = probability[self.value_np[i]]
return np_probs
class CategoricalTest6(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 2-D Tensor
# value used in probs and log_prob method has the same number of batches as the input
self.logits_np = np.random.rand(3, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 5).astype('float32')
self.value_np = np.array([[2, 1], [0, 3], [2, 3]]).astype('int64')
self.logits_shape = [3, 5]
self.dist_shape = [3]
self.sample_shape = [2, 4]
self.value_shape = [3, 2]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.value_shape)
for i in range(3):
for j in range(2):
np_probs[i][j] = probability[i][self.value_np[i][j]]
return np_probs
class CategoricalTest7(CategoricalTest):
def init_numpy_data(self, batch_size, dims):
# input logits is 3-D Tensor
# value used in probs and log_prob method has the same number of distributions as the input
self.logits_np = np.random.rand(3, 2, 5).astype('float32')
self.other_logits_np = np.random.rand(3, 2, 5).astype('float32')
self.value_np = np.array([2, 1, 3]).astype('int64')
self.logits_shape = [3, 2, 5]
self.dist_shape = [3, 2]
self.sample_shape = [2, 4]
self.value_shape = [3]
def get_numpy_selected_probs(self, probability):
np_probs = np.zeros(self.dist_shape + self.value_shape)
for i in range(3):
for j in range(2):
for k in range(3):
np_probs[i][j][k] = probability[i][j][self.value_np[k]]
return np_probs
class DistributionTestError(unittest.TestCase):
def test_distribution_error(self):
distribution = Distribution()
@@ -711,6 +979,7 @@ class DistributionTestError(unittest.TestCase):
self.assertRaises(NotImplementedError, distribution.probs, value_tensor)
def test_normal_error(self):
paddle.enable_static()
normal = Normal(0.0, 1.0)
value = [1.0, 2.0]
@@ -734,6 +1003,7 @@ class DistributionTestError(unittest.TestCase):
self.assertRaises(TypeError, normal.kl_divergence, normal_other)
def test_uniform_error(self):
paddle.enable_static()
uniform = Uniform(0.0, 1.0)
value = [1.0, 2.0]
@@ -752,6 +1022,39 @@ class DistributionTestError(unittest.TestCase):
# type of seed must be int
self.assertRaises(TypeError, uniform.sample, [2, 3], seed)
def test_categorical_error(self):
paddle.enable_static()
categorical = Categorical([0.4, 0.6])
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.log_prob, value)
value = [1, 0]
# type of value must be variable
self.assertRaises(AttributeError, categorical.probs, value)
shape = 1.0
# type of shape must be list
self.assertRaises(TypeError, categorical.sample, shape)
categorical_other = Uniform(1.0, 2.0)
# type of other must be an instance of Categorical
self.assertRaises(TypeError, categorical.kl_divergence,
categorical_other)
def test_shape_not_match_error():
# shape of value must match shape of logits
# value_shape[:-1] == logits_shape[:-1]
paddle.disable_static()
logits = paddle.rand([3, 5])
cat = Categorical(logits)
value = paddle.to_tensor([[2, 1, 3], [3, 2, 1]], dtype='int64')
cat.log_prob(value)
self.assertRaises(ValueError, test_shape_not_match_error)
class DistributionTestName(unittest.TestCase):
def get_prefix(self, string):
@@ -812,6 +1115,35 @@ class DistributionTestName(unittest.TestCase):
p = uniform1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
def test_categorical_name(self):
name = 'test_categorical'
categorical1 = Categorical([0.4, 0.6], name=name)
self.assertEqual(categorical1.name, name)
categorical2 = Categorical([0.5, 0.5])
self.assertEqual(categorical2.name, 'Categorical')
paddle.enable_static()
sample = categorical1.sample([2])
self.assertEqual(self.get_prefix(sample.name), name + '_sample')
entropy = categorical1.entropy()
self.assertEqual(self.get_prefix(entropy.name), name + '_entropy')
kl = categorical1.kl_divergence(categorical2)
self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence')
value_npdata = np.array([0], dtype="int64")
value_tensor = layers.create_tensor(dtype="int64")
layers.assign(value_npdata, value_tensor)
p = categorical1.probs(value_tensor)
self.assertEqual(self.get_prefix(p.name), name + '_probs')
lp = categorical1.log_prob(value_tensor)
self.assertEqual(self.get_prefix(lp.name), name + '_log_prob')
if __name__ == '__main__':
unittest.main()