未验证 提交 5f275aad 编写于 作者: P pangyoki 提交者: GitHub

fix sample method of Uniform and Normal class (#26713)

* fix sample shape error

* Add unittest

* change assert_allclose to assert_equal

* Add unittest doc

* fix encoding problem
上级 52a6ca0c
......@@ -243,10 +243,19 @@ class Uniform(Distribution):
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.low.dtype, 0.)
uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp, zero_tmp.shape, min=0., max=1., seed=seed)
output = uniform_random_tmp * (zero_tmp + self.high - self.low
) + self.low
return nn.reshape(output, output_shape, name=name)
zero_tmp,
zero_tmp.shape,
dtype=convert_dtype(zero_tmp.dtype),
min=0.,
max=1.,
seed=seed)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
uniform_random_tmp_reshape = nn.reshape(uniform_random_tmp,
output_shape)
output = uniform_random_tmp_reshape * (
zero_tmp_reshape + self.high - self.low)
output = elementwise_add(output, self.low, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
......@@ -446,11 +455,17 @@ class Normal(Distribution):
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.loc.dtype, 0.)
zero_tmp_shape = nn.shape(zero_tmp)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
zero_tmp_shape = nn.shape(zero_tmp_reshape)
normal_random_tmp = nn.gaussian_random(
zero_tmp_shape, mean=0., std=1., seed=seed)
output = normal_random_tmp * (zero_tmp + self.scale) + self.loc
return nn.reshape(output, output_shape, name=name)
zero_tmp_shape,
mean=0.,
std=1.,
seed=seed,
dtype=convert_dtype(self.loc.dtype))
output = normal_random_tmp * (zero_tmp_reshape + self.scale)
output = elementwise_add(output, self.loc, name=name)
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed) * \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册