diff --git a/core/model.py b/core/model.py index 6293e28d3b6a07c794a8ed96214ba66976f7e563..265f5311d2a49601fb21addc9031358170a287fd 100755 --- a/core/model.py +++ b/core/model.py @@ -41,6 +41,7 @@ class ModelBase(object): self._init_hyper_parameters() self._env = config self._slot_inited = False + self._clear_metrics = None def _init_hyper_parameters(self): pass @@ -111,8 +112,23 @@ class ModelBase(object): def get_infer_inputs(self): return self._infer_data_var + def get_clear_metrics(self): + if self._clear_metrics is not None: + return self._clear_metrics + self._clear_metrics = [] + for key in self._infer_results: + if isinstance(self._infer_results[key], Metric): + self._clear_metrics.append(self._infer_results[key]) + return self._clear_metrics + def get_infer_results(self): - return self._infer_results + res = dict() + for key in self._infer_results: + if isinstance(self._infer_results[key], Metric): + res.update(self._infer_results[key].get_result()) + elif isinstance(self._infer_results[key], Variable): + res[key] = self._infer_results[key] + return res def get_avg_cost(self): """R