diff --git a/model.py b/model.py index 3a32bff3212f826d519fa873442e064c0c337385..ae11bdcbc79d038db17e0e2c5ccb2c499b8e4fde 100644 --- a/model.py +++ b/model.py @@ -80,7 +80,7 @@ class StaticGraphAdapter(object): self._startup_prog = fluid.default_startup_program() self._orig_prog = fluid.default_main_program() - self._label_vars = None # label variables + self._label_vars = {} # label variables self._endpoints = {} self._loss_endpoint = None self._executor = None @@ -226,7 +226,7 @@ class StaticGraphAdapter(object): if inputs[idx] is not None: feed[n] = inputs[idx] if labels is not None: - for idx, v in enumerate(self._label_vars): + for idx, v in enumerate(self._label_vars[self.mode]): feed[v.name] = labels[idx] endpoints = self._endpoints[self.mode] @@ -264,8 +264,8 @@ class StaticGraphAdapter(object): loss_fn = getattr(fluid.layers, l) loss = loss_fn(o, label_var) losses.append(fluid.layers.reduce_mean(loss) * w) + self._label_vars[self.mode] = label_vars if self.mode == 'train': - self._label_vars = label_vars self._loss_endpoint = fluid.layers.sum(losses) self.model._optimizer.minimize(self._loss_endpoint) self._progs[self.mode] = prog