From 165dd8170383f916e78bf376da8ffb7b0656e4e4 Mon Sep 17 00:00:00 2001 From: liuyuhui Date: Mon, 14 Sep 2020 21:42:07 +0800 Subject: [PATCH] add save model for step (#210) * add save model for step * add PS/Collective support , add definition in yaml.md * add check for QueueDataset while saving by step, modify dnn config yaml as an example * add logging about model save operation * fix bug for adding check for QueueDataset, modify dnn config yaml * fix bug for dnn model config yaml * fix bug for checking QueueDataset * fix UT fail * fix UT fail Co-authored-by: tangwei12 --- core/trainers/framework/runner.py | 59 +++++++++++++++++++++++++++---- core/utils/envs.py | 21 ++++++++++- doc/yaml.md | 2 ++ models/rank/dnn/config.yaml | 19 ++++++++++ 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 8bf64a71..c50e608e 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -100,6 +100,7 @@ class RunnerBase(object): fetch_period = int( envs.get_global_env("runner." + context["runner_name"] + ".print_interval", 20)) + scope = context["model"][model_name]["scope"] program = context["model"][model_name]["main_program"] reader = context["dataset"][reader_name] @@ -139,6 +140,9 @@ class RunnerBase(object): fetch_period = int( envs.get_global_env("runner." + context["runner_name"] + ".print_interval", 20)) + save_step_interval = int( + envs.get_global_env("runner." 
+ context["runner_name"] + + ".save_step_interval", -1)) if context["is_infer"]: metrics = model_class.get_infer_results() else: @@ -202,6 +206,25 @@ class RunnerBase(object): metrics_logging = metrics.insert(1, seconds) begin_time = end_time logging.info(metrics_format.format(*metrics)) + + if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[ + "is_infer"] == False: + if context["fleet_mode"]: + if context["fleet_mode"].upper() == "PS": + train_prog = context["model"][model_dict[ + "name"]]["main_program"] + elif not context["is_fleet"] or context[ + "fleet_mode"].upper() == "COLLECTIVE": + train_prog = context["model"][model_dict["name"]][ + "default_main_program"] + startup_prog = context["model"][model_dict["name"]][ + "startup_program"] + with fluid.program_guard(train_prog, startup_prog): + self.save( + context, + is_fleet=context["is_fleet"], + epoch_id=None, + batch_id=batch_id) batch_id += 1 except fluid.core.EOFException: reader.reset() @@ -314,7 +337,7 @@ class RunnerBase(object): exec_strategy=_exe_strategy) return program - def save(self, epoch_id, context, is_fleet=False): + def save(self, context, is_fleet=False, epoch_id=None, batch_id=None): def need_save(epoch_id, epoch_interval, is_last=False): name = "runner." + context["runner_name"] + "." total_epoch = int(envs.get_global_env(name + "epochs", 1)) @@ -371,7 +394,8 @@ class RunnerBase(object): assert dirname is not None dirname = os.path.join(dirname, str(epoch_id)) - + logging.info("\tsave epoch_id:%d model into: \"%s\"" % + (epoch_id, dirname)) if is_fleet: warnings.warn( "Save inference model in cluster training is not recommended! 
Using save checkpoint instead.", @@ -394,14 +418,35 @@ class RunnerBase(object): if dirname is None or dirname == "": return dirname = os.path.join(dirname, str(epoch_id)) + logging.info("\tsave epoch_id:%d model into: \"%s\"" % + (epoch_id, dirname)) + if is_fleet: + if context["fleet"].worker_index() == 0: + context["fleet"].save_persistables(context["exe"], dirname) + else: + fluid.io.save_persistables(context["exe"], dirname) + + def save_checkpoint_step(): + name = "runner." + context["runner_name"] + "." + save_interval = int( + envs.get_global_env(name + "save_step_interval", -1)) + dirname = envs.get_global_env(name + "save_step_path", None) + if dirname is None or dirname == "": + return + dirname = os.path.join(dirname, str(batch_id)) + logging.info("\tsave batch_id:%d model into: \"%s\"" % + (batch_id, dirname)) if is_fleet: if context["fleet"].worker_index() == 0: context["fleet"].save_persistables(context["exe"], dirname) else: fluid.io.save_persistables(context["exe"], dirname) - save_persistables() - save_inference_model() + if isinstance(epoch_id, int): + save_persistables() + save_inference_model() + if isinstance(batch_id, int): + save_checkpoint_step() class SingleRunner(RunnerBase): @@ -453,7 +498,7 @@ class SingleRunner(RunnerBase): startup_prog = context["model"][model_dict["name"]][ "startup_program"] with fluid.program_guard(train_prog, startup_prog): - self.save(epoch, context) + self.save(context=context, epoch_id=epoch) context["status"] = "terminal_pass" @@ -506,7 +551,7 @@ class PSRunner(RunnerBase): startup_prog = context["model"][model_dict["name"]][ "startup_program"] with fluid.program_guard(train_prog, startup_prog): - self.save(epoch, context, True) + self.save(context=context, is_fleet=True, epoch_id=epoch) context["status"] = "terminal_pass" @@ -539,7 +584,7 @@ class CollectiveRunner(RunnerBase): startup_prog = context["model"][model_dict["name"]][ "startup_program"] with fluid.program_guard(train_prog, startup_prog): - 
self.save(epoch, context, True) + self.save(context=context, is_fleet=True, epoch_id=epoch) context["status"] = "terminal_pass" diff --git a/core/utils/envs.py b/core/utils/envs.py index ddcc9a94..6c2494a9 100755 --- a/core/utils/envs.py +++ b/core/utils/envs.py @@ -20,7 +20,7 @@ import socket import sys import six import traceback -import six +import warnings global_envs = {} global_envs_flatten = {} @@ -98,6 +98,25 @@ def set_global_envs(envs): value = os_path_adapter(workspace_adapter(value)) global_envs[name] = value + for runner in envs["runner"]: + if "save_step_interval" in runner or "save_step_path" in runner: + phase_name = runner["phases"] + phase = [ + phase for phase in envs["phase"] + if phase["name"] == phase_name[0] + ] + dataset_name = phase[0].get("dataset_name") + dataset = [ + dataset for dataset in envs["dataset"] + if dataset["name"] == dataset_name + ] + if dataset[0].get("type") == "QueueDataset": + runner["save_step_interval"] = None + runner["save_step_path"] = None + warnings.warn( + "QueueDataset can not support save by step, please not config save_step_interval and save_step_path in your yaml" + ) + if get_platform() != "LINUX": for dataset in envs["dataset"]: name = ".".join(["dataset", dataset["name"], "type"]) diff --git a/doc/yaml.md b/doc/yaml.md index 4c517f43..541b817d 100644 --- a/doc/yaml.md +++ b/doc/yaml.md @@ -27,6 +27,8 @@ | init_model_path | string | 路径 | 否 | 初始化模型地址 | | save_checkpoint_interval | int | >= 1 | 否 | Save参数的轮数间隔 | | save_checkpoint_path | string | 路径 | 否 | Save参数的地址 | +| save_step_interval | int | >= 1 | 否 | Step save参数的batch数间隔 | +| save_step_path | string | 路径 | 否 | Step save参数的地址 | | save_inference_interval | int | >= 1 | 否 | Save预测模型的轮数间隔 | | save_inference_path | string | 路径 | 否 | Save预测模型的地址 | | save_inference_feed_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的入口变量name | diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml index fdb4470b..f0c82462 100755 --- 
a/models/rank/dnn/config.yaml +++ b/models/rank/dnn/config.yaml @@ -114,6 +114,25 @@ runner: print_interval: 1 phases: [phase1] +- name: local_ps_train + class: local_cluster_train + # number of epochs + epochs: 1 + # device to run training or inference + device: cpu + selected_gpus: "0" # GPU card ids used for multi-card training + work_num: 1 + server_num: 1 + save_checkpoint_interval: 1 # save model interval of epochs + save_inference_interval: 4 # save inference model interval of epochs + save_step_interval: 1 # save model interval of batches (steps) + save_checkpoint_path: "increment_dnn" # save checkpoint path + save_inference_path: "inference" # save inference path + save_step_path: "step_save" # save-by-step model path + save_inference_feed_varnames: [] # feed vars of save inference + save_inference_fetch_varnames: [] # fetch vars of save inference + print_interval: 1 + phases: [phase1] # runner will run all the phase in each epoch phase: - name: phase1 -- GitLab