Commit df611d8e authored by mindspore-ci-bot, committed by Gitee

!5298 Add dtype check in uniform and normal distribution

Merge pull request !5298 from XunDeng/pp_issue_branch
@@ -368,3 +368,20 @@ class CheckTensor(PrimitiveWithInfer):
         if isinstance(x, Tensor):
             return x
         raise TypeError(f"For {name}, input type should be a Tensor.")
+
+def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
+    """
+    check if arg_a and arg_b have the same dtype.
+    """
+    if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
+        if isinstance(arg_a, np.ndarray):
+            a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
+        if isinstance(arg_b, np.ndarray):
+            b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
+        if a_dtype != b_dtype:
+            raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
+        int_type = mstype.int_type + mstype.uint_type
+        if a_dtype in int_type or a_dtype == mstype.float64:
+            return mstype.float32
+        return a_dtype
+    return hint_type
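For context, here is a minimal standalone sketch of the dtype-resolution rule that `common_dtype` implements, exercised with NumPy inputs only; the helper name `resolve_parameter_type` is invented for this illustration and is not part of the MindSpore API.

```python
import numpy as np
from mindspore import dtype as mstype

def resolve_parameter_type(arg_a, arg_b, hint_type=mstype.float32):
    """Mirror of common_dtype: require matching dtypes, map ints/float64 to float32."""
    if not (hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype')):
        return hint_type                       # plain Python scalars fall back to the hint type
    a_dtype = mstype.pytype_to_dtype(arg_a.dtype) if isinstance(arg_a, np.ndarray) else arg_a.dtype
    b_dtype = mstype.pytype_to_dtype(arg_b.dtype) if isinstance(arg_b, np.ndarray) else arg_b.dtype
    if a_dtype != b_dtype:
        raise TypeError("both distribution parameters should have the same dtype.")
    if a_dtype in mstype.int_type + mstype.uint_type or a_dtype == mstype.float64:
        return mstype.float32                  # integer and float64 parameters are promoted
    return a_dtype

print(resolve_parameter_type(np.zeros(2, np.int32), np.ones(2, np.int32)))      # Int32   -> Float32
print(resolve_parameter_type(np.zeros(2, np.float64), np.ones(2, np.float64)))  # Float64 -> Float32
print(resolve_parameter_type(0.0, 1.0))                                         # no dtype -> hint (Float32)
```

With this rule, `Normal` and `Uniform` below derive `parameter_type` from their mean/sd and low/high arguments instead of always using the declared `dtype`.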
@@ -32,7 +32,7 @@ class Bernoulli(Distribution):
         name (str): name of the distribution. Default: Bernoulli.
 
     Note:
-        probs should be proper probabilities (0 <= p <= 1).
+        probs should be proper probabilities (0 < p < 1).
         Dist_spec_args is probs.
 
     Examples:
@@ -26,8 +26,9 @@ class Distribution(Cell):
     Base class for all mathematical distributions.
 
     Args:
        seed (int): random seed used in sampling.
        dtype (mindspore.dtype): type of the distribution.
-        name (str): name of the distribution.
+        name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
+        param (dict): parameters used to initialize the distribution.
 
     Note:
@@ -35,7 +35,7 @@ class Geometric(Distribution):
         name (str): name of the distribution. Default: Geometric.
 
     Note:
-        probs should be proper probabilities (0 <= p <= 1).
+        probs should be proper probabilities (0 < p < 1).
         Dist_spec_args is probs.
 
     Examples:
@@ -141,7 +141,7 @@ class Geometric(Distribution):
     @property
     def probs(self):
         """
-        Returns the probability for the outcome is 1.
+        Returns the probability of success of the Bernoulli trial.
         """
         return self._probs
 
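As a quick illustration of the updated docstrings, here is a hedged usage sketch (not part of this PR) that constructs the two Bernoulli-trial based distributions with a proper probability 0 < p < 1 and reads `probs` back; the constructor signatures are assumed from the `mindspore.nn.probability.distribution` API of this release.

```python
import mindspore.nn.probability.distribution as msd
from mindspore import dtype as mstype

# p = 0.5 satisfies the documented constraint 0 < p < 1.
b = msd.Bernoulli(0.5, dtype=mstype.int32)
g = msd.Geometric(0.5, dtype=mstype.int32)

# `probs` is the probability of success of the underlying Bernoulli trial.
print(b.probs, g.probs)
```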
@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error
+    raise_none_error, common_dtype
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
 
 class Normal(Distribution):
@@ -104,7 +104,7 @@ class Normal(Distribution):
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
         if mean is not None and sd is not None:
             self._mean_value = cast_to_tensor(mean, self.parameter_type)
             self._sd_value = cast_to_tensor(sd, self.parameter_type)
@@ -126,6 +126,8 @@
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
+        self.dtypeop = P.DType()
+        self.sametypeshape = P.SameTypeShape()
 
     def extend_repr(self):
         if self.is_scalar_batch:
@@ -143,7 +145,6 @@
                 self.checktensor(mean, 'mean')
             else:
                 mean = self.checktensor(mean, 'mean')
-            mean = self.cast(mean, self.parameter_type)
         else:
             mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
         if sd is not None:
@@ -151,12 +152,14 @@
                 self.checktensor(sd, 'sd')
             else:
                 sd = self.checktensor(sd, 'sd')
-            sd = self.cast(sd, self.parameter_type)
         else:
             sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
         batch_shape = self.shape(mean + sd)
-        mean = mean * self.fill(self.dtype, batch_shape, 1.0)
-        sd = sd * self.fill(self.dtype, batch_shape, 1.0)
+        mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
+        sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
+        self.sametypeshape(mean, sd)
+        mean = self.cast(mean, self.parameter_type)
+        sd = self.cast(sd, self.parameter_type)
         return mean, sd
 
     def _mean(self, mean=None, sd=None):
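The rewritten `_check_param` now broadcasts both parameters to their common batch shape before casting, and verifies that they agree via `SameTypeShape`. Below is a minimal standalone sketch of that broadcast-then-verify pattern using the same MindSpore primitives; it mirrors the diff but is not the distribution code itself.

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)  # run the primitives eagerly

dtypeop = P.DType()
fill = P.Fill()
shape = P.Shape()
sametypeshape = P.SameTypeShape()

mean = Tensor(np.array([0.0], np.float32))
sd = Tensor(np.array([[1.0, 2.0]], np.float32))

batch_shape = shape(mean + sd)                       # broadcast shape of the two parameters
mean = mean * fill(dtypeop(mean), batch_shape, 1.0)  # expand mean to the batch shape
sd = sd * fill(dtypeop(sd), batch_shape, 1.0)        # expand sd to the batch shape
sametypeshape(mean, sd)                              # raises if dtype or shape still differ
print(mean.shape, sd.shape)
```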
@@ -18,7 +18,7 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    raise_none_error
+    raise_none_error, common_dtype
 from ._utils.custom_ops import exp_generic, log_generic
 
 class Uniform(Distribution):
@@ -103,7 +103,7 @@
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
         if low is not None and high is not None:
             self._low = cast_to_tensor(low, dtype)
             self._high = cast_to_tensor(high, dtype)
@@ -130,6 +130,8 @@
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform
+        self.dtypeop = P.DType()
+        self.sametypeshape = P.SameTypeShape()
 
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'
@@ -146,7 +148,6 @@
                 self.checktensor(low, 'low')
             else:
                 low = self.checktensor(low, 'low')
-            low = self.cast(low, self.parameter_type)
         else:
             low = self.low if self.low is not None else raise_none_error('low')
         if high is not None:
@@ -154,12 +155,14 @@
                 self.checktensor(high, 'high')
             else:
                 high = self.checktensor(high, 'high')
-            high = self.cast(high, self.parameter_type)
         else:
             high = self.high if self.high is not None else raise_none_error('high')
         batch_shape = self.shape(high - low)
-        high = high * self.fill(self.dtype, batch_shape, 1.0)
-        low = low * self.fill(self.dtype, batch_shape, 1.0)
+        high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
+        low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
+        self.sametypeshape(high, low)
+        low = self.cast(low, self.parameter_type)
+        high = self.cast(high, self.parameter_type)
         return low, high
 
     @property
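Putting the pieces together, here is a hedged usage sketch (not from the PR) of how the new check is expected to behave for Normal and Uniform, based on the `common_dtype` logic above; the constructor signatures and exact error text are assumptions.

```python
import numpy as np
import mindspore.nn.probability.distribution as msd

# Same dtype for both parameters: accepted as before.
n = msd.Normal(np.array([0.0], np.float32), np.array([1.0], np.float32))

# Mismatched parameter dtypes should now raise
# TypeError: "mean and sd should have the same dtype."
try:
    msd.Normal(np.array([0.0], np.float32), np.array([1.0], np.float64))
except TypeError as e:
    print(e)

# Integer (or float64) parameters are promoted to a float32 parameter_type.
u = msd.Uniform(np.array([0], np.int32), np.array([1], np.int32))
print(u.parameter_type)
```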