提交 e9543dc8 编写于 作者: M malin10

update

上级 71824d93
......@@ -41,6 +41,7 @@ class ModelBase(object):
self._init_hyper_parameters()
self._env = config
self._slot_inited = False
self._clear_metrics = None
def _init_hyper_parameters(self):
pass
......@@ -111,8 +112,23 @@ class ModelBase(object):
def get_infer_inputs(self):
return self._infer_data_var
def get_clear_metrics(self):
if self._clear_metrics is not None:
return self._clear_metrics
self._clear_metrics = []
for key in self._infer_results:
if isinstance(self._infer_results[key], Metric):
self._clear_metrics.append(self._infer_results[key])
return self._clear_metrics
def get_infer_results(self):
return self._infer_results
res = dict()
for key in self._infer_results:
if isinstance(self._infer_results[key], Metric):
res.update(self._infer_results[key].get_result())
elif isinstance(self._infer_results[key], Variable):
res[key] = self._infer_results[key]
return res
def get_avg_cost(self):
"""R
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册