diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index f1297545c27a1b7a43cd503fbe95a28485784643..58d1c7cd014b1be3c337c67c44268668df57f4d2 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -45,10 +45,6 @@ def cast_to_tensor(t, hint_type=mstype.float32): return t t_type = hint_type if isinstance(t, Tensor): - #check if the Tensor in shape of Tensor(4) - if t.dim() == 0: - value = t.asnumpy() - return Tensor([value], dtype=t_type) #convert the type of tensor to dtype return Tensor(t.asnumpy(), dtype=t_type) if isinstance(t, (list, np.ndarray)): @@ -56,7 +52,7 @@ def cast_to_tensor(t, hint_type=mstype.float32): if isinstance(t, bool): raise TypeError(f'Input cannot be Type Bool') if isinstance(t, (int, float)): - return Tensor([t], dtype=t_type) + return Tensor(t, dtype=t_type) raise TypeError("Input type is not supported.") def convert_to_batch(t, batch_shape, required_type): diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index c59a8f369130dede9b67c90d5cfcb6d8e9ba1e4b..509c6fe8e7f32a0514b5b21be6b0263c3a680c50 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -107,6 +107,7 @@ class Bernoulli(Distribution): self._probs = probs # ops needed for the class + self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() @@ -284,8 +285,16 @@ class Bernoulli(Distribution): probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs if probs1 is None: raise_none_error("probs") + origin_shape = shape + self.shape(probs1) + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape l_zero = self.const(0.0) h_one = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed) + sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed) sample = self.less(sample_uniform, probs1) - return self.cast(sample, self.dtype) + value = self.cast(sample, self.dtype) + if origin_shape == (): + value = self.squeeze(value) + return value diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index b933e54ee5db763fd5706c128beebbddcf3391f3..b89b8af62765b705061ee34eda4758041914d5a3 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -111,6 +111,7 @@ class Exponential(Distribution): self.minval = np.finfo(np.float).tiny # ops needed for the class + self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() @@ -276,8 +277,16 @@ class Exponential(Distribution): rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate if rate is None: raise_none_error("rate") + origin_shape = shape + self.shape(rate) + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape minval = self.const(self.minval) maxval = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed) + sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed) sample = -self.log(sample_uniform) / rate - return self.cast(sample, self.dtype) + value = self.cast(sample, self.dtype) + if origin_shape == (): + value = self.squeeze(value) + return value diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 9c658707e949a3d2600de2f67b49107d83d6ef91..45acecfe869bd97d6b8722f23cb49d46c20e0ffa 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -112,6 +112,7 @@ class Geometric(Distribution): self.minval = np.finfo(np.float).tiny # ops needed for the class + self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() @@ -283,8 +284,16 @@ class Geometric(Distribution): probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs if probs1 is None: raise_none_error("probs") + origin_shape = shape + self.shape(probs1) + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape minval = self.const(self.minval) maxval = self.const(1.0) - sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed) + sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed) sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1)) - return self.cast(sample, self.dtype) + value = self.cast(sample, self.dtype) + if origin_shape == (): + value = self.squeeze(value) + return value diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index 7d53273c60d052258f69db5ecf43b198e4031ef6..86c867696ff511f67afce5dd738ac12862461e68 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -114,6 +114,7 @@ class Normal(Distribution): #ops needed for the class + self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.erf = P.Erf() @@ -305,7 +306,14 @@ class Normal(Distribution): sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value if sd is None: raise_none_error("sd") - batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) - sample_shape = shape + batch_shape + batch_shape = self.shape(mean + sd) + origin_shape = shape + batch_shape + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape sample_norm = C.normal(sample_shape, mean, sd, self.seed) - return sample_norm + value = self.cast(sample_norm, self.dtype) + if origin_shape == (): + value = self.squeeze(value) + return value diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 3a0e36e761ede30d019ce93b8f6d826e68a255c3..2d1324804fbcf940f5a7fa9242e1a29c6978cb62 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -112,6 +112,7 @@ class Uniform(Distribution): self._high = high # ops needed for the class + self.squeeze = P.Squeeze(0) self.cast = P.Cast() self.const = P.ScalarToArray() self.dtypeop = P.DType() @@ -327,8 +328,16 @@ class Uniform(Distribution): if high is None: raise_none_error("high") broadcast_shape = self.shape(low + high) + origin_shape = shape + broadcast_shape + if origin_shape == (): + sample_shape = (1,) + else: + sample_shape = origin_shape l_zero = self.const(0.0) h_one = self.const(1.0) - sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed) + sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed) sample = (high - low) * sample_uniform + low - return self.cast(sample, self.dtype) + value = self.cast(sample, self.dtype) + if origin_shape == (): + value = self.squeeze(value) + return value