未验证 提交 45aefbc7 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #10592 from luotao1/fix_network_with_dtype

fix unittest-error: test_network_with_dtype
...@@ -24,7 +24,7 @@ BATCH_SIZE = 20 ...@@ -24,7 +24,7 @@ BATCH_SIZE = 20
class TestNetWithDtype(unittest.TestCase): class TestNetWithDtype(unittest.TestCase):
def setUp(self): def set_network(self):
self.dtype = "float64" self.dtype = "float64"
self.init_dtype() self.init_dtype()
self.x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype) self.x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype)
...@@ -55,12 +55,14 @@ class TestNetWithDtype(unittest.TestCase): ...@@ -55,12 +55,14 @@ class TestNetWithDtype(unittest.TestCase):
pass pass
def test_cpu(self): def test_cpu(self):
self.set_network()
place = fluid.CPUPlace() place = fluid.CPUPlace()
self.run_net_on_place(place) self.run_net_on_place(place)
def test_gpu(self): def test_gpu(self):
if not core.is_compiled_with_cuda(): if not core.is_compiled_with_cuda():
return return
self.set_network()
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
self.run_net_on_place(place) self.run_net_on_place(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册