未验证 提交 a0c98e67 编写于 作者: P pangyoki 提交者: GitHub

fix dtype not matching bug in log_prob and probs method of Distribution class (#26767)

* fix _to_tensor method of Distribution class

* Add unittest

* let dtype be consistent with value in log_prob and probs

* fix format

* fix dtype problem and change unittest

* fix dtype of Numpy class in unittest

* add formula for entropy and kl

* change formula

* fix kl formula format

* fix kl formula format 2

* change gt to np in unittest

* optimize unittest format

* delete dumplicate

* delete dumplicate 2

* extract common function used to convert dtype value
上级 f95e8ffc
......@@ -102,21 +102,24 @@ class Distribution(object):
tmp = 0.
for arg in args:
valid_arg = False
for cls in [float, list, np.ndarray, tensor.Variable]:
if isinstance(arg, cls):
valid_arg = True
break
assert valid_arg, "type of input args must be float, list, numpy.ndarray or Tensor."
if isinstance(arg, float):
arg = np.zeros(1) + arg
arg = [arg]
if not isinstance(arg, (list, np.ndarray, tensor.Variable)):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".
format(type(arg)))
arg_np = np.array(arg)
arg_dtype = arg_np.dtype
if str(arg_dtype) not in ['float32']:
warnings.warn(
"data type of argument only support float32, your argument will be convert to float32."
)
if str(arg_dtype) != 'float32':
if str(arg_dtype) != 'float64':
# "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
# and converted to float64 later using "cast".
warnings.warn(
"data type of argument only support float32 and float64, your argument will be convert to float32."
)
arg_np = arg_np.astype('float32')
# tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
tmp = tmp + arg_np
numpy_args.append(arg_np)
......@@ -129,6 +132,36 @@ class Distribution(object):
return tuple(variable_args)
def _check_values_dtype_in_probs(self, param, value):
"""
Log_prob and probs methods have input ``value``, if value's dtype is different from param,
convert value's dtype to be consistent with param's dtype.
Args:
param (int|float|list|numpy.ndarray|Tensor): low and high in Uniform class, loc and scale in Normal class.
value (Tensor): The input tensor.
Returns:
value (Tensor): Change value's dtype if value's dtype is different from param.
"""
if in_dygraph_mode():
if value.dtype != param.dtype and convert_dtype(
value.dtype) in ['float32', 'float64']:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return core.ops.cast(value, 'in_dtype', value.dtype,
'out_dtype', param.dtype)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
if value.dtype != param.dtype:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
return tensor.cast(value, dtype=param.dtype)
return value
class Uniform(Distribution):
"""Uniform distribution with `low` and `high` parameters.
......@@ -155,8 +188,8 @@ class Uniform(Distribution):
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args:
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
......@@ -206,6 +239,7 @@ class Uniform(Distribution):
self.all_arg_is_float = False
self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform'
self.dtype = 'float32'
if isinstance(low, int):
low = float(low)
......@@ -216,10 +250,22 @@ class Uniform(Distribution):
self.batch_size_unknown = True
self.low = low
self.high = high
self.dtype = convert_dtype(low.dtype)
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
if isinstance(
low,
np.ndarray) and str(low.dtype) in ['float32', 'float64']:
self.dtype = low.dtype
elif isinstance(
high,
np.ndarray) and str(high.dtype) in ['float32', 'float64']:
self.dtype = high.dtype
self.low, self.high = self._to_tensor(low, high)
if self.dtype != convert_dtype(self.low.dtype):
self.low = tensor.cast(self.low, dtype=self.dtype)
self.high = tensor.cast(self.high, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
......@@ -241,11 +287,11 @@ class Uniform(Distribution):
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.low.dtype, 0.)
self.low + self.high, batch_shape + shape, self.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp,
zero_tmp.shape,
dtype=convert_dtype(zero_tmp.dtype),
dtype=self.dtype,
min=0.,
max=1.,
seed=seed)
......@@ -259,9 +305,8 @@ class Uniform(Distribution):
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
output_shape, seed=seed) * (tensor.zeros(
output_shape, dtype=self.low.dtype) +
(self.high - self.low))
output_shape, seed=seed, dtype=self.dtype) * (tensor.zeros(
output_shape, dtype=self.dtype) + (self.high - self.low))
output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
......@@ -279,22 +324,20 @@ class Uniform(Distribution):
"""
name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
# ensure value in [low, high]
lb_bool = self.low < value
ub_bool = value < self.high
dtype = value.dtype
lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype)
value.dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype)
value.dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
......@@ -311,22 +354,19 @@ class Uniform(Distribution):
"""
name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.low, value)
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
dtype = value.dtype
lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype)
value.dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype)
value.dtype)
return (lb * ub) / (self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
lb_bool = control_flow.less_than(self.low, value)
ub_bool = control_flow.less_than(value, self.high)
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
......@@ -334,6 +374,12 @@ class Uniform(Distribution):
def entropy(self):
"""Shannon entropy in nats.
The entropy is
.. math::
entropy(low, high) = \\log (high - low)
Returns:
Tensor: Shannon entropy of uniform distribution.The data type is float32.
......@@ -364,8 +410,8 @@ class Normal(Distribution):
* :math:`Z`: is the normalization constant.
Args:
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
......@@ -418,6 +464,7 @@ class Normal(Distribution):
self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
if isinstance(loc, int):
loc = float(loc)
......@@ -428,10 +475,22 @@ class Normal(Distribution):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
self.dtype = convert_dtype(loc.dtype)
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
if isinstance(
loc,
np.ndarray) and str(loc.dtype) in ['float32', 'float64']:
self.dtype = loc.dtype
elif isinstance(
scale,
np.ndarray) and str(scale.dtype) in ['float32', 'float64']:
self.dtype = scale.dtype
self.loc, self.scale = self._to_tensor(loc, scale)
if self.dtype != convert_dtype(self.loc.dtype):
self.loc = tensor.cast(self.loc, dtype=self.dtype)
self.scale = tensor.cast(self.scale, dtype=self.dtype)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
......@@ -454,22 +513,18 @@ class Normal(Distribution):
if self.batch_size_unknown:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
self.loc + self.scale, batch_shape + shape, self.dtype, 0.)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
zero_tmp_shape = nn.shape(zero_tmp_reshape)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape,
mean=0.,
std=1.,
seed=seed,
dtype=convert_dtype(self.loc.dtype))
zero_tmp_shape, mean=0., std=1., seed=seed, dtype=self.dtype)
output = normal_random_tmp * (zero_tmp_reshape + self.scale)
output = elementwise_add(output, self.loc, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \
(tensor.zeros(output_shape, dtype=self.loc.dtype) + self.scale)
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
(tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
......@@ -479,6 +534,16 @@ class Normal(Distribution):
def entropy(self):
"""Shannon entropy in nats.
The entropy is
.. math::
entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2)
In the above equation:
* :math:`scale = \sigma`: is the std.
Returns:
Tensor: Shannon entropy of normal distribution.The data type is float32.
......@@ -486,7 +551,7 @@ class Normal(Distribution):
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.loc.dtype, 0.)
self.loc + self.scale, batch_shape, self.dtype, 0.)
return elementwise_add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
......@@ -502,11 +567,9 @@ class Normal(Distribution):
Tensor: log probability.The data type is same with value.
"""
if not in_dygraph_mode():
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
name = self.name + '_log_prob'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
log_scale = nn.log(self.scale)
return elementwise_sub(
......@@ -524,11 +587,9 @@ class Normal(Distribution):
Tensor: probability.The data type is same with value.
"""
if not in_dygraph_mode():
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
name = self.name + '_probs'
value = self._check_values_dtype_in_probs(self.loc, value)
var = self.scale * self.scale
return elementwise_div(
ops.exp(-1. * ((value - self.loc) * (value - self.loc)) /
......@@ -538,6 +599,29 @@ class Normal(Distribution):
def kl_divergence(self, other):
"""The KL-divergence between two normal distributions.
The probability density function (pdf) is
.. math::
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio})
.. math::
ratio = \\frac{\sigma_0}{\sigma_1}
.. math::
diff = \mu_1 - \mu_0
In the above equation:
* :math:`loc = \mu_0`: is the mean of current Normal distribution.
* :math:`scale = \sigma_0`: is the std of current Normal distribution.
* :math:`loc = \mu_1`: is the mean of other Normal distribution.
* :math:`scale = \sigma_1`: is the std of other Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.
Args:
other (Normal): instance of Normal.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册