提交 86616ac5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4773 Fix empty shape issue in distribution sample functions

Merge pull request !4773 from peixu_ren/custom_bijector
......@@ -45,10 +45,6 @@ def cast_to_tensor(t, hint_type=mstype.float32):
return t
t_type = hint_type
if isinstance(t, Tensor):
#check if the Tensor in shape of Tensor(4)
if t.dim() == 0:
value = t.asnumpy()
return Tensor([value], dtype=t_type)
#convert the type of tensor to dtype
return Tensor(t.asnumpy(), dtype=t_type)
if isinstance(t, (list, np.ndarray)):
......@@ -56,7 +52,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
if isinstance(t, bool):
raise TypeError(f'Input cannot be Type Bool')
if isinstance(t, (int, float)):
return Tensor([t], dtype=t_type)
return Tensor(t, dtype=t_type)
raise TypeError("Input type is not supported.")
def convert_to_batch(t, batch_shape, required_type):
......
......@@ -107,6 +107,7 @@ class Bernoulli(Distribution):
self._probs = probs
# ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
......@@ -284,8 +285,16 @@ class Bernoulli(Distribution):
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), l_zero, h_one, self.seed)
sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
sample = self.less(sample_uniform, probs1)
return self.cast(sample, self.dtype)
value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
......@@ -111,6 +111,7 @@ class Exponential(Distribution):
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
......@@ -276,8 +277,16 @@ class Exponential(Distribution):
rate = self.cast(rate, self.parameter_type) if rate is not None else self.rate
if rate is None:
raise_none_error("rate")
origin_shape = shape + self.shape(rate)
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(rate), minval, maxval, self.seed)
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
sample = -self.log(sample_uniform) / rate
return self.cast(sample, self.dtype)
value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
......@@ -112,6 +112,7 @@ class Geometric(Distribution):
self.minval = np.finfo(np.float).tiny
# ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
......@@ -283,8 +284,16 @@ class Geometric(Distribution):
probs1 = self.cast(probs, self.parameter_type) if probs is not None else self.probs
if probs1 is None:
raise_none_error("probs")
origin_shape = shape + self.shape(probs1)
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
minval = self.const(self.minval)
maxval = self.const(1.0)
sample_uniform = self.uniform(shape + self.shape(probs1), minval, maxval, self.seed)
sample_uniform = self.uniform(sample_shape, minval, maxval, self.seed)
sample = self.floor(self.log(sample_uniform) / self.log(1.0 - probs1))
return self.cast(sample, self.dtype)
value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
......@@ -114,6 +114,7 @@ class Normal(Distribution):
#ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.erf = P.Erf()
......@@ -305,7 +306,14 @@ class Normal(Distribution):
sd = self.cast(sd, self.parameter_type) if sd is not None else self._sd_value
if sd is None:
raise_none_error("sd")
batch_shape = self.shape(self.zeroslike(mean) + self.zeroslike(sd))
sample_shape = shape + batch_shape
batch_shape = self.shape(mean + sd)
origin_shape = shape + batch_shape
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
sample_norm = C.normal(sample_shape, mean, sd, self.seed)
return sample_norm
value = self.cast(sample_norm, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
......@@ -112,6 +112,7 @@ class Uniform(Distribution):
self._high = high
# ops needed for the class
self.squeeze = P.Squeeze(0)
self.cast = P.Cast()
self.const = P.ScalarToArray()
self.dtypeop = P.DType()
......@@ -327,8 +328,16 @@ class Uniform(Distribution):
if high is None:
raise_none_error("high")
broadcast_shape = self.shape(low + high)
origin_shape = shape + broadcast_shape
if origin_shape == ():
sample_shape = (1,)
else:
sample_shape = origin_shape
l_zero = self.const(0.0)
h_one = self.const(1.0)
sample_uniform = self.uniform(shape + broadcast_shape, l_zero, h_one, self.seed)
sample_uniform = self.uniform(sample_shape, l_zero, h_one, self.seed)
sample = (high - low) * sample_uniform + low
return self.cast(sample, self.dtype)
value = self.cast(sample, self.dtype)
if origin_shape == ():
value = self.squeeze(value)
return value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册