diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py
index 90f09fe41cd0554318a1eee6b087748e39dc2f73..7ee0a2dcd3888c39d6a9e2b9fb7c00ae699dbbfd 100644
--- a/mindspore/nn/probability/distribution/_utils/utils.py
+++ b/mindspore/nn/probability/distribution/_utils/utils.py
@@ -368,3 +368,24 @@ class CheckTensor(PrimitiveWithInfer):
         if isinstance(x, Tensor):
             return x
         raise TypeError(f"For {name}, input type should be a Tensor.")
+
+def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
+    """
+    Check that arg_a and arg_b have the same dtype and return the common parameter type.
+    """
+    if hasattr(arg_a, 'dtype') and hasattr(arg_b, 'dtype'):
+        if isinstance(arg_a, np.ndarray):
+            a_dtype = mstype.pytype_to_dtype(arg_a.dtype)
+        else:
+            a_dtype = arg_a.dtype
+        if isinstance(arg_b, np.ndarray):
+            b_dtype = mstype.pytype_to_dtype(arg_b.dtype)
+        else:
+            b_dtype = arg_b.dtype
+        if a_dtype != b_dtype:
+            raise TypeError(f"{name_a} and {name_b} should have the same dtype.")
+        int_type = mstype.int_type + mstype.uint_type
+        if a_dtype in int_type or a_dtype == mstype.float64:
+            return mstype.float32
+        return a_dtype
+    return hint_type
diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py
index cd03256d4d1a652540b08a0207b36866b6bb43b9..0dcbc59689b1550d86397f502c6840f5812bcb5a 100644
--- a/mindspore/nn/probability/distribution/bernoulli.py
+++ b/mindspore/nn/probability/distribution/bernoulli.py
@@ -32,7 +32,7 @@ class Bernoulli(Distribution):
         name (str): name of the distribution. Default: Bernoulli.
 
     Note:
-        probs should be proper probabilities (0 <= p <= 1).
+        probs should be proper probabilities (0 < p < 1).
         Dist_spec_args is probs.
 
     Examples:
diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py
index 9514198dfd43d39d1eaa1f972b51a3b8582187a0..a4d6faace1a626453732d8a5cc6ecebcb860f6ef 100644
--- a/mindspore/nn/probability/distribution/distribution.py
+++ b/mindspore/nn/probability/distribution/distribution.py
@@ -26,8 +26,9 @@ class Distribution(Cell):
     Base class for all mathematical distributions.
 
     Args:
+        seed (int): random seed used in sampling.
         dtype (mindspore.dtype): type of the distribution.
-        name (str): name of the distribution.
+        name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
         param (dict): parameters used to initialize the distribution.
 
     Note:
diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py
index 230f7a9174b34696fdc277fdacb8711e4276f61a..f3d3eb015f4b79f7e7b18e6a4801fab126b10b84 100644
--- a/mindspore/nn/probability/distribution/geometric.py
+++ b/mindspore/nn/probability/distribution/geometric.py
@@ -35,7 +35,7 @@ class Geometric(Distribution):
         name (str): name of the distribution. Default: Geometric.
 
     Note:
-        probs should be proper probabilities (0 <= p <= 1).
+        probs should be proper probabilities (0 < p < 1).
         Dist_spec_args is probs.
 
     Examples:
@@ -141,7 +141,7 @@ class Geometric(Distribution):
     @property
     def probs(self):
         """
-        Returns the probability for the outcome is 1.
+        Returns the probability of success of the Bernoulli trial.
         """
         return self._probs
 
diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py
index 1ad5b4dff725774dd3eea5a96e9a954ee608ced2..014537560af534c260bb7b356705b723f279d822 100644
--- a/mindspore/nn/probability/distribution/normal.py
+++ b/mindspore/nn/probability/distribution/normal.py
@@ -19,7 +19,7 @@
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
-    raise_none_error
+    raise_none_error, common_dtype
 from ._utils.custom_ops import exp_generic, expm1_generic, log_generic, erf_generic
 class Normal(Distribution):
@@ -104,7 +104,7 @@
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = common_dtype(mean, 'mean', sd, 'sd', self.dtype)
         if mean is not None and sd is not None:
             self._mean_value = cast_to_tensor(mean, self.parameter_type)
             self._sd_value = cast_to_tensor(sd, self.parameter_type)
@@ -126,6 +126,8 @@
         self.sq = P.Square()
         self.sqrt = P.Sqrt()
         self.zeroslike = P.ZerosLike()
+        self.dtypeop = P.DType()
+        self.sametypeshape = P.SameTypeShape()
 
     def extend_repr(self):
         if self.is_scalar_batch:
@@ -143,7 +145,6 @@
                 self.checktensor(mean, 'mean')
             else:
                 mean = self.checktensor(mean, 'mean')
-            mean = self.cast(mean, self.parameter_type)
         else:
             mean = self._mean_value if self._mean_value is not None else raise_none_error('mean')
         if sd is not None:
@@ -151,12 +152,14 @@
                 self.checktensor(sd, 'sd')
             else:
                 sd = self.checktensor(sd, 'sd')
-            sd = self.cast(sd, self.parameter_type)
         else:
             sd = self._sd_value if self._sd_value is not None else raise_none_error('sd')
         batch_shape = self.shape(mean + sd)
-        mean = mean * self.fill(self.dtype, batch_shape, 1.0)
-        sd = sd * self.fill(self.dtype, batch_shape, 1.0)
+        mean = mean * self.fill(self.dtypeop(mean), batch_shape, 1.0)
+        sd = sd * self.fill(self.dtypeop(sd), batch_shape, 1.0)
+        self.sametypeshape(mean, sd)
+        mean = self.cast(mean, self.parameter_type)
+        sd = self.cast(sd, self.parameter_type)
         return mean, sd
 
     def _mean(self, mean=None, sd=None):
diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py
index d2882d70feb9b14927ae1da23c85f248c80467ee..160dc3ad2dd38e942459f26bac32c916b7b07344 100644
--- a/mindspore/nn/probability/distribution/uniform.py
+++ b/mindspore/nn/probability/distribution/uniform.py
@@ -18,7 +18,7 @@
 from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_greater, check_type, check_distribution_name,\
-    raise_none_error
+    raise_none_error, common_dtype
 from ._utils.custom_ops import exp_generic, log_generic
 class Uniform(Distribution):
@@ -103,7 +103,7 @@
         valid_dtype = mstype.float_type
         check_type(dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)
-        self.parameter_type = dtype
+        self.parameter_type = common_dtype(low, 'low', high, 'high', self.dtype)
         if low is not None and high is not None:
             self._low = cast_to_tensor(low, dtype)
             self._high = cast_to_tensor(high, dtype)
@@ -130,6 +130,8 @@
         self.zeroslike = P.ZerosLike()
         self.uniform = C.uniform
 
+        self.sametypeshape = P.SameTypeShape()
+
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'low = {self.low}, high = {self.high}'
@@ -146,7 +148,6 @@
                 self.checktensor(low, 'low')
             else:
                 low = self.checktensor(low, 'low')
-            low = self.cast(low, self.parameter_type)
         else:
             low = self.low if self.low is not None else raise_none_error('low')
         if high is not None:
@@ -154,12 +155,14 @@
                 self.checktensor(high, 'high')
             else:
                 high = self.checktensor(high, 'high')
-            high = self.cast(high, self.parameter_type)
         else:
             high = self.high if self.high is not None else raise_none_error('high')
         batch_shape = self.shape(high - low)
-        high = high * self.fill(self.dtype, batch_shape, 1.0)
-        low = low * self.fill(self.dtype, batch_shape, 1.0)
+        high = high * self.fill(self.dtypeop(high), batch_shape, 1.0)
+        low = low * self.fill(self.dtypeop(low), batch_shape, 1.0)
+        self.sametypeshape(high, low)
+        low = self.cast(low, self.parameter_type)
+        high = self.cast(high, self.parameter_type)
         return low, high
 
     @property
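
For reviewers, a small sketch of the promotion rules the new common_dtype helper encodes. This is a hypothetical example, not part of the patch; it assumes the patched _utils.utils module is importable:

    import numpy as np
    from mindspore.common import dtype as mstype
    from mindspore.nn.probability.distribution._utils.utils import common_dtype

    # Matching float32 parameters keep their dtype.
    assert common_dtype(np.zeros(2, np.float32), 'mean',
                        np.ones(2, np.float32), 'sd', mstype.float32) == mstype.float32

    # Integer parameters (and float64) are mapped to float32.
    assert common_dtype(np.zeros(2, np.int32), 'mean',
                        np.ones(2, np.int32), 'sd', mstype.float32) == mstype.float32

    # Python scalars carry no `dtype` attribute, so the hint type is returned.
    assert common_dtype(0.0, 'mean', 1.0, 'sd', mstype.float32) == mstype.float32

    # Mismatched parameter dtypes fail fast.
    try:
        common_dtype(np.zeros(2, np.float32), 'mean',
                     np.ones(2, np.float64), 'sd', mstype.float32)
    except TypeError as e:
        print(e)  # mean and sd should have the same dtype.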
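The user-visible effect in Normal and Uniform is that parameter_type now follows the dtype of the supplied parameters instead of silently casting them to `dtype`, and _check_param verifies that both parameters agree before casting. A hypothetical before/after check, using the `msd` import alias from the class docstrings (exact constructor behavior may vary by version):

    import numpy as np
    import mindspore.nn.probability.distribution as msd
    from mindspore.common import dtype as mstype

    # parameter_type is inferred from mean/sd (float16 here), no longer forced to dtype.
    n = msd.Normal(np.array([0.0], np.float16), np.array([1.0], np.float16),
                   dtype=mstype.float32)
    print(n.parameter_type)  # Float16

    # Mixing parameter dtypes is now rejected at construction time.
    try:
        msd.Normal(np.array([0.0], np.float16), np.array([1.0], np.float32))
    except TypeError as e:
        print(e)  # mean and sd should have the same dtype.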