From f0b26313d8c31e6758a1eff426ebd3a2a6d664f0 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Sun, 6 Sep 2020 22:41:14 -0500 Subject: [PATCH] fix _check_values_dtype_in_probs method in Distribution class (#27046) --- python/paddle/distribution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py index 918ebce0782..35204affb3f 100644 --- a/python/paddle/distribution.py +++ b/python/paddle/distribution.py @@ -138,7 +138,7 @@ class Distribution(object): convert value's dtype to be consistent with param's dtype. Args: - param (int|float|list|numpy.ndarray|Tensor): low and high in Uniform class, loc and scale in Normal class. + param (Tensor): low and high in Uniform class, loc and scale in Normal class. value (Tensor): The input tensor. Returns: @@ -152,6 +152,7 @@ class Distribution(object): ) return core.ops.cast(value, 'in_dtype', value.dtype, 'out_dtype', param.dtype) + return value check_variable_and_dtype(value, 'value', ['float32', 'float64'], 'log_prob') -- GitLab