提交 86616ac5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4773 Fix empty shape issue in distribution sample functions

Merge pull request !4773 from peixu_ren/custom_bijector
...@@ -45,10 +45,6 @@ def cast_to_tensor(t, hint_type=mstype.float32): ...@@ -45,10 +45,6 @@ def cast_to_tensor(t, hint_type=mstype.float32):
return t return t
t_type = hint_type t_type = hint_type
if isinstance(t, Tensor): if isinstance(t, Tensor):
#check if the Tensor in shape of Tensor(4)
if t.dim() == 0:
value = t.asnumpy()
return Tensor([value], dtype=t_type)
#convert the type of tensor to dtype #convert the type of tensor to dtype
return Tensor(t.asnumpy(), dtype=t_type) return Tensor(t.asnumpy(), dtype=t_type)
if isinstance(t, (list, np.ndarray)): if isinstance(t, (list, np.ndarray)):
...@@ -56,7 +52,7 @@ def cast_to_tensor(t, hint_type=mstype.float32): ...@@ -56,7 +52,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
if isinstance(t, bool): if isinstance(t, bool):
raise TypeError(f'Input cannot be Type Bool') raise TypeError(f'Input cannot be Type Bool')
if isinstance(t, (int, float)): if isinstance(t, (int, float)):
return Tensor([t], dtype=t_type) return Tensor(t, dtype=t_type)
raise TypeError("Input type is not supported.") raise TypeError("Input type is not supported.")
def convert_to_batch(t, batch_shape, required_type): def convert_to_batch(t, batch_shape, required_type):
......
...@@ -107,6 +107,7 @@ class Bernoulli(Distribution): ...@@ -107,6 +107,7 @@ class Bernoulli(Distribution):
self._probs = probs self._probs = probs
# ops needed for the class # ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
...@@ -284,8 +285,16 @@ class Bernoulli(Distribution): ...@@ -284,8 +285,16 @@ class Bernoulli(Distribution):
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None: if probs1 is None:
raise_none_error("probs") raise_none_error("probs")
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed) sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
sample = self.less(sample_uniform, probs1) sample = self.less(sample_uniform, probs1)
return self.cast(sample, self.dtype) value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
...@@ -111,6 +111,7 @@ class Exponential(Distribution): ...@@ -111,6 +111,7 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny
# ops needed for the class # ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
...@@ -276,8 +277,16 @@ class Exponential(Distribution): ...@@ -276,8 +277,16 @@ class Exponential(Distribution):
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None: if rate is None:
raise_none_error("rate") raise_none_error("rate")
origin_shape = shape + self.shape(rate)
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed) sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
sample = -self.log(sample_uniform) / rate sample = -self.log(sample_uniform) / rate
return self.cast(sample, self.dtype) value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
...@@ -112,6 +112,7 @@ class Geometric(Distribution): ...@@ -112,6 +112,7 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny self.minval = np.finfo(np.float).tiny
# ops needed for the class # ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
...@@ -283,8 +284,16 @@ class Geometric(Distribution): ...@@ -283,8 +284,16 @@ class Geometric(Distribution):
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None: if probs1 is None:
raise_none_error("probs") raise_none_error("probs")
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
minval = self.const(self.minval) minval = self.const(self.minval)
maxval = self.const(1.0) maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed) sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1)) sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
return self.cast(sample, self.dtype) value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
...@@ -114,6 +114,7 @@ class Normal(Distribution): ...@@ -114,6 +114,7 @@ class Normal(Distribution):
#ops needed for the class #ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.erf = P.Erf() self.erf = P.Erf()
...@@ -305,7 +306,14 @@ class Normal(Distribution): ...@@ -305,7 +306,14 @@ class Normal(Distribution):
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None: if sd is None:
raise_none_error("sd") raise_none_error("sd")
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd)) batch_shape = self.shape(mean + sd)
sample_shape = shape + batch_shape origin_shape = shape + batch_shape
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
sample_norm = C.normal(sample_shape, mean, sd, self.seed) sample_norm = C.normal(sample_shape, mean, sd, self.seed)
return sample_norm value = self.cast(sample_norm, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
...@@ -112,6 +112,7 @@ class Uniform(Distribution): ...@@ -112,6 +112,7 @@ class Uniform(Distribution):
self._high = high self._high = high
# ops needed for the class # ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast() self.cast = P.Cast()
self.const = P.ScalarToArray() self.const = P.ScalarToArray()
self.dtypeop = P.DType() self.dtypeop = P.DType()
...@@ -327,8 +328,16 @@ class Uniform(Distribution): ...@@ -327,8 +328,16 @@ class Uniform(Distribution):
if high is None: if high is None:
raise_none_error("high") raise_none_error("high")
broadcast_shape = self.shape(low + high) broadcast_shape = self.shape(low + high)
origin_shape = shape + broadcast_shape
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
l_zero = self.const(0.0) l_zero = self.const(0.0)
h_one = self.const(1.0) h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed) sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
sample = (high - low) * sample_uniform + low sample = (high - low) * sample_uniform + low
return self.cast(sample, self.dtype) value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册