未验证 提交 9ec9fc0f 编写于 作者: Z Zhong Hui 提交者: GitHub

Fix the dtype-setting bug of the uniform_random op; support setting the dtype

Fix the bug in InferVarType for the uniform_random op, and add support for setting the output dtype
上级 6bf26ef1
...@@ -14322,7 +14322,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): ...@@ -14322,7 +14322,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
helper = LayerHelper("uniform_random", **locals()) helper = LayerHelper("uniform_random", **locals())
inputs = dict() inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max} attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
if in_dygraph_mode(): if in_dygraph_mode():
attrs['shape'] = shape attrs['shape'] = shape
else: else:
......
...@@ -190,6 +190,12 @@ class TestUniformRandomOpError(unittest.TestCase): ...@@ -190,6 +190,12 @@ class TestUniformRandomOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype) self.assertRaises(TypeError, test_dtype)
    # Regression check for the dtype attribute: uniform_random must create
    # its output variable with the requested dtype (here FP64), not a
    # hard-coded default. NOTE(review): this fragment sits inside a larger
    # unittest method not visible here; `self` comes from that enclosing scope.
    def test_out_dtype():
        out = fluid.layers.uniform_random(shape=[3, 4], dtype='float64')
        self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64)
    test_out_dtype()
class TestUniformRandomOpWithDiagInit(TestUniformRandomOp): class TestUniformRandomOpWithDiagInit(TestUniformRandomOp):
def init_attrs(self): def init_attrs(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册