未验证 提交 491d4f01 编写于 作者: P pangyoki 提交者: GitHub

fix Uniform sample method (#37823)

上级 34a06cf5
...@@ -305,7 +305,8 @@ class Uniform(Distribution): ...@@ -305,7 +305,8 @@ class Uniform(Distribution):
else: else:
output_shape = shape + batch_shape output_shape = shape + batch_shape
output = nn.uniform_random( output = nn.uniform_random(
output_shape, seed=seed, dtype=self.dtype) * (tensor.zeros( output_shape, dtype=self.dtype, min=0., max=1.,
seed=seed) * (tensor.zeros(
output_shape, dtype=self.dtype) + (self.high - self.low)) output_shape, dtype=self.dtype) + (self.high - self.low))
output = elementwise_add(output, self.low, name=name) output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float: if self.all_arg_is_float:
......
...@@ -336,6 +336,29 @@ class UniformTest11(UniformTest): ...@@ -336,6 +336,29 @@ class UniformTest11(UniformTest):
name='values', shape=[dims], dtype='float32') name='values', shape=[dims], dtype='float32')
class UniformTestSample(unittest.TestCase):
def setUp(self):
self.init_param()
def init_param(self):
self.low = 3.0
self.high = 4.0
def test_uniform_sample(self):
paddle.disable_static()
uniform = Uniform(low=self.low, high=self.high)
s = uniform.sample([100])
self.assertTrue((s >= self.low).all())
self.assertTrue((s < self.high).all())
paddle.enable_static()
class UniformTestSample2(UniformTestSample):
def init_param(self):
self.low = -5.0
self.high = 2.0
class NormalNumpy(DistributionNumpy): class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale): def __init__(self, loc, scale):
self.loc = np.array(loc) self.loc = np.array(loc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册