Unverified commit a0c98e67, authored by pangyoki, committed by GitHub

Fix dtype mismatch bug in the log_prob and probs methods of the Distribution class (#26767)

* fix _to_tensor method of Distribution class

* Add unittest

* let dtype be consistent with value in log_prob and probs

* fix format

* fix dtype problem and change unittest

* fix dtype of Numpy class in unittest

* add formula for entropy and kl

* change formula

* fix kl formula format

* fix kl formula format 2

* change gt to np in unittest

* optimize unittest format

* delete duplicate

* delete duplicate 2

* extract common function used to convert dtype value
Parent f95e8ffc
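As context for the diff below, a hypothetical reproduction of the bug being fixed (a sketch assuming the paddle 2.0-beta dygraph API of this period; `paddle.disable_static` and `paddle.distribution.Normal` are taken from that era): the parameters are float64 ndarrays, so the distribution is float64, while `value` arrives as float32. Before this change the dtypes clashed inside `log_prob`; after it, `value` is cast to the parameter dtype with a warning.

    import numpy as np
    import paddle

    paddle.disable_static()  # dygraph mode
    normal = paddle.distribution.Normal(loc=np.zeros(2, dtype='float64'),
                                        scale=np.ones(2, dtype='float64'))
    value = paddle.ones([2], dtype='float32')
    lp = normal.log_prob(value)  # 'value' is now cast to float64 first
    print(lp.dtype)              # expected: float64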
@@ -102,21 +102,24 @@ class Distribution(object):
         tmp = 0.
 
         for arg in args:
-            valid_arg = False
-            for cls in [float, list, np.ndarray, tensor.Variable]:
-                if isinstance(arg, cls):
-                    valid_arg = True
-                    break
-            assert valid_arg, "type of input args must be float, list, numpy.ndarray or Tensor."
             if isinstance(arg, float):
-                arg = np.zeros(1) + arg
+                arg = [arg]
+            if not isinstance(arg, (list, np.ndarray, tensor.Variable)):
+                raise TypeError(
+                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
+                    format(type(arg)))
+
             arg_np = np.array(arg)
             arg_dtype = arg_np.dtype
-            if str(arg_dtype) not in ['float32']:
-                warnings.warn(
-                    "data type of argument only support float32, your argument will be convert to float32."
-                )
+            if str(arg_dtype) != 'float32':
+                if str(arg_dtype) != 'float64':
+                    # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
+                    # and converted to float64 later using "cast".
+                    warnings.warn(
+                        "data type of argument only support float32 and float64, your argument will be convert to float32."
+                    )
                 arg_np = arg_np.astype('float32')
+            # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
             tmp = tmp + arg_np
             numpy_args.append(arg_np)
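To make the new conversion rule concrete, a plain-NumPy trace (a sketch of the logic only, not the Paddle code path): floats become one-element lists, float64 input is preserved (a float32 variable is created and cast back later, since "assign" lacks float64 support), and any other dtype falls back to float32.

    import numpy as np

    for arg in (1.0, np.arange(3), np.zeros(2, dtype='float64')):
        if isinstance(arg, float):
            arg = [arg]
        arg_np = np.array(arg)
        if str(arg_np.dtype) not in ('float32', 'float64'):
            arg_np = arg_np.astype('float32')  # e.g. int64 -> float32
        print(arg_np.dtype)  # float64, float32, float64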
@@ -129,6 +132,36 @@ class Distribution(object):
 
         return tuple(variable_args)
 
+    def _check_values_dtype_in_probs(self, param, value):
+        """
+        Log_prob and probs methods have input ``value``, if value's dtype is different from param,
+        convert value's dtype to be consistent with param's dtype.
+
+        Args:
+            param (int|float|list|numpy.ndarray|Tensor): low and high in Uniform class, loc and scale in Normal class.
+            value (Tensor): The input tensor.
+
+        Returns:
+            value (Tensor): Change value's dtype if value's dtype is different from param.
+        """
+        if in_dygraph_mode():
+            if value.dtype != param.dtype and convert_dtype(
+                    value.dtype) in ['float32', 'float64']:
+                warnings.warn(
+                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
+                )
+                return core.ops.cast(value, 'in_dtype', value.dtype,
+                                     'out_dtype', param.dtype)
+            return value
+
+        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
+                                 'log_prob')
+        if value.dtype != param.dtype:
+            warnings.warn(
+                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
+            )
+            return tensor.cast(value, dtype=param.dtype)
+        return value
+
 
 class Uniform(Distribution):
     """Uniform distribution with `low` and `high` parameters.
@@ -155,8 +188,8 @@ class Uniform(Distribution):
     [broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
 
     Args:
-        low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
-        high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
+        low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
+        high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
         name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Examples:
@@ -206,6 +239,7 @@ class Uniform(Distribution):
         self.all_arg_is_float = False
         self.batch_size_unknown = False
         self.name = name if name is not None else 'Uniform'
+        self.dtype = 'float32'
 
         if isinstance(low, int):
             low = float(low)
@@ -216,10 +250,22 @@ class Uniform(Distribution):
             self.batch_size_unknown = True
             self.low = low
             self.high = high
+            self.dtype = convert_dtype(low.dtype)
         else:
             if isinstance(low, float) and isinstance(high, float):
                 self.all_arg_is_float = True
+            if isinstance(
+                    low,
+                    np.ndarray) and str(low.dtype) in ['float32', 'float64']:
+                self.dtype = low.dtype
+            elif isinstance(
+                    high,
+                    np.ndarray) and str(high.dtype) in ['float32', 'float64']:
+                self.dtype = high.dtype
             self.low, self.high = self._to_tensor(low, high)
+            if self.dtype != convert_dtype(self.low.dtype):
+                self.low = tensor.cast(self.low, dtype=self.dtype)
+                self.high = tensor.cast(self.high, dtype=self.dtype)
 
     def sample(self, shape, seed=0):
         """Generate samples of the specified shape.
@@ -241,11 +287,11 @@ class Uniform(Distribution):
         if self.batch_size_unknown:
             output_shape = shape + batch_shape
             zero_tmp = tensor.fill_constant_batch_size_like(
-                self.low + self.high, batch_shape + shape, self.low.dtype, 0.)
+                self.low + self.high, batch_shape + shape, self.dtype, 0.)
             uniform_random_tmp = nn.uniform_random_batch_size_like(
                 zero_tmp,
                 zero_tmp.shape,
-                dtype=convert_dtype(zero_tmp.dtype),
+                dtype=self.dtype,
                 min=0.,
                 max=1.,
                 seed=seed)
@@ -259,9 +305,8 @@ class Uniform(Distribution):
         else:
             output_shape = shape + batch_shape
             output = nn.uniform_random(
-                output_shape, seed=seed) * (tensor.zeros(
-                output_shape, dtype=self.low.dtype) +
-                (self.high - self.low))
+                output_shape, seed=seed, dtype=self.dtype) * (tensor.zeros(
+                    output_shape, dtype=self.dtype) + (self.high - self.low))
             output = elementwise_add(output, self.low, name=name)
             if self.all_arg_is_float:
                 return nn.reshape(output, shape, name=name)
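The static-graph branch above is numerically `low + U(0,1) * (high - low)`, now drawn in `self.dtype`; a NumPy sketch of the equivalence:

    import numpy as np

    rng = np.random.default_rng(0)
    low, high = np.float64(1.0), np.float64(3.0)
    samples = low + rng.random(5, dtype=np.float64) * (high - low)
    print(samples.dtype, samples.min() >= low, samples.max() < high)  # float64 True True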
@@ -279,22 +324,20 @@ class Uniform(Distribution):
 
         """
         name = self.name + '_log_prob'
+        value = self._check_values_dtype_in_probs(self.low, value)
         if in_dygraph_mode():
+            # ensure value in [low, high]
             lb_bool = self.low < value
             ub_bool = value < self.high
-            dtype = value.dtype
 
             lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
-                               dtype)
+                               value.dtype)
             ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
-                               dtype)
+                               value.dtype)
             return nn.log(lb * ub) - nn.log(self.high - self.low)
 
-        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
-                                 'log_prob')
-
-        lb_bool = control_flow.less_than(self.low, value)
-        ub_bool = control_flow.less_than(value, self.high)
+        lb_bool = self.low < value
+        ub_bool = value < self.high
         lb = tensor.cast(lb_bool, dtype=value.dtype)
         ub = tensor.cast(ub_bool, dtype=value.dtype)
         return elementwise_sub(
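A quick NumPy check of the log-density computed above, log(1_{low<x<high}) - log(high - low): -log(2) inside the support, -inf outside.

    import numpy as np

    low, high = 1.0, 3.0
    x = np.array([2.0, 5.0], dtype='float32')
    lb = (low < x).astype('float32')
    ub = (x < high).astype('float32')
    with np.errstate(divide='ignore'):  # log(0) -> -inf outside the support
        print(np.log(lb * ub) - np.log(high - low))  # [-0.6931..., -inf]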
@@ -311,22 +354,19 @@ class Uniform(Distribution):
 
         """
         name = self.name + '_probs'
+        value = self._check_values_dtype_in_probs(self.low, value)
         if in_dygraph_mode():
             lb_bool = self.low < value
             ub_bool = value < self.high
-            dtype = value.dtype
 
             lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
-                               dtype)
+                               value.dtype)
             ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
-                               dtype)
+                               value.dtype)
             return (lb * ub) / (self.high - self.low)
 
-        check_variable_and_dtype(value, 'value', ['float32', 'float64'],
-                                 'log_prob')
-
-        lb_bool = control_flow.less_than(self.low, value)
-        ub_bool = control_flow.less_than(value, self.high)
+        lb_bool = self.low < value
+        ub_bool = value < self.high
         lb = tensor.cast(lb_bool, dtype=value.dtype)
         ub = tensor.cast(ub_bool, dtype=value.dtype)
         return elementwise_div((lb * ub), (self.high - self.low), name=name)
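The matching check for `probs`: the same indicator divided by (high - low).

    import numpy as np

    low, high = 1.0, 3.0
    x = np.array([2.0, 5.0], dtype='float32')
    print(((low < x) & (x < high)) / (high - low))  # [0.5, 0.0]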
@@ -334,6 +374,12 @@ class Uniform(Distribution):
     def entropy(self):
         """Shannon entropy in nats.
 
+        The entropy is
+
+        .. math::
+
+            entropy(low, high) = \\log (high - low)
+
         Returns:
             Tensor: Shannon entropy of uniform distribution.The data type is float32.
 
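Sanity check of the added entropy formula: Uniform(0, 2) has entropy ln 2 ≈ 0.6931 nats.

    import numpy as np

    low, high = 0.0, 2.0
    print(np.log(high - low))  # 0.6931...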
@@ -364,8 +410,8 @@ class Normal(Distribution):
     * :math:`Z`: is the normalization constant.
 
     Args:
-        loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
-        scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
+        loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
+        scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
         name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Examples:
@@ -418,6 +464,7 @@ class Normal(Distribution):
         self.batch_size_unknown = False
         self.all_arg_is_float = False
         self.name = name if name is not None else 'Normal'
+        self.dtype = 'float32'
 
         if isinstance(loc, int):
             loc = float(loc)
@@ -428,10 +475,22 @@ class Normal(Distribution):
             self.batch_size_unknown = True
             self.loc = loc
             self.scale = scale
+            self.dtype = convert_dtype(loc.dtype)
         else:
             if isinstance(loc, float) and isinstance(scale, float):
                 self.all_arg_is_float = True
+            if isinstance(
+                    loc,
+                    np.ndarray) and str(loc.dtype) in ['float32', 'float64']:
+                self.dtype = loc.dtype
+            elif isinstance(
+                    scale,
+                    np.ndarray) and str(scale.dtype) in ['float32', 'float64']:
+                self.dtype = scale.dtype
             self.loc, self.scale = self._to_tensor(loc, scale)
+            if self.dtype != convert_dtype(self.loc.dtype):
+                self.loc = tensor.cast(self.loc, dtype=self.dtype)
+                self.scale = tensor.cast(self.scale, dtype=self.dtype)
 
     def sample(self, shape, seed=0):
         """Generate samples of the specified shape.
@@ -454,22 +513,18 @@ class Normal(Distribution):
         if self.batch_size_unknown:
             output_shape = shape + batch_shape
             zero_tmp = tensor.fill_constant_batch_size_like(
-                self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
+                self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
             zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
             zero_tmp_shape = nn.shape(zero_tmp_reshape)
             normal_random_tmp = nn.gaussian_random(
-                zero_tmp_shape,
-                mean=0.,
-                std=1.,
-                seed=seed,
-                dtype=convert_dtype(self.loc.dtype))
+                zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype)
             output = normal_random_tmp * (zero_tmp_reshape + self.scale)
             output = elementwise_add(output, self.loc, name=name)
             return output
         else:
             output_shape = shape + batch_shape
-            output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \
-                     (tensor.zeros(output_shape, dtype=self.loc.dtype) + self.scale)
+            output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
+                     (tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
             output = elementwise_add(output, self.loc, name=name)
             if self.all_arg_is_float:
                 return nn.reshape(output, shape, name=name)
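As with Uniform, the sample path reduces to loc + N(0,1) * scale, drawn in `self.dtype`; a NumPy sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    loc, scale = np.float64(0.0), np.float64(2.0)
    samples = loc + rng.standard_normal(5, dtype=np.float64) * scale
    print(samples.dtype)  # float64 - matches the distribution's dtype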
@@ -479,6 +534,16 @@ class Normal(Distribution):
     def entropy(self):
         """Shannon entropy in nats.
 
+        The entropy is
+
+        .. math::
+
+            entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2)
+
+        In the above equation:
+
+        * :math:`scale = \sigma`: is the std.
+
         Returns:
             Tensor: Shannon entropy of normal distribution.The data type is float32.
 
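Numeric check of the added formula: for sigma = 1, 0.5 * log(2*pi*e) ≈ 1.41894 nats.

    import math

    sigma = 1.0
    print(0.5 * math.log(2 * math.pi * math.e * sigma ** 2))  # 1.4189...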
@@ -486,7 +551,7 @@ class Normal(Distribution):
         name = self.name + '_entropy'
         batch_shape = list((self.loc + self.scale).shape)
         zero_tmp = tensor.fill_constant_batch_size_like(
-            self.loc + self.scale, batch_shape, self.loc.dtype, 0.)
+            self.loc + self.scale, batch_shape, self.dtype, 0.)
         return elementwise_add(
             0.5 + zero_tmp,
             0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
@@ -502,11 +567,9 @@ class Normal(Distribution):
             Tensor: log probability.The data type is same with value.
 
         """
-        if not in_dygraph_mode():
-            check_variable_and_dtype(value, 'value', ['float32', 'float64'],
-                                     'log_prob')
-
         name = self.name + '_log_prob'
+        value = self._check_values_dtype_in_probs(self.loc, value)
+
         var = self.scale * self.scale
         log_scale = nn.log(self.scale)
         return elementwise_sub(
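A quick check of the closed form evaluated above, -(x - loc)^2 / (2*var) - log(scale) - 0.5*log(2*pi): the standard normal at x = 0 gives ≈ -0.91894.

    import math

    loc, scale, x = 0.0, 1.0, 0.0
    var = scale * scale
    print(-(x - loc) ** 2 / (2 * var) - math.log(scale)
          - 0.5 * math.log(2 * math.pi))  # -0.9189...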
@@ -524,11 +587,9 @@ class Normal(Distribution):
             Tensor: probability.The data type is same with value.
 
         """
-        if not in_dygraph_mode():
-            check_variable_and_dtype(value, 'value', ['float32', 'float64'],
-                                     'log_prob')
-
         name = self.name + '_probs'
+        value = self._check_values_dtype_in_probs(self.loc, value)
+
         var = self.scale * self.scale
         return elementwise_div(
             ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
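And the density itself, exp(-(x - loc)^2 / (2*var)) / sqrt(2*pi*var): the standard normal density at 0 is 1/sqrt(2*pi) ≈ 0.39894.

    import math

    loc, scale, x = 0.0, 1.0, 0.0
    var = scale * scale
    print(math.exp(-(x - loc) ** 2 / (2 * var))
          / math.sqrt(2 * math.pi * var))  # 0.3989...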
@@ -538,6 +599,29 @@ class Normal(Distribution):
     def kl_divergence(self, other):
         """The KL-divergence between two normal distributions.
 
+        The KL-divergence is
+
+        .. math::
+
+            KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio})
+
+        .. math::
+
+            ratio = \\frac{\sigma_0}{\sigma_1}
+
+        .. math::
+
+            diff = \mu_1 - \mu_0
+
+        In the above equation:
+
+        * :math:`loc = \mu_0`: is the mean of current Normal distribution.
+        * :math:`scale = \sigma_0`: is the std of current Normal distribution.
+        * :math:`loc = \mu_1`: is the mean of other Normal distribution.
+        * :math:`scale = \sigma_1`: is the std of other Normal distribution.
+        * :math:`ratio`: is the ratio of scales.
+        * :math:`diff`: is the difference between means.
+
         Args:
             other (Normal): instance of Normal.
 
......
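A direct NumPy transcription of the documented closed form (hypothetical `kl_normal` helper, shown only to sanity-check the formula):

    import numpy as np

    def kl_normal(mu0, sigma0, mu1, sigma1):
        # KL(N(mu0, sigma0) || N(mu1, sigma1)) per the docstring above.
        ratio = sigma0 / sigma1
        diff = mu1 - mu0
        return 0.5 * (ratio ** 2 + (diff / sigma1) ** 2 - 1.0 - 2.0 * np.log(ratio))

    print(kl_normal(0.0, 1.0, 0.0, 1.0))  # 0.0 for identical distributions
    print(kl_normal(0.0, 1.0, 1.0, 2.0))  # 0.4431...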