From 4310c4118835cfa310f54871dc66c4c6fc17e31d Mon Sep 17 00:00:00 2001 From: malin10 Date: Mon, 27 Jul 2020 19:54:39 +0800 Subject: [PATCH] update --- core/trainers/framework/runner.py | 33 ++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index db00c7b0..aeab7c4a 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -77,9 +77,10 @@ class RunnerBase(object): name = "dataset." + reader_name + "." if envs.get_global_env(name + "type") == "DataLoader": - self._executor_dataloader_train(model_dict, context) + return self._executor_dataloader_train(model_dict, context) else: self._executor_dataset_train(model_dict, context) + return None def _executor_dataset_train(self, model_dict, context): reader_name = model_dict["dataset_name"] @@ -137,8 +138,10 @@ class RunnerBase(object): metrics_varnames = [] metrics_format = [] + metrics_names = ["total_batch"] metrics_format.append("{}: {{}}".format("batch")) for name, var in metrics.items(): + metrics_names.append(name) metrics_varnames.append(var.name) metrics_format.append("{}: {{}}".format(name)) metrics_format = ", ".join(metrics_format) @@ -147,6 +150,7 @@ class RunnerBase(object): reader.start() batch_id = 0 scope = context["model"][model_name]["scope"] + result = None with fluid.scope_guard(scope): try: while True: @@ -169,7 +173,8 @@ class RunnerBase(object): reader.reset() if batch_id > 0: - print(metrics_format.format(*metrics)) + result = dict(zip(metrics_names, metrics)) + return result def _get_dataloader_program(self, model_dict, context): model_name = model_dict["name"] @@ -340,10 +345,16 @@ class SingleRunner(RunnerBase): for epoch in range(epochs): for model_dict in context["phases"]: begin_time = time.time() - self._run(context, model_dict) + result = self._run(context, model_dict) end_time = time.time() seconds = end_time - begin_time - print("epoch {} done, use time: {}".format(epoch, seconds)) + message = "epoch {} done, use time: {}".format(epoch, seconds) + if not result is None: + for key in result: + if key.upper().startswith("BATCH_"): + continue + message += ", {}: {}".format(key, result[key]) + print(message) with fluid.scope_guard(context["model"][model_dict["name"]][ "scope"]): train_prog = context["model"][model_dict["name"]][ @@ -477,16 +488,24 @@ class SingleInferRunner(RunnerBase): def run(self, context): self._dir_check(context) + self.epoch_model_name_list.sort() for index, epoch_name in enumerate(self.epoch_model_name_list): for model_dict in context["phases"]: self._load(context, model_dict, self.epoch_model_path_list[index]) begin_time = time.time() - self._run(context, model_dict) + result = self._run(context, model_dict) end_time = time.time() seconds = end_time - begin_time - print("Infer {} of {} done, use time: {}".format(model_dict[ - "name"], epoch_name, seconds)) + message = "Infer {} of epoch {} done, use time: {}".format( + model_dict["name"], epoch_name, seconds) + if not result is None: + for key in result: + if key.upper().startswith("BATCH_"): + continue + message += ", {}: {}".format(key, result[key]) + print(message) + context["status"] = "terminal_pass" def _load(self, context, model_dict, model_path): -- GitLab