From 491d4f01b0e33853606aa218d8fc0c6949f6398a Mon Sep 17 00:00:00 2001 From: pangyoki Date: Thu, 9 Dec 2021 14:30:26 +0800 Subject: [PATCH] fix Uniform sample method (#37823) --- python/paddle/distribution.py | 3 ++- .../tests/unittests/test_distribution.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution.py b/python/paddle/distribution.py index e30d3e4c20a..cf198eab1e8 100644 --- a/python/paddle/distribution.py +++ b/python/paddle/distribution.py @@ -305,7 +305,8 @@ class Uniform(Distribution): else: output_shape = shape + batch_shape output = nn.uniform_random( - output_shape, seed=seed, dtype=self.dtype) * (tensor.zeros( + output_shape, dtype=self.dtype, min=0., max=1., + seed=seed) * (tensor.zeros( output_shape, dtype=self.dtype) + (self.high - self.low)) output = elementwise_add(output, self.low, name=name) if self.all_arg_is_float: diff --git a/python/paddle/fluid/tests/unittests/test_distribution.py b/python/paddle/fluid/tests/unittests/test_distribution.py index f1c12c90490..6cf2c5f6e2c 100644 --- a/python/paddle/fluid/tests/unittests/test_distribution.py +++ b/python/paddle/fluid/tests/unittests/test_distribution.py @@ -336,6 +336,29 @@ class UniformTest11(UniformTest): name='values', shape=[dims], dtype='float32') +class UniformTestSample(unittest.TestCase): + def setUp(self): + self.init_param() + + def init_param(self): + self.low = 3.0 + self.high = 4.0 + + def test_uniform_sample(self): + paddle.disable_static() + uniform = Uniform(low=self.low, high=self.high) + s = uniform.sample([100]) + self.assertTrue((s >= self.low).all()) + self.assertTrue((s < self.high).all()) + paddle.enable_static() + + +class UniformTestSample2(UniformTestSample): + def init_param(self): + self.low = -5.0 + self.high = 2.0 + + class NormalNumpy(DistributionNumpy): def __init__(self, loc, scale): self.loc = np.array(loc) -- GitLab