提交 592f5b10 编写于 作者: C cuicheng01

support batch mix

上级 db46dbd8
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
Eval: Eval:
- CELoss: - CELoss:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
Eval: Eval:
- CELoss: - CELoss:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -22,7 +22,7 @@ Arch: ...@@ -22,7 +22,7 @@ Arch:
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
Train: Train:
- CELoss: - MixCELoss:
weight: 1.0 weight: 1.0
epsilon: 0.1 epsilon: 0.1
Eval: Eval:
......
...@@ -173,9 +173,12 @@ class Trainer(object): ...@@ -173,9 +173,12 @@ class Trainer(object):
out = self.model(batch[0]) out = self.model(batch[0])
else: else:
out = self.model(batch[0], batch[1]) out = self.model(batch[0], batch[1])
# calc loss # calc loss
loss_dict = self.train_loss_func(out, batch[1]) if self.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None):
loss_dict = self.train_loss_func(out, batch[1:])
else:
loss_dict = self.train_loss_func(out, batch[1])
for key in loss_dict: for key in loss_dict:
if not key in output_info: if not key in output_info:
......
...@@ -4,7 +4,7 @@ import paddle ...@@ -4,7 +4,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
from ppcls.utils import logger from ppcls.utils import logger
from .celoss import CELoss from .celoss import CELoss, MixCELoss
from .googlenetloss import GoogLeNetLoss from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss from .centerloss import CenterLoss
from .emlloss import EmlLoss from .emlloss import EmlLoss
...@@ -30,7 +30,6 @@ class CombinedLoss(nn.Layer): ...@@ -30,7 +30,6 @@ class CombinedLoss(nn.Layer):
assert isinstance(config_list, list), ( assert isinstance(config_list, list), (
'operator config should be a list') 'operator config should be a list')
for config in config_list: for config in config_list:
print(config)
assert isinstance(config, assert isinstance(config,
dict) and len(config) == 1, "yaml format error" dict) and len(config) == 1, "yaml format error"
name = list(config)[0] name = list(config)[0]
......
...@@ -18,6 +18,10 @@ import paddle.nn.functional as F ...@@ -18,6 +18,10 @@ import paddle.nn.functional as F
class CELoss(nn.Layer): class CELoss(nn.Layer):
"""
Cross entropy loss
"""
def __init__(self, epsilon=None): def __init__(self, epsilon=None):
super().__init__() super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1): if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
...@@ -50,3 +54,21 @@ class CELoss(nn.Layer): ...@@ -50,3 +54,21 @@ class CELoss(nn.Layer):
loss = F.cross_entropy(x, label=label, soft_label=soft_label) loss = F.cross_entropy(x, label=label, soft_label=soft_label)
loss = loss.mean() loss = loss.mean()
return {"CELoss": loss} return {"CELoss": loss}
class MixCELoss(CELoss):
"""
Cross entropy loss with mix(mixup, cutmix, fixmix)
"""
def __init__(self, epsilon=None):
super().__init__()
self.epsilon = epsilon
def __call__(self, input, batch):
target0, target1, lam = batch
loss0 = super().forward(input, target0)["CELoss"]
loss1 = super().forward(input, target1)["CELoss"]
loss = lam * loss0 + (1.0 - lam) * loss1
loss = paddle.mean(loss)
return {"MixCELoss": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册