未验证 提交 9a0bfece 编写于 作者: F Feiyu Chan 提交者: GitHub

remove redundant computation in Categorical.probs (#42114)

上级 6553a9d7
......@@ -115,6 +115,8 @@ class Categorical(distribution.Distribution):
self.logits = self._to_tensor(logits)[0]
if self.dtype != convert_dtype(self.logits.dtype):
self.logits = tensor.cast(self.logits, dtype=self.dtype)
dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
self._prob = self.logits / dist_sum
def sample(self, shape):
"""Generate samples of the specified shape.
......@@ -297,42 +299,21 @@ class Categorical(distribution.Distribution):
"""
name = self.name + '_probs'
dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
prob = self.logits / dist_sum
shape = list(prob.shape)
value_shape = list(value.shape)
if len(shape) == 1:
num_value_in_one_dist = np.prod(value_shape)
index_value = paddle.reshape(value, [num_value_in_one_dist, 1])
index = index_value
if len(self._prob.shape) == 1: # batch_shape is empty
return paddle.gather(
self._prob, value.reshape(
[-1], name=name), name=name).reshape(
value.shape, name=name)
else:
if len(value.shape) == 1:
return paddle.take_along_axis(
self._prob,
paddle.reshape(
value, (len(self._prob.shape) - 1) * [1] + [-1],
name=name),
axis=-1)
else:
num_dist = np.prod(shape[:-1])
num_value_in_one_dist = value_shape[-1]
prob = paddle.reshape(prob, [num_dist, shape[-1]])
if len(value_shape) == 1:
value = nn.expand(value, [num_dist])
value_shape = shape[:-1] + value_shape
index_value = paddle.reshape(value, [num_dist, -1, 1])
if shape[:-1] != value_shape[:-1]:
raise ValueError(
"shape of value {} must match shape of logits {}".format(
str(value_shape[:-1]), str(shape[:-1])))
index_prefix = paddle.unsqueeze(
arange(
num_dist, dtype=index_value.dtype), axis=-1)
index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
index_prefix = paddle.unsqueeze(index_prefix, axis=-1)
if index_value.dtype != index_prefix.dtype:
tensor.cast(index_prefix, dtype=index_value.dtype)
index = concat([index_prefix, index_value], axis=-1)
# value is the category index to search for the corresponding probability.
select_prob = gather_nd(prob, index)
return paddle.reshape(select_prob, value_shape, name=name)
return paddle.take_along_axis(self._prob, value, axis=-1)
def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册