未验证 提交 3b9c100b 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #236 from vslyu/fix_phase

fix phase & format save_step output information
...@@ -76,9 +76,6 @@ class Trainer(object): ...@@ -76,9 +76,6 @@ class Trainer(object):
_config = envs.load_yaml(config) _config = envs.load_yaml(config)
self._context["env"] = _config
self._context["dataset"] = _config.get("dataset")
phases = [] phases = []
if phase_names is None: if phase_names is None:
phases = _config.get("phase") phases = _config.get("phase")
...@@ -86,8 +83,10 @@ class Trainer(object): ...@@ -86,8 +83,10 @@ class Trainer(object):
for phase in _config.get("phase"): for phase in _config.get("phase"):
if phase["name"] in phase_names: if phase["name"] in phase_names:
phases.append(phase) phases.append(phase)
self._context["phases"] = phases self._context["phases"] = phases
_config["phase"] = phases
self._context["env"] = _config
self._context["dataset"] = _config.get("dataset")
print("PaddleRec: Runner {} Begin".format(self._runner_name)) print("PaddleRec: Runner {} Begin".format(self._runner_name))
self.which_engine() self.which_engine()
self.which_device() self.which_device()
......
...@@ -238,8 +238,8 @@ class PSNetwork(NetworkBase): ...@@ -238,8 +238,8 @@ class PSNetwork(NetworkBase):
else: else:
context["fleet"].init_worker() context["fleet"].init_worker()
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] + type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type") ".type")
if type == "DataLoader": if type == "DataLoader":
data_loader = DataLoader(context) data_loader = DataLoader(context)
...@@ -247,9 +247,9 @@ class PSNetwork(NetworkBase): ...@@ -247,9 +247,9 @@ class PSNetwork(NetworkBase):
model._data_loader) model._data_loader)
elif type == "QueueDataset": elif type == "QueueDataset":
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][phase[
"name"]] = dataset_class.create_dataset( "dataset_name"]] = dataset_class.create_dataset(
dataset["name"], context) phase["dataset_name"], context)
context["status"] = "startup_pass" context["status"] = "startup_pass"
def _build_strategy(self, context): def _build_strategy(self, context):
...@@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase): ...@@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase):
self._server(context) self._server(context)
else: else:
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] + type = envs.get_global_env("dataset." + dataset["name"] +
".type") ".type")
if type == "DataLoader": if type == "DataLoader":
...@@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase): ...@@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase):
def build_network(self, context): def build_network(self, context):
context["model"] = {} context["model"] = {}
if len(context["env"]["phase"]) > 1: if len(context["env"]["phase"]) > 1:
print("CollectiveNetwork phase:{}".format(context["env"]["phase"]))
warnings.warn( warnings.warn(
"Cluster Train Only Support One Phase.", "Cluster Train Only Support One Phase.",
category=UserWarning, category=UserWarning,
...@@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase): ...@@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase):
context["model"][model_dict["name"]]["compiled_program"] = None context["model"][model_dict["name"]]["compiled_program"] = None
context["dataset"] = {} context["dataset"] = {}
for dataset in context["env"]["dataset"]: for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] + ".type") type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type")
if type == "QueueDataset": if type == "QueueDataset":
raise ValueError( raise ValueError(
"Collective don't support QueueDataset training, please use DataLoader." "Collective don't support QueueDataset training, please use DataLoader."
) )
dataset_class = QueueDataset(context) dataset_class = QueueDataset(context)
context["dataset"][dataset[ context["dataset"][phase[
"name"]] = dataset_class.create_dataset(dataset["name"], "dataset_name"]] = dataset_class.create_dataset(
context) phase["dataset_name"], context)
context["status"] = "startup_pass" context["status"] = "startup_pass"
def _build_strategy(self, context): def _build_strategy(self, context):
......
...@@ -436,9 +436,11 @@ class RunnerBase(object): ...@@ -436,9 +436,11 @@ class RunnerBase(object):
dirname = envs.get_global_env(name + "save_step_path", None) dirname = envs.get_global_env(name + "save_step_path", None)
if dirname is None or dirname == "": if dirname is None or dirname == "":
return return
dirname = os.path.join(dirname, str(batch_id)) dirname = os.path.join(dirname,
logging.info("\tsave batch_id:%d model into: \"%s\"" % "epoch_" + str(context["current_epoch"]) +
(batch_id, dirname)) "_batch_" + str(batch_id))
logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" %
(context["current_epoch"], batch_id, dirname))
if is_fleet: if is_fleet:
if context["fleet"].worker_index() == 0: if context["fleet"].worker_index() == 0:
context["fleet"].save_persistables(context["exe"], dirname) context["fleet"].save_persistables(context["exe"], dirname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册