未验证 提交 1e53088a 编写于 作者: P pangyoki 提交者: GitHub

fix Distribution class problem (#26535)

* fix problems commented by Lutao

* rename _to_variable to _to_tensor

* fix unittest coverage

* little problem
上级 67d03bed
...@@ -25,6 +25,7 @@ from .fluid.layers import control_flow ...@@ -25,6 +25,7 @@ from .fluid.layers import control_flow
from .fluid.layers import tensor from .fluid.layers import tensor
from .fluid.layers import ops from .fluid.layers import ops
from .fluid.layers import nn from .fluid.layers import nn
from .fluid import core
from .fluid.framework import in_dygraph_mode from .fluid.framework import in_dygraph_mode
from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub from .tensor.math import elementwise_mul, elementwise_div, elementwise_add, elementwise_sub
import math import math
...@@ -87,7 +88,7 @@ class Distribution(object): ...@@ -87,7 +88,7 @@ class Distribution(object):
return is_variable return is_variable
def _to_variable(self, *args): def _to_tensor(self, *args):
""" """
Argument convert args to Tensor Argument convert args to Tensor
...@@ -134,7 +135,7 @@ class Uniform(Distribution): ...@@ -134,7 +135,7 @@ class Uniform(Distribution):
Mathematical Details Mathematical Details
The probability density function (pdf) is, The probability density function (pdf) is
.. math:: .. math::
...@@ -154,8 +155,8 @@ class Uniform(Distribution): ...@@ -154,8 +155,8 @@ class Uniform(Distribution):
[broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation). [broadcasting](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/beginners_guide/basic_concept/broadcasting_en.html) (e.g., `high - low` is a valid operation).
Args: Args:
low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is float32 or int low(int|float|list|numpy.ndarray|Tensor): The lower boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is float32 or int high(int|float|list|numpy.ndarray|Tensor): The higher boundary of uniform distribution.The data type is int, float32, list, numpy.ndarray or Tensor
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples: Examples:
...@@ -218,7 +219,7 @@ class Uniform(Distribution): ...@@ -218,7 +219,7 @@ class Uniform(Distribution):
else: else:
if isinstance(low, float) and isinstance(high, float): if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True self.all_arg_is_float = True
self.low, self.high = self._to_variable(low, high) self.low, self.high = self._to_tensor(low, high)
def sample(self, shape, seed=0): def sample(self, shape, seed=0):
"""Generate samples of the specified shape. """Generate samples of the specified shape.
...@@ -272,10 +273,13 @@ class Uniform(Distribution): ...@@ -272,10 +273,13 @@ class Uniform(Distribution):
if in_dygraph_mode(): if in_dygraph_mode():
lb_bool = self.low < value lb_bool = self.low < value
ub_bool = value < self.high ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype) dtype = value.dtype
return elementwise_sub( lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
nn.log(lb * ub), nn.log(self.high - self.low), name=name) dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype)
return nn.log(lb * ub) - nn.log(self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'], check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob') 'log_prob')
...@@ -301,9 +305,13 @@ class Uniform(Distribution): ...@@ -301,9 +305,13 @@ class Uniform(Distribution):
if in_dygraph_mode(): if in_dygraph_mode():
lb_bool = self.low < value lb_bool = self.low < value
ub_bool = value < self.high ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype) dtype = value.dtype
return elementwise_div((lb * ub), (self.high - self.low), name=name) lb = core.ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype',
dtype)
ub = core.ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype',
dtype)
return (lb * ub) / (self.high - self.low)
check_variable_and_dtype(value, 'value', ['float32', 'float64'], check_variable_and_dtype(value, 'value', ['float32', 'float64'],
'log_prob') 'log_prob')
...@@ -330,7 +338,7 @@ class Normal(Distribution): ...@@ -330,7 +338,7 @@ class Normal(Distribution):
Mathematical details Mathematical details
The probability density function (pdf) is, The probability density function (pdf) is
.. math:: .. math::
...@@ -347,8 +355,8 @@ class Normal(Distribution): ...@@ -347,8 +355,8 @@ class Normal(Distribution):
* :math:`Z`: is the normalization constant. * :math:`Z`: is the normalization constant.
Args: Args:
loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32 or int. loc(int|float|list|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is float32 or int. scale(int|float|list|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float32, list, numpy.ndarray or Tensor.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Examples: Examples:
...@@ -414,7 +422,7 @@ class Normal(Distribution): ...@@ -414,7 +422,7 @@ class Normal(Distribution):
else: else:
if isinstance(loc, float) and isinstance(scale, float): if isinstance(loc, float) and isinstance(scale, float):
self.all_arg_is_float = True self.all_arg_is_float = True
self.loc, self.scale = self._to_variable(loc, scale) self.loc, self.scale = self._to_tensor(loc, scale)
def sample(self, shape, seed=0): def sample(self, shape, seed=0):
"""Generate samples of the specified shape. """Generate samples of the specified shape.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册