From fcae408a261ed99bfb699663d079091d4df9bc9a Mon Sep 17 00:00:00 2001
From: Yang Zhang
Date: Tue, 31 Dec 2019 14:11:04 +0800
Subject: [PATCH] Support weighted losses

---
 model.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/model.py b/model.py
index 58f5446..f8e4658 100644
--- a/model.py
+++ b/model.py
@@ -168,14 +168,18 @@ class StaticGraphAdapter(object):
         label_vars = []
         if self.mode != 'test':
             losses = []
-            for o, l in zip(outputs, self.model._loss_functions):
+            loss_weights = self.model._loss_weights
+            if loss_weights is None:
+                loss_weights = [1. for _ in self.model._loss_functions]
+            for o, l, w in zip(outputs, self.model._loss_functions,
+                               loss_weights):
                 if l is None:
                     continue
                 label_var = self._infer_label_var(o, l)
                 label_vars.append(label_var)
                 loss_fn = getattr(fluid.layers, l)
                 loss = loss_fn(o, label_var)
-                losses.append(fluid.layers.reduce_mean(loss))
+                losses.append(fluid.layers.reduce_mean(loss) * w)
             outputs = losses
         if self.mode == 'train':
             self._label_vars = label_vars
@@ -314,12 +318,16 @@ class DynamicGraphAdapter(object):
 
     def _loss(self, pred, labels):
         losses = []
-        for o, l, t in zip(to_list(pred), self.model._loss_functions, labels):
+        loss_weights = self.model._loss_weights
+        if loss_weights is None:
+            loss_weights = [1. for _ in self.model._loss_functions]
+        for o, l, w, t in zip(to_list(pred), self.model._loss_functions,
+                              loss_weights, labels):
             if l is None:
                 continue
             loss_fn = getattr(fluid.layers, l)
             loss = loss_fn(o, to_variable(t))
-            losses.append(fluid.layers.reduce_mean(loss))
+            losses.append(fluid.layers.reduce_mean(loss) * w)
         return losses
 
 
@@ -328,6 +336,7 @@ class Model(fluid.dygraph.Layer):
         super(Model, self).__init__(self.__class__.__name__)
         self.mode = 'train'
         self._loss_functions = []
+        self._loss_weights = None
         self._optimizer = None
         if in_dygraph_mode():
             self._adapter = DynamicGraphAdapter(self)
@@ -349,9 +358,11 @@ class Model(fluid.dygraph.Layer):
     def load(self, *args, **kwargs):
         return self._adapter.load(*args, **kwargs)
 
-    def prepare(self, optimizer, loss_functions):
+    def prepare(self, optimizer, loss_functions, loss_weights=None):
         self._optimizer = optimizer
         self._loss_functions = to_list(loss_functions)
+        if loss_weights is not None:
+            self._loss_weights = to_list(loss_weights)
 
     def parameters(self, *args, **kwargs):
         return self._adapter.parameters(*args, **kwargs)
--
GitLab
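
A minimal usage sketch for the new loss_weights keyword, based on the prepare() signature added in this patch. The model class, loss function names, and optimizer settings below are illustrative assumptions, not part of the patch itself:

    import paddle.fluid as fluid

    model = MyModel()  # hypothetical Model subclass with two outputs
    optimizer = fluid.optimizer.Adam(learning_rate=1e-3,
                                     parameter_list=model.parameters())

    # Each output is paired with a loss function by position; here the
    # second loss contributes half as much to the total. Omitting
    # loss_weights leaves _loss_weights as None, so every loss gets an
    # implicit weight of 1., matching the default branch in the diff.
    model.prepare(optimizer,
                  loss_functions=['cross_entropy', 'cross_entropy'],
                  loss_weights=[1., 0.5])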