From 9ec9fc0f3630f3632e332e33858f251991d6d29e Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Sat, 25 Apr 2020 12:08:11 +0800 Subject: [PATCH] =?UTF-8?q?fix=20the=20set=20dtype=20bug=20of=20uniform=5F?= =?UTF-8?q?random=20op=EF=BC=8Csupport=20set=20the=20dtype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix the bug in inferVartype in the uniform_random op, add the support the set of dtype --- python/paddle/fluid/layers/nn.py | 2 +- .../paddle/fluid/tests/unittests/test_uniform_random_op.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 852345c997..c4cfe141c9 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14322,7 +14322,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): helper = LayerHelper("uniform_random", **locals()) inputs = dict() - attrs = {'seed': seed, 'min': min, 'max': max} + attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} if in_dygraph_mode(): attrs['shape'] = shape else: diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index bc939a5ac7..9aca04cabd 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -190,6 +190,12 @@ class TestUniformRandomOpError(unittest.TestCase): self.assertRaises(TypeError, test_dtype) + def test_out_dtype(): + out = fluid.layers.uniform_random(shape=[3, 4], dtype='float64') + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) + + test_out_dtype() + class TestUniformRandomOpWithDiagInit(TestUniformRandomOp): def init_attrs(self): -- GitLab