From 009f347d64ee65ce23466e8e265158530d990fdd Mon Sep 17 00:00:00 2001
From: zh-hike <1583124882@qq.com>
Date: Thu, 15 Dec 2022 09:14:59 +0000
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=A9=BA=E6=A0=BC=E7=AD=89?=
=?UTF-8?q?=E4=BB=A3=E7=A0=81=E8=A7=84=E8=8C=83?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../semi_supervised_learning/FixMatchCCSSL.md | 4 +-
ppcls/arch/__init__.py | 2 -
ppcls/arch/backbone/model_zoo/wideresnet.py | 2 +-
.../FixMatchCCSSL_cifar100_10000_4gpu.yaml | 43 ++++++-------------
.../FixMatchCCSSL_cifar10_4000_4gpu.yaml | 4 +-
ppcls/data/preprocess/__init__.py | 6 +--
ppcls/engine/train/train_fixmatch_ccssl.py | 10 ++---
ppcls/loss/softsuploss.py | 20 ++++++---
8 files changed, 36 insertions(+), 55 deletions(-)
diff --git a/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md b/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md
index 201c1f99..53322776 100644
--- a/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md
+++ b/docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md
@@ -33,13 +33,13 @@
pytorch版本 |
|
|
- |
+ 95.54 |
paddle版本 |
|
|
- |
+ 95.61 |
cifar10上,paddle版本的配置文件及训练好的模型如下表所示
diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py
index 805e7867..56d5c8d1 100644
--- a/ppcls/arch/__init__.py
+++ b/ppcls/arch/__init__.py
@@ -14,8 +14,6 @@
import copy
import importlib
-from pyexpat import features
-
import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec
diff --git a/ppcls/arch/backbone/model_zoo/wideresnet.py b/ppcls/arch/backbone/model_zoo/wideresnet.py
index 12b8f77d..100289d3 100644
--- a/ppcls/arch/backbone/model_zoo/wideresnet.py
+++ b/ppcls/arch/backbone/model_zoo/wideresnet.py
@@ -235,4 +235,4 @@ def WideResNet(depth,
num_classes=num_classes,
proj=proj,
low_dim=low_dim,
- **kwargs)
+ **kwargs)
\ No newline at end of file
diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml
index 59040c6f..5bd466f5 100644
--- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml
+++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar100_10000_4gpu.yaml
@@ -30,8 +30,8 @@ Arch:
name: WideResNet
widen_factor: 8
depth: 28
- dropout: 0 # CCSSL为 drop_rate
- num_classes: &sign_num_classes 100
+ dropout: 0
+ num_classes: 100
low_dim: 64
proj: true
proj_after: false
@@ -59,14 +59,6 @@ UnLabelLoss:
- SoftSupConLoss:
weight: 1.0
temperature: 0.07
- # - CCSSLLoss:
- # CELoss:
- # weight: 1.0
- # reduction: "none"
- # SoftSupConLoss:
- # weight: 1.0
- # temperature: 0.07
- # weight: 1.
Optimizer:
name: Momentum
@@ -80,8 +72,8 @@ Optimizer:
num_training_steps: 524800
DataLoader:
- mean: &sign_mean [0.5071, 0.4867, 0.4408]
- std: &sign_std [0.2675, 0.2565, 0.2761]
+ mean: [0.5071, 0.4867, 0.4408]
+ std: [0.2675, 0.2565, 0.2761]
Train:
dataset:
name: CIFAR100SSL
@@ -99,11 +91,11 @@ DataLoader:
padding_mode: "reflect"
- ToTensor:
- Normalize:
- mean: *sign_mean
- std: *sign_std
+ mean: [0.5071, 0.4867, 0.4408]
+ std: [0.2675, 0.2565, 0.2761]
sampler:
- name: DistributedBatchSampler # DistributedBatchSampler
+ name: DistributedBatchSampler
batch_size: 16
drop_last: true
shuffle: true
@@ -111,8 +103,6 @@ DataLoader:
num_workers: 4
use_shared_memory: true
-
-
UnLabelTrain:
dataset:
name: CIFAR100SSL
@@ -129,8 +119,8 @@ DataLoader:
padding_mode: 'reflect'
- ToTensor:
- Normalize:
- mean: *sign_mean
- std: *sign_std
+ mean: [0.5071, 0.4867, 0.4408]
+ std: [0.2675, 0.2565, 0.2761]
transform_s1:
- RandomHorizontalFlip:
@@ -144,8 +134,8 @@ DataLoader:
m: 10
- ToTensor:
- Normalize:
- mean: *sign_mean
- std: *sign_std
+ mean: [0.5071, 0.4867, 0.4408]
+ std: [0.2675, 0.2565, 0.2761]
transform_s2:
- RandomResizedCrop:
@@ -163,12 +153,9 @@ DataLoader:
- RandomGrayscale:
p: 0.2
- ToTensor:
- # - Normalize:
- # mean: *sign_mean
- # std: *sign_std
sampler:
- name: DistributedBatchSampler # DistributedBatchSampler
+ name: DistributedBatchSampler
batch_size: 112
drop_last: true
shuffle: true
@@ -185,8 +172,8 @@ DataLoader:
transform_ops:
- ToTensor:
- Normalize:
- mean: *sign_mean
- std: *sign_std
+ mean: [0.5071, 0.4867, 0.4408]
+ std: [0.2675, 0.2565, 0.2761]
sampler:
name: DistributedBatchSampler
batch_size: 16
@@ -196,8 +183,6 @@ DataLoader:
num_workers: 4
use_shared_memory: true
-
-
Metric:
Eval:
- TopkAcc:
diff --git a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
index 1a675371..9c6bc28d 100644
--- a/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
+++ b/ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
@@ -7,7 +7,7 @@ Global:
eval_during_train: true
eval_interval: 1
epochs: 1024
- iter_per_epoch: 40
+ iter_per_epoch: 1024
print_batch_step: 20
use_visualdl: false
use_dali: false
@@ -196,8 +196,6 @@ DataLoader:
num_workers: 4
use_shared_memory: true
-
-
Metric:
Eval:
- TopkAcc:
diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py
index e581f8e3..13ee7930 100644
--- a/ppcls/data/preprocess/__init__.py
+++ b/ppcls/data/preprocess/__init__.py
@@ -51,13 +51,10 @@ from paddle.vision.transforms import Pad as Pad_paddle_vision
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
-
-
import numpy as np
from PIL import Image
import random
-
def transform(data, ops=[]):
""" transform """
for op in ops:
@@ -120,5 +117,4 @@ class TimmAutoAugment(RawTimmAutoAugment):
if isinstance(img, Image.Image):
img = np.asarray(img)
- return img
-
\ No newline at end of file
+ return img
\ No newline at end of file
diff --git a/ppcls/engine/train/train_fixmatch_ccssl.py b/ppcls/engine/train/train_fixmatch_ccssl.py
index b452733b..43815368 100644
--- a/ppcls/engine/train/train_fixmatch_ccssl.py
+++ b/ppcls/engine/train/train_fixmatch_ccssl.py
@@ -1,7 +1,4 @@
-
-
from __future__ import absolute_import, division, print_function
-
import time
from turtle import update
import paddle
@@ -11,11 +8,11 @@ from ppcls.utils import profiler
from paddle.nn import functional as F
import numpy as np
import paddle
-# from reprod_log import ReprodLogger
def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
-
+ print(engine.model.state_dict().keys())
+ assert 1==0
tic = time.time()
if not hasattr(engine, 'train_dataloader_iter'):
engine.train_dataloader_iter = iter(engine.train_dataloader)
@@ -135,5 +132,4 @@ def get_loss(engine,
loss_dict[k] = v
loss_dict['loss'] = loss_dict_label['loss'] + unlabel_loss['loss']
- return loss_dict, logits_x
-
\ No newline at end of file
+ return loss_dict, logits_x
\ No newline at end of file
diff --git a/ppcls/loss/softsuploss.py b/ppcls/loss/softsuploss.py
index f428fc94..7a64d380 100644
--- a/ppcls/loss/softsuploss.py
+++ b/ppcls/loss/softsuploss.py
@@ -1,7 +1,16 @@
-"""
-CCSSL loss
-author: zhhike
-"""
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import paddle
import paddle.nn as nn
@@ -65,5 +74,4 @@ class SoftSupConLoss(nn.Layer):
if reduction == 'mean':
loss = loss.mean()
- return {"SoftSupConLoss": loss}
-
\ No newline at end of file
+ return {"SoftSupConLoss": loss}
\ No newline at end of file
--
GitLab