diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py
index fbc961f4bffaa02aa51ddd334822de458f8781f9..839e3ed4d6e04b13f69e6c2cfc463e83aef130f7 100644
--- a/core/trainers/framework/runner.py
+++ b/core/trainers/framework/runner.py
@@ -100,6 +100,7 @@ class RunnerBase(object):
         fetch_period = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".print_interval", 20))
+        scope = context["model"][model_name]["scope"]
         program = context["model"][model_name]["main_program"]
         reader = context["dataset"][reader_name]
 
@@ -139,6 +140,9 @@ class RunnerBase(object):
         fetch_period = int(
             envs.get_global_env("runner." + context["runner_name"] +
                                 ".print_interval", 20))
+        save_step_interval = int(
+            envs.get_global_env("runner." + context["runner_name"] +
+                                ".save_step_interval", -1))
         if context["is_infer"]:
             metrics = model_class.get_infer_results()
         else:
@@ -202,6 +206,24 @@ class RunnerBase(object):
                         metrics_logging.insert(1, seconds)
                         begin_time = end_time
                         logging.info(metrics_format.format(*metrics_logging))
+
+                    if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
+                            "is_infer"] == False:
+                        if context["fleet_mode"].upper() == "PS":
+                            train_prog = context["model"][model_dict["name"]][
+                                "main_program"]
+                        else:
+                            train_prog = context["model"][model_dict["name"]][
+                                "default_main_program"]
+                        startup_prog = context["model"][model_dict["name"]][
+                            "startup_program"]
+                        with fluid.program_guard(train_prog, startup_prog):
+                            self.save(
+                                context,
+                                is_fleet=context["is_fleet"],
+                                epoch_id=None,
+                                batch_id=batch_id)
+
                     batch_id += 1
             except fluid.core.EOFException:
                 reader.reset()
@@ -314,7 +336,7 @@ class RunnerBase(object):
                 exec_strategy=_exe_strategy)
         return program
 
-    def save(self, epoch_id, context, is_fleet=False):
+    def save(self, context, is_fleet=False, epoch_id=None, batch_id=None):
         def need_save(epoch_id, epoch_interval, is_last=False):
             name = "runner." + context["runner_name"] + "."
             total_epoch = int(envs.get_global_env(name + "epochs", 1))
@@ -371,7 +393,8 @@ class RunnerBase(object):
 
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))
-
+            logging.info("\tsave epoch_id:%d model into: \"%s\"" %
+                         (epoch_id, dirname))
             if is_fleet:
                 warnings.warn(
                     "Save inference model in cluster training is not recommended! Using save checkpoint instead.",
@@ -394,14 +417,35 @@ class RunnerBase(object):
             if dirname is None or dirname == "":
                 return
             dirname = os.path.join(dirname, str(epoch_id))
+            logging.info("\tsave epoch_id:%d model into: \"%s\"" %
+                         (epoch_id, dirname))
+            if is_fleet:
+                if context["fleet"].worker_index() == 0:
+                    context["fleet"].save_persistables(context["exe"], dirname)
+            else:
+                fluid.io.save_persistables(context["exe"], dirname)
+
+        def save_checkpoint_step():
+            name = "runner." + context["runner_name"] + "."
+            save_interval = int(
+                envs.get_global_env(name + "save_step_interval", -1))
+            dirname = envs.get_global_env(name + "save_step_path", None)
+            if dirname is None or dirname == "":
+                return
+            dirname = os.path.join(dirname, str(batch_id))
+            logging.info("\tsave batch_id:%d model into: \"%s\"" %
+                         (batch_id, dirname))
             if is_fleet:
                 if context["fleet"].worker_index() == 0:
                     context["fleet"].save_persistables(context["exe"], dirname)
             else:
                 fluid.io.save_persistables(context["exe"], dirname)
 
-        save_persistables()
-        save_inference_model()
+        if isinstance(epoch_id, int):
+            save_persistables()
+            save_inference_model()
+        if isinstance(batch_id, int):
+            save_checkpoint_step()
 
 
 class SingleRunner(RunnerBase):
@@ -453,7 +497,7 @@ class SingleRunner(RunnerBase):
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
                 with fluid.program_guard(train_prog, startup_prog):
-                    self.save(epoch, context)
+                    self.save(context=context, epoch_id=epoch)
 
         context["status"] = "terminal_pass"
 
@@ -506,7 +550,7 @@ class PSRunner(RunnerBase):
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
                 with fluid.program_guard(train_prog, startup_prog):
-                    self.save(epoch, context, True)
+                    self.save(context=context, is_fleet=True, epoch_id=epoch)
 
         context["status"] = "terminal_pass"
 
@@ -539,7 +583,7 @@ class CollectiveRunner(RunnerBase):
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
                 with fluid.program_guard(train_prog, startup_prog):
-                    self.save(epoch, context, True)
+                    self.save(context=context, is_fleet=True, epoch_id=epoch)
 
         context["status"] = "terminal_pass"
 
diff --git a/core/utils/envs.py b/core/utils/envs.py
index ddcc9a94b3adc47cda2023c4d9e196b9fb16faeb..6c2494a903ad821fecf4e3a5786606730e725ba6 100755
--- a/core/utils/envs.py
+++ b/core/utils/envs.py
@@ -20,7 +20,7 @@ import socket
 import sys
 import six
 import traceback
-import six
+import warnings
 
 global_envs = {}
 global_envs_flatten = {}
@@ -98,6 +98,25 @@ def set_global_envs(envs):
             value = os_path_adapter(workspace_adapter(value))
             global_envs[name] = value
 
+    for runner in envs["runner"]:
+        if "save_step_interval" in runner or "save_step_path" in runner:
+            phase_name = runner["phases"]
+            phase = [
+                phase for phase in envs["phase"]
+                if phase["name"] == phase_name[0]
+            ]
+            dataset_name = phase[0].get("dataset_name")
+            dataset = [
+                dataset for dataset in envs["dataset"]
+                if dataset["name"] == dataset_name
+            ]
+            if dataset[0].get("type") == "QueueDataset":
+                runner["save_step_interval"] = None
+                runner["save_step_path"] = None
+                warnings.warn(
+                    "QueueDataset does not support saving by step, please do not configure save_step_interval and save_step_path in your yaml"
+                )
+
     if get_platform() != "LINUX":
         for dataset in envs["dataset"]:
             name = ".".join(["dataset", dataset["name"], "type"])
diff --git a/doc/yaml.md b/doc/yaml.md
index 4c517f43e76872246bc1919d01f77c55e56104ea..541b817d123ddd897b2056bc43b25d7aee78ddc2 100644
--- a/doc/yaml.md
+++ b/doc/yaml.md
@@ -27,6 +27,8 @@
 | init_model_path | string | path | No | path used to initialize the model |
 | save_checkpoint_interval | int | >= 1 | No | epoch interval for saving checkpoint parameters |
 | save_checkpoint_path | string | path | No | path for saving checkpoint parameters |
+| save_step_interval | int | >= 1 | No | batch interval for step-level checkpoint saving |
+| save_step_path | string | path | No | path for step-level checkpoint saving |
 | save_inference_interval | int | >= 1 | No | epoch interval for saving the inference model |
 | save_inference_path | string | path | No | path for saving the inference model |
 | save_inference_feed_varnames | list[string] | names of Variables in the network | No | names of the inference model's input (feed) variables |
diff --git a/models/rank/dnn/config.yaml b/models/rank/dnn/config.yaml
index fdb4470b492ce1b1aece958ca20ac78903f84b46..75826684dbc0734e4acf40983bbc837c7b97ac84 100755
--- a/models/rank/dnn/config.yaml
+++ b/models/rank/dnn/config.yaml
@@ -114,6 +114,23 @@ runner:
   print_interval: 1
   phases: [phase1]
 
+- name: single_multi_gpu_train
+  class: train
+  # num of epochs
+  epochs: 1
+  # device to run training or infer
+  device: gpu
+  selected_gpus: "0,1" # select multiple GPUs for training
+  save_checkpoint_interval: 1 # save model interval of epochs
+  save_inference_interval: 4 # save inference
+  save_step_interval: 1
+  save_checkpoint_path: "increment_dnn" # save checkpoint path
+  save_inference_path: "inference" # save inference path
+  save_step_path: "step_save"
+  save_inference_feed_varnames: [] # feed vars of save inference
+  save_inference_fetch_varnames: [] # fetch vars of save inference
+  print_interval: 1
+  phases: [phase1]
 # runner will run all the phase in each epoch
 phase:
 - name: phase1
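
Taken together, the changes above wire step-level saving through two new runner keys: save_step_interval controls how many batches pass between checkpoints, and save_checkpoint_step() writes each one to os.path.join(save_step_path, str(batch_id)); envs.py clears both keys with a warning when the phase's dataset type is QueueDataset, since step saving is implemented inside the per-batch DataLoader loop. A minimal runner fragment using only the step-save keys is sketched below; the runner name, device, and interval/path values are illustrative placeholders, not taken from the repository:

- name: step_save_train       # hypothetical runner name
  class: train
  epochs: 1
  device: cpu
  save_step_interval: 1000    # write a checkpoint every 1000 batches
  save_step_path: "step_save" # checkpoints land in step_save/<batch_id>/
  print_interval: 20
  phases: [phase1]

With these values, checkpoints would appear under directories such as step_save/1000 and step_save/2000, independent of any epoch-level output configured through save_checkpoint_interval / save_checkpoint_path.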