From 009f347d64ee65ce23466e8e265158530d990fdd Mon Sep 17 00:00:00 2001 From: zh-hike <1583124882@qq.com> Date: Thu, 15 Dec 2022 09:14:59 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=A9=BA=E6=A0=BC=E7=AD=89?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../semi_supervised_learning/FixMatchCCSSL.md | 4 +- ppcls/arch/__init__.py | 2 - ppcls/arch/backbone/model_zoo/wideresnet.py | 2 +- .../FixMatchCCSSL_cifar100_10000_4gpu.yaml | 43 ++++++------------- .../FixMatchCCSSL_cifar10_4000_4gpu.yaml | 4 +- ppcls/data/preprocess/__init__.py | 6 +-- ppcls/engine/train/train_fixmatch_ccssl.py | 10 ++--- ppcls/loss/softsuploss.py | 20 ++++++--- 8 files changed, 36 insertions(+), 55 deletions(-) diff --git a/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md b/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md index 201c1f99..53322776 100644 --- a/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md +++ b/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md @@ -33,13 +33,13 @@ pytorch版本 - + 95.54 paddle版本 - + 95.61 cifar10上,paddle版本的配置文件及训练好的模型如下表所示 diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 805e7867..56d5c8d1 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -14,8 +14,6 @@ import copy import importlib -from pyexpat import features - import paddle.nn as nn from paddle.jit import to_static from paddle.static import InputSpec diff --git a/ppcls/arch/backbone/model_zoo/wideresnet.py b/ppcls/arch/backbone/model_zoo/wideresnet.py index 12b8f77d..100289d3 100644 --- a/ppcls/arch/backbone/model_zoo/wideresnet.py +++ b/ppcls/arch/backbone/model_zoo/wideresnet.py @@ -235,4 +235,4 @@ def WideResNet(depth, num_classes=num_classes, proj=proj, low_dim=low_dim, - **kwargs) + **kwargs) \ No newline at end of file diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml index 59040c6f..5bd466f5 100644 --- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml +++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml @@ -30,8 +30,8 @@ Arch: name: WideResNet widen_factor: 8 depth: 28 - dropout: 0 # CCSSL为 drop_rate - num_classes: &sign_num_classes 100 + dropout: 0 + num_classes: 100 low_dim: 64 proj: true proj_after: false @@ -59,14 +59,6 @@ UnLabelLoss: - SoftSupConLoss: weight: 1.0 temperature: 0.07 - # - CCSSLLoss: - # CELoss: - # weight: 1.0 - # reduction: "none" - # SoftSupConLoss: - # weight: 1.0 - # temperature: 0.07 - # weight: 1. Optimizer: name: Momentum @@ -80,8 +72,8 @@ Optimizer: num_training_steps: 524800 DataLoader: - mean: &sign_mean [0.5071, 0.4867, 0.4408] - std: &sign_std [0.2675, 0.2565, 0.2761] + mean: [0.5071, 0.4867, 0.4408] + std: [0.2675, 0.2565, 0.2761] Train: dataset: name: CIFAR100SSL @@ -99,11 +91,11 @@ DataLoader: padding_mode: "reflect" - ToTensor: - Normalize: - mean: *sign_mean - std: *sign_std + mean: [0.5071, 0.4867, 0.4408] + std: [0.2675, 0.2565, 0.2761] sampler: - name: DistributedBatchSampler # DistributedBatchSampler + name: DistributedBatchSampler batch_size: 16 drop_last: true shuffle: true @@ -111,8 +103,6 @@ DataLoader: num_workers: 4 use_shared_memory: true - - UnLabelTrain: dataset: name: CIFAR100SSL @@ -129,8 +119,8 @@ DataLoader: padding_mode: 'reflect' - ToTensor: - Normalize: - mean: *sign_mean - std: *sign_std + mean: [0.5071, 0.4867, 0.4408] + std: [0.2675, 0.2565, 0.2761] transform_s1: - RandomHorizontalFlip: @@ -144,8 +134,8 @@ DataLoader: m: 10 - ToTensor: - Normalize: - mean: *sign_mean - std: *sign_std + mean: [0.5071, 0.4867, 0.4408] + std: [0.2675, 0.2565, 0.2761] transform_s2: - RandomResizedCrop: @@ -163,12 +153,9 @@ DataLoader: - RandomGrayscale: p: 0.2 - ToTensor: - # - Normalize: - # mean: *sign_mean - # std: *sign_std sampler: - name: DistributedBatchSampler # DistributedBatchSampler + name: DistributedBatchSampler batch_size: 112 drop_last: true shuffle: true @@ -185,8 +172,8 @@ DataLoader: transform_ops: - ToTensor: - Normalize: - mean: *sign_mean - std: *sign_std + mean: [0.5071, 0.4867, 0.4408] + std: [0.2675, 0.2565, 0.2761] sampler: name: DistributedBatchSampler batch_size: 16 @@ -196,8 +183,6 @@ DataLoader: num_workers: 4 use_shared_memory: true - - Metric: Eval: - TopkAcc: diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml index 1a675371..9c6bc28d 100644 --- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml +++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml @@ -7,7 +7,7 @@ Global: eval_during_train: true eval_interval: 1 epochs: 1024 - iter_per_epoch: 40 + iter_per_epoch: 1024 print_batch_step: 20 use_visualdl: false use_dali: false @@ -196,8 +196,6 @@ DataLoader: num_workers: 4 use_shared_memory: true - - Metric: Eval: - TopkAcc: diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index e581f8e3..13ee7930 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -51,13 +51,10 @@ from paddle.vision.transforms import Pad as Pad_paddle_vision from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid - - import numpy as np from PIL import Image import random - def transform(data, ops=[]): """ transform """ for op in ops: @@ -120,5 +117,4 @@ class TimmAutoAugment(RawTimmAutoAugment): if isinstance(img, Image.Image): img = np.asarray(img) - return img - \ No newline at end of file + return img \ No newline at end of file diff --git a/ppcls/engine/train/train_fixmatch_ccssl.py b/ppcls/engine/train/train_fixmatch_ccssl.py index b452733b..43815368 100644 --- a/ppcls/engine/train/train_fixmatch_ccssl.py +++ b/ppcls/engine/train/train_fixmatch_ccssl.py @@ -1,7 +1,4 @@ - - from __future__ import absolute_import, division, print_function - import time from turtle import update import paddle @@ -11,11 +8,11 @@ from ppcls.utils import profiler from paddle.nn import functional as F import numpy as np import paddle -# from reprod_log import ReprodLogger def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step): - + print(engine.model.state_dict().keys()) + assert 1==0 tic = time.time() if not hasattr(engine, 'train_dataloader_iter'): engine.train_dataloader_iter = iter(engine.train_dataloader) @@ -135,5 +132,4 @@ def get_loss(engine, loss_dict[k] = v loss_dict['loss'] = loss_dict_label['loss'] + unlabel_loss['loss'] - return loss_dict, logits_x - \ No newline at end of file + return loss_dict, logits_x \ No newline at end of file diff --git a/ppcls/loss/softsuploss.py b/ppcls/loss/softsuploss.py index f428fc94..7a64d380 100644 --- a/ppcls/loss/softsuploss.py +++ b/ppcls/loss/softsuploss.py @@ -1,7 +1,16 @@ -""" -CCSSL loss -author: zhhike -""" +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import paddle import paddle.nn as nn @@ -65,5 +74,4 @@ class SoftSupConLoss(nn.Layer): if reduction == 'mean': loss = loss.mean() - return {"SoftSupConLoss": loss} - \ No newline at end of file + return {"SoftSupConLoss": loss} \ No newline at end of file -- GitLab