diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py index 35204affb3fd168b8bd137d78c3413a08885e2bb..ff3e882229ae8cbc809804e14d2af06455ffcbcf 100644 --- a/python/paddle/distribution.py +++ b/python/paddle/distribution.py @@ -28,13 +28,14 @@ from .fluid.layers import nn from .fluid import core from .fluid.framework import in_dygraph_mode from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub +from .tensor import arange, gather_nd, concat, multinomial import math import numpy as np import warnings from .fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype -__all__ = ['Distribution', 'Uniform', 'Normal'] +__all__ = ['Distribution', 'Uniform', 'Normal', 'Categorical'] class Distribution(object): @@ -640,3 +641,318 @@ class Normal(Distribution): t1 = (t1 * t1) return elementwise_add( 0.5 * var_ratio, 0.5 * (t1 - 1. - nn.log(var_ratio)), name=name) + + +class Categorical(Distribution): + """ + Categorical distribution is a discrete probability distribution that + describes the possible results of a random variable that can take on + one of K possible categories, with the probability of each category + separately specified. + + The probability mass function (pmf) is: + + .. math:: + + pmf(k; p_i) = \prod_{i=1}^{k} p_i^{[x=i]} + + In the above equation: + + * :math:`[x=i]` : it evaluates to 1 if :math:`x==i` , 0 otherwise. + + Args: + logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64. + + Examples: + .. code-block:: python + + import paddle + from paddle.distribution import Categorical + + x = paddle.rand([6]) + print(x.numpy()) + # [0.32564053, 0.99334985, 0.99034804, + # 0.09053693, 0.30820143, 0.19095989] + y = paddle.rand([6]) + print(y.numpy()) + # [0.6365463 , 0.7278677 , 0.90260243, + # 0.5226815 , 0.35837543, 0.13981032] + + cat = Categorical(x) + cat2 = Categorical(y) + + cat.sample([2,3]) + # [[5, 1, 1], + # [0, 1, 2]] + + cat.entropy() + # [1.71887] + + cat.kl_divergence(cat2) + # [0.0278455] + + value = paddle.to_tensor([2,1,3]) + cat.probs(value) + # [0.341613 0.342648 0.03123] + + cat.log_prob(value) + # [-1.07408 -1.07105 -3.46638] + + """ + + def __init__(self, logits, name=None): + """ + Args: + logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64. + """ + if not in_dygraph_mode(): + check_type(logits, 'logits', (np.ndarray, tensor.Variable, list), + 'Categorical') + + self.name = name if name is not None else 'Categorical' + self.dtype = 'float32' + + if self._validate_args(logits): + self.logits = logits + self.dtype = convert_dtype(logits.dtype) + else: + if isinstance(logits, np.ndarray) and str( + logits.dtype) in ['float32', 'float64']: + self.dtype = logits.dtype + self.logits = self._to_tensor(logits)[0] + if self.dtype != convert_dtype(self.logits.dtype): + self.logits = tensor.cast(self.logits, dtype=self.dtype) + + def sample(self, shape): + """Generate samples of the specified shape. + + Args: + shape (list): Shape of the generated samples. + + Returns: + Tensor: A tensor with prepended dimensions shape. + + Examples: + .. 
code-block:: python
+
+                import paddle
+                from paddle.distribution import Categorical
+
+                x = paddle.rand([6])
+                print(x.numpy())
+                # [0.32564053, 0.99334985, 0.99034804,
+                #  0.09053693, 0.30820143, 0.19095989]
+
+                cat = Categorical(x)
+
+                cat.sample([2,3])
+                # [[5, 1, 1],
+                #  [0, 1, 2]]
+
+        """
+        name = self.name + '_sample'
+        if not in_dygraph_mode():
+            check_type(shape, 'shape', (list), 'sample')
+
+        num_samples = np.prod(np.array(shape))
+
+        logits_shape = list(self.logits.shape)
+        if len(logits_shape) > 1:
+            sample_shape = shape + logits_shape[:-1]
+            logits = nn.reshape(self.logits,
+                                [np.prod(logits_shape[:-1]), logits_shape[-1]])
+        else:
+            sample_shape = shape
+            logits = self.logits
+
+        sample_index = multinomial(logits, num_samples, True)
+        return nn.reshape(sample_index, sample_shape, name=name)
+
+    def kl_divergence(self, other):
+        """The KL-divergence between two Categorical distributions.
+
+        Args:
+            other (Categorical): instance of Categorical. The data type is float32.
+
+        Returns:
+            Variable: kl-divergence between two Categorical distributions.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                from paddle.distribution import Categorical
+
+                x = paddle.rand([6])
+                print(x.numpy())
+                # [0.32564053, 0.99334985, 0.99034804,
+                #  0.09053693, 0.30820143, 0.19095989]
+                y = paddle.rand([6])
+                print(y.numpy())
+                # [0.6365463 , 0.7278677 , 0.90260243,
+                #  0.5226815 , 0.35837543, 0.13981032]
+
+                cat = Categorical(x)
+                cat2 = Categorical(y)
+
+                cat.kl_divergence(cat2)
+                # [0.0278455]
+
+        """
+        name = self.name + '_kl_divergence'
+        if not in_dygraph_mode():
+            check_type(other, 'other', Categorical, 'kl_divergence')
+
+        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
+        other_logits = other.logits - nn.reduce_max(
+            other.logits, dim=-1, keep_dim=True)
+        e_logits = ops.exp(logits)
+        other_e_logits = ops.exp(other_logits)
+        z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
+        other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
+        prob = e_logits / z
+        kl = nn.reduce_sum(
+            prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
+            dim=-1,
+            keep_dim=True,
+            name=name)
+
+        return kl
+
+    def entropy(self):
+        """Shannon entropy in nats.
+
+        Returns:
+            Variable: Shannon entropy of Categorical distribution. The data type is float32.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                from paddle.distribution import Categorical
+
+                x = paddle.rand([6])
+                print(x.numpy())
+                # [0.32564053, 0.99334985, 0.99034804,
+                #  0.09053693, 0.30820143, 0.19095989]
+
+                cat = Categorical(x)
+
+                cat.entropy()
+                # [1.71887]
+
+        """
+        name = self.name + '_entropy'
+        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
+        e_logits = ops.exp(logits)
+        z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
+        prob = e_logits / z
+
+        neg_entropy = nn.reduce_sum(
+            prob * (logits - nn.log(z)), dim=-1, keep_dim=True)
+        entropy = nn.scale(neg_entropy, scale=-1.0, name=name)
+        return entropy
+
+    def probs(self, value):
+        """Probabilities of the given category (``value``).
+
+        If ``logits`` is 2-D or higher dimensional, the last dimension will be regarded as
+        the category, and the others represent different distributions.
+        At the same time, if ``value`` is a 1-D Tensor, ``value`` will be broadcast to the
+        same number of distributions as ``logits``.
+        If ``value`` is not a 1-D Tensor, ``value`` should have the same number of distributions
+        as ``logits``. That is, ``value.shape[:-1]`` must match ``logits.shape[:-1]``.
+
+        Args:
+            value (Tensor): The input tensor representing the selected category index.
+
+        Returns:
+            Tensor: probability according to the category index.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                from paddle.distribution import Categorical
+
+                x = paddle.rand([6])
+                print(x.numpy())
+                # [0.32564053, 0.99334985, 0.99034804,
+                #  0.09053693, 0.30820143, 0.19095989]
+
+                cat = Categorical(x)
+
+                value = paddle.to_tensor([2,1,3])
+                cat.probs(value)
+                # [0.341613 0.342648 0.03123]
+
+        """
+        name = self.name + '_probs'
+
+        dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
+        prob = self.logits / dist_sum
+
+        shape = list(prob.shape)
+        value_shape = list(value.shape)
+        if len(shape) == 1:
+            num_value_in_one_dist = np.prod(value_shape)
+            index_value = nn.reshape(value, [num_value_in_one_dist, 1])
+            index = index_value
+        else:
+            num_dist = np.prod(shape[:-1])
+            num_value_in_one_dist = value_shape[-1]
+            prob = nn.reshape(prob, [num_dist, shape[-1]])
+            if len(value_shape) == 1:
+                value = nn.expand(value, [num_dist])
+                value_shape = shape[:-1] + value_shape
+            index_value = nn.reshape(value, [num_dist, -1, 1])
+            if shape[:-1] != value_shape[:-1]:
+                raise ValueError(
+                    "shape of value {} must match shape of logits {}".format(
+                        str(value_shape[:-1]), str(shape[:-1])))
+
+            index_prefix = nn.unsqueeze(
+                arange(
+                    num_dist, dtype=index_value.dtype), axes=-1)
+            index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
+            index_prefix = nn.unsqueeze(index_prefix, axes=-1)
+
+            if index_value.dtype != index_prefix.dtype:
+                # tensor.cast is not in-place; keep the returned tensor.
+                index_prefix = tensor.cast(index_prefix, dtype=index_value.dtype)
+            index = concat([index_prefix, index_value], axis=-1)
+
+        # value is the category index to search for the corresponding probability.
+        select_prob = gather_nd(prob, index)
+        return nn.reshape(select_prob, value_shape, name=name)
+
+    def log_prob(self, value):
+        """Log probabilities of the given category. Refer to ``probs`` method.
+
+        Args:
+            value (Tensor): The input tensor representing the selected category index.
+
+        Returns:
+            Tensor: Log probability.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                from paddle.distribution import Categorical
+
+                x = paddle.rand([6])
+                print(x.numpy())
+                # [0.32564053, 0.99334985, 0.99034804,
+                #  0.09053693, 0.30820143, 0.19095989]
+
+                cat = Categorical(x)
+
+                value = paddle.to_tensor([2,1,3])
+
+                cat.log_prob(value)
+                # [-1.07408 -1.07105 -3.46638]
+
+        """
+        name = self.name + '_log_prob'
+
+        return nn.log(self.probs(value), name=name)
diff --git a/python/paddle/fluid/tests/unittests/test_distribution.py b/python/paddle/fluid/tests/unittests/test_distribution.py
index 40611fed65260765f8b634448413fd22f245541c..d5790811df94f3938faeeb6efa1cb51090366787 100644
--- a/python/paddle/fluid/tests/unittests/test_distribution.py
+++ b/python/paddle/fluid/tests/unittests/test_distribution.py
@@ -65,41 +65,6 @@ class UniformNumpy(DistributionNumpy):
         return np.log(self.high - self.low)
 
 
-class NormalNumpy(DistributionNumpy):
-    def __init__(self, loc, scale):
-        self.loc = np.array(loc)
-        self.scale = np.array(scale)
-        if str(self.loc.dtype) not in ['float32', 'float64']:
-            self.loc = self.loc.astype('float32')
-            self.scale = self.scale.astype('float32')
-
-    def sample(self, shape):
-        shape = tuple(shape) + (self.loc + self.scale).shape
-        return self.loc + (np.random.randn(*shape) * self.scale)
-
-    def log_prob(self, value):
-        var = self.scale * self.scale
-        log_scale = np.log(self.scale)
-        return -((value - self.loc) * (value - self.loc)) / (
-            2. 
* var) - log_scale - math.log(math.sqrt(2. * math.pi)) - - def probs(self, value): - var = self.scale * self.scale - return np.exp(-1. * ((value - self.loc) * (value - self.loc)) / - (2. * var)) / (math.sqrt(2 * math.pi) * self.scale) - - def entropy(self): - return 0.5 + 0.5 * np.log( - np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale) - - def kl_divergence(self, other): - var_ratio = (self.scale / other.scale) - var_ratio = var_ratio * var_ratio - t1 = ((self.loc - other.loc) / other.scale) - t1 = (t1 * t1) - return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) - - class UniformTest(unittest.TestCase): def setUp(self, use_gpu=False, batch_size=5, dims=6): self.use_gpu = use_gpu @@ -336,6 +301,41 @@ class UniformTest9(UniformTest): name='values', shape=[dims], dtype='float32') +class NormalNumpy(DistributionNumpy): + def __init__(self, loc, scale): + self.loc = np.array(loc) + self.scale = np.array(scale) + if str(self.loc.dtype) not in ['float32', 'float64']: + self.loc = self.loc.astype('float32') + self.scale = self.scale.astype('float32') + + def sample(self, shape): + shape = tuple(shape) + (self.loc + self.scale).shape + return self.loc + (np.random.randn(*shape) * self.scale) + + def log_prob(self, value): + var = self.scale * self.scale + log_scale = np.log(self.scale) + return -((value - self.loc) * (value - self.loc)) / ( + 2. * var) - log_scale - math.log(math.sqrt(2. * math.pi)) + + def probs(self, value): + var = self.scale * self.scale + return np.exp(-1. * ((value - self.loc) * (value - self.loc)) / + (2. * var)) / (math.sqrt(2 * math.pi) * self.scale) + + def entropy(self): + return 0.5 + 0.5 * np.log( + np.array(2. * math.pi).astype(self.loc.dtype)) + np.log(self.scale) + + def kl_divergence(self, other): + var_ratio = (self.scale / other.scale) + var_ratio = var_ratio * var_ratio + t1 = ((self.loc - other.loc) / other.scale) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + class NormalTest(unittest.TestCase): def setUp(self, use_gpu=False, batch_size=2, dims=3): self.use_gpu = use_gpu @@ -559,26 +559,6 @@ class NormalTest5(NormalTest): class NormalTest6(NormalTest): - def init_data(self, batch_size=2, dims=3): - # loc and scale are Tensor with dtype 'VarType.FP32'. - self.loc_np = np.random.randn(batch_size, dims).astype('float32') - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - while not np.all(self.scale_np > 0): - self.scale_np = np.random.randn(batch_size, dims).astype('float32') - self.values_np = np.random.randn(batch_size, dims).astype('float32') - self.loc = paddle.to_tensor(self.loc_np) - self.scale = paddle.to_tensor(self.scale_np) - self.values = paddle.to_tensor(self.values_np) - # used to construct another Normal object to calculate kl_divergence - self.other_loc_np = np.random.randn(batch_size, dims).astype('float32') - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - while not np.all(self.scale_np > 0): - self.other_scale_np = np.random.randn(batch_size, - dims).astype('float32') - self.other_loc = paddle.to_tensor(self.other_loc_np) - self.other_scale = paddle.to_tensor(self.other_scale_np) - def init_numpy_data(self, batch_size, dims): # loc and scale are Tensor with dtype 'VarType.FP32'. 
self.loc_np = np.random.randn(batch_size, dims).astype('float32') @@ -693,6 +673,294 @@ class NormalTest8(NormalTest): name='other_scale', shape=[dims], dtype='float64') +class CategoricalNumpy(DistributionNumpy): + def __init__(self, logits): + self.logits = np.array(logits).astype('float32') + + def entropy(self): + logits = self.logits - np.max(self.logits, axis=-1, keepdims=True) + e_logits = np.exp(logits) + z = np.sum(e_logits, axis=-1, keepdims=True) + prob = e_logits / z + return -1. * np.sum(prob * (logits - np.log(z)), axis=-1, keepdims=True) + + def kl_divergence(self, other): + logits = self.logits - np.max(self.logits, axis=-1, keepdims=True) + other_logits = other.logits - np.max( + other.logits, axis=-1, keepdims=True) + e_logits = np.exp(logits) + other_e_logits = np.exp(other_logits) + z = np.sum(e_logits, axis=-1, keepdims=True) + other_z = np.sum(other_e_logits, axis=-1, keepdims=True) + prob = e_logits / z + return np.sum(prob * (logits - np.log(z) - other_logits \ + + np.log(other_z)), axis=-1, keepdims=True) + + +class CategoricalTest(unittest.TestCase): + def setUp(self, use_gpu=False, batch_size=3, dims=5): + self.use_gpu = use_gpu + if not use_gpu: + self.place = fluid.CPUPlace() + self.gpu_id = -1 + else: + self.place = fluid.CUDAPlace(0) + self.gpu_id = 0 + + self.batch_size = batch_size + self.dims = dims + self.init_numpy_data(batch_size, dims) + + paddle.disable_static(self.place) + self.init_dynamic_data(batch_size, dims) + + paddle.enable_static() + self.test_program = fluid.Program() + self.executor = fluid.Executor(self.place) + self.init_static_data(batch_size, dims) + + def init_numpy_data(self, batch_size, dims): + # input logtis is 2-D Tensor + # value used in probs and log_prob method is 1-D Tensor + self.logits_np = np.random.rand(batch_size, dims).astype('float32') + self.other_logits_np = np.random.rand(batch_size, + dims).astype('float32') + self.value_np = np.array([2, 1, 3]).astype('int64') + + self.logits_shape = [batch_size, dims] + # dist_shape = logits_shape[:-1], it represents the number of + # different distributions. + self.dist_shape = [batch_size] + # sample shape represents the number of samples + self.sample_shape = [2, 4] + # value used in probs and log_prob method + # If value is 1-D and logits is 2-D or higher dimension, value will be + # broadcasted to have the same number of distributions with logits. + # If value is 2-D or higher dimentsion, it should have the same number + # of distributions with logtis. 
``value[:-1] = logits[:-1] + self.value_shape = [3] + + def init_dynamic_data(self, batch_size, dims): + self.logits = paddle.to_tensor(self.logits_np) + self.other_logits = paddle.to_tensor(self.other_logits_np) + self.value = paddle.to_tensor(self.value_np) + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.logits_static = fluid.data( + name='logits', shape=self.logits_shape, dtype='float32') + self.other_logits_static = fluid.data( + name='other_logits', shape=self.logits_shape, dtype='float32') + self.value_static = fluid.data( + name='value', shape=self.value_shape, dtype='int64') + + def get_numpy_selected_probs(self, probability): + np_probs = np.zeros(self.dist_shape + self.value_shape) + for i in range(self.batch_size): + for j in range(3): + np_probs[i][j] = probability[i][self.value_np[j]] + return np_probs + + def compare_with_numpy(self, fetch_list, tolerance=1e-6): + sample, entropy, kl, probs, log_prob = fetch_list + log_tolerance = 1e-4 + + np.testing.assert_equal(sample.shape, + self.sample_shape + self.dist_shape) + + np_categorical = CategoricalNumpy(self.logits_np) + np_other_categorical = CategoricalNumpy(self.other_logits_np) + np_entropy = np_categorical.entropy() + np_kl = np_categorical.kl_divergence(np_other_categorical) + + np.testing.assert_allclose( + entropy, np_entropy, rtol=log_tolerance, atol=log_tolerance) + np.testing.assert_allclose( + kl, np_kl, rtol=log_tolerance, atol=log_tolerance) + + sum_dist = np.sum(self.logits_np, axis=-1, keepdims=True) + probability = self.logits_np / sum_dist + np_probs = self.get_numpy_selected_probs(probability) + np_log_prob = np.log(np_probs) + + np.testing.assert_allclose( + probs, np_probs, rtol=tolerance, atol=tolerance) + np.testing.assert_allclose( + log_prob, np_log_prob, rtol=tolerance, atol=tolerance) + + def test_categorical_distribution_dygraph(self, tolerance=1e-6): + paddle.disable_static(self.place) + categorical = Categorical(self.logits) + other_categorical = Categorical(self.other_logits) + + sample = categorical.sample(self.sample_shape).numpy() + entropy = categorical.entropy().numpy() + kl = categorical.kl_divergence(other_categorical).numpy() + probs = categorical.probs(self.value).numpy() + log_prob = categorical.log_prob(self.value).numpy() + + fetch_list = [sample, entropy, kl, probs, log_prob] + self.compare_with_numpy(fetch_list) + + def test_categorical_distribution_static(self, tolerance=1e-6): + paddle.enable_static() + with fluid.program_guard(self.test_program): + categorical = Categorical(self.logits_static) + other_categorical = Categorical(self.other_logits_static) + + sample = categorical.sample(self.sample_shape) + entropy = categorical.entropy() + kl = categorical.kl_divergence(other_categorical) + probs = categorical.probs(self.value_static) + log_prob = categorical.log_prob(self.value_static) + + fetch_list = [sample, entropy, kl, probs, log_prob] + + feed_vars = { + 'logits': self.logits_np, + 'other_logits': self.other_logits_np, + 'value': self.value_np + } + + self.executor.run(fluid.default_startup_program()) + fetch_list = self.executor.run(program=self.test_program, + feed=feed_vars, + fetch_list=fetch_list) + + self.compare_with_numpy(fetch_list) + + +class CategoricalTest2(CategoricalTest): + def init_numpy_data(self, batch_size, dims): + # input logtis is 2-D Tensor with dtype Float64 + # value used in probs and log_prob method is 1-D Tensor + self.logits_np = np.random.rand(batch_size, dims).astype('float64') + 
self.other_logits_np = np.random.rand(batch_size, + dims).astype('float64') + self.value_np = np.array([2, 1, 3]).astype('int64') + + self.logits_shape = [batch_size, dims] + self.dist_shape = [batch_size] + self.sample_shape = [2, 4] + self.value_shape = [3] + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.logits_static = fluid.data( + name='logits', shape=self.logits_shape, dtype='float64') + self.other_logits_static = fluid.data( + name='other_logits', shape=self.logits_shape, dtype='float64') + self.value_static = fluid.data( + name='value', shape=self.value_shape, dtype='int64') + + +class CategoricalTest3(CategoricalTest): + def init_dynamic_data(self, batch_size, dims): + # input logtis is 2-D numpy.ndarray with dtype Float32 + # value used in probs and log_prob method is 1-D Tensor + self.logits = self.logits_np + self.other_logits = self.other_logits_np + self.value = paddle.to_tensor(self.value_np) + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.logits_static = self.logits_np + self.other_logits_static = self.other_logits_np + self.value_static = fluid.data( + name='value', shape=self.value_shape, dtype='int64') + + +class CategoricalTest4(CategoricalTest): + def init_numpy_data(self, batch_size, dims): + # input logtis is 2-D numpy.ndarray with dtype Float64 + # value used in probs and log_prob method is 1-D Tensor + self.logits_np = np.random.rand(batch_size, dims).astype('float64') + self.other_logits_np = np.random.rand(batch_size, + dims).astype('float64') + self.value_np = np.array([2, 1, 3]).astype('int64') + + self.logits_shape = [batch_size, dims] + self.dist_shape = [batch_size] + self.sample_shape = [2, 4] + self.value_shape = [3] + + def init_dynamic_data(self, batch_size, dims): + self.logits = self.logits_np + self.other_logits = self.other_logits_np + self.value = paddle.to_tensor(self.value_np) + + def init_static_data(self, batch_size, dims): + with fluid.program_guard(self.test_program): + self.logits_static = self.logits_np + self.other_logits_static = self.other_logits_np + self.value_static = fluid.data( + name='value', shape=self.value_shape, dtype='int64') + + +# test shape of logits and value used in probs and log_prob method +class CategoricalTest5(CategoricalTest): + def init_numpy_data(self, batch_size, dims): + # input logtis is 1-D Tensor + # value used in probs and log_prob method is 1-D Tensor + self.logits_np = np.random.rand(dims).astype('float32') + self.other_logits_np = np.random.rand(dims).astype('float32') + self.value_np = np.array([2, 1, 3]).astype('int64') + + self.logits_shape = [dims] + self.dist_shape = [] + self.sample_shape = [2, 4] + self.value_shape = [3] + + def get_numpy_selected_probs(self, probability): + np_probs = np.zeros(self.value_shape) + for i in range(3): + np_probs[i] = probability[self.value_np[i]] + return np_probs + + +class CategoricalTest6(CategoricalTest): + def init_numpy_data(self, batch_size, dims): + # input logtis is 2-D Tensor + # value used in probs and log_prob method has the same number of batches with input + self.logits_np = np.random.rand(3, 5).astype('float32') + self.other_logits_np = np.random.rand(3, 5).astype('float32') + self.value_np = np.array([[2, 1], [0, 3], [2, 3]]).astype('int64') + + self.logits_shape = [3, 5] + self.dist_shape = [3] + self.sample_shape = [2, 4] + self.value_shape = [3, 2] + + def get_numpy_selected_probs(self, probability): + np_probs = np.zeros(self.value_shape) 
+ for i in range(3): + for j in range(2): + np_probs[i][j] = probability[i][self.value_np[i][j]] + return np_probs + + +class CategoricalTest7(CategoricalTest): + def init_numpy_data(self, batch_size, dims): + # input logtis is 3-D Tensor + # value used in probs and log_prob method has the same number of distribuions with input + self.logits_np = np.random.rand(3, 2, 5).astype('float32') + self.other_logits_np = np.random.rand(3, 2, 5).astype('float32') + self.value_np = np.array([2, 1, 3]).astype('int64') + + self.logits_shape = [3, 2, 5] + self.dist_shape = [3, 2] + self.sample_shape = [2, 4] + self.value_shape = [3] + + def get_numpy_selected_probs(self, probability): + np_probs = np.zeros(self.dist_shape + self.value_shape) + for i in range(3): + for j in range(2): + for k in range(3): + np_probs[i][j][k] = probability[i][j][self.value_np[k]] + return np_probs + + class DistributionTestError(unittest.TestCase): def test_distribution_error(self): distribution = Distribution() @@ -711,6 +979,7 @@ class DistributionTestError(unittest.TestCase): self.assertRaises(NotImplementedError, distribution.probs, value_tensor) def test_normal_error(self): + paddle.enable_static() normal = Normal(0.0, 1.0) value = [1.0, 2.0] @@ -734,6 +1003,7 @@ class DistributionTestError(unittest.TestCase): self.assertRaises(TypeError, normal.kl_divergence, normal_other) def test_uniform_error(self): + paddle.enable_static() uniform = Uniform(0.0, 1.0) value = [1.0, 2.0] @@ -752,6 +1022,39 @@ class DistributionTestError(unittest.TestCase): # type of seed must be int self.assertRaises(TypeError, uniform.sample, [2, 3], seed) + def test_categorical_error(self): + paddle.enable_static() + + categorical = Categorical([0.4, 0.6]) + + value = [1, 0] + # type of value must be variable + self.assertRaises(AttributeError, categorical.log_prob, value) + + value = [1, 0] + # type of value must be variable + self.assertRaises(AttributeError, categorical.probs, value) + + shape = 1.0 + # type of shape must be list + self.assertRaises(TypeError, categorical.sample, shape) + + categorical_other = Uniform(1.0, 2.0) + # type of other must be an instance of Categorical + self.assertRaises(TypeError, categorical.kl_divergence, + categorical_other) + + def test_shape_not_match_error(): + # shape of value must match shape of logits + # value_shape[:-1] == logits_shape[:-1] + paddle.disable_static() + logits = paddle.rand([3, 5]) + cat = Categorical(logits) + value = paddle.to_tensor([[2, 1, 3], [3, 2, 1]], dtype='int64') + cat.log_prob(value) + + self.assertRaises(ValueError, test_shape_not_match_error) + class DistributionTestName(unittest.TestCase): def get_prefix(self, string): @@ -812,6 +1115,35 @@ class DistributionTestName(unittest.TestCase): p = uniform1.probs(value_tensor) self.assertEqual(self.get_prefix(p.name), name + '_probs') + def test_categorical_name(self): + name = 'test_categorical' + categorical1 = Categorical([0.4, 0.6], name=name) + self.assertEqual(categorical1.name, name) + + categorical2 = Categorical([0.5, 0.5]) + self.assertEqual(categorical2.name, 'Categorical') + + paddle.enable_static() + + sample = categorical1.sample([2]) + self.assertEqual(self.get_prefix(sample.name), name + '_sample') + + entropy = categorical1.entropy() + self.assertEqual(self.get_prefix(entropy.name), name + '_entropy') + + kl = categorical1.kl_divergence(categorical2) + self.assertEqual(self.get_prefix(kl.name), name + '_kl_divergence') + + value_npdata = np.array([0], dtype="int64") + value_tensor = 
layers.create_tensor(dtype="int64") + layers.assign(value_npdata, value_tensor) + + p = categorical1.probs(value_tensor) + self.assertEqual(self.get_prefix(p.name), name + '_probs') + + lp = categorical1.log_prob(value_tensor) + self.assertEqual(self.get_prefix(lp.name), name + '_log_prob') + if __name__ == '__main__': unittest.main()
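
For reviewers who want to exercise the new API outside the unit tests, the sketch below cross-checks ``sample``, ``probs``, ``log_prob`` and ``entropy`` against NumPy in dygraph mode. It is illustrative only: it assumes a Paddle build that already contains the ``Categorical`` class added in this patch, and it mirrors the conventions used by ``CategoricalTest`` (``probs`` normalizes the raw logits by their sum over the last axis rather than applying softmax, while ``entropy`` and ``kl_divergence`` operate on the softmax of the logits).

.. code-block:: python

    import numpy as np
    import paddle
    from paddle.distribution import Categorical

    paddle.disable_static()

    logits_np = np.random.rand(3, 5).astype('float32')
    value_np = np.array([2, 1, 3]).astype('int64')

    cat = Categorical(paddle.to_tensor(logits_np))

    # sample() prepends the requested shape to the distribution shape
    # ([3] here, the leading dimension of the 2-D logits).
    samples = cat.sample([2, 4])
    assert list(samples.shape) == [2, 4, 3]

    # probs() divides the raw logits by their sum over the last axis and then
    # gathers the entries selected by value; a 1-D value is broadcast to every
    # distribution, so the result has shape [3, 3].
    prob_np = logits_np / logits_np.sum(axis=-1, keepdims=True)
    expected_probs = prob_np[:, value_np]
    np.testing.assert_allclose(
        cat.probs(paddle.to_tensor(value_np)).numpy(), expected_probs, rtol=1e-5)

    # log_prob() is simply log(probs()).
    np.testing.assert_allclose(
        cat.log_prob(paddle.to_tensor(value_np)).numpy(),
        np.log(expected_probs),
        rtol=1e-4)

    # entropy() uses the softmax of the logits and keeps the reduced axis,
    # so the result has shape [3, 1].
    softmax_np = np.exp(logits_np - logits_np.max(axis=-1, keepdims=True))
    softmax_np /= softmax_np.sum(axis=-1, keepdims=True)
    expected_entropy = -np.sum(
        softmax_np * np.log(softmax_np), axis=-1, keepdims=True)
    np.testing.assert_allclose(
        cat.entropy().numpy(), expected_entropy, rtol=1e-4)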