Commit 692b8d8c authored by zh-hike, committed by Walter

Add code reuse for the data pipeline and modify RecModel; the code has been verified to run

Parent 4db13244
......@@ -2,20 +2,20 @@
Paper: [https://arxiv.org/abs/2203.02261](https://arxiv.org/abs/2203.02261)
## Contents
* 1. Method Introduction
* 2. Accuracy Metrics
* 3. Data Preparation
* 4. Model Training
* 5. Model Evaluation and Inference Deployment
* 5.1 Model Evaluation
* 5.2 Model Inference
* * 5.2.1 Preparing the Inference Model
* * 5.2.2 Inference with the Python Prediction Engine
* * 5.2.3 Inference with the C++ Prediction Engine
* 5.4 Serving Deployment
* 5.5 On-Device Deployment
* 5.6 Paddle2ONNX Model Conversion and Prediction
* 6. References
* [1. Method Introduction](#1-原理介绍)
* [2. Accuracy Metrics](#2-精度指标)
* [3. Data Preparation](#3-数据准备)
* [4. Model Training](#4-模型训练)
* [5. Model Evaluation and Inference Deployment](#5-模型评估与推理部署)
* [5.1 Model Evaluation](#51-模型评估)
* [5.2 Model Inference](#52-模型推理)
* * [5.2.1 Preparing the Inference Model](#521-推理模型准备)
* * [5.2.2 Inference with the Python Prediction Engine](#522-基于-python-预测引擎推理)
* * [5.2.3 Inference with the C++ Prediction Engine](#523-基于c预测引擎推理)
* [5.4 Serving Deployment](#54-服务化部署)
* [5.5 On-Device Deployment](#55-端侧部署)
* [5.6 Paddle2ONNX Model Conversion and Prediction](#56-paddle2onnx-模型转换与预测)
* [6. References](#6-参考资料)
## 1. Method Introduction
The authors propose a novel semi-supervised learning method. While the model trains on labeled data, each unlabeled image receives one weak augmentation and two strong augmentations. If the classification confidence on the weakly augmented view exceeds a threshold, the prediction on that view is used as the pseudo-label. From the pseudo-labels, a supervised contrastive matrix containing only class-level information is built. A class-aware contrastive matrix is then formed through image-level contrast on out-of-distribution data, which reduces confirmation bias. A re-weighting module focuses learning on clean data and yields the final target matrix. In addition, a feature affinity matrix is computed from the two strongly augmented views. The class-aware contrastive module for unlabeled data is trained by minimizing the cross-entropy between the affinity matrix and the target matrix. The overall pipeline is shown in the figure below; a simplified sketch of the target construction follows.
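A minimal, hedged sketch of that target construction (function and variable names are illustrative, the re-weighting is simplified, and 0.07 is the `SoftSupConLoss` temperature from the configs; this is not the repository's exact implementation):

```python
import paddle
import paddle.nn.functional as F

def class_aware_targets(logits_w, feat_s1, feat_s2, T=1.0, threshold=0.95):
    """Illustrative sketch of CCSSL's class-aware contrastive target.

    logits_w: [N, C] logits of the weakly augmented views
    feat_s1, feat_s2: [N, D] L2-normalized features of the two strong views
    """
    probs = F.softmax(logits_w / T, axis=-1)           # sharpened predictions
    max_probs = probs.max(axis=-1)
    pseudo = probs.argmax(axis=-1)                     # pseudo-labels
    conf = (max_probs >= threshold).astype('float32')  # confident-sample mask

    # Class-level supervised-contrastive target: 1 where pseudo-labels agree.
    same_class = (pseudo.unsqueeze(1) == pseudo.unsqueeze(0)).astype('float32')
    # Image-level fallback (identity) for low-confidence samples, re-weighted
    # so learning focuses on clean pairs.
    w = conf.unsqueeze(1) * conf.unsqueeze(0)
    target = w * same_class + (1.0 - w) * paddle.eye(pseudo.shape[0])

    # Feature affinity between the two strongly augmented views.
    sim = paddle.matmul(feat_s1, feat_s2, transpose_y=True) / 0.07
    affinity = F.softmax(sim, axis=-1)
    # Contrastive objective: cross-entropy between affinity and target rows.
    target = target / target.sum(axis=1, keepdim=True)
    loss = -(target * paddle.log(affinity + 1e-12)).sum(axis=1).mean()
    return loss, pseudo, conf
```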
......@@ -178,7 +178,7 @@ Paddle Lite is a high-performance, lightweight, flexible, and easily extensible deep…
PaddleClas provides an example of [on-device deployment](https://github.com/zh-hike/PaddleClas/blob/develop/docs/zh_CN/deployment/image_classification/paddle_lite.md) based on Paddle Lite; you can refer to it to complete the corresponding deployment work.
## Paddle2ONNX Model Conversion and Prediction
## 5.6 Paddle2ONNX Model Conversion and Prediction
Paddle2ONNX converts models from the PaddlePaddle format to the ONNX format. Through ONNX, a Paddle model can be deployed to a variety of inference engines, including TensorRT/OpenVINO/MNN/TNN/NCNN, as well as any other inference engine or hardware that supports the open ONNX format. For more about Paddle2ONNX, see the Paddle2ONNX repository.
PaddleClas provides an example of converting an inference model to ONNX and running prediction with it based on Paddle2ONNX; you can refer to [Paddle2ONNX](https://github.com/zh-hike/PaddleClas/blob/develop/docs/zh_CN/deployment/image_classification/paddle2onnx.md) model conversion and prediction to complete the corresponding deployment work.
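As a minimal, hedged sketch of the final step (assuming the model has already been exported to `model.onnx`; the input name and CIFAR-sized shape are illustrative assumptions), the converted model can be exercised with onnxruntime:

```python
import numpy as np
import onnxruntime as ort

# Run the converted model on a dummy CIFAR-sized input.
sess = ort.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name          # actual name is model dependent
dummy = np.random.rand(1, 3, 32, 32).astype("float32")
logits = sess.run(None, {input_name: dummy})[0]
print(logits.shape)
```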
......
......@@ -14,6 +14,7 @@
import copy
import importlib
import paddle.nn as nn
from paddle.jit import to_static
......@@ -71,6 +72,9 @@ class RecModel(TheseusLayer):
super().__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
self.decoup = False
if backbone_config.get('decoup', False):
self.decoup = backbone_config.pop('decoup')
self.backbone = eval(backbone_name)(**backbone_config)
if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]["name"]
......@@ -85,25 +89,26 @@ class RecModel(TheseusLayer):
self.head = build_gear(config["Head"])
else:
self.head = None
if "Decoup" in config:
self.decoup = build_gear(config['Decoup'])
else:
self.decoup = None
def forward(self, x, label=None):
out = dict()
x = self.backbone(x)
if self.decoup is not None:
return self.decoup(x)
out["backbone"] = x
if self.decoup:
logits_index, features_index = self.decoup['logits_index'], self.decoup['features_index']
logits, feat = x[logits_index], x[features_index]
out['logits'] = logits
out['features'] = feat
return out
if self.neck is not None:
x = self.neck(x)
out["neck"] = x
out["features"] = x
feat = self.neck(x)
out["neck"] = feat
out["features"] = out['neck'] if self.neck else x
if self.head is not None:
y = self.head(x, label)
y = self.head(out['features'], label)
out["logits"] = y
return out
......
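A minimal usage sketch of the new decoup routing (a hypothetical config mirroring the YAML configs later in this commit; it assumes RecModel receives the `Arch` section as keyword arguments, and that `WideResNet` with `proj: true` returns a `(logits, features)` tuple):

```python
import paddle

# Hypothetical Arch config: the popped `decoup` dict tells RecModel which
# slots of the backbone's output tuple hold logits and features.
arch_cfg = {
    "Backbone": {
        "name": "WideResNet",
        "decoup": {"logits_index": 0, "features_index": 1},
        "widen_factor": 2,
        "depth": 28,
        "dropout": 0,
        "num_classes": 10,
        "low_dim": 64,
        "proj": True,
    },
}
model = RecModel(**arch_cfg)

x = paddle.rand([4, 3, 32, 32])
out = model(x)
# With decoup set, forward returns early with both parts of the tuple:
print(out["logits"].shape, out["features"].shape)
```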
......@@ -20,7 +20,6 @@ from .vehicle_neck import VehicleNeck
from paddle.nn import Tanh
from .bnneck import BNNeck
from .adamargin import AdaMargin
from .decoup import Decoup
__all__ = ['build_gear']
......@@ -28,7 +27,7 @@ __all__ = ['build_gear']
def build_gear(config):
support_dict = [
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
'BNNeck', 'AdaMargin', 'FRFBNeck', 'Decoup'
'BNNeck', 'AdaMargin',
]
module_name = config.pop('name')
assert module_name in support_dict, Exception(
......
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output
device: gpu
save_interval: -1
eval_during_train: true
eval_interval: 1
epochs: 1024
iter_per_epoch: 1024
print_batch_step: 20
use_visualdl: false
use_dali: false
train_mode: fixmatch_ccssl
image_shape: [3, 32, 32]
save_inference_dir: ./inference
SSL:
T: 1
threshold: 0.95
EMA:
decay: 0.999
Arch:
name: RecModel
infer_output_key: logits
infer_add_softmax: false
Backbone:
name: WideResNet
widen_factor: 8
depth: 28
dropout: 0 # corresponds to drop_rate in the CCSSL reference code
num_classes: &sign_num_classes 100
low_dim: 64
proj: true
proj_after: false
Decoup:
name: Decoup
logits_index: 0
features_index: 1
use_sync_bn: true
Loss:
Train:
- CELoss:
weight: 1.0
reduction: "mean"
Eval:
- CELoss:
weight: 1.0
UnLabelLoss:
Train:
- CCSSLCeLoss:
weight: 1.
- SoftSupConLoss:
weight: 1.0
temperature: 0.07
# - CCSSLLoss:
# CELoss:
# weight: 1.0
# reduction: "none"
# SoftSupConLoss:
# weight: 1.0
# temperature: 0.07
# weight: 1.
Optimizer:
name: Momentum
momentum: 0.9
use_nesterov: true
weight_decay: 0.001
lr:
name: 'cosine_schedule_with_warmup'
learning_rate: 0.03
num_warmup_steps: 0
num_training_steps: 524800
DataLoader:
mean: &sign_mean [0.5071, 0.4867, 0.4408]
std: &sign_std [0.2675, 0.2565, 0.2761]
Train:
dataset:
name: CIFAR100SSL
data_file: null
mode: 'train'
download: true
sample_per_label: 100
expand_labels: 1
transform_ops:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
padding: 4
padding_mode: "reflect"
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: true
shuffle: true
loader:
num_workers: 4
use_shared_memory: true
UnLabelTrain:
dataset:
name: CIFAR100SSL
data_file: null
mode: 'train'
download: true
transform_w:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
padding: 4
padding_mode: 'reflect'
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
transform_s1:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
padding: 4
padding_mode: 'reflect'
- RandAugmentMC:
n: 2
m: 10
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
transform_s2:
- RandomResizedCrop:
size: 32
- RandomHorizontalFlip:
prob: 0.5
- RandomApply:
transforms:
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
hue: 0.1
p: 0.8
- RandomGrayscale:
p: 0.2
- ToTensor:
# - Normalize:
# mean: *sign_mean
# std: *sign_std
sampler:
name: DistributedBatchSampler
batch_size: 112
drop_last: true
shuffle: true
loader:
num_workers: 4
use_shared_memory: true
Eval:
dataset:
name: CIFAR100SSL
mode: 'test'
download: true
data_file: null
transform_ops:
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: true
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
\ No newline at end of file
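The `SSL` block above supplies the FixMatch-style temperature `T` and confidence `threshold`, and `EMA.decay` drives the evaluation model; a hedged sketch of how these values are typically consumed (simplified, not the engine's exact code):

```python
import paddle
import paddle.nn.functional as F

T, threshold, decay = 1.0, 0.95, 0.999   # values from the SSL and EMA blocks

def pseudo_label(logits_w):
    # Sharpen weak-view predictions with temperature T and keep only
    # predictions whose confidence clears the threshold.
    probs = F.softmax(logits_w / T, axis=-1)
    targets = probs.argmax(axis=-1)
    mask = (probs.max(axis=-1) >= threshold).astype('float32')
    return targets, mask

def ema_update(model, shadow):
    # Exponential moving average of the parameters used for evaluation.
    for name, p in model.state_dict().items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * p
```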
......@@ -7,7 +7,7 @@ Global:
eval_during_train: true
eval_interval: 1
epochs: 1024
iter_per_epoch: 200
iter_per_epoch: 40
print_batch_step: 20
use_visualdl: false
use_dali: false
......@@ -28,19 +28,17 @@ Arch:
infer_add_softmax: false
Backbone:
name: WideResNet
decoup:
logits_index: 0
features_index: 1
widen_factor: 2
depth: 28
dropout: 0 # corresponds to drop_rate in the CCSSL reference code
num_classes: &sign_num_classes 10
dropout: 0
num_classes: 10
low_dim: 64
proj: true
proj_after: false
Decoup:
name: Decoup
logits_index: 0
features_index: 1
use_sync_bn: true
Loss:
......@@ -59,14 +57,6 @@ UnLabelLoss:
- SoftSupConLoss:
weight: 1.0
temperature: 0.07
# - CCSSLLoss:
# CELoss:
# weight: 1.0
# reduction: "none"
# SoftSupConLoss:
# weight: 1.0
# temperature: 0.07
# weight: 1.
Optimizer:
name: Momentum
......@@ -80,27 +70,30 @@ Optimizer:
num_training_steps: 524800
DataLoader:
mean: &sign_mean [0.4914, 0.4822, 0.4465]
std: &sign_std [0.2471, 0.2435, 0.2616]
mean: [0.4914, 0.4822, 0.4465]
std: [0.2471, 0.2435, 0.2616]
Train:
dataset:
name: CIFAR10SSL
name: Cifar10
data_file: null
mode: 'train'
download: true
backend: 'pil'
sample_per_label: 400
expand_labels: 1
transform_ops:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
- RandFlipImage:
flip_code: 1
- Pad_paddle_vision:
padding: 4
padding_mode: "reflect"
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
padding_mode: reflect
- RandCropImageV2:
size: [32, 32]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2471, 0.2435, 0.2616]
order: hwc
sampler:
name: DistributedBatchSampler # DistributedBatchSampler
......@@ -115,46 +108,51 @@ DataLoader:
UnLabelTrain:
dataset:
name: CIFAR10SSL
name: Cifar10
data_file: null
mode: 'train'
backend: 'pil'
download: true
transform_w:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
transform_ops_weak:
- RandFlipImage:
flip_code: 1
- Pad_paddle_vision:
padding: 4
padding_mode: 'reflect'
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
transform_s1:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
padding_mode: reflect
- RandCropImageV2:
size: [32, 32]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2471, 0.2435, 0.2616]
order: hwc
transform_ops_strong:
- RandFlipImage:
flip_code: 1
- Pad_paddle_vision:
padding: 4
padding_mode: 'reflect'
- RandAugmentMC:
n: 2
m: 10
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
padding_mode: reflect
- RandCropImageV2:
size: [32, 32]
- RandAugment:
num_layers: 2
magnitude: 10
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2471, 0.2435, 0.2616]
order: hwc
transform_s2:
- RandomResizedCrop:
size: 32
- RandomHorizontalFlip:
prob: 0.5
transform_ops_strong2:
- RandCropImageV2:
size: [32, 32]
- RandFlipImage:
flip_code: 1
- RandomApply:
transforms:
- ColorJitter:
- RawColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
......@@ -162,13 +160,14 @@ DataLoader:
p: 0.8
- RandomGrayscale:
p: 0.2
- ToTensor:
# - Normalize:
# mean: *sign_mean
# std: *sign_std
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2471, 0.2435, 0.2616]
order: hwc
sampler:
name: DistributedBatchSampler # DistributedBatchSampler
name: DistributedBatchSampler
batch_size: 448
drop_last: true
shuffle: true
......@@ -178,15 +177,17 @@ DataLoader:
Eval:
dataset:
name: CIFAR10SSL
name: Cifar10
mode: 'test'
backend: 'pil'
download: true
data_file: null
transform_ops:
- ToTensor:
- Normalize:
mean: *sign_mean
std: *sign_std
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2471, 0.2435, 0.2616]
order: hwc
sampler:
name: DistributedBatchSampler
batch_size: 64
......
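One detail worth noting in the transform lists above: YAML parses `scale: 1.0/255.0` as a string. `NormalizeImage` in ppcls is expected to evaluate such strings to a float; a hedged sketch of that handling:

```python
# Hedged sketch: resolving a string scale like "1.0/255.0" to a float.
scale = "1.0/255.0"
scale = eval(scale) if isinstance(scale, str) else float(scale)
assert abs(scale - 1.0 / 255.0) < 1e-12
```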
......@@ -35,7 +35,7 @@ from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100, CIFAR10SSL
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
......
......@@ -13,4 +13,4 @@ from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
from ppcls.data.dataloader.cifar import CIFAR10SSL
# from ppcls.data.dataloader.cifar import CIFAR10SSL, CIFAR100SSL
......@@ -18,13 +18,13 @@ import cv2
import shutil
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.data.preprocess import BaseTransform, ListTransform
# from ppcls.data.preprocess import BaseTransform, ListTransform
from ppcls.data.dataloader.common_dataset import create_operators
from paddle.vision.datasets import Cifar10 as Cifar10_paddle
from paddle.vision.datasets import Cifar100 as Cifar100_paddle
from paddle.vision.datasets import cifar
# from paddle.vision.datasets import cifar
import os
from PIL import Image
# from PIL import Image
class Cifar10(Cifar10_paddle):
......@@ -37,12 +37,14 @@ class Cifar10(Cifar10_paddle):
expand_labels=1,
transform_ops=None,
transform_ops_weak=None,
transform_ops_strong=None):
transform_ops_strong=None,
transform_ops_strong2=None):
super().__init__(data_file, mode, None, download, backend)
assert isinstance(expand_labels, int)
self._transform_ops = create_operators(transform_ops)
self._transform_ops_weak = create_operators(transform_ops_weak)
self._transform_ops_strong = create_operators(transform_ops_strong)
self._transform_ops_strong2 = create_operators(transform_ops_strong2)
self.class_num = 10
labels = []
for x in self.data:
......@@ -64,6 +66,15 @@ class Cifar10(Cifar10_paddle):
image1 = transform(image, self._transform_ops)
image1 = image1.transpose((2, 0, 1))
return (image1, np.int64(label))
elif self._transform_ops_weak and self._transform_ops_strong and self._transform_ops_strong2:
image2 = transform(image, self._transform_ops_weak)
image2 = image2.transpose((2, 0, 1))
image3 = transform(image, self._transform_ops_strong)
image3 = image3.transpose((2, 0, 1))
image4 = transform(image, self._transform_ops_strong2)
image4 = image4.transpose((2, 0, 1))
return (image2, image3, image4, np.int64(label))
elif self._transform_ops_weak and self._transform_ops_strong:
image2 = transform(image, self._transform_ops_weak)
image2 = image2.transpose((2, 0, 1))
......@@ -120,85 +131,156 @@ class Cifar100(Cifar100_paddle):
return (image2, image3, np.int64(label))
def np_convert_pil(array):
"""
Convert a flat CIFAR array into a PIL image.
Args:
array: 1-D array of length 3*32*32
"""
assert len(array.shape) == 1, "array must be 1-D"
img = Image.fromarray(array.reshape(3, 32, 32).transpose(1, 2, 0))
return img
class CIFAR10(cifar.Cifar10):
"""
cifar10 dataset
"""
def __init__(self, data_file, download=True, mode='train'):
super().__init__(download=download, mode=mode)
if data_file is not None:
os.makedirs(data_file, exist_ok=True)
if not os.path.exists(os.path.join(data_file, 'cifar-10-python.tar.gz')):
shutil.move('~/.cache/paddle/dataset/cifar/cifar-10-python.tar.gz', data_file)
self.num_classes = 10
self.x = []
self.y = []
for d in self.data:
self.x.append(d[0])
self.y.append(d[1])
self.x = np.array(self.x)
self.y = np.array(self.y)
# def np_convert_pil(array):
# """
# Convert a flat CIFAR array into a PIL image.
# Args:
# array: 1-D array of length 3*32*32
# """
# assert len(array.shape) == 1, "array must be 1-D"
# img = Image.fromarray(array.reshape(3, 32, 32).transpose(1, 2, 0))
# return img
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def __len__(self):
return self.x.shape[0]
# class CIFAR10(cifar.Cifar10):
# """
# cifar10 dataset
# """
# def __init__(self, data_file, download=True, mode='train'):
# super().__init__(download=download, mode=mode)
# if data_file is not None:
# os.makedirs(data_file, exist_ok=True)
# if not os.path.exists(os.path.join(data_file, 'cifar-10-python.tar.gz')):
# shutil.move('~/.cache/paddle/dataset/cifar/cifar-10-python.tar.gz', data_file)
# self.num_classes = 10
# self.x = []
# self.y = []
# for d in self.data:
# self.x.append(d[0])
# self.y.append(d[1])
# self.x = np.array(self.x)
# self.y = np.array(self.y)
class CIFAR10SSL(CIFAR10):
"""
from Cifar10
"""
# def __getitem__(self, idx):
# return self.x[idx], self.y[idx]
def __init__(self,
data_file=None,
sample_per_label=None,
download=True,
expand_labels=1,
mode='train',
transform_ops=None,
transform_w=None,
transform_s1=None,
transform_s2=None):
super().__init__(data_file, download=download, mode=mode)
self.data_type = 'unlabeled_train' if mode == 'train' else 'val'
if transform_ops is not None and sample_per_label is not None:
index = []
self.data_type = 'labeled_train'
for c in range(self.num_classes):
idx = np.where(self.y == c)[0]
idx = np.random.choice(idx, sample_per_label, False)
index.extend(idx)
index = index * expand_labels
# print(index)
self.x = self.x[index]
self.y = self.y[index]
self.transforms = [transform_ops] if transform_ops is not None else [transform_w, transform_s1, transform_s2]
self.mode = mode
# def __len__(self):
# return self.x.shape[0]
def __getitem__(self, idx):
img, label = np_convert_pil(self.x[idx]), self.y[idx]
results = ListTransform(self.transforms)(img)
if self.data_type == 'unlabeled_train':
return results
return results[0], label
# class CIFAR100(cifar.Cifar100):
# """
# cifar10 dataset
# """
# def __init__(self, data_file, download=True, mode='train'):
# super().__init__(download=download, mode=mode)
# if data_file is not None:
# os.makedirs(data_file, exist_ok=True)
# if not os.path.exists(os.path.join(data_file, 'cifar-100-python.tar.gz')):
# shutil.move('~/.cache/paddle/dataset/cifar/cifar-100-python.tar.gz', data_file)
# self.num_classes = 100
# self.x = []
# self.y = []
# for d in self.data:
# self.x.append(d[0])
# self.y.append(d[1])
# self.x = np.array(self.x)
# self.y = np.array(self.y)
# def __getitem__(self, idx):
# return self.x[idx], self.y[idx]
# def __len__(self):
# return self.x.shape[0]
# class CIFAR10SSL(CIFAR10):
# """
# from Cifar10
# """
# def __init__(self,
# data_file=None,
# sample_per_label=None,
# download=True,
# expand_labels=1,
# mode='train',
# transform_ops=None,
# transform_w=None,
# transform_s1=None,
# transform_s2=None):
# super().__init__(data_file, download=download, mode=mode)
# self.data_type = 'unlabeled_train' if mode == 'train' else 'val'
# if transform_ops is not None and sample_per_label is not None:
# index = []
# self.data_type = 'labeled_train'
# for c in range(self.num_classes):
# idx = np.where(self.y == c)[0]
# idx = np.random.choice(idx, sample_per_label, False)
# index.extend(idx)
# index = index * expand_labels
# # print(index)
# self.x = self.x[index]
# self.y = self.y[index]
# self.transforms = [transform_ops] if transform_ops is not None else [transform_w, transform_s1, transform_s2]
# self.mode = mode
# def __getitem__(self, idx):
# img, label = np_convert_pil(self.x[idx]), self.y[idx]
# results = ListTransform(self.transforms)(img)
# if self.data_type == 'unlabeled_train':
# return results
# return results[0], label
# def __len__(self):
# return self.x.shape[0]
def __len__(self):
return self.x.shape[0]
# class CIFAR100SSL(CIFAR100):
# """
# from Cifar100
# """
# def __init__(self,
# data_file=None,
# sample_per_label=None,
# download=True,
# expand_labels=1,
# mode='train',
# transform_ops=None,
# transform_w=None,
# transform_s1=None,
# transform_s2=None):
# super().__init__(data_file, download=download, mode=mode)
# self.data_type = 'unlabeled_train' if mode == 'train' else 'val'
# if transform_ops is not None and sample_per_label is not None:
# index = []
# self.data_type = 'labeled_train'
# for c in range(self.num_classes):
# idx = np.where(self.y == c)[0]
# idx = np.random.choice(idx, sample_per_label, False)
# index.extend(idx)
# index = index * expand_labels
# # print(index)
# self.x = self.x[index]
# self.y = self.y[index]
# self.transforms = [transform_ops] if transform_ops is not None else [transform_w, transform_s1, transform_s2]
# self.mode = mode
# def __getitem__(self, idx):
# img, label = np_convert_pil(self.x[idx]), self.y[idx]
# results = ListTransform(self.transforms)(img)
# if self.data_type == 'unlabeled_train':
# return results
# return results[0], label
# def __len__(self):
# return self.x.shape[0]
# def x_u_split(num_labeled, num_classes, label):
# """
......
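With the new `transform_ops_strong2` hook in `Cifar10` above, the unlabeled dataset yields three views per image. A hedged usage sketch (keyword names are taken from the diff; exact defaults and op availability are assumptions):

```python
# Each pipeline ends with NormalizeImage (order: hwc) so that __getitem__
# can transpose the resulting HWC numpy array to CHW.
norm = {'NormalizeImage': {'scale': '1.0/255.0',
                           'mean': [0.4914, 0.4822, 0.4465],
                           'std': [0.2471, 0.2435, 0.2616],
                           'order': 'hwc'}}
dataset = Cifar10(
    data_file=None, mode='train', download=True, backend='pil',
    transform_ops_weak=[{'RandFlipImage': {'flip_code': 1}}, norm],
    transform_ops_strong=[{'RandAugment': {'num_layers': 2, 'magnitude': 10}}, norm],
    transform_ops_strong2=[{'RandomGrayscale': {'p': 0.2}}, norm],
)
weak, s1, s2, label = dataset[0]   # three augmented views of one image
```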
......@@ -14,6 +14,7 @@
from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
from ppcls.data.preprocess.ops.randaugment import RandomApply
from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment
from ppcls.data.preprocess.ops.cutout import Cutout
......@@ -50,7 +51,7 @@ from paddle.vision.transforms import Pad as Pad_paddle_vision
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
from .ops.randaugmentmc import RandAugmentMC, RandomApply
import numpy as np
from PIL import Image
......@@ -124,39 +125,39 @@ class TimmAutoAugment(RawTimmAutoAugment):
return img
class BaseTransform:
def __init__(self, cfg) -> None:
"""
Args:
cfg: list [dict, dict, dict]
"""
ts = []
for op in cfg:
name = list(op.keys())[0]
if op[name] is None:
ts.append(eval(name)())
else:
ts.append(eval(name)(**(op[name])))
# class BaseTransform:
# def __init__(self, cfg) -> None:
# """
# Args:
# cfg: list [dict, dict, dict]
# """
# ts = []
# for op in cfg:
# name = list(op.keys())[0]
# if op[name] is None:
# ts.append(eval(name)())
# else:
# ts.append(eval(name)(**(op[name])))
self.t = T.Compose(ts)
# self.t = T.Compose(ts)
def __call__(self, img):
# def __call__(self, img):
return self.t(img)
class ListTransform:
def __init__(self, ops) -> None:
"""
Args:
ops: list[list[dict, dict], ...]
"""
self.ts = []
for op in ops:
self.ts.append(BaseTransform(op))
def __call__(self, img):
results = []
for op in self.ts:
results.append(op(img))
return results
# return self.t(img)
# class ListTransform:
# def __init__(self, ops) -> None:
# """
# Args:
# ops: list[list[dict, dict], ...]
# """
# self.ts = []
# for op in ops:
# self.ts.append(BaseTransform(op))
# def __call__(self, img):
# results = []
# for op in self.ts:
# results.append(op(img))
# return results
......@@ -18,6 +18,8 @@
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
from .operators import RawColorJitter
from paddle.vision.transforms import transforms as T
class RandAugment(object):
......@@ -105,3 +107,18 @@ class RandAugment(object):
op_name = np.random.choice(avaiable_op_names)
img = self.func[op_name](img, self.level_map[op_name])
return img
class RandomApply(object):
"""Apply a composed list of transforms with probability p.
transforms is a list of single-key dicts, e.g. [{'RawColorJitter': {...}}],
mirroring the transform entries in the YAML configs.
"""
def __init__(self, p, transforms):
self.p = p
ts = []
for t in transforms:
for key in t.keys():
ts.append(eval(key)(**t[key]))
self.trans = T.Compose(ts)
def __call__(self, img):
# Skip the augmentation with probability (1 - p).
if random.random() > self.p:
return img
return self.trans(img)
\ No newline at end of file
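A usage sketch matching the `RandomApply` entries in the SSL configs (op names must be importable in this module, as `RawColorJitter` is above):

```python
from PIL import Image

# Apply a ColorJitter pipeline with probability 0.8, as in the configs.
aug = RandomApply(
    p=0.8,
    transforms=[{'RawColorJitter': {'brightness': 0.4, 'contrast': 0.4,
                                    'saturation': 0.4, 'hue': 0.1}}])
img_out = aug(Image.new('RGB', (32, 32)))
```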
......@@ -15,15 +15,7 @@ import paddle
def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
##############################################################
# out_logger = ReprodLogger()
# loss_logger = ReprodLogger()
# epoch = 0
##############################################################
paddle.save(engine.model.state_dict(), '../recmodel.pdparams')
assert 1==0
tic = time.time()
if not hasattr(engine, 'train_dataloader_iter'):
engine.train_dataloader_iter = iter(engine.train_dataloader)
......@@ -34,8 +26,6 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
assert engine.iter_per_epoch is not None, "Global.iter_per_epoch need to be set"
threshold = paddle.to_tensor(threshold)
# dataload_logger = ReprodLogger()
for iter_id in range(engine.iter_per_epoch):
if iter_id >= engine.iter_per_epoch:
break
......@@ -56,22 +46,9 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
engine.unlabel_train_dataloader_iter = iter(engine.unlabel_train_dataloader)
unlabel_data_batch = engine.unlabel_train_dataloader_iter.next()
assert len(unlabel_data_batch) == 3
assert len(unlabel_data_batch) == 4
assert unlabel_data_batch[0].shape == unlabel_data_batch[1].shape == unlabel_data_batch[2].shape
##############################################################
# inputs_x, target_x = label_data_batch
# inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch
# dataload_logger.add(f"inputs_x_iter_{iter_id}", inputs_x.detach().numpy())
# dataload_logger.add(f"target_x_iter_{iter_id}", target_x.detach().numpy())
# dataload_logger.add(f"inputs_w_iter_{iter_id}", inputs_w.detach().numpy())
# dataload_logger.add(f"inputs_s1_iter_{iter_id}", inputs_s1.detach().numpy())
# dataload_logger.add(f"inputs_s2_iter_{iter_id}", inputs_s2.detach().numpy())
# dataload_logger.save('../align/step2/data/paddle.npy')
# assert 1==0
##############################################################
engine.time_info['reader_cost'].update(time.time() - tic)
batch_size = label_data_batch[0].shape[0] \
+ unlabel_data_batch[0].shape[0] \
......@@ -79,25 +56,18 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
+ unlabel_data_batch[2].shape[0]
engine.global_step += 1
# make inputs
inputs_x, targets_x = label_data_batch
# inputs_x = inputs_x[0]
inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch
inputs_w, inputs_s1, inputs_s2 = unlabel_data_batch[:3]
batch_size_label = inputs_x.shape[0]
inputs = paddle.concat([inputs_x, inputs_w, inputs_s1, inputs_s2], axis=0)
loss_dict, logits_label = get_loss(engine, inputs, batch_size_label,
temperture, threshold, targets_x,
# epoch=epoch,
# batch_idx=iter_id,
# out_logger=out_logger,
# loss_logger=loss_logger
)
loss = loss_dict['loss']
loss.backward()
for i in range(len(engine.optimizer)):
engine.optimizer[i].step()
......@@ -123,9 +93,6 @@ def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
tic = time.time()
# if iter_id == 10:
# assert 1==0
for i in range(len(engine.lr_sch)):
if getattr(engine.lr_sch[i], 'by_epoch', False):
engine.lr_sch[i].step()
......@@ -140,7 +107,6 @@ def get_loss(engine,
):
out = engine.model(inputs)
logits, feats = out['logits'], out['features']
# logits, feats = engine.model(inputs)
feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)
feat_x = feats[:batch_size_label]
logits_x = logits[:batch_size_label]
......@@ -150,7 +116,6 @@ def get_loss(engine,
max_probs, p_targets_u_w = probs_u_w.max(axis=-1), probs_u_w.argmax(axis=-1)
mask = paddle.greater_equal(max_probs, threshold).astype('float')
# feats = paddle.concat([logits_s1.unsqueeze(1), logits_s2.unsqueeze(1)], axis=1)
feats = paddle.concat([feat_s1.unsqueeze(1), feat_s2.unsqueeze(1)], axis=1)
batch = {'logits_w': logits_w,
'logits_s1': logits_s1,
......@@ -170,30 +135,5 @@ def get_loss(engine,
loss_dict[k] = v
loss_dict['loss'] = loss_dict_label['loss'] + unlabel_loss['loss']
##############################################################
# print(loss_dict)
# epoch = kwargs['epoch']
# batch_idx = kwargs['batch_idx']
# out_logger = kwargs['out_logger']
# loss_logger = kwargs['loss_logger']
# out_logger.add(f'logit_x_{epoch}_{batch_idx}', logits_x.detach().numpy())
# out_logger.add(f'logit_u_w_{epoch}_{batch_idx}', logits_w.detach().numpy())
# out_logger.add(f'logit_u_s1_{epoch}_{batch_idx}', logits_s1.detach().numpy())
# out_logger.add(f'logit_u_s2_{epoch}_{batch_idx}', logits_s2.detach().numpy())
# out_logger.add(f'feat_x_{epoch}_{batch_idx}', feat_x.detach().numpy())
# out_logger.add(f'feat_w_{epoch}_{batch_idx}', feat_w.detach().numpy())
# out_logger.add(f'feat_s1_{epoch}_{batch_idx}', feat_s1.detach().numpy())
# out_logger.add(f'feat_s2_{epoch}_{batch_idx}', feat_s2.detach().numpy())
# loss_logger.add(f'loss_{epoch}_{batch_idx}', loss_dict['loss'].detach().numpy())
# loss_logger.add(f'loss_x_{epoch}_{batch_idx}', loss_dict['CELoss'].detach().cpu().numpy())
# loss_logger.add(f'loss_u_{epoch}_{batch_idx}', loss_dict['CCSSLCeLoss'].detach().cpu().numpy())
# loss_logger.add(f'loss_c_{epoch}_{batch_idx}', loss_dict['SoftSupConLoss'].detach().cpu().numpy())
# loss_logger.add(f'mask_prob_{epoch}_{batch_idx}', mask.mean().detach().numpy())
# out_logger.save('../align/step3/data/paddle_out.npy')
# loss_logger.save('../align/step3/data/paddle_loss.npy')
##############################################################
# assert 1==0
return loss_dict, logits_x
\ No newline at end of file
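To summarize the flow after these cleanups, `get_loss` runs one forward pass over the concatenated batch and splits it back into views; a schematic restatement (not additional engine code):

```python
# Batch layout produced above: [labeled | weak | strong1 | strong2].
def split_outputs(logits, feats, batch_size_label):
    logits_x = logits[:batch_size_label]
    logits_w, logits_s1, logits_s2 = logits[batch_size_label:].chunk(3)
    feat_w, feat_s1, feat_s2 = feats[batch_size_label:].chunk(3)
    return logits_x, (logits_w, logits_s1, logits_s2), (feat_w, feat_s1, feat_s2)

# The total objective then combines the configured terms:
#   loss = CELoss(logits_x, targets_x)                    # labeled branch
#        + CCSSLCeLoss (masked CE on the strong views, FixMatch-style)
#        + SoftSupConLoss on the stacked (feat_s1, feat_s2) features
```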