diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index b181a25fbcee1ecebb7241bd991fc78e152cbea3..97a3df490b1d05070d6224866de0265b7f3c43df 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -115,6 +115,8 @@ class Categorical(distribution.Distribution): self.logits = self._to_tensor(logits)[0] if self.dtype != convert_dtype(self.logits.dtype): self.logits = tensor.cast(self.logits, dtype=self.dtype) + dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True) + self._prob = self.logits / dist_sum def sample(self, shape): """Generate samples of the specified shape. @@ -297,42 +299,21 @@ class Categorical(distribution.Distribution): """ name = self.name + '_probs' - - dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True) - prob = self.logits / dist_sum - - shape = list(prob.shape) - value_shape = list(value.shape) - if len(shape) == 1: - num_value_in_one_dist = np.prod(value_shape) - index_value = paddle.reshape(value, [num_value_in_one_dist, 1]) - index = index_value + if len(self._prob.shape) == 1: # batch_shape is empty + return paddle.gather( + self._prob, value.reshape( + [-1], name=name), name=name).reshape( + value.shape, name=name) else: - num_dist = np.prod(shape[:-1]) - num_value_in_one_dist = value_shape[-1] - prob = paddle.reshape(prob, [num_dist, shape[-1]]) - if len(value_shape) == 1: - value = nn.expand(value, [num_dist]) - value_shape = shape[:-1] + value_shape - index_value = paddle.reshape(value, [num_dist, -1, 1]) - if shape[:-1] != value_shape[:-1]: - raise ValueError( - "shape of value {} must match shape of logits {}".format( - str(value_shape[:-1]), str(shape[:-1]))) - - index_prefix = paddle.unsqueeze( - arange( - num_dist, dtype=index_value.dtype), axis=-1) - index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist]) - index_prefix = paddle.unsqueeze(index_prefix, axis=-1) - - if index_value.dtype != index_prefix.dtype: - tensor.cast(index_prefix, dtype=index_value.dtype) - index = concat([index_prefix, index_value], axis=-1) - - # value is the category index to search for the corresponding probability. - select_prob = gather_nd(prob, index) - return paddle.reshape(select_prob, value_shape, name=name) + if len(value.shape) == 1: + return paddle.take_along_axis( + self._prob, + paddle.reshape( + value, (len(self._prob.shape) - 1) * [1] + [-1], + name=name), + axis=-1) + else: + return paddle.take_along_axis(self._prob, value, axis=-1) def log_prob(self, value): """Log probabilities of the given category. Refer to ``probs`` method.