提交 592f5b10 编写于 作者: C cuicheng01

support batch mix

上级 db46dbd8
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
Eval:
- CELoss:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
Eval:
- CELoss:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
- MixCELoss:
weight: 1.0
epsilon: 0.1
Eval:
......
......@@ -173,9 +173,12 @@ class Trainer(object):
out = self.model(batch[0])
else:
out = self.model(batch[0], batch[1])
# calc loss
loss_dict = self.train_loss_func(out, batch[1])
if self.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None):
loss_dict = self.train_loss_func(out, batch[1:])
else:
loss_dict = self.train_loss_func(out, batch[1])
for key in loss_dict:
if not key in output_info:
......
......@@ -4,7 +4,7 @@ import paddle
import paddle.nn as nn
from ppcls.utils import logger
from .celoss import CELoss
from .celoss import CELoss, MixCELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .emlloss import EmlLoss
......@@ -30,7 +30,6 @@ class CombinedLoss(nn.Layer):
assert isinstance(config_list, list), (
'operator config should be a list')
for config in config_list:
print(config)
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
......
......@@ -18,6 +18,10 @@ import paddle.nn.functional as F
class CELoss(nn.Layer):
"""
Cross entropy loss
"""
def __init__(self, epsilon=None):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
......@@ -50,3 +54,21 @@ class CELoss(nn.Layer):
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
loss = loss.mean()
return {"CELoss": loss}
class MixCELoss(CELoss):
"""
Cross entropy loss with mix(mixup, cutmix, fixmix)
"""
def __init__(self, epsilon=None):
super().__init__()
self.epsilon = epsilon
def __call__(self, input, batch):
target0, target1, lam = batch
loss0 = super().forward(input, target0)["CELoss"]
loss1 = super().forward(input, target1)["CELoss"]
loss = lam * loss0 + (1.0 - lam) * loss1
loss = paddle.mean(loss)
return {"MixCELoss": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册