Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
f83ff59c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f83ff59c
编写于
12月 15, 2022
作者:
Z
zh-hike
提交者:
Walter
1月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
增加代码规范,删除一些无用的function
上级
692b8d8c
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
10 addition
and
309 deletion
+10
-309
docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md
.../zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md
+6
-6
ppcls/arch/gears/decoup.py
ppcls/arch/gears/decoup.py
+0
-16
ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
...gs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
+1
-2
ppcls/data/dataloader/__init__.py
ppcls/data/dataloader/__init__.py
+0
-1
ppcls/data/dataloader/cifar.py
ppcls/data/dataloader/cifar.py
+1
-178
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+1
-40
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+0
-1
ppcls/loss/ccssl_loss.py
ppcls/loss/ccssl_loss.py
+0
-42
ppcls/loss/softsuploss.py
ppcls/loss/softsuploss.py
+1
-5
ppcls/optimizer/learning_rate.py
ppcls/optimizer/learning_rate.py
+0
-18
未找到文件。
docs/zh_CN/training/semi_supervised_learning/FixMatchCCSSL.md
浏览文件 @
f83ff59c
...
...
@@ -79,14 +79,14 @@ python tools/train.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4
python -m paddle.distributed.launch --gpus='0,1,2,3' tools/train.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
```
2.
**查看训练日志和保存的模型参数文件**
训练过程中屏幕会实时打印loss等指标信息,同时会保存日志文件
`train.log`
,模型参数文件
`*.pdparams`
,优化器参数文件
`*.pdopt`
等内容到
`Global.output_dir`
指定的文件夹下,默认在
`PaddleClas/output/
WideResNet
/`
文件夹下。
2.
**查看训练日志和保存的模型参数文件**
训练过程中屏幕会实时打印loss等指标信息,同时会保存日志文件
`train.log`
,模型参数文件
`*.pdparams`
,优化器参数文件
`*.pdopt`
等内容到
`Global.output_dir`
指定的文件夹下,默认在
`PaddleClas/output/
RecModel
/`
文件夹下。
## 5. 模型评估与推理部署
### 5.1 模型评估
准备用于评估的
`*.pdparams`
模型参数文件,可以使用训练好的模型,也可以使用
*4. 模型训练*
中保存的模型。
*
以训练过程中保存的
`best_model_ema.ema.pdparams`
为例,执行如下命令即可进行评估。
```
python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml -o Global.pretrained_model="./output/
WideResNet
/best_model_ema.ema"
python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml -o Global.pretrained_model="./output/
RecModel
/best_model_ema.ema"
```
*
以训练好的模型为例,下载提供的已经训练好的模型,到
`PaddleClas/pretrained_models`
文件夹中,执行如下命令即可进行评估。
...
...
@@ -98,7 +98,7 @@ cd pretrained_models
wget
cd ..
# 评估
python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatch
_CCSSL_cifar10_4000.yaml -o Global.pretrained_model="
"
python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatch
CCSSL_cifar10_4000.yaml -o Global.pretrained_model="./output/RecModel/best_model_ema.ema
"
```
**注:**
`pretrained_model`
后填入的地址不需要加
`.pdparams`
后缀,在程序运行时会自动补上。
...
...
@@ -114,15 +114,15 @@ python3.7 tools/eval.py -c ppcls/configs/ssl/FixMatch_CCSSL_cifar10_4000.yaml -o
[2022/12/08 09:36:16] ppcls INFO: [Eval][Epoch 0][Iter: 140/157]CELoss: 0.03242, loss: 0.03242, top1: 0.95601, top5: 0.99945, batch_cost: 0.02084s, reader_cost: 0.00075, ips: 3071.00311 images/sec
[2022/12/08 09:36:16] ppcls INFO: [Eval][Epoch 0][Avg]CELoss: 0.16041, loss: 0.16041, top1: 0.95610, top5: 0.99950
```
默认评估日志保存在
`PaddleClas/output/
WideResNetCCSSL/eval.log`
中,可以看到我们提供的模型在cifar10数据集上的评估指标为top1: 95.61
, top5: 99.95
默认评估日志保存在
`PaddleClas/output/
RecModel/eval.log`
中,可以看到我们提供的模型在cifar10数据集上的评估指标为top1: 95.57
, top5: 99.95
### 5.2 模型推理
#### 5.2.1 推理模型准备
将训练过程中保存的模型文件转成inference模型,同样以
`best_model_ema.ema_pdparams`
为例,执行以下命令进行转换
```
python3.7 tools/export_model.py \
-c ppcls/configs/ssl/FixMatch
_CCSSL/FixMatch_
CCSSL_cifar10_4000.yaml \
-o Global.pretrained_model="output/
WideResNetCCSSL
/best_model_ema.ema" \
-c ppcls/configs/ssl/FixMatch
CCSSL/FixMatch
CCSSL_cifar10_4000.yaml \
-o Global.pretrained_model="output/
RecModel
/best_model_ema.ema" \
-o Global.save_inference_fir="./deploy/inference"
```
...
...
ppcls/arch/gears/decoup.py
已删除
100644 → 0
浏览文件 @
692b8d8c
import
paddle
import
paddle.nn
as
nn
class
Decoup
(
nn
.
Layer
):
def
__init__
(
self
,
logits_index
,
features_index
,
**
kwargs
):
super
(
Decoup
,
self
).
__init__
()
self
.
logits_index
=
logits_index
self
.
features_index
=
features_index
def
forward
(
self
,
out
,
**
kwargs
):
assert
isinstance
(
out
,
(
list
,
tuple
)),
'out must be list or tuple'
out
=
{
'logits'
:
out
[
self
.
logits_index
],
'features'
:
out
[
self
.
features_index
]}
return
out
ppcls/configs/ssl/FixMatchCCSSL/FixMatchCCSSL_cifar10_4000_4gpu.yaml
浏览文件 @
f83ff59c
...
...
@@ -64,10 +64,9 @@ Optimizer:
use_nesterov
:
true
weight_decay
:
0.001
lr
:
name
:
'
cosine_schedule_with_warmup
'
name
:
'
CosineFixmatch
'
learning_rate
:
0.03
num_warmup_steps
:
0
num_training_steps
:
524800
DataLoader
:
mean
:
[
0.4914
,
0.4822
,
0.4465
]
...
...
ppcls/data/dataloader/__init__.py
浏览文件 @
f83ff59c
...
...
@@ -13,4 +13,3 @@ from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from
ppcls.data.dataloader.face_dataset
import
AdaFaceDataset
,
FiveValidationDataset
from
ppcls.data.dataloader.custom_label_dataset
import
CustomLabelDataset
from
ppcls.data.dataloader.cifar
import
Cifar10
,
Cifar100
# from ppcls.data.dataloader.cifar import CIFAR10SSL, CIFAR100SSL
ppcls/data/dataloader/cifar.py
浏览文件 @
f83ff59c
...
...
@@ -15,16 +15,11 @@
from
__future__
import
print_function
import
numpy
as
np
import
cv2
import
shutil
from
ppcls.data
import
preprocess
from
ppcls.data.preprocess
import
transform
# from ppcls.data.preprocess import BaseTransform, ListTransform
from
ppcls.data.dataloader.common_dataset
import
create_operators
from
paddle.vision.datasets
import
Cifar10
as
Cifar10_paddle
from
paddle.vision.datasets
import
Cifar100
as
Cifar100_paddle
# from paddle.vision.datasets import cifar
import
os
# from PIL import Image
class
Cifar10
(
Cifar10_paddle
):
...
...
@@ -128,176 +123,4 @@ class Cifar100(Cifar100_paddle):
image3
=
transform
(
image
,
self
.
_transform_ops_strong
)
image3
=
image3
.
transpose
((
2
,
0
,
1
))
return
(
image2
,
image3
,
np
.
int64
(
label
))
# def np_convert_pil(array):
# """
# array conver image
# Args:
# array: array and dim is 1
# """
# assert len(array.shape), "dim of array should 1"
# img = Image.fromarray(array.reshape(3, 32, 32).transpose(1, 2, 0))
# return img
# class CIFAR10(cifar.Cifar10):
# """
# cifar10 dataset
# """
# def __init__(self, data_file, download=True, mode='train'):
# super().__init__(download=download, mode=mode)
# if data_file is not None:
# os.makedirs(data_file, exist_ok=True)
# if not os.path.exists(os.path.join(data_file, 'cifar-10-python.tar.gz')):
# shutil.move('~/.cache/paddle/dataset/cifar/cifar-10-python.tar.gz', data_file)
# self.num_classes = 10
# self.x = []
# self.y = []
# for d in self.data:
# self.x.append(d[0])
# self.y.append(d[1])
# self.x = np.array(self.x)
# self.y = np.array(self.y)
# def __getitem__(self, idx):
# return self.x[idx], self.y[idx]
# def __len__(self):
# return self.x.shape[0]
# class CIFAR100(cifar.Cifar100):
# """
# cifar10 dataset
# """
# def __init__(self, data_file, download=True, mode='train'):
# super().__init__(download=download, mode=mode)
# if data_file is not None:
# os.makedirs(data_file, exist_ok=True)
# if not os.path.exists(os.path.join(data_file, 'cifar-100-python.tar.gz')):
# shutil.move('~/.cache/paddle/dataset/cifar/cifar-100-python.tar.gz', data_file)
# self.num_classes = 100
# self.x = []
# self.y = []
# for d in self.data:
# self.x.append(d[0])
# self.y.append(d[1])
# self.x = np.array(self.x)
# self.y = np.array(self.y)
# def __getitem__(self, idx):
# return self.x[idx], self.y[idx]
# def __len__(self):
# return self.x.shape[0]
# class CIFAR10SSL(CIFAR10):
# """
# from Cifar10
# """
# def __init__(self,
# data_file=None,
# sample_per_label=None,
# download=True,
# expand_labels=1,
# mode='train',
# transform_ops=None,
# transform_w=None,
# transform_s1=None,
# transform_s2=None):
# super().__init__(data_file, download=download, mode=mode)
# self.data_type = 'unlabeled_train' if mode == 'train' else 'val'
# if transform_ops is not None and sample_per_label is not None:
# index = []
# self.data_type = 'labeled_train'
# for c in range(self.num_classes):
# idx = np.where(self.y == c)[0]
# idx = np.random.choice(idx, sample_per_label, False)
# index.extend(idx)
# index = index * expand_labels
# # print(index)
# self.x = self.x[index]
# self.y = self.y[index]
# self.transforms = [transform_ops] if transform_ops is not None else [transform_w, transform_s1, transform_s2]
# self.mode = mode
# def __getitem__(self, idx):
# img, label = np_convert_pil(self.x[idx]), self.y[idx]
# results = ListTransform(self.transforms)(img)
# if self.data_type == 'unlabeled_train':
# return results
# return results[0], label
# def __len__(self):
# return self.x.shape[0]
# class CIFAR100SSL(CIFAR100):
# """
# from Cifar100
# """
# def __init__(self,
# data_file=None,
# sample_per_label=None,
# download=True,
# expand_labels=1,
# mode='train',
# transform_ops=None,
# transform_w=None,
# transform_s1=None,
# transform_s2=None):
# super().__init__(data_file, download=download, mode=mode)
# self.data_type = 'unlabeled_train' if mode == 'train' else 'val'
# if transform_ops is not None and sample_per_label is not None:
# index = []
# self.data_type = 'labeled_train'
# for c in range(self.num_classes):
# idx = np.where(self.y == c)[0]
# idx = np.random.choice(idx, sample_per_label, False)
# index.extend(idx)
# index = index * expand_labels
# # print(index)
# self.x = self.x[index]
# self.y = self.y[index]
# self.transforms = [transform_ops] if transform_ops is not None else [transform_w, transform_s1, transform_s2]
# self.mode = mode
# def __getitem__(self, idx):
# img, label = np_convert_pil(self.x[idx]), self.y[idx]
# results = ListTransform(self.transforms)(img)
# if self.data_type == 'unlabeled_train':
# return results
# return results[0], label
# def __len__(self):
# return self.x.shape[0]
# def x_u_split(num_labeled, num_classes, label):
# """
# split index of dataset to labeled x and unlabeled u
# Args:
# num_labeled: num of labeled dataset
# label: list or array, label
# """
# assert num_labeled <= len(label), "arg num_labeled should <= num of label"
# label = np.array(label) if isinstance(label, list) else label
# label_per_class = num_labeled // num_classes
# labeled_idx = []
# unlabeled_idx = np.array(list(range(label.shape[0])))
# for c in range(num_classes):
# idx = np.where(label == c)[0]
# idx = np.random.choice(idx, label_per_class, False)
# labeled_idx.extend(idx)
# np.random.shuffle(labeled_idx)
# return labeled_idx, unlabeled_idx
\ No newline at end of file
return
(
image2
,
image3
,
np
.
int64
(
label
))
\ No newline at end of file
ppcls/data/preprocess/__init__.py
浏览文件 @
f83ff59c
...
...
@@ -56,8 +56,6 @@ from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
import
numpy
as
np
from
PIL
import
Image
import
random
from
paddle.vision.transforms
import
transforms
as
T
from
paddle.vision.transforms.transforms
import
RandomCrop
,
ToTensor
,
Normalize
def
transform
(
data
,
ops
=
[]):
...
...
@@ -123,41 +121,4 @@ class TimmAutoAugment(RawTimmAutoAugment):
img
=
np
.
asarray
(
img
)
return
img
# class BaseTransform:
# def __init__(self, cfg) -> None:
# """
# Args:
# cfg: list [dict, dict, dict]
# """
# ts = []
# for op in cfg:
# name = list(op.keys())[0]
# if op[name] is None:
# ts.append(eval(name)())
# else:
# ts.append(eval(name)(**(op[name])))
# self.t = T.Compose(ts)
# def __call__(self, img):
# return self.t(img)
# class ListTransform:
# def __init__(self, ops) -> None:
# """
# Args:
# ops: list[list[dict, dict], ...]
# """
# self.ts = []
# for op in ops:
# self.ts.append(BaseTransform(op))
# def __call__(self, img):
# results = []
# for op in self.ts:
# results.append(op(img))
# return results
\ No newline at end of file
ppcls/loss/__init__.py
浏览文件 @
f83ff59c
...
...
@@ -22,7 +22,6 @@ from .pairwisecosface import PairwiseCosface
from
.dmlloss
import
DMLLoss
from
.distanceloss
import
DistanceLoss
from
.softtargetceloss
import
SoftTargetCrossEntropy
from
.ccssl_loss
import
CCSSLLoss
from
.distillationloss
import
DistillationCELoss
from
.distillationloss
import
DistillationGTCELoss
from
.distillationloss
import
DistillationDMLLoss
...
...
ppcls/loss/ccssl_loss.py
浏览文件 @
f83ff59c
from
ppcls.engine.train.train
import
forward
from
.softsuploss
import
SoftSupConLoss
import
copy
import
paddle.nn
as
nn
class
CCSSLLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
(
CCSSLLoss
,
self
).
__init__
()
ce_cfg
=
copy
.
deepcopy
(
kwargs
[
'CELoss'
])
self
.
ce_weight
=
ce_cfg
.
pop
(
'weight'
)
softsupconloss_cfg
=
copy
.
deepcopy
(
kwargs
[
'SoftSupConLoss'
])
self
.
softsupconloss_weight
=
softsupconloss_cfg
.
pop
(
'weight'
)
self
.
softsuploss
=
SoftSupConLoss
(
**
softsupconloss_cfg
)
self
.
celoss
=
nn
.
CrossEntropyLoss
(
reduction
=
'none'
)
def
forward
(
self
,
feats
,
batch
,
**
kwargs
):
"""
Args:
feats: feature of s1 and s2, (n, 2, d)
batch: dict
"""
logits_w
=
batch
[
'logits_w'
]
logits_s1
=
batch
[
'logits_s1'
]
p_targets_u_w
=
batch
[
'p_targets_u_w'
]
mask
=
batch
[
'mask'
]
max_probs
=
batch
[
'max_probs'
]
# reduction = batch['reduction']
loss_u
=
self
.
celoss
(
logits_s1
,
p_targets_u_w
)
*
mask
loss_u
=
loss_u
.
mean
()
loss_c
=
self
.
softsuploss
(
feats
,
max_probs
,
p_targets_u_w
)
return
{
'CCSSLLoss'
:
self
.
ce_weight
*
loss_u
+
self
.
softsupconloss_weight
*
loss_c
}
class
CCSSLCeLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
(
CCSSLCeLoss
,
self
).
__init__
()
self
.
celoss
=
nn
.
CrossEntropyLoss
(
reduction
=
'none'
)
def
forward
(
self
,
inputs
,
batch
,
**
kwargs
):
p_targets_u_w
=
batch
[
'p_targets_u_w'
]
logits_s1
=
batch
[
'logits_s1'
]
...
...
@@ -56,6 +17,3 @@ class CCSSLCeLoss(nn.Layer):
loss_u
=
loss_u
.
mean
()
return
{
'CCSSLCeLoss'
:
loss_u
}
ppcls/loss/softsuploss.py
浏览文件 @
f83ff59c
...
...
@@ -41,7 +41,6 @@ class SoftSupConLoss(nn.Layer):
score_mask
=
paddle
.
matmul
(
max_probs
,
max_probs
.
T
)
mask
=
paddle
.
multiply
(
mask
,
score_mask
)
contrast_count
=
feat
.
shape
[
1
]
...
...
@@ -55,7 +54,6 @@ class SoftSupConLoss(nn.Layer):
mask
=
paddle
.
concat
([
mask
,
mask
],
axis
=
0
)
mask
=
paddle
.
concat
([
mask
,
mask
],
axis
=
1
)
# mask = paddle.repeat_interleave(paddle.repeat_interleave(mask, 2, 0), 2, 1)
logits_mask
=
1
-
paddle
.
eye
(
batch_size
*
contrast_count
,
dtype
=
paddle
.
float64
)
mask
=
mask
*
logits_mask
exp_logits
=
paddle
.
exp
(
logits
)
*
logits_mask
...
...
@@ -68,6 +66,4 @@ class SoftSupConLoss(nn.Layer):
loss
=
loss
.
mean
()
return
{
"SoftSupConLoss"
:
loss
}
\ No newline at end of file
ppcls/optimizer/learning_rate.py
浏览文件 @
f83ff59c
...
...
@@ -519,21 +519,3 @@ class CosineFixmatch(LRBase):
last_epoch
=
self
.
last_epoch
)
setattr
(
learning_rate
,
"by_epoch"
,
self
.
by_epoch
)
return
learning_rate
def
cosine_schedule_with_warmup
(
learning_rate
,
num_warmup_steps
,
num_training_steps
,
num_cycles
=
7.
/
16
,
last_epoch
=-
1
,
**
kwargs
):
def
_lr_lambda
(
current_step
):
if
current_step
<
num_warmup_steps
:
return
float
(
current_step
)
/
float
(
max
(
1
,
num_warmup_steps
))
no_progress
=
float
(
current_step
-
num_warmup_steps
)
/
float
(
max
(
1
,
num_training_steps
-
num_warmup_steps
))
return
max
(
0.
,
math
.
cos
(
math
.
pi
*
num_cycles
*
no_progress
))
return
lr
.
LambdaDecay
(
learning_rate
=
learning_rate
,
lr_lambda
=
_lr_lambda
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录