未验证 提交 9a0bfece 编写于 作者: F Feiyu Chan 提交者: GitHub

remove redundant computation in Categorical.probs (#42114)

上级 6553a9d7
...@@ -115,6 +115,8 @@ class Categorical(distribution.Distribution): ...@@ -115,6 +115,8 @@ class Categorical(distribution.Distribution):
self.logits = self._to_tensor(logits)[0] self.logits = self._to_tensor(logits)[0]
if self.dtype != convert_dtype(self.logits.dtype): if self.dtype != convert_dtype(self.logits.dtype):
self.logits = tensor.cast(self.logits, dtype=self.dtype) self.logits = tensor.cast(self.logits, dtype=self.dtype)
dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
self._prob = self.logits / dist_sum
def sample(self, shape): def sample(self, shape):
"""Generate samples of the specified shape. """Generate samples of the specified shape.
...@@ -297,42 +299,21 @@ class Categorical(distribution.Distribution): ...@@ -297,42 +299,21 @@ class Categorical(distribution.Distribution):
""" """
name = self.name + '_probs' name = self.name + '_probs'
if len(self._prob.shape) == 1: # batch_shape is empty
dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True) return paddle.gather(
prob = self.logits / dist_sum self._prob, value.reshape(
[-1], name=name), name=name).reshape(
shape = list(prob.shape) value.shape, name=name)
value_shape = list(value.shape)
if len(shape) == 1:
num_value_in_one_dist = np.prod(value_shape)
index_value = paddle.reshape(value, [num_value_in_one_dist, 1])
index = index_value
else: else:
num_dist = np.prod(shape[:-1]) if len(value.shape) == 1:
num_value_in_one_dist = value_shape[-1] return paddle.take_along_axis(
prob = paddle.reshape(prob, [num_dist, shape[-1]]) self._prob,
if len(value_shape) == 1: paddle.reshape(
value = nn.expand(value, [num_dist]) value, (len(self._prob.shape) - 1) * [1] + [-1],
value_shape = shape[:-1] + value_shape name=name),
index_value = paddle.reshape(value, [num_dist, -1, 1]) axis=-1)
if shape[:-1] != value_shape[:-1]: else:
raise ValueError( return paddle.take_along_axis(self._prob, value, axis=-1)
"shape of value {} must match shape of logits {}".format(
str(value_shape[:-1]), str(shape[:-1])))
index_prefix = paddle.unsqueeze(
arange(
num_dist, dtype=index_value.dtype), axis=-1)
index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
index_prefix = paddle.unsqueeze(index_prefix, axis=-1)
if index_value.dtype != index_prefix.dtype:
tensor.cast(index_prefix, dtype=index_value.dtype)
index = concat([index_prefix, index_value], axis=-1)
# value is the category index to search for the corresponding probability.
select_prob = gather_nd(prob, index)
return paddle.reshape(select_prob, value_shape, name=name)
def log_prob(self, value): def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method. """Log probabilities of the given category. Refer to ``probs`` method.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册