Unverified commit 122f7f97, authored by leozhang0912, committed by GitHub

add MoCo V2 (#2757)

Parent 5d06a88a
# rfc_task-137: Design document for implementing the MoCo-v2 model in PaddleClas
|Model name | MoCo-v2 model |
|---|---|
|Related paper| https://arxiv.org/pdf/2003.04297.pdf |
|Reference projects| https://github.com/PaddlePaddle/PASSL https://github.com/facebookresearch/MoCo|
|Author | 张乐 |
|Submission date | 2022-03-11 |
|Required PaddlePaddle version | PaddlePaddle2.4.1 |
|File name | rfc_task_137_model_MoCo-v2.md |
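# Design document for implementing the MoCo-v2 model in PaddleClas
## 1. Overview
MoCo-v2[<sup>2</sup>](#moco-v2) builds on MoCo by adding extra data augmentation, replacing the single fc layer with a multi-layer MLP, and switching the learning-rate decay strategy to cosine decay. We therefore focus on the MoCo model here.
MoCo[<sup>1</sup>](#moco-v1) is a self-supervised contrastive learning framework that learns good image representations from large-scale image datasets. Its pretrained models can be plugged into many vision tasks, such as image classification, object detection, and segmentation.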
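**Overview of the MoCo framework**
**Forward pass**
We walk through the MoCo framework by following the forward pass of an input $minibatchImgs=\{I_1,I_2,..I_N\}$. First, two transformations $view_1$ and $view_2$ are applied to each $I_n$:
$$I^{view1}_n=view_1(I_n)$$
$$I^{view2}_n=view_2(I_n)$$
Here $view_1$ and $view_2$ each denote a series of image preprocessing transforms (random cropping, grayscale conversion, normalization, etc.; see the paper's source code for details), and the minibatch size is $N$. Each input image $I_n$ thus yields two transformed images, $I^{view1}_n$ and $I^{view2}_n$.
Next, $I^{view1}_n$ and $I^{view2}_n$ are fed into two separate encoders:
$$q_n=L2_{normalization}(Encoder_1(I^{view1}_n))$$
$$k_n=L2_{normalization}(Encoder_2(I^{view2}_n))$$
where $q_n$ and $k_n$ both have feature dimension $K$, and $Encoder_1$ and $Encoder_2$ each consist of a ResNet50 backbone followed by an MLP.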
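Contrastive learning needs both positive and negative samples. The authors treat the two views of each input image as a positive pair; negatives come from a very large **dynamic** dictionary $Dict_{K\times C}$ that holds $C$ keys of dimension $K$. $Loss_+$ is obtained by taking the dot product of each matching pair from the positive sets $q_+=\{q_1,q_2...q_N\}$ and $k_+=\{k_1,k_2...k_N\}$:
$$Loss_+=\{l^{1}_+;l^{2}_+; ...;l^{N}_+\}=\{ q_1\cdot k_1; q_2\cdot k_2;...; q_N\cdot k_N \}; Loss_+\in N \times 1$$
$Loss_-$ is computed as:
$$l^{n,c}_-=q_n \cdot Dict_{:,c};Loss_-\in N \times C$$
The concatenated result is:
$$Loss=concat(Loss_+, Loss_-)\in N \times (1+C)$$
The dictionary $Dict$ can be viewed as a latent feature space throughout representation learning, and the authors found that the larger the dictionary, the better the learned visual representations. After each forward pass, the keys $k_n$ of the current minibatch are **enqueued** into $Dict$, and the oldest minibatch is **dequeued** at the same time.
The training objective is the cross-entropy loss:
$$Loss_{crossentropy}=-\log \frac{\exp(l_+/ \tau)}{ \sum_{i=0}^{C} \exp(l_i / \tau)}$$
where $l_0=l_+$, $l_1,\dots,l_C$ are the negative logits, and the hyperparameter $\tau$ is set to 0.07.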
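As a rough illustration of the loss above, the following sketch computes the cross-entropy for a single query in plain Python. The helper names (`l2_normalize`, `dot`, `info_nce_loss`) and the toy 2-D vectors are assumptions made for this example and are not part of the PaddleClas code; the real model works on batched tensors.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def info_nce_loss(q, k_pos, negatives, tau=0.07):
    """Cross-entropy over [l_pos, l_neg_1, ..., l_neg_C] with target index 0."""
    q = l2_normalize(q)
    logits = [dot(q, l2_normalize(k_pos))]
    logits += [dot(q, l2_normalize(n)) for n in negatives]
    scaled = [l / tau for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(s - m) for s in scaled))
    return log_sum - scaled[0]


# Toy data (hypothetical): the key matches the query in the first case only.
loss_aligned = info_nce_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0], [-1.0, 0.0]])
loss_misaligned = info_nce_loss([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0], [-1.0, 0.0]])
```

The loss is small when the positive key is the closest entry to the query and large when one of the dictionary entries is closer.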
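**Backward pass**
During backpropagation, gradients are used only to update the parameters $Param_{Encoder_1}$ of $Encoder_1$. To keep the visual representations in the dynamic dictionary $Dict$ consistent, the parameters $Param_{Encoder_2}$ of $Encoder_2$ are updated as:
$$Param_{Encoder_2}=m \cdot Param_{Encoder_2} + ( 1- m ) \cdot Param_{Encoder_1} $$
where the hyperparameter $m$ is set to 0.999.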
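The momentum update can be sketched on plain Python lists as follows. The function name `momentum_update` and the toy parameter values are assumptions for illustration only; in the model the same rule is applied to every parameter tensor of the key encoder.

```python
def momentum_update(params_k, params_q, m=0.999):
    """Return the key-encoder parameters after one momentum step."""
    return [m * pk + (1.0 - m) * pq for pk, pq in zip(params_k, params_q)]


# Toy values (hypothetical): the query encoder stays fixed and the key encoder starts at zero.
params_q = [1.0, -2.0]
params_k = [0.0, 0.0]
for _ in range(1000):
    params_k = momentum_update(params_k, params_q)
```

With $m=0.999$ the key encoder moves slowly: after 1000 steps it has covered only about 63% of the distance to the fixed query parameters, which is what keeps the keys stored in the dictionary consistent with each other.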
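## 2. Design approach and implementation plan
### Model backbone (already implemented in PaddleClas)
- ResNet50 backbone (with the final fully connected layer removed)
- The MLP consists of two fully connected layers, FC1 $ 2048 \times 2048 $ and FC2 $ 2048 \times 128 $
- Dynamic dictionary size: $65536$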
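The dynamic dictionary behaves like a fixed-size ring buffer: each new minibatch of keys overwrites the oldest one. A minimal sketch under that assumption, using a hypothetical `KeyQueue` class and string stand-ins for key vectors:

```python
class KeyQueue:
    """Fixed-size queue of keys; the newest batch replaces the oldest batch."""

    def __init__(self, size):
        self.size = size
        self.slots = [None] * size
        self.ptr = 0  # index of the oldest batch, which is overwritten next

    def dequeue_and_enqueue(self, keys):
        n = len(keys)
        # As in the model code, the queue size must be a multiple of the batch size.
        assert self.size % n == 0
        self.slots[self.ptr:self.ptr + n] = keys
        self.ptr = (self.ptr + n) % self.size


queue = KeyQueue(size=4)
for batch in (["a", "b"], ["c", "d"], ["e", "f"]):
    queue.dequeue_and_enqueue(batch)
```

After the third batch the queue wraps around, so the first batch has been replaced while the second is still stored.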
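### optimizer
- SGD: stochastic gradient descent optimizer
- Initial learning rate: $0.03$
- Weight decay: $1e-4$
- momentum of SGD: $0.9$
### Training strategy (already implemented in PaddleClas)
- batch-size: 256
- Single machine with 8 V100 GPUs
- shuffle_BN on each GPU
- Total training length: $epochs=200$
- lr schedule (step decay): at $epoch=[120, 160]$, $lr=lr \times 0.1$
- Learning-rate decay strategy (MoCo-v2): $cosine$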
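The two learning-rate schedules listed above can be compared with a small sketch. The function names `step_lr` and `cosine_lr` are made up for this example, and the formulas are the usual per-epoch step and cosine rules rather than code taken from PaddleClas:

```python
import math


def step_lr(epoch, base_lr=0.03, milestones=(120, 160), gamma=0.1):
    """Multiply the learning rate by gamma at every milestone already reached."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)


def cosine_lr(epoch, base_lr=0.03, total_epochs=200):
    """Half-cosine decay from base_lr at epoch 0 to 0 at total_epochs."""
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))


schedule = [(e, step_lr(e), cosine_lr(e)) for e in (0, 100, 120, 160, 199)]
```

The step schedule holds the initial rate until epoch 120, whereas the cosine schedule has already halved it by epoch 100.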
### metric (already implemented in PaddleClas)
- top1
- top5
### dataset
- Dataset: ImageNet
- Data augmentation (basic transforms already implemented in PaddleClas)
```Python
# PyTorch code
augmentation = [
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8  # not strengthened
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
```
- Random transforms of image pairs and Gaussian blur (**basic transforms already implemented in PASSL; they need to be ported to PaddleClas**)
```python
# PyTorch code
import random

from PIL import ImageFilter


class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[0.1, 2.0]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
```
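### Comparison of the PASSL and PaddleClas frameworks
- The two projects name the per-layer parameters of the base ResNet50 model differently, so weights trained with PASSL must be converted before PaddleClas can use them
- PASSL uses a Register class to link the model's architecture, backbone, neck, head, dataset, optimizer, and hook functions, so the whole training process can be driven from the command line with a single yaml file; PaddleClas works similarly
### Detailed design
1. Add the backbone, neck, and head parameter configuration for MoCo to model_zoo; the backbone is the ResNet50 network defined with TheseusLayer;
2. MoCo.yaml follows the format used by the PaddleClas project;
3. Add a GaussianBlur class to ppcls.data.preprocess.ops.operators.py
4. Refactor the __init__ and __getitem__ methods of the ImageNetDataset class. The original ImageNetDataset can only return (img, label); add an option to return (sample_1, sample_2, label), where sample_1 and sample_2 are obtained by applying view_trans1 and view_trans2 to img respectively;
5. In train.py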
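## 3. Testing methods for functional modules
|Module|Test method|
|---|---|
|Exact forward alignment|Given the same input, check that the output of the PaddleClas model matches the official PyTorch version|
|Exact backward alignment|Given the same input, check that the parameter updates of the PaddleClas implementation match the official PyTorch version|
|Image preprocessing|Write the Paddle version by following the official implementation|
|Hyperparameter configuration|Keep consistent with the official implementation|
|Training environment|Preferably also 8 V100 GPUs, using single-machine multi-GPU distributed training, consistent with the official setup|
|Accuracy alignment|After pretraining and finetuning on the provided small dataset, reach the same accuracy as the original PASSL model|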
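One simple way to carry out the forward-alignment check is to compare the two outputs element by element against a tolerance. The sketch below assumes both outputs have been flattened to Python lists; the helper names, the sample numbers, and the tolerance `atol=1e-5` are illustrative assumptions, not values specified by this RFC.

```python
def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two flat sequences."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same number of elements")
    return max(abs(r - c) for r, c in zip(reference, candidate))


def is_aligned(reference, candidate, atol=1e-5):
    # atol is an assumed tolerance, not a value taken from the RFC.
    return max_abs_diff(reference, candidate) <= atol


# Stand-in numbers (hypothetical) for the official and the PaddleClas outputs.
pytorch_out = [0.1234567, -0.7654321, 0.5]
paddle_out = [0.1234569, -0.7654320, 0.5]
aligned = is_aligned(pytorch_out, paddle_out)
```

The same comparison can be applied to parameter values after one optimizer step to check backward alignment.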
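## 4. Feasibility analysis and schedule
|Dates|Development plan|Duration|
|---|---|---|
|03.11-03.19|Get familiar with the tools; forward alignment|9 days|
|03.20-04.02|Backward alignment|14 days|
|04.03-04.16|Training alignment|14 days|
|04.16-04.29|Code merge|14 days|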
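## 5. Risks and impact
Risks:
- After training, a MoCo model is generally used as an image feature extractor, so there is no inference process in the usual sense
- **All models and algorithms in PaddleClas must pass the PaddlePaddle training-inference integration certification; at present, a newly added model only needs to pass the basic training and inference certification.** This conflicts with how MoCo is trained and used, so could explicit certification requirements be given for the MoCo-v2 model?
- The task title is MoCo-v2; should the MoCo-v1 code module also be considered when merging (the original PASSL project implements it)?
- The original PASSL project has a MoCo-Clas classification model; should this module also be added when merging (the original PASSL project implements it)?
- Part of train.py may need to be modified
Impact:
The data Dataloader, data augmentation, and model are all newly added scripts and do not affect other modules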
# Glossary
MoCo (Momentum Contrast)
# Attachments and references
<div id="moco-v1"></div>
[1] He K, Fan H, Wu Y, et al. Momentum contrast for unsupervised visual representation learning[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020: 9729-9738.
<div id="moco-v2"></div>
[2] Chen X, Fan H, Girshick R, et al. Improved baselines with momentum contrastive learning[J]. arXiv preprint arXiv:2003.04297, 2020.
......@@ -75,7 +75,8 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p
from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384
from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .model_zoo.moco import MoCo_V1, MoCo_V2
from .model_zoo.moco_finetune import MoCo_finetune
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
from .variant_models.resnet_variant import ResNet50_metabin
......
......@@ -346,7 +346,7 @@ class ResNet(TheseusLayer):
[32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(* [
self.stem = nn.Sequential(*[
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
......@@ -396,7 +396,7 @@ class ResNet(TheseusLayer):
self.data_format = data_format
super().init_res(
super().init_net(
stages_pattern,
return_patterns=return_patterns,
return_stages=return_stages)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://arxiv.org/abs/1911.05722
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from ppcls.utils.initializer import kaiming_normal_, constant_, normal_
from ..legendary_models import *
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {"MoCo_V1": "UNKNOWN", "MoCo_V2": "UNKNOWN"}
__all__ = list(MODEL_URLS.keys())
class LinearNeck(nn.Layer):
    """Linear neck: fc only.
    """

    def __init__(self, in_channels, out_channels, with_avg_pool=False):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        if self.with_avg_pool:
            x = self.avgpool(x)
        return self.fc(x.reshape([x.shape[0], -1]))


class NonLinearNeck(nn.Layer):
    """The non-linear neck in MoCo v2: fc-relu-fc.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=False):
        super(NonLinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.ReLU(), nn.Linear(hid_channels, out_channels))

    def forward(self, x):
        if self.with_avg_pool:
            x = self.avgpool(x)
        return self.mlp(x.reshape([x.shape[0], -1]))


class ContrastiveHead(nn.Layer):
    """Head for contrastive learning.

    Args:
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Default: 0.1.
    """

    def __init__(self, temperature=0.1):
        super(ContrastiveHead, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, pos, neg):
        """Forward head.

        Args:
            pos (Tensor): Nx1 positive similarity.
            neg (Tensor): Nxk negative similarity.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        N = pos.shape[0]
        logits = paddle.concat((pos, neg), axis=1)
        logits /= self.temperature
        labels = paddle.zeros((N, 1), dtype='int64')
        return logits, labels


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
class MoCo(nn.Layer):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """

    def __init__(self,
                 backbone_config,
                 neck_config,
                 head_config,
                 dim=128,
                 K=65536,
                 m=0.999,
                 T=0.07):
        """
        initialize `MoCoV1` or `MoCoV2` model depends on args
        Args:
            backbone_config (dict): config of backbone(eg: ResNet50).
            neck_config (dict): config of neck(eg: MLP or FC)
            head_config (dict): config of head
            dim (int): feature dimension. Default: 128.
            K (int): queue size; number of negative keys. Default: 65536.
            m (float): moco momentum of updating key encoder. Default: 0.999.
            T (float): softmax temperature. Default: 0.07.
        """
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = T

        backbone_type = backbone_config.pop('name')
        backbone = eval(backbone_type)

        neck_type = neck_config.pop('name')
        neck = eval(neck_type)

        head_type = head_config.pop('name')
        head = eval(head_type)

        backbone_1 = backbone()
        backbone_1.stop_after(stop_layer_name='avg_pool')
        backbone_2 = backbone()
        backbone_2.stop_after(stop_layer_name='avg_pool')

        self.encoder_q = nn.Sequential(backbone_1, neck(**neck_config))
        self.encoder_k = nn.Sequential(backbone_2, neck(**neck_config))

        self.backbone = self.encoder_q[0]
        self.head = head(**head_config)

        # initialize function by kaiming
        self.init_parameters()

        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.set_value(param_q)  # moco initialize
            param_k.stop_gradient = True  # not update by gradient

        # frozen bn normal
        freeze_batchnorm_statictis(self.encoder_k)

        # create the queue
        self.register_buffer("queue", paddle.randn([dim, K]))
        self.queue = nn.functional.normalize(self.queue, axis=0)
        self.register_buffer("queue_ptr", paddle.zeros([1], 'int64'))

    def init_parameters(self, init_linear='kaiming', std=0.01, bias=0.):
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                kaiming_normal_(m, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.layer.norm._BatchNormBase, nn.GroupNorm)):
                constant_(m, 1)
            elif isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_(m, std=std, bias=bias)
                else:
                    kaiming_normal_(m, mode='fan_in', nonlinearity='relu')

    @paddle.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            paddle.assign((param_k * self.m + param_q * (1. - self.m)),
                          param_k)
            param_k.stop_gradient = True

    @paddle.no_grad()
    def _dequeue_and_enqueue(self, keys):
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr[0])
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose([1, 0])
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @paddle.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = paddle.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = paddle.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = paddle.distributed.get_rank()
        idx_this = idx_shuffle.reshape([num_gpus, -1])[gpu_idx]
        return paddle.index_select(x_gather, idx_this), idx_unshuffle

    @paddle.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = paddle.distributed.get_rank()
        idx_this = idx_unshuffle.reshape([num_gpus, -1])[gpu_idx]

        return paddle.index_select(x_gather, idx_this)

    def train_iter(self, inputs, **kwargs):
        img_q, img_k = inputs

        # compute query features
        q = self.encoder_q(img_q)  # queries: NxC
        q = nn.functional.normalize(q, axis=1)

        # compute key features
        with paddle.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            img_k = paddle.to_tensor(img_k)
            im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, axis=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # FIXME: Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = paddle.sum(q * k, axis=1).unsqueeze(-1)
        # negative logits: NxK
        l_neg = paddle.matmul(q, self.queue.clone().detach())

        outputs = self.head(l_pos, l_neg)
        self._dequeue_and_enqueue(k)

        # add return label
        return outputs

    def forward(self, inputs, mode='train', **kwargs):
        if mode == 'train':
            return self.train_iter(inputs, **kwargs)
        elif mode == 'test':
            return self.test_iter(inputs, **kwargs)
        elif mode == 'extract':
            return self.backbone(inputs)
        else:
            raise Exception("No such mode: {}".format(mode))
@paddle.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    """
    if paddle.distributed.get_world_size() < 2:
        return tensor

    tensors_gather = []
    paddle.distributed.all_gather(tensors_gather, tensor)

    output = paddle.concat(tensors_gather, axis=0)
    return output


def freeze_batchnorm_statictis(layer):
    # NOTE: as committed, freeze_bn is defined but never applied to `layer`,
    # so calling this function leaves the layer unchanged.
    def freeze_bn(layer):
        if isinstance(layer, (nn.layer.norm._BatchNormBase)):
            layer._use_global_stats = True


def MoCo_V1(backbone, neck, head, pretrained=False, use_ssld=False):
    model = MoCo(
        backbone_config=backbone, neck_config=neck, head_config=head, T=0.07)
    _load_pretrained(
        pretrained, model, MODEL_URLS["MoCo_V1"], use_ssld=use_ssld)
    return model


def MoCo_V2(backbone, neck, head, pretrained=False, use_ssld=False):
    model = MoCo(
        backbone_config=backbone, neck_config=neck, head_config=head, T=0.2)
    _load_pretrained(
        pretrained, model, MODEL_URLS["MoCo_V2"], use_ssld=use_ssld)
    return model
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://arxiv.org/abs/1911.05722
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from ....utils import logger
from ppcls.utils.initializer import normal_
from ..legendary_models import *
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {"MoCo_finetune": "UNKNOWN"}
__all__ = list(MODEL_URLS.keys())
class ClasHead(nn.Layer):
    """Simple classifier head.
    """

    def __init__(self, with_avg_pool=False, in_channels=2048, class_num=1000):
        super(ClasHead, self).__init__()
        self.with_avg_pool = with_avg_pool
        self.in_channels = in_channels
        self.num_classes = class_num
        if self.with_avg_pool:
            self.avg_pool = nn.AdaptiveAvgPool2D((1, 1))
        self.fc = nn.Linear(in_channels, class_num)
        # reset_parameters(self.fc_cls)
        normal_(self.fc, mean=0.0, std=0.01, bias=0.0)

    def forward(self, x):
        if self.with_avg_pool:
            x = self.avg_pool(x)
        x = paddle.reshape(x, [-1, self.in_channels])
        x = self.fc(x)
        return x


def _load_pretrained(pretrained_config, model, use_ssld=False):
    if pretrained_config is not None:
        if pretrained_config.startswith("http"):
            load_dygraph_pretrain_from_url(model, pretrained_config)
        else:
            load_dygraph_pretrain(model, pretrained_config)


class Classification(nn.Layer):
    """
    Simple image classification.
    """

    def __init__(self, backbone, head, with_sobel=False):
        super(Classification, self).__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, inputs):
        x = self.backbone(inputs)
        x = self.head(x)
        return x


def freeze_batchnorm_statictis(layer):
    def freeze_bn(layer):
        if isinstance(layer, nn.BatchNorm):
            layer._use_global_stats = True
def freeze_params(model):
    from ppcls.arch.backbone.legendary_models.resnet import ConvBNLayer, BottleneckBlock
    for item in ['stem', 'max_pool', 'blocks', 'avg_pool']:
        m = getattr(model, item)
        if isinstance(m, nn.Sequential):
            for item in m:
                if isinstance(item, ConvBNLayer):
                    print(item.bn)
                    freeze_batchnorm_statictis(item.bn)

                if isinstance(item, BottleneckBlock):
                    freeze_batchnorm_statictis(item.conv0.bn)
                    freeze_batchnorm_statictis(item.conv1.bn)
                    freeze_batchnorm_statictis(item.conv2.bn)
                    if hasattr(item, 'short'):
                        freeze_batchnorm_statictis(item.short.bn)

        for param in m.parameters():
            param.trainable = False


def MoCo_finetune(backbone, head, pretrained=False, use_ssld=False):
    backbone_config = backbone
    head_config = head
    backbone_name = backbone_config.pop('name')
    backbone = eval(backbone_name)(**backbone_config)

    # stop layer for backbone
    stop_layer_name = backbone_config.pop('stop_layer_name', None)
    if stop_layer_name:
        backbone.stop_after(stop_layer_name=stop_layer_name)

    # freeze specified layer before
    freeze_layer_name = backbone_config.pop('freeze_befor', None)
    if freeze_layer_name:
        ret = backbone.freeze_befor(freeze_layer_name)
        if ret:
            logger.info(
                "moco_clas backbone successfully freeze param update befor the layer: {}".
                format(freeze_layer_name))
        else:
            logger.error(
                "moco_clas backbone failurely freeze param update befor the layer: {}".
                format(freeze_layer_name))

    freeze_params(backbone)

    head_name = head_config.pop('name')
    head = eval(head_name)(**head_config)
    model = Classification(backbone=backbone, head=head)

    # load pretrain_moco_model weight
    pretrained_config = backbone_config.pop('pretrained_model')
    _load_pretrained(pretrained_config, model, use_ssld=use_ssld)

    return model
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 50
  # train_epoch_iter_two_samples
  train_mode: iter_two_samples
  eval_during_train: False
  eval_interval: 1
  epochs: 200
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False

# model architecture
Arch:
  name: MoCo_V2
  backbone:
    name: ResNet50
    stop_layer_name: AvgPool2D
  neck:
    name: NonLinearNeck
    in_channels: 2048
    hid_channels: 2048
    out_channels: 128
  head:
    name: ContrastiveHead
    temperature: 0.2

# loss function config
Loss:
  Train:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 0.0001
  lr:
    name: Cosine
    learning_rate: 0.03
    T_max: 200

# data loader for train
DataLoader:
  Train:
    dataset:
      name: MoCoImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      return_label: False
      return_two_sample: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandomResizedCrop:
            size: 224
            scale: [0.2, 1.]
      view_trans1:
        - RandomApply:
            transforms:
              - RawColorJitter:
                  brightness: 0.4
                  contrast: 0.4
                  saturation: 0.4
                  hue: 0.1
            p: 0.8
        - RandomGrayscale:
            p: 0.2
        - RandomApply:
            transforms:
              - GaussianBlur:
                  sigma: [0.1, 2.0]
            p: 0.5
        - RandomHorizontalFlip:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
      view_trans2:
        - RandomApply:
            transforms:
              - RawColorJitter:
                  brightness: 0.4
                  contrast: 0.4
                  saturation: 0.4
                  hue: 0.1
            p: 0.8
        - RandomGrayscale:
            p: 0.2
        - RandomApply:
            transforms:
              - GaussianBlur:
                  sigma: [0.1, 2.0]
            p: 0.5
        - RandomHorizontalFlip:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
# global configs
Global:
  checkpoints: null
  output_dir: ./output/
  device: gpu
  save_interval: 20
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False

Arch:
  name: MoCo_finetune
  pretrained_model: ./pretrain/moco_v2_bs_256_epoch_200
  backbone:
    name: ResNet50
    stop_layer_name: avg_pool
    freeze_befor: avg_pool
  head:
    name: ClasHead
    class_num: 1000

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: MultiStepDecay
    learning_rate: 30.0
    milestones: [60, 80]

DataLoader:
  Train:
    dataset:
      name: MoCoImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      return_label: True
      return_two_sample: False
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandomResizedCrop:
            size: 224
        - RandomHorizontalFlip:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      shuffle: True
      drop_last: False
    loader:
      num_workers: 4
      use_shared_memory: True
  Eval:
    dataset:
      name: MoCoImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      return_label: True
      return_two_sample: False
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - Resize:
            size: 256
        - CenterCrop:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      shuffle: True
      drop_last: True
    loader:
      num_workers: 4
      use_shared_memory: True

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
\ No newline at end of file
......@@ -14,3 +14,4 @@ from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDat
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
from ppcls.data.dataloader.moco_imagenet_dataset import MoCoImageNetDataset
......@@ -72,4 +72,4 @@ class ImageNetDataset(CommonDataset):
else:
self.labels.append(np.int64(line[1]))
assert os.path.exists(self.images[
-1]), f"path {self.images[-1]} does not exist."
-1]), f"path {self.images[-1]} does not exist."
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import os
from ppcls.utils import logger
from .common_dataset import CommonDataset, create_operators
from ppcls.data.preprocess import transform
class MoCoImageNetDataset(CommonDataset):
    """MoCoImageNetDataset

    Args:
        image_root (str): image root, path to `ILSVRC2012`
        cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
        return_label (bool, optional): whether return original label.
        return_two_sample (bool, optional): whether return two views about original image.
        transform_ops (list, optional): list of transform op(s). Defaults to None.
        delimiter (str, optional): delimiter. Defaults to None.
        relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
        view_trans1 (list): some transform op(s) for view1.
        view_trans2 (list): some transform op(s) for view2.
    """

    def __init__(
            self,
            image_root,
            cls_label_path,
            return_label=True,
            return_two_sample=False,
            transform_ops=None,
            delimiter=None,
            relabel=False,
            view_trans1=None,
            view_trans2=None, ):
        self.delimiter = delimiter if delimiter is not None else " "
        self.relabel = relabel
        super(MoCoImageNetDataset, self).__init__(image_root, cls_label_path,
                                                  transform_ops)

        self.return_label = return_label
        self.return_two_sample = return_two_sample

        if self.return_two_sample:
            self.view_transform1 = create_operators(view_trans1)
            self.view_transform2 = create_operators(view_trans2)

    def __getitem__(self, idx):
        try:
            with open(self.images[idx], 'rb') as f:
                img = f.read()

            if self.return_two_sample:
                sample1 = transform(img, self._transform_ops)
                sample2 = transform(img, self._transform_ops)
                sample1 = transform(sample1, self.view_transform1)
                sample2 = transform(sample2, self.view_transform2)

                if self.return_label:
                    return (sample1, sample2, self.labels[idx])
                else:
                    return (sample1, sample2)

            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, self.labels[idx])

        except Exception as ex:
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def _load_anno(self, seed=None):
        assert os.path.exists(
            self._cls_path), f"path {self._cls_path} does not exist."
        assert os.path.exists(
            self._img_root), f"path {self._img_root} does not exist."
        self.images = []
        self.labels = []

        with open(self._cls_path) as fd:
            lines = fd.readlines()
            if self.relabel:
                label_set = set()
                for line in lines:
                    line = line.strip().split(self.delimiter)
                    label_set.add(np.int64(line[1]))
                label_map = {
                    oldlabel: newlabel
                    for newlabel, oldlabel in enumerate(label_set)
                }

            if seed is not None:
                np.random.RandomState(seed).shuffle(lines)
            for line in lines:
                line = line.strip().split(self.delimiter)
                self.images.append(os.path.join(self._img_root, line[0]))
                if self.relabel:
                    self.labels.append(label_map[np.int64(line[1])])
                else:
                    self.labels.append(np.int64(line[1]))
                assert os.path.exists(self.images[
                    -1]), f"path {self.images[-1]} does not exist."
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop, Transpose
from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
from ppcls.data.preprocess.ops.randaugment import RandomApply
......@@ -48,6 +49,8 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from ppcls.data.preprocess.ops.operators import PCALighting
from ppcls.data.preprocess.ops.operators import GaussianBlur
from .ops.operators import format_data
from paddle.vision.transforms import Pad as Pad_paddle_vision
......@@ -58,6 +61,7 @@ import numpy as np
from PIL import Image
import random
def transform(data, ops=[]):
""" transform """
for op in ops:
......@@ -139,4 +143,4 @@ class TimmAutoAugment(RawTimmAutoAugment):
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
\ No newline at end of file
return img
......@@ -24,12 +24,13 @@ import math
import random
import cv2
import numpy as np
from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
from PIL import ImageFilter, Image, ImageOps, __version__ as PILLOW_VERSION
from paddle.vision.transforms import ColorJitter as RawColorJitter
from paddle.vision.transforms import CenterCrop, Resize
from paddle.vision.transforms import RandomRotation as RawRandomRotation
from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop
from paddle.vision.transforms import functional as F
from paddle.vision.transforms import transforms as T
from .autoaugment import ImageNetPolicy
from .functional import augmentations
from ppcls.utils import logger
......@@ -742,8 +743,8 @@ class Pad(object):
# Process fill color for affine transforms
major_found, minor_found = (int(v)
for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (int(v) for v in
min_pil_version.split('.')[:2])
major_required, minor_required = (
int(v) for v in min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and
minor_found < minor_required):
if fill is None:
......@@ -858,6 +859,25 @@ class BlurImage(object):
return {"img": img, "blur_image": label}
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.], backend="cv2"):
        self.sigma = sigma
        self.kernel_size = 23
        self.backbend = backend

    def __call__(self, x):
        sigma = np.random.uniform(self.sigma[0], self.sigma[1])
        if self.backbend == "PIL":
            x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
            return x
        else:
            x = cv2.GaussianBlur(
                np.array(x), (self.kernel_size, self.kernel_size), sigma)
            return Image.fromarray(x.astype(np.uint8))
class RandomGrayscale(object):
"""Randomly convert image to grayscale with a probability of p (default 0.1).
......@@ -878,14 +898,20 @@ class RandomGrayscale(object):
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be converted to grayscale.
img (PIL.Image|np.array): Image to be converted to grayscale.
Returns:
PIL Image: Randomly grayscaled image.
"""
num_output_channels = 1 if img.mode == 'L' else 3
if random.random() < self.p:
return F.to_grayscale(img, num_output_channels=num_output_channels)
if isinstance(img, Image.Image):
if img.mode == 'L':
num_output_channels = 1
if isinstance(img, np.ndarray) or isinstance(img, Image.Image):
num_output_channels = 3
if random.random() < self.p:
return F.to_grayscale(
img, num_output_channels=num_output_channels)
return img
def __repr__(self):
......
......@@ -259,4 +259,4 @@ class RandAugmentV2(RandAugment):
"equalize": lambda img, _: ImageOps.equalize(img),
"invert": lambda img, _: ImageOps.invert(img),
"cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0])
}
}
\ No newline at end of file
......@@ -330,6 +330,7 @@ class Engine(object):
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.DataParallel(
......
......@@ -16,3 +16,4 @@ from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
from ppcls.engine.train.train_progressive import train_epoch_progressive
from ppcls.engine.train.train_metabin import train_epoch_metabin
from ppcls.engine.train.train_iter_two_samples import train_epoch_iter_two_samples
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler


def train_epoch_iter_two_samples(engine, epoch_id, print_batch_step):
    tic = time.time()

    if not hasattr(engine, "train_dataloader_iter"):
        engine.train_dataloader_iter = iter(engine.train_dataloader)

    for iter_id in range(engine.iter_per_epoch):
        # fetch data batch from dataloader
        try:
            batch = next(engine.train_dataloader_iter)
        except Exception:
            engine.train_dataloader_iter = iter(engine.train_dataloader)
            batch = next(engine.train_dataloader_iter)

        profiler.add_profiler_step(engine.config["profiler_options"])
        if iter_id == 5:
            for key in engine.time_info:
                engine.time_info[key].reset()
        engine.time_info["reader_cost"].update(time.time() - tic)

        # view_1_samples: batch[0]
        # view_2_samples: batch[1]
        batch_size = batch[0].shape[0]
        engine.global_step += 1

        # image input
        if engine.amp:
            amp_level = engine.config["AMP"].get("level", "O1").upper()
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=amp_level):
                logits, labels = forward(engine, batch)
                loss_dict = engine.train_loss_func(logits, labels)
        else:
            logits, labels = forward(engine, batch)
            loss_dict = engine.train_loss_func(logits, labels)

        # loss
        loss = loss_dict["loss"] / engine.update_freq

        # backward & step opt
        if engine.amp:
            scaled = engine.scaler.scale(loss)
            scaled.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.scaler.minimize(engine.optimizer[i], scaled)
        else:
            loss.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.optimizer[i].step()

        if (iter_id + 1) % engine.update_freq == 0:
            # clear grad
            for i in range(len(engine.optimizer)):
                engine.optimizer[i].clear_grad()
            # step lr(by step)
            for i in range(len(engine.lr_sch)):
                if not getattr(engine.lr_sch[i], "by_epoch", False):
                    engine.lr_sch[i].step()
            # update ema
            if engine.ema:
                engine.model_ema.update(engine.model)

        # below code just for logging
        # update metric_for_logger
        update_metric(engine, logits, [labels], batch_size)
        # update_loss_for_logger
        update_loss(engine, loss_dict, batch_size)
        engine.time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(engine, batch_size, epoch_id, iter_id)
        tic = time.time()

    # step lr(by epoch)
    for i in range(len(engine.lr_sch)):
        if getattr(engine.lr_sch[i], "by_epoch", False) and \
                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
            engine.lr_sch[i].step()


def forward(engine, batch):
    return engine.model(batch)
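The training loop above combines two patterns: it restarts the dataloader iterator when it runs out, and it steps the optimizer only every `update_freq` iterations (gradient accumulation). The sketch below isolates both in plain Python, with no Paddle dependency. `fetch_batches` and `optimizer_step_ids` are hypothetical names introduced for illustration. Unlike the loop above, the sketch catches only `StopIteration`, creates a fresh iterator per call instead of keeping one on the engine, and assumes a non-empty dataloader.

```python
def fetch_batches(dataloader, iter_per_epoch):
    """Yield exactly `iter_per_epoch` batches, restarting the loader if needed.

    Hypothetical helper; assumes `dataloader` is a non-empty,
    re-iterable object such as a list.
    """
    it = iter(dataloader)
    for _ in range(iter_per_epoch):
        try:
            batch = next(it)
        except StopIteration:
            # Loader exhausted: start over from the first batch.
            it = iter(dataloader)
            batch = next(it)
        yield batch


def optimizer_step_ids(iter_per_epoch, update_freq):
    """Return the iteration ids at which the optimizer would step."""
    return [i for i in range(iter_per_epoch) if (i + 1) % update_freq == 0]


print(list(fetch_batches([1, 2, 3], 5)))
print(optimizer_step_ids(6, 2))
```

In this sketch, when `iter_per_epoch` is not a multiple of `update_freq` (for example 5 and 2), the final iteration accumulates a gradient but no step follows it within the same call.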