diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 0e3186f9eca6b119273276b6b7546c918f30f497..19e70892a219517a4f23c25b73ef0da9f0341897 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -225,25 +225,6 @@ class TestDistCTR2x2(FleetDistRunnerBase): debug=False) pass_time = time.time() - pass_start - res_dict = dict() - res_dict['loss'] = self.avg_cost - - class FH(fluid.executor.FetchHandler): - def handle(self, res_dict): - for key in res_dict: - v = res_dict[key] - print("{}: \n {}\n".format(key, v)) - - for epoch_id in range(1): - pass_start = time.time() - dataset.set_filelist(filelist) - exe.train_from_dataset( - program=fleet.main_program, - dataset=dataset, - fetch_handler=FH(var_dict=res_dict, period_secs=2), - debug=False) - pass_time = time.time() - pass_start - if os.getenv("SAVE_MODEL") == "1": model_dir = tempfile.mkdtemp() fleet.save_inference_model(exe, model_dir, @@ -251,6 +232,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): self.avg_cost) self.check_model_right(model_dir) shutil.rmtree(model_dir) + fleet.stop_worker()