提交 4310c411 编写于 作者: M malin10

update

上级 b02bc6ca
...@@ -77,9 +77,10 @@ class RunnerBase(object): ...@@ -77,9 +77,10 @@ class RunnerBase(object):
name = "dataset." + reader_name + "." name = "dataset." + reader_name + "."
if envs.get_global_env(name + "type") == "DataLoader": if envs.get_global_env(name + "type") == "DataLoader":
self._executor_dataloader_train(model_dict, context) return self._executor_dataloader_train(model_dict, context)
else: else:
self._executor_dataset_train(model_dict, context) self._executor_dataset_train(model_dict, context)
return None
def _executor_dataset_train(self, model_dict, context): def _executor_dataset_train(self, model_dict, context):
reader_name = model_dict["dataset_name"] reader_name = model_dict["dataset_name"]
...@@ -137,8 +138,10 @@ class RunnerBase(object): ...@@ -137,8 +138,10 @@ class RunnerBase(object):
metrics_varnames = [] metrics_varnames = []
metrics_format = [] metrics_format = []
metrics_names = ["total_batch"]
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
for name, var in metrics.items(): for name, var in metrics.items():
metrics_names.append(name)
metrics_varnames.append(var.name) metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name)) metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format) metrics_format = ", ".join(metrics_format)
...@@ -147,6 +150,7 @@ class RunnerBase(object): ...@@ -147,6 +150,7 @@ class RunnerBase(object):
reader.start() reader.start()
batch_id = 0 batch_id = 0
scope = context["model"][model_name]["scope"] scope = context["model"][model_name]["scope"]
result = None
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
try: try:
while True: while True:
...@@ -169,7 +173,8 @@ class RunnerBase(object): ...@@ -169,7 +173,8 @@ class RunnerBase(object):
reader.reset() reader.reset()
if batch_id > 0: if batch_id > 0:
print(metrics_format.format(*metrics)) result = dict(zip(metrics_names, metrics))
return result
def _get_dataloader_program(self, model_dict, context): def _get_dataloader_program(self, model_dict, context):
model_name = model_dict["name"] model_name = model_dict["name"]
...@@ -340,10 +345,16 @@ class SingleRunner(RunnerBase): ...@@ -340,10 +345,16 @@ class SingleRunner(RunnerBase):
for epoch in range(epochs): for epoch in range(epochs):
for model_dict in context["phases"]: for model_dict in context["phases"]:
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
seconds = end_time - begin_time seconds = end_time - begin_time
print("epoch {} done, use time: {}".format(epoch, seconds)) message = "epoch {} done, use time: {}".format(epoch, seconds)
if not result is None:
for key in result:
if key.upper().startswith("BATCH_"):
continue
message += ", {}: {}".format(key, result[key])
print(message)
with fluid.scope_guard(context["model"][model_dict["name"]][ with fluid.scope_guard(context["model"][model_dict["name"]][
"scope"]): "scope"]):
train_prog = context["model"][model_dict["name"]][ train_prog = context["model"][model_dict["name"]][
...@@ -477,16 +488,24 @@ class SingleInferRunner(RunnerBase): ...@@ -477,16 +488,24 @@ class SingleInferRunner(RunnerBase):
def run(self, context): def run(self, context):
self._dir_check(context) self._dir_check(context)
self.epoch_model_name_list.sort()
for index, epoch_name in enumerate(self.epoch_model_name_list): for index, epoch_name in enumerate(self.epoch_model_name_list):
for model_dict in context["phases"]: for model_dict in context["phases"]:
self._load(context, model_dict, self._load(context, model_dict,
self.epoch_model_path_list[index]) self.epoch_model_path_list[index])
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
seconds = end_time - begin_time seconds = end_time - begin_time
print("Infer {} of {} done, use time: {}".format(model_dict[ message = "Infer {} of epoch {} done, use time: {}".format(
"name"], epoch_name, seconds)) model_dict["name"], epoch_name, seconds)
if not result is None:
for key in result:
if key.upper().startswith("BATCH_"):
continue
message += ", {}: {}".format(key, result[key])
print(message)
context["status"] = "terminal_pass" context["status"] = "terminal_pass"
def _load(self, context, model_dict, model_path): def _load(self, context, model_dict, model_path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册