提交 ee3c643c 编写于 作者: D dongshuilong

update combined loss for accelerate classification training speed

上级 81746459
......@@ -44,6 +44,12 @@ class CombinedLoss(nn.Layer):
def __call__(self, input, batch):
loss_dict = {}
# just for accelerate classification traing speed
if len(self.loss_func) == 1:
loss = self.loss_func[0](input, batch)
loss_dict.update(loss)
loss_dict["loss"] = list(loss.values())[0]
else:
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch)
weight = self.loss_weight[idx]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册