From f7cf7b5b2538d3d7a5078c231f085efa69006629 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Fri, 3 Jan 2020 10:39:51 +0800 Subject: [PATCH] Keep track of label feed for eval mode --- model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model.py b/model.py index 3a32bff..ae11bdc 100644 --- a/model.py +++ b/model.py @@ -80,7 +80,7 @@ class StaticGraphAdapter(object): self._startup_prog = fluid.default_startup_program() self._orig_prog = fluid.default_main_program() - self._label_vars = None # label variables + self._label_vars = {} # label variables self._endpoints = {} self._loss_endpoint = None self._executor = None @@ -226,7 +226,7 @@ class StaticGraphAdapter(object): if inputs[idx] is not None: feed[n] = inputs[idx] if labels is not None: - for idx, v in enumerate(self._label_vars): + for idx, v in enumerate(self._label_vars[self.mode]): feed[v.name] = labels[idx] endpoints = self._endpoints[self.mode] @@ -264,8 +264,8 @@ class StaticGraphAdapter(object): loss_fn = getattr(fluid.layers, l) loss = loss_fn(o, label_var) losses.append(fluid.layers.reduce_mean(loss) * w) + self._label_vars[self.mode] = label_vars if self.mode == 'train': - self._label_vars = label_vars self._loss_endpoint = fluid.layers.sum(losses) self.model._optimizer.minimize(self._loss_endpoint) self._progs[self.mode] = prog -- GitLab