未验证 提交 3b9c100b 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #236 from vslyu/fix_phase

fix phase & format save_step output information
......@@ -76,9 +76,6 @@ class Trainer(object):
_config = envs.load_yaml(config)
self._context["env"] = _config
self._context["dataset"] = _config.get("dataset")
phases = []
if phase_names is None:
phases = _config.get("phase")
......@@ -86,8 +83,10 @@ class Trainer(object):
for phase in _config.get("phase"):
if phase["name"] in phase_names:
phases.append(phase)
self._context["phases"] = phases
_config["phase"] = phases
self._context["env"] = _config
self._context["dataset"] = _config.get("dataset")
print("PaddleRec: Runner {} Begin".format(self._runner_name))
self.which_engine()
self.which_device()
......
......@@ -238,8 +238,8 @@ class PSNetwork(NetworkBase):
else:
context["fleet"].init_worker()
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] +
for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type")
if type == "DataLoader":
data_loader = DataLoader(context)
......@@ -247,9 +247,9 @@ class PSNetwork(NetworkBase):
model._data_loader)
elif type == "QueueDataset":
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(
dataset["name"], context)
context["dataset"][phase[
"dataset_name"]] = dataset_class.create_dataset(
phase["dataset_name"], context)
context["status"] = "startup_pass"
def _build_strategy(self, context):
......@@ -336,7 +336,7 @@ class PslibNetwork(NetworkBase):
self._server(context)
else:
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + dataset["name"] +
".type")
if type == "DataLoader":
......@@ -363,6 +363,7 @@ class CollectiveNetwork(NetworkBase):
def build_network(self, context):
context["model"] = {}
if len(context["env"]["phase"]) > 1:
print("CollectiveNetwork phase:{}".format(context["env"]["phase"]))
warnings.warn(
"Cluster Train Only Support One Phase.",
category=UserWarning,
......@@ -407,16 +408,17 @@ class CollectiveNetwork(NetworkBase):
context["model"][model_dict["name"]]["compiled_program"] = None
context["dataset"] = {}
for dataset in context["env"]["dataset"]:
type = envs.get_global_env("dataset." + dataset["name"] + ".type")
for phase in context["env"]["phase"]:
type = envs.get_global_env("dataset." + phase["dataset_name"] +
".type")
if type == "QueueDataset":
raise ValueError(
"Collective don't support QueueDataset training, please use DataLoader."
)
dataset_class = QueueDataset(context)
context["dataset"][dataset[
"name"]] = dataset_class.create_dataset(dataset["name"],
context)
context["dataset"][phase[
"dataset_name"]] = dataset_class.create_dataset(
phase["dataset_name"], context)
context["status"] = "startup_pass"
def _build_strategy(self, context):
......
......@@ -436,9 +436,11 @@ class RunnerBase(object):
dirname = envs.get_global_env(name + "save_step_path", None)
if dirname is None or dirname == "":
return
dirname = os.path.join(dirname, str(batch_id))
logging.info("\tsave batch_id:%d model into: \"%s\"" %
(batch_id, dirname))
dirname = os.path.join(dirname,
"epoch_" + str(context["current_epoch"]) +
"_batch_" + str(batch_id))
logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" %
(context["current_epoch"], batch_id, dirname))
if is_fleet:
if context["fleet"].worker_index() == 0:
context["fleet"].save_persistables(context["exe"], dirname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册