未验证 提交 1e53088a 编写于 作者: P pangyoki 提交者: GitHub

fix Distribution class problem (#26535)

* fix problems commented by Lutao

* rename _to_variable to _to_tensor

* fix unittest coverage

* little problem
上级 67d03bed
......@@ -25,6 +25,7 @@ from .fluid.layers import control_flow
from .fluid.layers import tensor
from .fluid.layers import ops
from .fluid.layers import nn
from .fluid import core
from .fluid.framework import in_dygraph_mode
from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub
import math
......@@ -87,7 +88,7 @@ class Distribution(object):
return is_variable
def _to_variable(self, *args):
def _to_tensor(self, *args):
"""
Argument convert args to Tensor
......@@ -134,7 +135,7 @@ class Uniform(Distribution):
Mathematical Details
The probability density function (pdf) is,
The probability density function (pdf) is
.. math::
......@@ -154,8 +155,8 @@ class Uniform(Distribution):
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args:
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is float32 or int
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is float32 or int
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
......@@ -218,7 +219,7 @@ class Uniform(Distribution):
else:
if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True
self.low, self.high = self._to_variable(low, high)
self.low, self.high = self._to_tensor(low, high)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
......@@ -272,10 +273,13 @@ class Uniform(Distribution):
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(
nn.log(lb * ub), nn.log(self.high - self.low), name=name)
dtype = value.dtype
lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
......@@ -301,9 +305,13 @@ class Uniform(Distribution):
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_div((lb * ub), (self.high - self.low), name=name)
dtype = value.dtype
lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype)
return (lb * ub) / (self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob')
......@@ -330,7 +338,7 @@ class Normal(Distribution):
Mathematical details
The probability density function (pdf) is,
The probability density function (pdf) is
.. math::
......@@ -347,8 +355,8 @@ class Normal(Distribution):
* :math:`Z`: is the normalization constant.
Args:
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32 or int.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is float32 or int.
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples:
......@@ -414,7 +422,7 @@ class Normal(Distribution):
else:
if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True
self.loc, self.scale = self._to_variable(loc, scale)
self.loc, self.scale = self._to_tensor(loc, scale)
def sample(self, shape, seed=0):
"""Generate samples of the specified shape.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册