From 592f5b10f50e7ba64beadbae71aba46f67c2266d Mon Sep 17 00:00:00 2001 From: cuicheng01 Date: Tue, 13 Jul 2021 06:38:01 +0000 Subject: [PATCH] support batch mix --- ppcls/configs/ImageNet/DarkNet/DarkNet53.yaml | 2 +- .../ImageNet/DataAugment/ResNet50_Cutmix.yaml | 2 +- .../ImageNet/DataAugment/ResNet50_Mixup.yaml | 2 +- .../ImageNet/Inception/InceptionV3.yaml | 2 +- .../ImageNet/Inception/InceptionV4.yaml | 2 +- .../Res2Net/Res2Net101_vd_26w_4s.yaml | 2 +- .../Res2Net/Res2Net200_vd_26w_4s.yaml | 2 +- .../ImageNet/Res2Net/Res2Net50_14w_8s.yaml | 2 +- .../ImageNet/Res2Net/Res2Net50_26w_4s.yaml | 2 +- .../ImageNet/Res2Net/Res2Net50_vd_26w_4s.yaml | 2 +- .../ImageNet/ResNeXt/ResNeXt101_vd_32x4d.yaml | 2 +- .../ImageNet/ResNeXt/ResNeXt101_vd_64x4d.yaml | 2 +- .../ImageNet/ResNeXt/ResNeXt152_vd_32x4d.yaml | 2 +- .../ImageNet/ResNeXt/ResNeXt152_vd_64x4d.yaml | 2 +- .../ImageNet/ResNeXt/ResNeXt50_vd_32x4d.yaml | 2 +- .../ImageNet/ResNeXt/ResNeXt50_vd_64x4d.yaml | 2 +- .../configs/ImageNet/ResNet/ResNet101_vd.yaml | 2 +- .../configs/ImageNet/ResNet/ResNet152_vd.yaml | 2 +- .../configs/ImageNet/ResNet/ResNet18_vd.yaml | 2 +- .../configs/ImageNet/ResNet/ResNet200_vd.yaml | 2 +- .../configs/ImageNet/ResNet/ResNet34_vd.yaml | 2 +- .../configs/ImageNet/ResNet/ResNet50_vd.yaml | 2 +- ppcls/configs/ImageNet/SENet/SENet154_vd.yaml | 2 +- .../ImageNet/SENet/SE_ResNeXt101_32x4d.yaml | 2 +- .../ImageNet/SENet/SE_ResNeXt50_32x4d.yaml | 2 +- .../ImageNet/SENet/SE_ResNeXt50_vd_32x4d.yaml | 2 +- .../ImageNet/SENet/SE_ResNet18_vd.yaml | 2 +- .../ImageNet/SENet/SE_ResNet34_vd.yaml | 2 +- .../ImageNet/SENet/SE_ResNet50_vd.yaml | 2 +- .../configs/ImageNet/Xception/Xception65.yaml | 2 +- .../configs/ImageNet/Xception/Xception71.yaml | 2 +- ppcls/engine/trainer.py | 7 ++++-- ppcls/loss/__init__.py | 3 +-- ppcls/loss/celoss.py | 22 +++++++++++++++++++ 34 files changed, 59 insertions(+), 35 deletions(-) diff --git a/ppcls/configs/ImageNet/DarkNet/DarkNet53.yaml b/ppcls/configs/ImageNet/DarkNet/DarkNet53.yaml index 1a55e75d..b69ccfcf 100644 --- a/ppcls/configs/ImageNet/DarkNet/DarkNet53.yaml +++ b/ppcls/configs/ImageNet/DarkNet/DarkNet53.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/DataAugment/ResNet50_Cutmix.yaml b/ppcls/configs/ImageNet/DataAugment/ResNet50_Cutmix.yaml index 6ab79d35..918a7629 100644 --- a/ppcls/configs/ImageNet/DataAugment/ResNet50_Cutmix.yaml +++ b/ppcls/configs/ImageNet/DataAugment/ResNet50_Cutmix.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 Eval: - CELoss: diff --git a/ppcls/configs/ImageNet/DataAugment/ResNet50_Mixup.yaml b/ppcls/configs/ImageNet/DataAugment/ResNet50_Mixup.yaml index 448440ec..b1256715 100644 --- a/ppcls/configs/ImageNet/DataAugment/ResNet50_Mixup.yaml +++ b/ppcls/configs/ImageNet/DataAugment/ResNet50_Mixup.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 Eval: - CELoss: diff --git a/ppcls/configs/ImageNet/Inception/InceptionV3.yaml b/ppcls/configs/ImageNet/Inception/InceptionV3.yaml index a8c30ea1..fa8b64a5 100644 --- a/ppcls/configs/ImageNet/Inception/InceptionV3.yaml +++ b/ppcls/configs/ImageNet/Inception/InceptionV3.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Inception/InceptionV4.yaml b/ppcls/configs/ImageNet/Inception/InceptionV4.yaml index 17415b3c..6a6dbb62 100644 --- a/ppcls/configs/ImageNet/Inception/InceptionV4.yaml +++ b/ppcls/configs/ImageNet/Inception/InceptionV4.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Res2Net/Res2Net101_vd_26w_4s.yaml b/ppcls/configs/ImageNet/Res2Net/Res2Net101_vd_26w_4s.yaml index bf27b303..7e5cbfd3 100644 --- a/ppcls/configs/ImageNet/Res2Net/Res2Net101_vd_26w_4s.yaml +++ b/ppcls/configs/ImageNet/Res2Net/Res2Net101_vd_26w_4s.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Res2Net/Res2Net200_vd_26w_4s.yaml b/ppcls/configs/ImageNet/Res2Net/Res2Net200_vd_26w_4s.yaml index 90b7b879..edceda10 100644 --- a/ppcls/configs/ImageNet/Res2Net/Res2Net200_vd_26w_4s.yaml +++ b/ppcls/configs/ImageNet/Res2Net/Res2Net200_vd_26w_4s.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Res2Net/Res2Net50_14w_8s.yaml b/ppcls/configs/ImageNet/Res2Net/Res2Net50_14w_8s.yaml index af1c4c73..1f3ecde9 100644 --- a/ppcls/configs/ImageNet/Res2Net/Res2Net50_14w_8s.yaml +++ b/ppcls/configs/ImageNet/Res2Net/Res2Net50_14w_8s.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Res2Net/Res2Net50_26w_4s.yaml b/ppcls/configs/ImageNet/Res2Net/Res2Net50_26w_4s.yaml index e792e9d0..31ad95e6 100644 --- a/ppcls/configs/ImageNet/Res2Net/Res2Net50_26w_4s.yaml +++ b/ppcls/configs/ImageNet/Res2Net/Res2Net50_26w_4s.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Res2Net/Res2Net50_vd_26w_4s.yaml b/ppcls/configs/ImageNet/Res2Net/Res2Net50_vd_26w_4s.yaml index 58d4968b..1157ac0c 100644 --- a/ppcls/configs/ImageNet/Res2Net/Res2Net50_vd_26w_4s.yaml +++ b/ppcls/configs/ImageNet/Res2Net/Res2Net50_vd_26w_4s.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_32x4d.yaml b/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_32x4d.yaml index c400b9e2..4ac6ab70 100644 --- a/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_32x4d.yaml +++ b/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_32x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_64x4d.yaml b/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_64x4d.yaml index 4f5f3c79..1754e63a 100644 --- a/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_64x4d.yaml +++ b/ppcls/configs/ImageNet/ResNeXt/ResNeXt101_vd_64x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_32x4d.yaml b/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_32x4d.yaml index d3054143..5cfb972f 100644 --- a/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_32x4d.yaml +++ b/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_32x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_64x4d.yaml b/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_64x4d.yaml index c8b76d0f..a9590731 100644 --- a/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_64x4d.yaml +++ b/ppcls/configs/ImageNet/ResNeXt/ResNeXt152_vd_64x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_32x4d.yaml b/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_32x4d.yaml index 3a03646f..466dfb36 100644 --- a/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_32x4d.yaml +++ b/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_32x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_64x4d.yaml b/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_64x4d.yaml index c9b9a101..d2a2f86e 100644 --- a/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_64x4d.yaml +++ b/ppcls/configs/ImageNet/ResNeXt/ResNeXt50_vd_64x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNet/ResNet101_vd.yaml b/ppcls/configs/ImageNet/ResNet/ResNet101_vd.yaml index f30ca077..83d1fc02 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet101_vd.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet101_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNet/ResNet152_vd.yaml b/ppcls/configs/ImageNet/ResNet/ResNet152_vd.yaml index f3168c43..e09bb60c 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet152_vd.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet152_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNet/ResNet18_vd.yaml b/ppcls/configs/ImageNet/ResNet/ResNet18_vd.yaml index 2dc6bba0..e0ba71a6 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet18_vd.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet18_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNet/ResNet200_vd.yaml b/ppcls/configs/ImageNet/ResNet/ResNet200_vd.yaml index a52c8374..98de87e3 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet200_vd.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet200_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNet/ResNet34_vd.yaml b/ppcls/configs/ImageNet/ResNet/ResNet34_vd.yaml index daae960b..9ff07171 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet34_vd.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet34_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml b/ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml index 0a2c4aa4..ba38350b 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SENet154_vd.yaml b/ppcls/configs/ImageNet/SENet/SENet154_vd.yaml index f7f1ba0f..f8255a97 100644 --- a/ppcls/configs/ImageNet/SENet/SENet154_vd.yaml +++ b/ppcls/configs/ImageNet/SENet/SENet154_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d.yaml index 3b09c3fd..bf274618 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_32x4d.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_32x4d.yaml index d04f298a..2c128692 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_32x4d.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_32x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_vd_32x4d.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_vd_32x4d.yaml index cabff29b..48e6e420 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_vd_32x4d.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNeXt50_vd_32x4d.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNet18_vd.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNet18_vd.yaml index fcaada93..20b3a0c4 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNet18_vd.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNet18_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNet34_vd.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNet34_vd.yaml index 69d15cca..7280e324 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNet34_vd.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNet34_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNet50_vd.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNet50_vd.yaml index f670c159..030dff93 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNet50_vd.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNet50_vd.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Xception/Xception65.yaml b/ppcls/configs/ImageNet/Xception/Xception65.yaml index 1b840677..2ff30d9d 100644 --- a/ppcls/configs/ImageNet/Xception/Xception65.yaml +++ b/ppcls/configs/ImageNet/Xception/Xception65.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/configs/ImageNet/Xception/Xception71.yaml b/ppcls/configs/ImageNet/Xception/Xception71.yaml index 7475a5f9..bda7ecfe 100644 --- a/ppcls/configs/ImageNet/Xception/Xception71.yaml +++ b/ppcls/configs/ImageNet/Xception/Xception71.yaml @@ -22,7 +22,7 @@ Arch: # loss function config for traing/eval process Loss: Train: - - CELoss: + - MixCELoss: weight: 1.0 epsilon: 0.1 Eval: diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 37a53408..569d3b41 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -173,9 +173,12 @@ class Trainer(object): out = self.model(batch[0]) else: out = self.model(batch[0], batch[1]) - # calc loss - loss_dict = self.train_loss_func(out, batch[1]) + if self.config["DataLoader"]["Train"]["dataset"].get( + "batch_transform_ops", None): + loss_dict = self.train_loss_func(out, batch[1:]) + else: + loss_dict = self.train_loss_func(out, batch[1]) for key in loss_dict: if not key in output_info: diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index cee4b05a..5421f421 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -4,7 +4,7 @@ import paddle import paddle.nn as nn from ppcls.utils import logger -from .celoss import CELoss +from .celoss import CELoss, MixCELoss from .googlenetloss import GoogLeNetLoss from .centerloss import CenterLoss from .emlloss import EmlLoss @@ -30,7 +30,6 @@ class CombinedLoss(nn.Layer): assert isinstance(config_list, list), ( 'operator config should be a list') for config in config_list: - print(config) assert isinstance(config, dict) and len(config) == 1, "yaml format error" name = list(config)[0] diff --git a/ppcls/loss/celoss.py b/ppcls/loss/celoss.py index 54c37030..7bc3c06c 100644 --- a/ppcls/loss/celoss.py +++ b/ppcls/loss/celoss.py @@ -18,6 +18,10 @@ import paddle.nn.functional as F class CELoss(nn.Layer): + """ + Cross entropy loss + """ + def __init__(self, epsilon=None): super().__init__() if epsilon is not None and (epsilon <= 0 or epsilon >= 1): @@ -50,3 +54,21 @@ class CELoss(nn.Layer): loss = F.cross_entropy(x, label=label, soft_label=soft_label) loss = loss.mean() return {"CELoss": loss} + + +class MixCELoss(CELoss): + """ + Cross entropy loss with mix(mixup, cutmix, fixmix) + """ + + def __init__(self, epsilon=None): + super().__init__() + self.epsilon = epsilon + + def __call__(self, input, batch): + target0, target1, lam = batch + loss0 = super().forward(input, target0)["CELoss"] + loss1 = super().forward(input, target1)["CELoss"] + loss = lam * loss0 + (1.0 - lam) * loss1 + loss = paddle.mean(loss) + return {"MixCELoss": loss} -- GitLab