提交 fcae408a 编写于 作者: Y Yang Zhang

Support weighted losses

上级 90a90cf9
...@@ -168,14 +168,18 @@ class StaticGraphAdapter(object): ...@@ -168,14 +168,18 @@ class StaticGraphAdapter(object):
label_vars = [] label_vars = []
if self.mode != 'test': if self.mode != 'test':
losses = [] losses = []
for o, l in zip(outputs, self.model._loss_functions): loss_weights = self.model._loss_weights
if loss_weights is None:
loss_weights = [1. for _ in self.model._loss_functions]
for o, l, w in zip(outputs, self.model._loss_functions,
loss_weights):
if l is None: if l is None:
continue continue
label_var = self._infer_label_var(o, l) label_var = self._infer_label_var(o, l)
label_vars.append(label_var) label_vars.append(label_var)
loss_fn = getattr(fluid.layers, l) loss_fn = getattr(fluid.layers, l)
loss = loss_fn(o, label_var) loss = loss_fn(o, label_var)
losses.append(fluid.layers.reduce_mean(loss)) losses.append(fluid.layers.reduce_mean(loss) * w)
outputs = losses outputs = losses
if self.mode == 'train': if self.mode == 'train':
self._label_vars = label_vars self._label_vars = label_vars
...@@ -314,12 +318,16 @@ class DynamicGraphAdapter(object): ...@@ -314,12 +318,16 @@ class DynamicGraphAdapter(object):
def _loss(self, pred, labels): def _loss(self, pred, labels):
losses = [] losses = []
for o, l, t in zip(to_list(pred), self.model._loss_functions, labels): loss_weights = self.model._loss_weights
if loss_weights is None:
loss_weights = [1. for _ in self.model._loss_functions]
for o, l, w, t in zip(to_list(pred), self.model._loss_functions,
loss_weights, labels):
if l is None: if l is None:
continue continue
loss_fn = getattr(fluid.layers, l) loss_fn = getattr(fluid.layers, l)
loss = loss_fn(o, to_variable(t)) loss = loss_fn(o, to_variable(t))
losses.append(fluid.layers.reduce_mean(loss)) losses.append(fluid.layers.reduce_mean(loss) * w)
return losses return losses
...@@ -328,6 +336,7 @@ class Model(fluid.dygraph.Layer): ...@@ -328,6 +336,7 @@ class Model(fluid.dygraph.Layer):
super(Model, self).__init__(self.__class__.__name__) super(Model, self).__init__(self.__class__.__name__)
self.mode = 'train' self.mode = 'train'
self._loss_functions = [] self._loss_functions = []
self._loss_weights = None
self._optimizer = None self._optimizer = None
if in_dygraph_mode(): if in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self) self._adapter = DynamicGraphAdapter(self)
...@@ -349,9 +358,11 @@ class Model(fluid.dygraph.Layer): ...@@ -349,9 +358,11 @@ class Model(fluid.dygraph.Layer):
def load(self, *args, **kwargs): def load(self, *args, **kwargs):
return self._adapter.load(*args, **kwargs) return self._adapter.load(*args, **kwargs)
def prepare(self, optimizer, loss_functions): def prepare(self, optimizer, loss_functions, loss_weights=None):
self._optimizer = optimizer self._optimizer = optimizer
self._loss_functions = to_list(loss_functions) self._loss_functions = to_list(loss_functions)
if loss_weights is not None:
self._loss_weights = to_list(loss_weights)
def parameters(self, *args, **kwargs): def parameters(self, *args, **kwargs):
return self._adapter.parameters(*args, **kwargs) return self._adapter.parameters(*args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册