提交 d7ce065d 编写于 作者: D dengkaipeng

not reture metrics if emtpy

上级 9c4f9205
......@@ -317,7 +317,7 @@ class StaticGraphAdapter(object):
metrics = []
for metric, state in zip(self.model._metrics, metric_states):
metrics.append(metric.update(*state))
return losses, metrics
return (losses, metrics) if len(metrics) > 0 else losses
def _make_program(self, inputs):
prog = self._orig_prog.clone()
......@@ -453,7 +453,8 @@ class DynamicGraphAdapter(object):
metric_outs = metric.add_metric_op(outputs, [to_variable(l) for l in labels])
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return [to_numpy(l) for l in losses], metrics
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
def eval(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._loss_function, \
......@@ -469,7 +470,8 @@ class DynamicGraphAdapter(object):
metric_outs = metric.add_metric_op(outputs, [to_variable(l) for l in labels])
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return [to_numpy(l) for l in losses], metrics
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
def test(self, inputs, device='CPU', device_ids=None):
super(Model, self.model).eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册