未验证 提交 f0b26313 编写于 作者: P pangyoki 提交者: GitHub

fix _check_values_dtype_in_probs method in Distribution class (#27046)

上级 b150f2b3
...@@ -138,7 +138,7 @@ class Distribution(object): ...@@ -138,7 +138,7 @@ class Distribution(object):
convert value's dtype to be consistent with param's dtype. convert value's dtype to be consistent with param's dtype.
Args: Args:
param (int|float|list|numpy.ndarray|Tensor): low and high in Uniform class, loc and scale in Normal class. param (Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor. value (Tensor): The input tensor.
Returns: Returns:
...@@ -152,6 +152,7 @@ class Distribution(object): ...@@ -152,6 +152,7 @@ class Distribution(object):
) )
return core.ops.cast(value, 'in_dtype', value.dtype, return core.ops.cast(value, 'in_dtype', value.dtype,
'out_dtype', param.dtype) 'out_dtype', param.dtype)
return value
check_variable_and_dtype(value, 'value', ['float32', 'float64'], check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob') 'log_prob')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册