未验证 提交 04496d89 编写于 作者: Z zmxdream 提交者: GitHub

[PSCore]Fix test fleet base 2 (#38588)

上级 15cbf81b
...@@ -627,6 +627,8 @@ class Fleet(object): ...@@ -627,6 +627,8 @@ class Fleet(object):
""" """
self._runtime_handle._init_server(*args, **kwargs) self._runtime_handle._init_server(*args, **kwargs)
@is_non_distributed_check
@inited_runtime_handler
def load_model(self, path, mode): def load_model(self, path, mode):
""" """
load fleet model from path load fleet model from path
...@@ -699,6 +701,8 @@ class Fleet(object): ...@@ -699,6 +701,8 @@ class Fleet(object):
""" """
self._runtime_handle._stop_worker() self._runtime_handle._stop_worker()
@is_non_distributed_check
@inited_runtime_handler
def save(self, dirname, feed=[], fetch=[], **configs): def save(self, dirname, feed=[], fetch=[], **configs):
inference = True inference = True
...@@ -742,6 +746,8 @@ class Fleet(object): ...@@ -742,6 +746,8 @@ class Fleet(object):
self._runtime_handle._save_persistables( self._runtime_handle._save_persistables(
executor, dirname, main_program=None, mode=increment_mode) executor, dirname, main_program=None, mode=increment_mode)
@is_non_distributed_check
@inited_runtime_handler
def save_inference_model(self, def save_inference_model(self,
executor, executor,
dirname, dirname,
...@@ -777,6 +783,8 @@ class Fleet(object): ...@@ -777,6 +783,8 @@ class Fleet(object):
executor, dirname, feeded_var_names, target_vars, main_program, executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment, mode) export_for_deployment, mode)
@is_non_distributed_check
@inited_runtime_handler
def save_persistables(self, executor, dirname, main_program=None, mode=0): def save_persistables(self, executor, dirname, main_program=None, mode=0):
""" """
......
...@@ -62,10 +62,6 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): ...@@ -62,10 +62,6 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
self.assertEqual(sends, 0) self.assertEqual(sends, 0)
self.assertEqual(sgds, 0) self.assertEqual(sgds, 0)
fleet.init_worker()
time.sleep(8)
fleet.stop_worker()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -24,9 +24,9 @@ class TestFleetBase(unittest.TestCase): ...@@ -24,9 +24,9 @@ class TestFleetBase(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ["POD_IP"] = "127.0.0.1" os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36000" os.environ["PADDLE_PORT"] = "36000"
os.environ["PADDLE_TRAINERS_NUM"] = "2" os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ #os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001" # "127.0.0.1:36001,127.0.0.2:36001"
def test_ps_minimize(self): def test_ps_minimize(self):
import paddle import paddle
...@@ -78,45 +78,6 @@ class TestFleetBase(unittest.TestCase): ...@@ -78,45 +78,6 @@ class TestFleetBase(unittest.TestCase):
fleet.load_model(path="/tmp", mode=0) fleet.load_model(path="/tmp", mode=0)
fleet.load_model(path="/tmp", mode=1) fleet.load_model(path="/tmp", mode=1)
self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor="exe")
self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor=exe,
main_program=compiled_prog)
self.assertRaises(
Exception,
fleet.save_inference_model,
dirname='afs:/tmp/',
feeded_var_names=['x', 'y'],
target_vars=[avg_cost],
executor=exe,
main_program=compiled_prog)
self.assertRaises(
Exception, fleet.save_persistables, executor=pe, dirname='/tmp/')
self.assertRaises(
Exception, fleet.save_persistables, executor="exe", dirname='/tmp/')
self.assertRaises(
Exception,
fleet.save_persistables,
executor=exe,
dirname='/tmp/',
main_program=compiled_prog)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册