提交 f7cf7b5b 编写于 作者: Y Yang Zhang

Keep track of label feed for eval mode

上级 8069b961
......@@ -80,7 +80,7 @@ class StaticGraphAdapter(object):
self._startup_prog = fluid.default_startup_program()
self._orig_prog = fluid.default_main_program()
self._label_vars = None # label variables
self._label_vars = {} # label variables
self._endpoints = {}
self._loss_endpoint = None
self._executor = None
......@@ -226,7 +226,7 @@ class StaticGraphAdapter(object):
if inputs[idx] is not None:
feed[n] = inputs[idx]
if labels is not None:
for idx, v in enumerate(self._label_vars):
for idx, v in enumerate(self._label_vars[self.mode]):
feed[v.name] = labels[idx]
endpoints = self._endpoints[self.mode]
......@@ -264,8 +264,8 @@ class StaticGraphAdapter(object):
loss_fn = getattr(fluid.layers, l)
loss = loss_fn(o, label_var)
losses.append(fluid.layers.reduce_mean(loss) * w)
self._label_vars[self.mode] = label_vars
if self.mode == 'train':
self._label_vars = label_vars
self._loss_endpoint = fluid.layers.sum(losses)
self.model._optimizer.minimize(self._loss_endpoint)
self._progs[self.mode] = prog
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册