From ee3c643cc1a72eeba83c8bf2cff9ad4f1baec4c1 Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Fri, 19 Nov 2021 08:35:47 +0000 Subject: [PATCH] update combined loss for accelerate classification training speed --- ppcls/loss/__init__.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 102934d1..68739de2 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -44,12 +44,18 @@ class CombinedLoss(nn.Layer): def __call__(self, input, batch): loss_dict = {} - for idx, loss_func in enumerate(self.loss_func): - loss = loss_func(input, batch) - weight = self.loss_weight[idx] - loss = {key: loss[key] * weight for key in loss} + # just for accelerate classification traing speed + if len(self.loss_func) == 1: + loss = self.loss_func[0](input, batch) loss_dict.update(loss) - loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) + loss_dict["loss"] = list(loss.values())[0] + else: + for idx, loss_func in enumerate(self.loss_func): + loss = loss_func(input, batch) + weight = self.loss_weight[idx] + loss = {key: loss[key] * weight for key in loss} + loss_dict.update(loss) + loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) return loss_dict -- GitLab