提交 4310c411 编写于 作者: M malin10

update

上级 b02bc6ca
......@@ -77,9 +77,10 @@ class RunnerBase(object):
name = "dataset." + reader_name + "."
if envs.get_global_env(name + "type") == "DataLoader":
self._executor_dataloader_train(model_dict, context)
return self._executor_dataloader_train(model_dict, context)
else:
self._executor_dataset_train(model_dict, context)
return None
def _executor_dataset_train(self, model_dict, context):
reader_name = model_dict["dataset_name"]
......@@ -137,8 +138,10 @@ class RunnerBase(object):
metrics_varnames = []
metrics_format = []
metrics_names = ["total_batch"]
metrics_format.append("{}: {{}}".format("batch"))
for name, var in metrics.items():
metrics_names.append(name)
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
......@@ -147,6 +150,7 @@ class RunnerBase(object):
reader.start()
batch_id = 0
scope = context["model"][model_name]["scope"]
result = None
with fluid.scope_guard(scope):
try:
while True:
......@@ -169,7 +173,8 @@ class RunnerBase(object):
reader.reset()
if batch_id > 0:
print(metrics_format.format(*metrics))
result = dict(zip(metrics_names, metrics))
return result
def _get_dataloader_program(self, model_dict, context):
model_name = model_dict["name"]
......@@ -340,10 +345,16 @@ class SingleRunner(RunnerBase):
for epoch in range(epochs):
for model_dict in context["phases"]:
begin_time = time.time()
self._run(context, model_dict)
result = self._run(context, model_dict)
end_time = time.time()
seconds = end_time - begin_time
print("epoch {} done, use time: {}".format(epoch, seconds))
message = "epoch {} done, use time: {}".format(epoch, seconds)
if not result is None:
for key in result:
if key.upper().startswith("BATCH_"):
continue
message += ", {}: {}".format(key, result[key])
print(message)
with fluid.scope_guard(context["model"][model_dict["name"]][
"scope"]):
train_prog = context["model"][model_dict["name"]][
......@@ -477,16 +488,24 @@ class SingleInferRunner(RunnerBase):
def run(self, context):
self._dir_check(context)
self.epoch_model_name_list.sort()
for index, epoch_name in enumerate(self.epoch_model_name_list):
for model_dict in context["phases"]:
self._load(context, model_dict,
self.epoch_model_path_list[index])
begin_time = time.time()
self._run(context, model_dict)
result = self._run(context, model_dict)
end_time = time.time()
seconds = end_time - begin_time
print("Infer {} of {} done, use time: {}".format(model_dict[
"name"], epoch_name, seconds))
message = "Infer {} of epoch {} done, use time: {}".format(
model_dict["name"], epoch_name, seconds)
if not result is None:
for key in result:
if key.upper().startswith("BATCH_"):
continue
message += ", {}: {}".format(key, result[key])
print(message)
context["status"] = "terminal_pass"
def _load(self, context, model_dict, model_path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册