未验证 提交 5e0227dc 编写于 作者: G guofei 提交者: GitHub

[cherry-pick 1.8] Fix the unittests to surpport python3.8 (#27451)

test=release/1.8
上级 2747adff
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
"""Defination of Role Makers.""" """Defination of Role Makers."""
from __future__ import print_function from __future__ import print_function
from multiprocessing import Process, Manager import multiprocessing
import paddle.fluid as fluid import paddle.fluid as fluid
import os import os
import sys
import time import time
__all__ = [ __all__ = [
...@@ -602,7 +603,7 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -602,7 +603,7 @@ class GeneralRoleMaker(RoleMakerBase):
if ip_port != "": if ip_port != "":
self._http_ip_port = ip_port.split(":") self._http_ip_port = ip_port.split(":")
# it's for communication between processes # it's for communication between processes
self._manager = Manager() self._manager = multiprocessing.Manager()
# global dict to store status # global dict to store status
self._http_server_d = self._manager.dict() self._http_server_d = self._manager.dict()
# set running status of http server # set running status of http server
...@@ -636,9 +637,15 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -636,9 +637,15 @@ class GeneralRoleMaker(RoleMakerBase):
"all": len(worker_endpoints) + len(eplist) "all": len(worker_endpoints) + len(eplist)
} }
# child process for http server # child process for http server
self._http_server = Process( if sys.version_info >= (3, 8) and sys.platform == 'darwin':
target=self.__start_kv_server, self._http_server = multiprocessing.get_context(
args=(self._http_server_d, size_d)) 'fork').Process(
target=self.__start_kv_server,
args=(self._http_server_d, size_d))
else:
self._http_server = multiprocessing.Process(
target=self.__start_kv_server,
args=(self._http_server_d, size_d))
self._http_server.daemon = True self._http_server.daemon = True
# set running status to True # set running status to True
self._http_server_d["running"] = True self._http_server_d["running"] = True
......
...@@ -33,6 +33,14 @@ def execute(main_program, startup_program): ...@@ -33,6 +33,14 @@ def execute(main_program, startup_program):
exe.run(main_program) exe.run(main_program)
def get_vaild_warning_num(warning, w):
num = 0
for i in range(len(w)):
if warning in str(w[i].message):
num += 1
return num
class TestDeviceGuard(unittest.TestCase): class TestDeviceGuard(unittest.TestCase):
def test_device_guard(self): def test_device_guard(self):
main_program = fluid.Program() main_program = fluid.Program()
...@@ -108,7 +116,10 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -108,7 +116,10 @@ class TestDeviceGuard(unittest.TestCase):
i = fluid.layers.increment(x=i, value=1, in_place=True) i = fluid.layers.increment(x=i, value=1, in_place=True)
fluid.layers.less_than(x=i, y=loop_len, cond=cond) fluid.layers.less_than(x=i, y=loop_len, cond=cond)
assert len(w) == 1 warning = "The Op(while) is not support to set device."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 1
all_ops = main_program.global_block().ops all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops: for op in all_ops:
...@@ -138,7 +149,10 @@ class TestDeviceGuard(unittest.TestCase): ...@@ -138,7 +149,10 @@ class TestDeviceGuard(unittest.TestCase):
shape=[1], value=4.0, dtype='float32') shape=[1], value=4.0, dtype='float32')
result = fluid.layers.less_than(x=x, y=y, force_cpu=False) result = fluid.layers.less_than(x=x, y=y, force_cpu=False)
assert len(w) == 2 warning = "\'device_guard\' has higher priority when they are used at the same time."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 2
all_ops = main_program.global_block().ops all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops: for op in all_ops:
......
...@@ -22,7 +22,7 @@ class TestDebugStringFramework(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestDebugStringFramework(unittest.TestCase):
def test_debug_str(self): def test_debug_str(self):
p = Program() p = Program()
p.current_block().create_var(name='t', shape=[0, 1]) p.current_block().create_var(name='t', shape=[0, 1])
self.assertRaises(ValueError, callableObj=p.__str__) self.assertRaises(ValueError, p.to_string, True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -50,7 +50,7 @@ class TestSaveModelWithoutVar(unittest.TestCase): ...@@ -50,7 +50,7 @@ class TestSaveModelWithoutVar(unittest.TestCase):
params_filename='params') params_filename='params')
expected_warn = "no variable in your model, please ensure there are any variables in your model to save" expected_warn = "no variable in your model, please ensure there are any variables in your model to save"
self.assertTrue(len(w) > 0) self.assertTrue(len(w) > 0)
self.assertTrue(expected_warn == str(w[0].message)) self.assertTrue(expected_warn == str(w[-1].message))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册