From e9543dc8da07afa3c47e83a1b3849a1b37a5cab7 Mon Sep 17 00:00:00 2001 From: malin10 Date: Tue, 21 Jul 2020 21:29:14 +0800 Subject: [PATCH] update --- core/model.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/core/model.py b/core/model.py index 6293e28d..265f5311 100755 --- a/core/model.py +++ b/core/model.py @@ -41,6 +41,7 @@ class ModelBase(object): self._init_hyper_parameters() self._env = config self._slot_inited = False + self._clear_metrics = None def _init_hyper_parameters(self): pass @@ -111,8 +112,23 @@ class ModelBase(object): def get_infer_inputs(self): return self._infer_data_var + def get_clear_metrics(self): + if self._clear_metrics is not None: + return self._clear_metrics + self._clear_metrics = [] + for key in self._infer_results: + if isinstance(self._infer_results[key], Metric): + self._clear_metrics.append(self._infer_results[key]) + return self._clear_metrics + def get_infer_results(self): - return self._infer_results + res = dict() + for key in self._infer_results: + if isinstance(self._infer_results[key], Metric): + res.update(self._infer_results[key].get_result()) + elif isinstance(self._infer_results[key], Variable): + res[key] = self._infer_results[key] + return res def get_avg_cost(self): """R -- GitLab