未验证 提交 72e0969b 编写于 作者: H hong 提交者: GitHub

fix uniform random (#21009)

* fix uniform random; test=develop

* add uniform random test; test=develop
上级 e5e699ec
...@@ -17854,7 +17854,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): ...@@ -17854,7 +17854,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
inputs = dict() inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max} attrs = {'seed': seed, 'min': min, 'max': max}
if in_dygraph_mode(): if in_dygraph_mode():
attrs = {'shape': shape} attrs['shape'] = shape
else: else:
if isinstance(shape, Variable): if isinstance(shape, Variable):
shape.stop_gradient = True shape.stop_gradient = True
......
...@@ -416,5 +416,15 @@ class TestUniformRandomOpSelectedRowsShapeTensorList(unittest.TestCase): ...@@ -416,5 +416,15 @@ class TestUniformRandomOpSelectedRowsShapeTensorList(unittest.TestCase):
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomDygraphMode(unittest.TestCase):
def test_check_output(self):
with fluid.dygraph.guard():
x = fluid.layers.uniform_random(
[10], dtype="float32", min=0.0, max=1.0)
x_np = x.numpy()
for i in range(10):
self.assertTrue((x_np[i] > 0 and x_np[i] < 1.0))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册