From f8978a2f0cfbf1f3a8d964dcec30d9020f0bec32 Mon Sep 17 00:00:00 2001 From: zh-hike <1583124882@qq.com> Date: Fri, 30 Dec 2022 08:05:52 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9cifar100=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../FixMatchCCSSL_cifar100_10000_4gpu.yaml | 7 +++---- .../FixMatchCCSSL_cifar10_4000_4gpu.yaml | 2 +- ppcls/data/dataloader/cifar.py | 12 +++++++++++- ppcls/engine/train/train_fixmatch_ccssl.py | 2 +- ppcls/utils/ema.py | 2 +- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml index 4e9fcf2b..e8055d5b 100644 --- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml +++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml @@ -97,7 +97,6 @@ DataLoader: - RandCropImageV2: size: [32, 32] - NormalizeImage: - - Normalize: scale: 1.0/255.0 mean: [0.5071, 0.4867, 0.4408] std: [0.2675, 0.2565, 0.2761] @@ -152,7 +151,7 @@ DataLoader: order: hwc transform_ops_strong2: - - RandCropImageV2: + - RandomResizedCrop: size: [32, 32] - RandFlipImage: flip_code: 1 @@ -168,8 +167,8 @@ DataLoader: p: 0.2 - NormalizeImage: scale: 1.0/255.0 - mean: [0.5071, 0.4867, 0.4408] - std: [0.2675, 0.2565, 0.2761] + mean: [0., 0., 0.] + std: [1., 1., 1.] order: hwc diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml index 4b6381e8..22ec818e 100644 --- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml +++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml @@ -60,7 +60,7 @@ Loss: UnLabelLoss: Train: - - CCSSLCeLoss: + - CCSSLCELoss: weight: 1. - SoftSupConLoss: weight: 1.0 diff --git a/ppcls/data/dataloader/cifar.py b/ppcls/data/dataloader/cifar.py index 4cf707aa..614f4f3e 100644 --- a/ppcls/data/dataloader/cifar.py +++ b/ppcls/data/dataloader/cifar.py @@ -89,12 +89,14 @@ class Cifar100(Cifar100_paddle): expand_labels=1, transform_ops=None, transform_ops_weak=None, - transform_ops_strong=None): + transform_ops_strong=None, + transform_ops_strong2=None): super().__init__(data_file, mode, None, download, backend) assert isinstance(expand_labels, int) self._transform_ops = create_operators(transform_ops) self._transform_ops_weak = create_operators(transform_ops_weak) self._transform_ops_strong = create_operators(transform_ops_strong) + self._transform_ops_strong2 = create_operators(transform_ops_strong2) self.class_num = 100 labels = [] @@ -117,6 +119,14 @@ class Cifar100(Cifar100_paddle): image1 = transform(image, self._transform_ops) image1 = image1.transpose((2, 0, 1)) return (image1, np.int64(label)) + elif self._transform_ops_weak and self._transform_ops_strong and self._transform_ops_strong2: + image2 = transform(image, self._transform_ops_weak) + image2 = image2.transpose((2, 0, 1)) + image3 = transform(image, self._transform_ops_strong) + image3 = image3.transpose((2, 0, 1)) + image4 = transform(image, self._transform_ops_strong2) + image4 = image4.transpose((2, 0, 1)) + return (image2, image3, image4, np.int64(label)) elif self._transform_ops_weak and self._transform_ops_strong: image2 = transform(image, self._transform_ops_weak) image2 = image2.transpose((2, 0, 1)) diff --git a/ppcls/engine/train/train_fixmatch_ccssl.py b/ppcls/engine/train/train_fixmatch_ccssl.py index 5626ebfd..e42ac8a3 100644 --- a/ppcls/engine/train/train_fixmatch_ccssl.py +++ b/ppcls/engine/train/train_fixmatch_ccssl.py @@ -41,7 +41,7 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step): engine.unlabel_train_dataloader_iter = iter(engine.unlabel_train_dataloader) unlabel_data_batch = engine.unlabel_train_dataloader_iter.next() - assert len(unlabel_data_batch) == 4 + assert len(unlabel_data_batch) in [3, 4] assert unlabel_data_batch[0].shape == unlabel_data_batch[1].shape == unlabel_data_batch[2].shape engine.time_info['reader_cost'].update(time.time() - tic) diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py index 82927819..9a3b65cc 100644 --- a/ppcls/utils/ema.py +++ b/ppcls/utils/ema.py @@ -39,4 +39,4 @@ class ExponentialMovingAverage(): self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) def set(self, model): - self._update(model, update_fn=lambda e, m: m) + self._update(model, update_fn=lambda e, m: m) \ No newline at end of file -- GitLab