未验证 提交 a22aaf70 编写于 作者: C Chengmo 提交者: GitHub

test=develop, fix test dist fleet geo unittest (#22287)

fix timeout of test_dist_fleet_geo
上级 2d20869c
...@@ -225,25 +225,6 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -225,25 +225,6 @@ class TestDistCTR2x2(FleetDistRunnerBase):
debug=False) debug=False)
pass_time = time.time() - pass_start pass_time = time.time() - pass_start
res_dict = dict()
res_dict['loss'] = self.avg_cost
class FH(fluid.executor.FetchHandler):
def handle(self, res_dict):
for key in res_dict:
v = res_dict[key]
print("{}: \n {}\n".format(key, v))
for epoch_id in range(1):
pass_start = time.time()
dataset.set_filelist(filelist)
exe.train_from_dataset(
program=fleet.main_program,
dataset=dataset,
fetch_handler=FH(var_dict=res_dict, period_secs=2),
debug=False)
pass_time = time.time() - pass_start
if os.getenv("SAVE_MODEL") == "1": if os.getenv("SAVE_MODEL") == "1":
model_dir = tempfile.mkdtemp() model_dir = tempfile.mkdtemp()
fleet.save_inference_model(exe, model_dir, fleet.save_inference_model(exe, model_dir,
...@@ -251,6 +232,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -251,6 +232,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
self.avg_cost) self.avg_cost)
self.check_model_right(model_dir) self.check_model_right(model_dir)
shutil.rmtree(model_dir) shutil.rmtree(model_dir)
fleet.stop_worker() fleet.stop_worker()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册