未验证 提交 9bea834e 编写于 作者: G guofei 提交者: GitHub

Refine the unittest to support py38 (#27208)

* Refine the unittest to support py38

    test=develop
上级 a7fadce8
...@@ -33,6 +33,14 @@ def execute(main_program, startup_program): ...@@ -33,6 +33,14 @@ def execute(main_program, startup_program):
exe.run(main_program) exe.run(main_program)
def get_vaild_warning_num(warning, w):
num = 0
for i in range(len(w)):
if warning in str(w[i].message):
num += 1
return num
class TestDeviceGuard(unittest.TestCase): class TestDeviceGuard(unittest.TestCase):
def test_device_guard(self): def test_device_guard(self):
main_program = fluid.Program() main_program = fluid.Program()
...@@ -133,7 +141,10 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -133,7 +141,10 @@ class TestDeviceGuard(unittest.TestCase):
i = fluid.layers.increment(x=i, value=1, in_place=True) i = fluid.layers.increment(x=i, value=1, in_place=True)
fluid.layers.less_than(x=i, y=loop_len, cond=cond) fluid.layers.less_than(x=i, y=loop_len, cond=cond)
assert len(w) == 1 warning = "The Op(while) is not support to set device."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 1
all_ops = main_program.global_block().ops all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops: for op in all_ops:
...@@ -169,7 +180,10 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -169,7 +180,10 @@ class TestDeviceGuard(unittest.TestCase):
shape=[1], value=4.0, dtype='float32') shape=[1], value=4.0, dtype='float32')
result = fluid.layers.less_than(x=x, y=y, force_cpu=False) result = fluid.layers.less_than(x=x, y=y, force_cpu=False)
assert len(w) == 2 warning = "\'device_guard\' has higher priority when they are used at the same time."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 2
all_ops = main_program.global_block().ops all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops: for op in all_ops:
......
...@@ -216,7 +216,7 @@ class API_TestGather(unittest.TestCase): ...@@ -216,7 +216,7 @@ class API_TestGather(unittest.TestCase):
"index": index_np, "index": index_np,
'axis': axis_np}, 'axis': axis_np},
fetch_list=[out]) fetch_list=[out])
expected_output = gather_numpy(x_np, index_np, axis_np) expected_output = gather_numpy(x_np, index_np, axis_np[0])
self.assertTrue(np.allclose(result, expected_output)) self.assertTrue(np.allclose(result, expected_output))
......
...@@ -50,7 +50,7 @@ class TestSaveModelWithoutVar(unittest.TestCase): ...@@ -50,7 +50,7 @@ class TestSaveModelWithoutVar(unittest.TestCase):
params_filename='params') params_filename='params')
expected_warn = "no variable in your model, please ensure there are any variables in your model to save" expected_warn = "no variable in your model, please ensure there are any variables in your model to save"
self.assertTrue(len(w) > 0) self.assertTrue(len(w) > 0)
self.assertTrue(expected_warn == str(w[0].message)) self.assertTrue(expected_warn == str(w[-1].message))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册