提交 df611d8e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5298 Add dtype check in uniform and normal distribution

Merge pull request !5298 from XunDeng/pp_issue_branch
...@@ -368,3 +368,20 @@ class CheckTensor(PrimitiveWithInfer): ...@@ -368,3 +368,20 @@ class CheckTensor(PrimitiveWithInfer):
if isinstance(x, Tensor): if isinstance(x, Tensor):
return x return x
raise TypeError(f"For {name}, input type should be a Tensor.") raise TypeError(f"For {name}, input type should be a Tensor.")
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
    """
    Check that `arg_a` and `arg_b` share the same dtype and return it.

    Args:
        arg_a: first value; an np.ndarray or any object with a `dtype` attribute.
        name_a (str): display name of `arg_a`, used in the error message.
        arg_b: second value; an np.ndarray or any object with a `dtype` attribute.
        name_b (str): display name of `arg_b`, used in the error message.
        hint_type (mindspore.dtype): fallback dtype returned when either
            argument has no `dtype` attribute.

    Returns:
        mindspore.dtype: the common dtype, promoted to float32 when it is an
        integer/unsigned type or float64; otherwise `hint_type`.

    Raises:
        TypeError: if the two dtypes differ.
    """
    if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
        # numpy arrays carry a numpy dtype; convert it to a mindspore dtype
        # so both sides are compared in the same type system. Non-ndarray
        # arguments (e.g. Tensors) already expose a mindspore dtype.
        if isinstance(arg_a, np.ndarray):
            a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
        else:
            a_dtype = arg_a.dtype
        # Fixed: the original re-tested isinstance(arg_a, np.ndarray) here,
        # so arg_b's dtype was never converted and b_dtype could be unbound
        # (NameError) whenever arg_a was not an ndarray.
        if isinstance(arg_b, np.ndarray):
            b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
        else:
            b_dtype = arg_b.dtype
        if a_dtype != b_dtype:
            raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
        # Promote integer and float64 dtypes to float32 — presumably because
        # downstream distribution ops only accept float32 here; confirm.
        int_type = mstype.int_type + mstype.uint_type
        if a_dtype in int_type or a_dtype == mstype.float64:
            return mstype.float32
        return a_dtype
    return hint_type
...@@ -32,7 +32,7 @@ class Bernoulli(Distribution): ...@@ -32,7 +32,7 @@ class Bernoulli(Distribution):
name (str): name of the distribution. Default: Bernoulli. name (str): name of the distribution. Default: Bernoulli.
Note: Note:
probs should be proper probabilities (0 <= p <= 1). probs should be proper probabilities (0 < p < 1).
Dist_spec_args is probs. Dist_spec_args is probs.
Examples: Examples:
......
...@@ -26,8 +26,9 @@ class Distribution(Cell): ...@@ -26,8 +26,9 @@ class Distribution(Cell):
Base class for all mathematical distributions. Base class for all mathematical distributions.
Args: Args:
seed (int): random seed used in sampling.
dtype (mindspore.dtype): type of the distribution. dtype (mindspore.dtype): type of the distribution.
name (str): name of the distribution. name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
param (dict): parameters used to initialize the distribution. param (dict): parameters used to initialize the distribution.
Note: Note:
......
...@@ -35,7 +35,7 @@ class Geometric(Distribution): ...@@ -35,7 +35,7 @@ class Geometric(Distribution):
name (str): name of the distribution. Default: Geometric. name (str): name of the distribution. Default: Geometric.
Note: Note:
probs should be proper probabilities (0 <= p <= 1). probs should be proper probabilities (0 < p < 1).
Dist_spec_args is probs. Dist_spec_args is probs.
Examples: Examples:
...@@ -141,7 +141,7 @@ class Geometric(Distribution): ...@@ -141,7 +141,7 @@ class Geometric(Distribution):
@property @property
def probs(self): def probs(self):
""" """
Returns the probability for the outcome is 1. Returns the probability of success of the Bernoulli trial.
""" """
return self._probs return self._probs
......
...@@ -19,7 +19,7 @@ from mindspore.ops import composite as C ...@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
raise_none_error raise_none_error, common_dtype
from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
class Normal(Distribution): class Normal(Distribution):
...@@ -104,7 +104,7 @@ class Normal(Distribution): ...@@ -104,7 +104,7 @@ class Normal(Distribution):
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__) check_type(dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param) super(Normal, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
if mean is not None and sd is not None: if mean is not None and sd is not None:
self._mean_value = cast_to_tensor(mean, self.parameter_type) self._mean_value = cast_to_tensor(mean, self.parameter_type)
self._sd_value = cast_to_tensor(sd, self.parameter_type) self._sd_value = cast_to_tensor(sd, self.parameter_type)
...@@ -126,6 +126,8 @@ class Normal(Distribution): ...@@ -126,6 +126,8 @@ class Normal(Distribution):
self.sq = P.Square() self.sq = P.Square()
self.sqrt = P.Sqrt() self.sqrt = P.Sqrt()
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.dtypeop = P.DType()
self.sametypeshape = P.SameTypeShape()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
...@@ -143,7 +145,6 @@ class Normal(Distribution): ...@@ -143,7 +145,6 @@ class Normal(Distribution):
self.checktensor(mean, 'mean') self.checktensor(mean, 'mean')
else: else:
mean = self.checktensor(mean, 'mean') mean = self.checktensor(mean, 'mean')
mean = self.cast(mean, self.parameter_type)
else: else:
mean = self._mean_value if self._mean_value is not None else raise_none_error('mean') mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
if sd is not None: if sd is not None:
...@@ -151,12 +152,14 @@ class Normal(Distribution): ...@@ -151,12 +152,14 @@ class Normal(Distribution):
self.checktensor(sd, 'sd') self.checktensor(sd, 'sd')
else: else:
sd = self.checktensor(sd, 'sd') sd = self.checktensor(sd, 'sd')
sd = self.cast(sd, self.parameter_type)
else: else:
sd = self._sd_value if self._sd_value is not None else raise_none_error('sd') sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
batch_shape = self.shape(mean + sd) batch_shape = self.shape(mean + sd)
mean = mean * self.fill(self.dtype, batch_shape, 1.0) mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
sd = sd * self.fill(self.dtype, batch_shape, 1.0) sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
self.sametypeshape(mean, sd)
mean = self.cast(mean, self.parameter_type)
sd = self.cast(sd, self.parameter_type)
return mean, sd return mean, sd
def _mean(self, mean=None, sd=None): def _mean(self, mean=None, sd=None):
......
...@@ -18,7 +18,7 @@ from mindspore.ops import composite as C ...@@ -18,7 +18,7 @@ from mindspore.ops import composite as C
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from .distribution import Distribution from .distribution import Distribution
from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\ from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
raise_none_error raise_none_error, common_dtype
from ._utils.custom_ops import exp_generic, log_generic from ._utils.custom_ops import exp_generic, log_generic
class Uniform(Distribution): class Uniform(Distribution):
...@@ -103,7 +103,7 @@ class Uniform(Distribution): ...@@ -103,7 +103,7 @@ class Uniform(Distribution):
valid_dtype = mstype.float_type valid_dtype = mstype.float_type
check_type(dtype, valid_dtype, type(self).__name__) check_type(dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param) super(Uniform, self).__init__(seed, dtype, name, param)
self.parameter_type = dtype self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
if low is not None and high is not None: if low is not None and high is not None:
self._low = cast_to_tensor(low, dtype) self._low = cast_to_tensor(low, dtype)
self._high = cast_to_tensor(high, dtype) self._high = cast_to_tensor(high, dtype)
...@@ -130,6 +130,8 @@ class Uniform(Distribution): ...@@ -130,6 +130,8 @@ class Uniform(Distribution):
self.zeroslike = P.ZerosLike() self.zeroslike = P.ZerosLike()
self.uniform = C.uniform self.uniform = C.uniform
self.sametypeshape = P.SameTypeShape()
def extend_repr(self): def extend_repr(self):
if self.is_scalar_batch: if self.is_scalar_batch:
str_info = f'low = {self.low}, high = {self.high}' str_info = f'low = {self.low}, high = {self.high}'
...@@ -146,7 +148,6 @@ class Uniform(Distribution): ...@@ -146,7 +148,6 @@ class Uniform(Distribution):
self.checktensor(low, 'low') self.checktensor(low, 'low')
else: else:
low = self.checktensor(low, 'low') low = self.checktensor(low, 'low')
low = self.cast(low, self.parameter_type)
else: else:
low = self.low if self.low is not None else raise_none_error('low') low = self.low if self.low is not None else raise_none_error('low')
if high is not None: if high is not None:
...@@ -154,12 +155,14 @@ class Uniform(Distribution): ...@@ -154,12 +155,14 @@ class Uniform(Distribution):
self.checktensor(high, 'high') self.checktensor(high, 'high')
else: else:
high = self.checktensor(high, 'high') high = self.checktensor(high, 'high')
high = self.cast(high, self.parameter_type)
else: else:
high = self.high if self.high is not None else raise_none_error('high') high = self.high if self.high is not None else raise_none_error('high')
batch_shape = self.shape(high - low) batch_shape = self.shape(high - low)
high = high * self.fill(self.dtype, batch_shape, 1.0) high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
low = low * self.fill(self.dtype, batch_shape, 1.0) low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
self.sametypeshape(high, low)
low = self.cast(low, self.parameter_type)
high = self.cast(high, self.parameter_type)
return low, high return low, high
@property @property
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册