提交 0e833f95 编写于 作者: D dengkaipeng

to_variable before loss_function

上级 59f12446
...@@ -463,7 +463,7 @@ class DynamicGraphAdapter(object): ...@@ -463,7 +463,7 @@ class DynamicGraphAdapter(object):
self.mode = 'train' self.mode = 'train'
inputs = to_list(inputs) inputs = to_list(inputs)
if labels is not None: if labels is not None:
labels = to_list(labels) labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs])) outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
losses = self.model._loss_function(outputs, labels) losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses) final_loss = fluid.layers.sum(losses)
...@@ -472,7 +472,7 @@ class DynamicGraphAdapter(object): ...@@ -472,7 +472,7 @@ class DynamicGraphAdapter(object):
self.model.clear_gradients() self.model.clear_gradients()
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, [to_variable(l) for l in labels]) metric_outs = metric.add_metric_op(outputs, to_list(labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
return ([to_numpy(l) for l in losses], metrics) \ return ([to_numpy(l) for l in losses], metrics) \
...@@ -483,7 +483,7 @@ class DynamicGraphAdapter(object): ...@@ -483,7 +483,7 @@ class DynamicGraphAdapter(object):
self.mode = 'eval' self.mode = 'eval'
inputs = to_list(inputs) inputs = to_list(inputs)
if labels is not None: if labels is not None:
labels = to_list(labels) labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs])) outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
if self.model._loss_function: if self.model._loss_function:
...@@ -493,7 +493,7 @@ class DynamicGraphAdapter(object): ...@@ -493,7 +493,7 @@ class DynamicGraphAdapter(object):
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, [to_variable(l) for l in labels]) metric_outs = metric.add_metric_op(outputs, labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册