Created by: pangyoki
PR types
Bug fixes
PR changes
APIs
Describe
In _to_tensor
method of Distribution class (refer to PR #26355 and PR #26535). Even if we want to support both float32
and float64
dtype in Distribution classes, when parameters (low
and high
in Uniform
, loc
and scale
in Normal
) are numpy.ndarray
and dtypes are float64
, we can only set dtype to be float32
using assign
op to get the correspoding variable. Becase assign
op doesn't support float64
when input is numpy.ndarray
.
Thus, when parameters are `numpy.ndarray` with `float64` dtype,
we will transfer them into `VarType.FP32` variable.
In log_prob
and probs
methods in Distribution class, the input value
of these methods is a tensor. In users' view, it's reasonable that the dtype of value
and parameters are same.
When dtype of parameters are `float64` in `numpy.ndarray` and dtype of `value` is `VarType.FP64`,
it will cause error because the dtype of parameters change to `VarType.FP32`.
The following is an example code:
import numpy as np
import paddle
from paddle.distribution import Normal
paddle.disable_static()
value_np = np.array([0.8, 0.3], dtype='float64')
value_tensor = paddle.to_tensor(value_np) # 'float64' Tensor
loc_np = np.array([1, 2]).astype('float64') # will be converted to 'float32' Tensor automatically
scale_np = np.array([11, 22]).astype('float64') # will be converted to 'float32' Tensor automatically
normal = Normal(loc_np, scale_np)
lp = normal.log_prob(value_tensor) # error !!!
We are going to let assign
op support float64
, but it will lose precision because Attr
don't support float64
in framework.proto (refer to https://github.com/PaddlePaddle/Paddle/pull/26797). That is, assign op
can only support float32
.
-
Thus, in this PR, we use
cast
operation to convert dtype afterassign
op if dtype isfloat64
. If users define aUniform
distribution whoselow
andhigh
arefloat64
numpy.ndarray, we firstly useassign
op to getfloat32
variable. Then usecast
to getfloat64
variable. -
What's more,
probs
andlog_prob
methods have a variable input namedvalues
. If dtype ofvalues
is different withlow
inUniform
orloc
inNormal
, it will cause error. To solve this dtype conflict, wecast
dtype ofvalues
to be the same as that oflow
orloc
. (in_check_values_dtype_in_probs
function) -
In Doc discribtion, we add formula for
entropy
andkl-divergence
methods. Formula forlog_prob
andprobs
have been given in doc of class, that is, thepdf
(probability density function) of the distribution. -
By the way, we rewrite unittest to make it more readable.