提交 e9543dc8 编写于 作者: M malin10

update

上级 71824d93
...@@ -41,6 +41,7 @@ class ModelBase(object): ...@@ -41,6 +41,7 @@ class ModelBase(object):
self._init_hyper_parameters() self._init_hyper_parameters()
self._env = config self._env = config
self._slot_inited = False self._slot_inited = False
self._clear_metrics = None
def _init_hyper_parameters(self): def _init_hyper_parameters(self):
pass pass
...@@ -111,8 +112,23 @@ class ModelBase(object): ...@@ -111,8 +112,23 @@ class ModelBase(object):
def get_infer_inputs(self): def get_infer_inputs(self):
return self._infer_data_var return self._infer_data_var
def get_clear_metrics(self):
if self._clear_metrics is not None:
return self._clear_metrics
self._clear_metrics = []
for key in self._infer_results:
if isinstance(self._infer_results[key], Metric):
self._clear_metrics.append(self._infer_results[key])
return self._clear_metrics
def get_infer_results(self): def get_infer_results(self):
return self._infer_results res = dict()
for key in self._infer_results:
if isinstance(self._infer_results[key], Metric):
res.update(self._infer_results[key].get_result())
elif isinstance(self._infer_results[key], Variable):
res[key] = self._infer_results[key]
return res
def get_avg_cost(self): def get_avg_cost(self):
"""R """R
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册