diff --git a/core/trainer.py b/core/trainer.py index bbba6250529283d24389e2719b7110f8aa321973..8951e69f6f2a69270e1688979533b0346d27b521 100755 --- a/core/trainer.py +++ b/core/trainer.py @@ -76,9 +76,6 @@ class Trainer(object): _config = envs.load_yaml(config) - self._context["env"] = _config - self._context["dataset"] = _config.get("dataset") - phases = [] if phase_names is None: phases = _config.get("phase") @@ -86,8 +83,10 @@ class Trainer(object): for phase in _config.get("phase"): if phase["name"] in phase_names: phases.append(phase) - self._context["phases"] = phases + _config["phase"] = phases + self._context["env"] = _config + self._context["dataset"] = _config.get("dataset") print("PaddleRec: Runner {} Begin".format(self._runner_name)) self.which_engine() self.which_device() diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py index 7d7a8273b6a402bd163f653a7beb3900de899ae3..2a9a3a4003f36627aff4fe7ab4f86a4979c46525 100644 --- a/core/trainers/framework/network.py +++ b/core/trainers/framework/network.py @@ -238,8 +238,8 @@ class PSNetwork(NetworkBase): else: context["fleet"].init_worker() context["dataset"] = {} - for dataset in context["env"]["dataset"]: - type = envs.get_global_env("dataset." + dataset["name"] + + for phase in context["env"]["phase"]: + type = envs.get_global_env("dataset." + phase["dataset_name"] + ".type") if type == "DataLoader": data_loader = DataLoader(context) @@ -247,9 +247,9 @@ class PSNetwork(NetworkBase): model._data_loader) elif type == "QueueDataset": dataset_class = QueueDataset(context) - context["dataset"][dataset[ - "name"]] = dataset_class.create_dataset( - dataset["name"], context) + context["dataset"][phase[ + "dataset_name"]] = dataset_class.create_dataset( + phase["dataset_name"], context) context["status"] = "startup_pass" def _build_strategy(self, context): @@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase): self._server(context) else: context["dataset"] = {} - for dataset in context["env"]["dataset"]: + for phase in context["env"]["phase"]: type = envs.get_global_env("dataset." + dataset["name"] + ".type") if type == "DataLoader": @@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase): def build_network(self, context): context["model"] = {} if len(context["env"]["phase"]) > 1: + print("CollectiveNetwork phase:{}".format(context["env"]["phase"])) warnings.warn( "Cluster Train Only Support One Phase.", category=UserWarning, @@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase): context["model"][model_dict["name"]]["compiled_program"] = None context["dataset"] = {} - for dataset in context["env"]["dataset"]: - type = envs.get_global_env("dataset." + dataset["name"] + ".type") + for phase in context["env"]["phase"]: + type = envs.get_global_env("dataset." + phase["dataset_name"] + + ".type") if type == "QueueDataset": raise ValueError( "Collective don't support QueueDataset training, please use DataLoader." ) dataset_class = QueueDataset(context) - context["dataset"][dataset[ - "name"]] = dataset_class.create_dataset(dataset["name"], - context) + context["dataset"][phase[ + "dataset_name"]] = dataset_class.create_dataset( + phase["dataset_name"], context) context["status"] = "startup_pass" def _build_strategy(self, context): diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index 9064989ad467d3104adf50d64300b7903376e7ce..b837cd676d80f2cddddf5c2079c5f39e96a21ab7 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -436,9 +436,11 @@ class RunnerBase(object): dirname = envs.get_global_env(name + "save_step_path", None) if dirname is None or dirname == "": return - dirname = os.path.join(dirname, str(batch_id)) - logging.info("\tsave batch_id:%d model into: \"%s\"" % - (batch_id, dirname)) + dirname = os.path.join(dirname, + "epoch_" + str(context["current_epoch"]) + + "_batch_" + str(batch_id)) + logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" % + (context["current_epoch"], batch_id, dirname)) if is_fleet: if context["fleet"].worker_index() == 0: context["fleet"].save_persistables(context["exe"], dirname)