diff --git a/docs/zh_CN/community/rfcs/rfc_task-137_model_MoCo-v2.md b/docs/zh_CN/community/rfcs/rfc_task-137_model_MoCo-v2.md
new file mode 100644
index 0000000000000000000000000000000000000000..018064f7bed6a8331634a951c8dc8489e4e5f612
--- /dev/null
+++ b/docs/zh_CN/community/rfcs/rfc_task-137_model_MoCo-v2.md
@@ -0,0 +1,183 @@
+# rfc_task-137: MoCo-v2 model design document for PaddleClas
+
+|Model name | MoCo-v2 |
+|---|---|
+|Paper| https://arxiv.org/pdf/2003.04297.pdf |
+|Reference projects| https://github.com/PaddlePaddle/PASSL https://github.com/facebookresearch/MoCo|
+|Author | 张乐 |
+|Submitted | 2022-03-11 |
+|PaddlePaddle version | PaddlePaddle 2.4.1 |
+|File name | rfc_task_137_model_MoCo-v2.md |
+
+# MoCo-v2 model: PaddleClas implementation design
+## 1. Overview
+
+MoCo-v2 [2](#moco-v2) builds on MoCo with stronger data augmentation, an MLP projection head in place of the single fc layer, and a cosine learning-rate schedule. The rest of this section therefore focuses on the MoCo framework itself.
+
+MoCo [1](#moco-v1) is a self-supervised contrastive-learning framework that learns good image representations from large-scale image datasets. Its pretrained models plug seamlessly into many vision tasks, such as image classification, object detection and segmentation.
+
+**The MoCo framework in brief**
+
+**Forward pass**
+
+We walk through the forward pass of one minibatch $minibatchImgs=\{I_1,I_2,...,I_N\}$. Each image $I_n$ is first transformed by two augmentation pipelines $view_1$ and $view_2$:
+$$I^{view1}_n=view_1(I_n)$$
+$$I^{view2}_n=view_2(I_n)$$
+where $view_1$ and $view_2$ are sequences of image transforms (random crop, grayscale, normalization, etc.; see the paper and its source code) and $N$ is the minibatch size, so every input image $I_n$ yields two augmented views $I^{view1}_n$ and $I^{view2}_n$.
+
+The two views are then fed into two encoders:
+$$q_n=L2_{normalization}(Encoder_1(I^{view1}_n))$$
+$$k_n=L2_{normalization}(Encoder_2(I^{view2}_n))$$
+
+where $q_n$ and $k_n$ both have feature dimension $C$, and $Encoder_1$ and $Encoder_2$ each consist of a ResNet50 backbone followed by an MLP head.
+
+Contrastive learning needs positive and negative samples. Each image naturally supplies its own positive pair $(q_n, k_n)$; the negatives come from a **dynamic** dictionary $Dict \in \mathbb{R}^{C \times K}$ that stores $K$ keys from earlier minibatches. The positive logits are the per-sample dot products between $q=\{q_1,q_2,...,q_N\}$ and $k=\{k_1,k_2,...,k_N\}$:
+
+$$Logits_+=\{l^{1}_+;l^{2}_+;...;l^{N}_+\}=\{q_1\cdot k_1;\ q_2\cdot k_2;\ ...;\ q_N\cdot k_N\};\quad Logits_+\in N\times 1$$
+
+The negative logits are computed against every key in the dictionary:
+$$l^{n,j}_-=q_n\cdot Dict_{:,j};\quad Logits_-\in N\times K$$
+
+The final logits are the concatenation of the two:
+$$Logits=concat(Logits_+,\ Logits_-)\in N\times (1+K)$$
+The dictionary $Dict$ acts as a latent feature space throughout representation learning, and the authors found that the larger the dictionary, the better the learned visual representation. After each forward pass the keys $k_n$ of the current minibatch are **enqueued** into $Dict$ while the oldest minibatch is **dequeued**.
+
+The training objective is the cross-entropy (InfoNCE) loss:
+
+$$Loss_{crossentropy}=-\log\frac{\exp(l_+/\tau)}{\exp(l_+/\tau)+\sum_{j=1}^{K}\exp(l^{j}_-/\tau)}$$
+
+where the temperature $\tau$ is 0.07 for MoCo v1 (MoCo v2 uses 0.2).
+
+**Backward pass**
+
+During backpropagation only the parameters $Param_{Encoder_1}$ of $Encoder_1$ are updated by gradients. To keep the features stored in the dynamic dictionary $Dict$ consistent, the parameters of $Encoder_2$ are updated by a momentum rule instead:
+
+$$Param_{Encoder_2}=m\cdot Param_{Encoder_2}+(1-m)\cdot Param_{Encoder_1}$$
+where the momentum coefficient $m$ is 0.999.
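+
+The per-iteration logic above can be condensed into a short sketch covering the momentum update, the InfoNCE logits and the queue bookkeeping. This is a minimal illustration against a generic Paddle-style API (`moco_train_step`, `encoder_q` and `encoder_k` are names assumed for illustration; the queue is updated functionally, and shuffle-BN and the distributed gather are omitted), not the implementation proposed below:
+
+```python
+import paddle
+import paddle.nn.functional as F
+
+
+def moco_train_step(encoder_q, encoder_k, queue, img_q, img_k, m=0.999, T=0.07):
+    """One MoCo iteration. `queue` holds the C x K negative keys."""
+    # momentum update of the key encoder (it is never updated by gradients)
+    with paddle.no_grad():
+        for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
+            paddle.assign(p_k * m + p_q * (1.0 - m), p_k)
+
+    q = F.normalize(encoder_q(img_q), axis=1)           # queries: N x C
+    with paddle.no_grad():
+        k = F.normalize(encoder_k(img_k), axis=1)       # keys:    N x C
+
+    l_pos = (q * k).sum(axis=1, keepdim=True)           # positives: N x 1
+    l_neg = paddle.matmul(q, queue)                     # negatives: N x K
+    logits = paddle.concat([l_pos, l_neg], axis=1) / T  # N x (1 + K)
+    # the positive key sits in column 0, so every label is 0
+    labels = paddle.zeros([q.shape[0]], dtype='int64')
+    loss = F.cross_entropy(logits, labels)
+
+    # enqueue the current keys, dequeue the oldest minibatch (FIFO)
+    queue = paddle.concat([queue[:, k.shape[0]:], k.t()], axis=1)
+    return loss, queue
+```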
+
+## 2. Design and implementation plan
+
+### Model backbone (already implemented in PaddleClas)
+
+- ResNet50 backbone (with the final fully connected layer removed)
+- MLP head built from two fully connected layers, FC1 ($2048 \times 2048$) and FC2 ($2048 \times 128$)
+- Dynamic dictionary size: $K=65536$
+
+### Optimizer
+- SGD: stochastic gradient descent
+- Initial learning rate: $0.03$
+- Weight decay: $1e-4$
+- SGD momentum: $0.9$
+
+### Training strategy (already implemented in PaddleClas)
+- Batch size: 256
+- 8 V100 GPUs on a single machine
+- Shuffle-BN across GPUs
+- Total epochs: $200$
+- MoCo v1 lr schedule: step decay, $lr = lr \times 0.1$ at $epoch=[120, 160]$
+- MoCo v2 lr schedule: cosine decay
+
+### Metric (already implemented in PaddleClas)
+- top1
+- top5
+
+### Dataset
+- Dataset: ImageNet
+- Data augmentation (basic transforms already implemented in PaddleClas)
+```Python
+# PyTorch code
+augmentation = [
+    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
+    transforms.RandomApply(
+        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8  # not strengthened
+    ),
+    transforms.RandomGrayscale(p=0.2),
+    transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    normalize,
+]
+```
+- Two-view random transform and Gaussian blur (**basic transforms exist in PASSL and need to be ported to PaddleClas**)
+
+```python
+# PyTorch code
+class TwoCropsTransform:
+    """Take two random crops of one image as the query and key."""
+
+    def __init__(self, base_transform):
+        self.base_transform = base_transform
+
+    def __call__(self, x):
+        q = self.base_transform(x)
+        k = self.base_transform(x)
+        return [q, k]
+
+
+class GaussianBlur(object):
+    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
+
+    def __init__(self, sigma=[0.1, 2.0]):
+        self.sigma = sigma
+
+    def __call__(self, x):
+        sigma = random.uniform(self.sigma[0], self.sigma[1])
+        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
+        return x
+```
+
+### PASSL vs. PaddleClas
+
+- The two projects name the ResNet50 parameters differently, so PASSL-trained weights have to be converted before PaddleClas can load them (a conversion sketch follows this section).
+- PASSL wires the architecture, backbone, neck, head, dataset, optimizer and hooks together through a Register mechanism, so a full training run is driven by a single yaml file passed on the command line; PaddleClas works the same way.
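+
+Because the two checkpoint formats differ only in parameter names, the conversion can be sketched as a key-remapping pass over the state dict. The prefix pairs below are placeholders (the real mapping has to be read off `state_dict().keys()` of both models); only the mechanism is shown:
+
+```python
+import paddle
+
+
+def convert_passl_to_paddleclas(passl_ckpt, out_ckpt, rules):
+    """Remap parameter names of a PASSL checkpoint to PaddleClas names.
+
+    `rules` is a list of (passl_prefix, paddleclas_prefix) pairs;
+    the pairs used below are hypothetical examples.
+    """
+    state = paddle.load(passl_ckpt)
+    converted = {}
+    for name, tensor in state.items():
+        new_name = name
+        for src, dst in rules:
+            if new_name.startswith(src):
+                new_name = dst + new_name[len(src):]
+                break
+        converted[new_name] = tensor
+    paddle.save(converted, out_ckpt)
+
+
+# usage with placeholder prefixes:
+# convert_passl_to_paddleclas("moco_passl.pdparams", "moco_ppcls.pdparams",
+#                             rules=[("backbone.", ""), ("neck.mlp.", "mlp.")])
+```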
+
+### Detailed design
+
+1. Add the MoCo backbone, neck and head parameter configuration to the model zoo; the backbone is the ResNet50 network defined with TheseusLayer.
+2. Format MoCo.yaml after the existing PaddleClas configs.
+3. Add a GaussianBlur class in ppcls/data/preprocess/ops/operators.py.
+4. Rework the __init__ and __getitem__ methods of ImageNetDataset: the original class only returns (img, label); add an optional mode that returns (sample_1, sample_2, label), where sample_1 and sample_2 are the results of applying view_trans1 and view_trans2 to img (see the sketch after this list).
+5. Register a new training loop with train.py / the engine (the `train_epoch_iter_two_samples` entry added by this PR) so that it can be selected through `Global.train_mode`.
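+
+A minimal sketch of the two-view `__getitem__` from item 4 (class and attribute names here are simplified stand-ins; the `MoCoImageNetDataset` added later in this PR is the authoritative version):
+
+```python
+import numpy as np
+
+
+class TwoViewDataset:
+    """Sketch: return (view1, view2[, label]) instead of (img, label)."""
+
+    def __init__(self, images, labels, base_ops, view_ops1, view_ops2,
+                 return_label=False):
+        self.images, self.labels = images, labels
+        self.base_ops = base_ops    # shared decode/random-crop ops
+        self.view_ops1 = view_ops1  # augmentation pipeline for view 1
+        self.view_ops2 = view_ops2  # augmentation pipeline for view 2
+        self.return_label = return_label
+
+    @staticmethod
+    def _apply(data, ops):
+        for op in ops:
+            data = op(data)
+        return data
+
+    def __getitem__(self, idx):
+        with open(self.images[idx], 'rb') as f:
+            img = f.read()
+        # the base ops are random, so each view gets its own crop
+        sample1 = self._apply(self._apply(img, self.base_ops), self.view_ops1)
+        sample2 = self._apply(self._apply(img, self.base_ops), self.view_ops2)
+        if self.return_label:
+            return sample1, sample2, np.int64(self.labels[idx])
+        return sample1, sample2
+```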
+
+## 3. Module test plan
+|Module|Test method|
+|---|---|
+|Forward alignment|Feed identical inputs and compare the outputs of the PaddleClas model against the official PyTorch implementation|
+|Backward alignment|Feed identical inputs, run a parameter update, and check that the PaddleClas and official PyTorch updates match|
+|Image preprocessing|Write the Paddle version against the official implementation|
+|Hyperparameters|Keep them identical to the official implementation|
+|Training environment|Preferably 8 V100 GPUs with single-machine multi-GPU distributed training, matching the official setup|
+|Accuracy alignment|Pretrain and finetune on the provided small dataset and match the accuracy of the original PASSL model|
+
+## 4. Feasibility and schedule
+|Time|Milestone|Duration|
+|---|---|---|
+|03.11-03.19|Get familiar with the tooling, forward alignment|9 days|
+|03.20-04.02|Backward alignment|14 days|
+|04.03-04.16|Training alignment|14 days|
+|04.16-04.29|Code merge|14 days|
+
+## 5. Risks and impact
+
+Risks:
+- A trained MoCo model is normally used as an image feature extractor, so there is no inference stage in the usual sense.
+- **Every model and algorithm in PaddleClas must pass Paddle's training-and-inference certification; a newly added model only needs the basic training and inference certification.** This conflicts with how MoCo is trained and used, so explicit guidance is needed on how MoCo-v2 should be certified.
+- The task title is MoCo-v2; should a MoCo-v1 module (already implemented in PASSL) be merged at the same time?
+- PASSL also provides a MoCo-Clas classification model; should that module be merged as well?
+- Part of train.py may need to be modified.
+
+Impact:
+The dataloader, data augmentation and model are all new scripts, so no other module is affected.
+
+# Glossary
+MoCo (Momentum Contrast)
+# References
+
+[1] He K, Fan H, Wu Y, et al. Momentum contrast for unsupervised visual representation learning[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 9729-9738.
+
+[2] Chen X, Fan H, Girshick R, et al. Improved baselines with momentum contrastive learning[J]. arXiv preprint arXiv:2003.04297, 2020.
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index 78e9b4dc25c5132229afa1273219d94511cd11fe..be0b374e62ebe559da083199ef44c2fb3f940d47 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -75,7 +75,8 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p
 from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384
 from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
 from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
-
+from .model_zoo.moco import MoCo_V1, MoCo_V2
+from .model_zoo.moco_finetune import MoCo_finetune
 from .variant_models.resnet_variant import ResNet50_last_stage_stride1
 from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
 from .variant_models.resnet_variant import ResNet50_metabin
diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py
index 7a4f3b37a2eb70225056d18dbca5a18ef3e18955..0fff46ace97c667c52021d5ca999024f8d0370f7 100644
--- a/ppcls/arch/backbone/legendary_models/resnet.py
+++ b/ppcls/arch/backbone/legendary_models/resnet.py
@@ -346,7 +346,7 @@ class ResNet(TheseusLayer):
                  [32, 32, 3, 1], [32, 64, 3, 1]]
     }
 
-        self.stem = nn.Sequential(* [
+        self.stem = nn.Sequential(*[
             ConvBNLayer(
                 num_channels=in_c,
                 num_filters=out_c,
@@ -396,7 +396,7 @@ class ResNet(TheseusLayer):
 
         self.data_format = data_format
 
-        super().init_res(
+        super().init_net(
             stages_pattern,
             return_patterns=return_patterns,
             return_stages=return_stages)
diff --git a/ppcls/arch/backbone/model_zoo/moco.py b/ppcls/arch/backbone/model_zoo/moco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f36b6b3e5f12282bf09d597cbb51e0ffaf3e3f44
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/moco.py
@@ -0,0 +1,354 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# reference: https://arxiv.org/abs/1911.05722 (MoCo), https://arxiv.org/abs/2003.04297 (MoCo v2)
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+from ppcls.utils.initializer import kaiming_normal_, constant_, normal_
+from ..legendary_models import *
+from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
+MODEL_URLS = {"MoCo_V1": "UNKNOWN", "MoCo_V2": "UNKNOWN"}
+
+__all__ = list(MODEL_URLS.keys())
+
+
+class LinearNeck(nn.Layer):
+    """Linear neck: fc only.
+    """
+
+    def __init__(self, in_channels, out_channels, with_avg_pool=False):
+        super(LinearNeck, self).__init__()
+        self.with_avg_pool = with_avg_pool
+        if with_avg_pool:
+            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
+        self.fc = nn.Linear(in_channels, out_channels)
+
+    def forward(self, x):
+        if self.with_avg_pool:
+            x = self.avgpool(x)
+        return self.fc(x.reshape([x.shape[0], -1]))
+
+
+class NonLinearNeck(nn.Layer):
+    """The non-linear neck in MoCo v2: fc-relu-fc.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 hid_channels,
+                 out_channels,
+                 with_avg_pool=False):
+        super(NonLinearNeck, self).__init__()
+        self.with_avg_pool = with_avg_pool
+        if with_avg_pool:
+            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
+
+        self.mlp = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.ReLU(), nn.Linear(hid_channels, out_channels))
+
+    def forward(self, x):
+        if self.with_avg_pool:
+            x = self.avgpool(x)
+        return self.mlp(x.reshape([x.shape[0], -1]))
+
+
+class ContrastiveHead(nn.Layer):
+    """Head for contrastive learning.
+
+    Args:
+        temperature (float): The temperature hyper-parameter that
+            controls the concentration level of the distribution.
+            Default: 0.1.
+    """
+
+    def __init__(self, temperature=0.1):
+        super(ContrastiveHead, self).__init__()
+        self.temperature = temperature
+
+    def forward(self, pos, neg):
+        """Forward head.
+
+        Args:
+            pos (Tensor): Nx1 positive similarity.
+            neg (Tensor): NxK negative similarity.
+
+        Returns:
+            tuple[Tensor, Tensor]: the temperature-scaled logits of shape
+            Nx(1+K) and the int64 zero labels consumed by the loss.
+        """
+        N = pos.shape[0]
+        logits = paddle.concat((pos, neg), axis=1)
+        logits /= self.temperature
+        labels = paddle.zeros((N, 1), dtype='int64')
+
+        return logits, labels
+
+
+def _load_pretrained(pretrained, model, model_url, use_ssld=False):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )
+
+
+class MoCo(nn.Layer):
+    """
+    Build a MoCo model with: a query encoder, a key encoder, and a queue
+    https://arxiv.org/abs/1911.05722
+    """
+
+    def __init__(self,
+                 backbone_config,
+                 neck_config,
+                 head_config,
+                 dim=128,
+                 K=65536,
+                 m=0.999,
+                 T=0.07):
+        """
+        Initialize a `MoCo_V1` or `MoCo_V2` model depending on the configs.
+        Args:
+            backbone_config (dict): config of the backbone (e.g. ResNet50).
+            neck_config (dict): config of the neck (e.g. MLP or FC).
+            head_config (dict): config of the head.
+            dim (int): feature dimension. Default: 128.
+            K (int): queue size; number of negative keys. Default: 65536.
+            m (float): moco momentum of updating key encoder. Default: 0.999.
+            T (float): softmax temperature. Default: 0.07.
+        """
+        super(MoCo, self).__init__()
+        self.K = K
+        self.m = m
+        self.T = T
+
+        backbone_type = backbone_config.pop('name')
+        backbone = eval(backbone_type)
+
+        neck_type = neck_config.pop('name')
+        neck = eval(neck_type)
+
+        head_type = head_config.pop('name')
+        head = eval(head_type)
+
+        backbone_1 = backbone()
+        backbone_1.stop_after(stop_layer_name='avg_pool')
+        backbone_2 = backbone()
+        backbone_2.stop_after(stop_layer_name='avg_pool')
+
+        self.encoder_q = nn.Sequential(backbone_1, neck(**neck_config))
+        self.encoder_k = nn.Sequential(backbone_2, neck(**neck_config))
+
+        self.backbone = self.encoder_q[0]
+
+        self.head = head(**head_config)
+
+        # Kaiming initialization
+        self.init_parameters()
+
+        for param_q, param_k in zip(self.encoder_q.parameters(),
+                                    self.encoder_k.parameters()):
+            param_k.set_value(param_q)  # initialize key encoder from query encoder
+            param_k.stop_gradient = True  # not updated by gradient
+
+        # freeze the BatchNorm statistics of the key encoder
+        freeze_batchnorm_statictis(self.encoder_k)
+
+        # create the queue
+        self.register_buffer("queue", paddle.randn([dim, K]))
+        self.queue = nn.functional.normalize(self.queue, axis=0)
+
+        self.register_buffer("queue_ptr", paddle.zeros([1], 'int64'))
+
+    def init_parameters(self, init_linear='kaiming', std=0.01, bias=0.):
+        assert init_linear in ['normal', 'kaiming'], \
+            "Undefined init_linear: {}".format(init_linear)
+        for m in self.sublayers():
+            if isinstance(m, nn.Conv2D):
+                kaiming_normal_(m, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.layer.norm._BatchNormBase, nn.GroupNorm)):
+                constant_(m, 1)
+            elif isinstance(m, nn.Linear):
+                if init_linear == 'normal':
+                    normal_(m, std=std, bias=bias)
+                else:
+                    kaiming_normal_(m, mode='fan_in', nonlinearity='relu')
+
+    @paddle.no_grad()
+    def _momentum_update_key_encoder(self):
+        """
+        Momentum update of the key encoder
+        """
+        for param_q, param_k in zip(self.encoder_q.parameters(),
+                                    self.encoder_k.parameters()):
+            paddle.assign((param_k * self.m + param_q * (1. - self.m)),
+                          param_k)
+            param_k.stop_gradient = True
+
+    @paddle.no_grad()
+    def _dequeue_and_enqueue(self, keys):
+        keys = concat_all_gather(keys)
+
+        batch_size = keys.shape[0]
+
+        ptr = int(self.queue_ptr[0])
+        assert self.K % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.queue[:, ptr:ptr + batch_size] = keys.transpose([1, 0])
+        ptr = (ptr + batch_size) % self.K  # move pointer
+
+        self.queue_ptr[0] = ptr
+
+    @paddle.no_grad()
+    def _batch_shuffle_ddp(self, x):
+        """
+        Batch shuffle, for making use of BatchNorm.
+        *** Only supports DistributedDataParallel (DDP) models. ***
+        """
+        # gather from all gpus
+        batch_size_this = x.shape[0]
+        x_gather = concat_all_gather(x)
+        batch_size_all = x_gather.shape[0]
+
+        num_gpus = batch_size_all // batch_size_this
+
+        # random shuffle index
+        idx_shuffle = paddle.randperm(batch_size_all).cuda()
+
+        # broadcast to all gpus
+        if paddle.distributed.get_world_size() > 1:
+            paddle.distributed.broadcast(idx_shuffle, src=0)
+
+        # index for restoring
+        idx_unshuffle = paddle.argsort(idx_shuffle)
+
+        # shuffled index for this gpu
+        gpu_idx = paddle.distributed.get_rank()
+        idx_this = idx_shuffle.reshape([num_gpus, -1])[gpu_idx]
+        return paddle.index_select(x_gather, idx_this), idx_unshuffle
+
+    @paddle.no_grad()
+    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
+        """
+        Undo batch shuffle.
+        *** Only supports DistributedDataParallel (DDP) models. ***
+        """
+        # gather from all gpus
+        batch_size_this = x.shape[0]
+        x_gather = concat_all_gather(x)
+        batch_size_all = x_gather.shape[0]
+
+        num_gpus = batch_size_all // batch_size_this
+
+        # restored index for this gpu
+        gpu_idx = paddle.distributed.get_rank()
+        idx_this = idx_unshuffle.reshape([num_gpus, -1])[gpu_idx]
+
+        return paddle.index_select(x_gather, idx_this)
+
+    def train_iter(self, inputs, **kwargs):
+        img_q, img_k = inputs
+
+        # compute query features
+        q = self.encoder_q(img_q)  # queries: NxC
+        q = nn.functional.normalize(q, axis=1)
+
+        # compute key features
+        with paddle.no_grad():  # no gradient flows to the keys
+            self._momentum_update_key_encoder()  # update the key encoder
+
+            # shuffle for making use of BN
+            img_k = paddle.to_tensor(img_k)
+            im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
+
+            k = self.encoder_k(im_k)  # keys: NxC
+            k = nn.functional.normalize(k, axis=1)
+
+            # undo shuffle
+            k = self._batch_unshuffle_ddp(k, idx_unshuffle)
+
+        # positive logits: Nx1
+        l_pos = paddle.sum(q * k, axis=1).unsqueeze(-1)
+        # negative logits: NxK
+        l_neg = paddle.matmul(q, self.queue.clone().detach())
+
+        outputs = self.head(l_pos, l_neg)
+        self._dequeue_and_enqueue(k)
+
+        return outputs
+
+    def forward(self, inputs, mode='train', **kwargs):
+        if mode == 'train':
+            return self.train_iter(inputs, **kwargs)
+        elif mode == 'extract':
+            return self.backbone(inputs)
+        else:
+            raise Exception("No such mode: {}".format(mode))
+
+
+@paddle.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    """
+    if paddle.distributed.get_world_size() < 2:
+        return tensor
+
+    tensors_gather = []
+    paddle.distributed.all_gather(tensors_gather, tensor)
+
+    output = paddle.concat(tensors_gather, axis=0)
+    return output
+
+
+def freeze_batchnorm_statictis(layer):
+    def freeze_bn(layer):
+        if isinstance(layer, nn.layer.norm._BatchNormBase):
+            layer._use_global_stats = True
+
+    # apply the hook to every sublayer
+    layer.apply(freeze_bn)
+
+
+def MoCo_V1(backbone, neck, head, pretrained=False, use_ssld=False):
+    model = MoCo(
+        backbone_config=backbone, neck_config=neck, head_config=head, T=0.07)
+    _load_pretrained(
+        pretrained, model, MODEL_URLS["MoCo_V1"], use_ssld=use_ssld)
+    return model
+
+
+def MoCo_V2(backbone, neck, head, pretrained=False, use_ssld=False):
+    model = MoCo(
+        backbone_config=backbone, neck_config=neck, head_config=head, T=0.2)
+    _load_pretrained(
+        pretrained, model, MODEL_URLS["MoCo_V2"], use_ssld=use_ssld)
+    return model
diff --git a/ppcls/arch/backbone/model_zoo/moco_finetune.py b/ppcls/arch/backbone/model_zoo/moco_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..8de551c7848f86f1c383b5fb70f1468029640598
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/moco_finetune.py
@@ -0,0 +1,139 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# reference: https://arxiv.org/abs/1911.05722 (MoCo), https://arxiv.org/abs/2003.04297 (MoCo v2)
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+from ....utils import logger
+from ppcls.utils.initializer import normal_
+from ..legendary_models import *
+from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
+MODEL_URLS = {"MoCo_finetune": "UNKNOWN"}
+
+__all__ = list(MODEL_URLS.keys())
+
+
+class ClasHead(nn.Layer):
+    """Simple classifier head.
+    """
+
+    def __init__(self, with_avg_pool=False, in_channels=2048, class_num=1000):
+        super(ClasHead, self).__init__()
+        self.with_avg_pool = with_avg_pool
+        self.in_channels = in_channels
+        self.num_classes = class_num
+
+        if self.with_avg_pool:
+            self.avg_pool = nn.AdaptiveAvgPool2D((1, 1))
+        self.fc = nn.Linear(in_channels, class_num)
+        normal_(self.fc, mean=0.0, std=0.01, bias=0.0)
+
+    def forward(self, x):
+        if self.with_avg_pool:
+            x = self.avg_pool(x)
+        x = paddle.reshape(x, [-1, self.in_channels])
+        x = self.fc(x)
+        return x
+
+
+def _load_pretrained(pretrained_config, model, use_ssld=False):
+    if pretrained_config is not None:
+        if pretrained_config.startswith("http"):
+            load_dygraph_pretrain_from_url(model, pretrained_config)
+        else:
+            load_dygraph_pretrain(model, pretrained_config)
+
+
+class Classification(nn.Layer):
+    """
+    Simple image classification.
+    """
+
+    def __init__(self, backbone, head, with_sobel=False):
+        super(Classification, self).__init__()
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, inputs):
+        x = self.backbone(inputs)
+        x = self.head(x)
+        return x
+
+
+def freeze_batchnorm_statictis(layer):
+    def freeze_bn(layer):
+        # ResNet uses BatchNorm2D, so test against the common base class
+        if isinstance(layer, nn.layer.norm._BatchNormBase):
+            layer._use_global_stats = True
+
+    # apply the hook to every sublayer
+    layer.apply(freeze_bn)
+
+
+def freeze_params(model):
+    from ppcls.arch.backbone.legendary_models.resnet import ConvBNLayer, BottleneckBlock
+    for name in ['stem', 'max_pool', 'blocks', 'avg_pool']:
+        m = getattr(model, name)
+        if isinstance(m, nn.Sequential):
+            for item in m:
+                if isinstance(item, ConvBNLayer):
+                    freeze_batchnorm_statictis(item.bn)
+
+                if isinstance(item, BottleneckBlock):
+                    freeze_batchnorm_statictis(item.conv0.bn)
+                    freeze_batchnorm_statictis(item.conv1.bn)
+                    freeze_batchnorm_statictis(item.conv2.bn)
+                    if hasattr(item, 'short'):
+                        freeze_batchnorm_statictis(item.short.bn)
+
+        for param in m.parameters():
+            param.trainable = False
+
+
+def MoCo_finetune(backbone, head, pretrained=False, use_ssld=False):
+    backbone_config = backbone
+    head_config = head
+    backbone_name = backbone_config.pop('name')
+
+    # pop the non-constructor keys before instantiating the backbone,
+    # otherwise they would reach the backbone as unexpected kwargs
+    stop_layer_name = backbone_config.pop('stop_layer_name', None)
+    freeze_layer_name = backbone_config.pop('freeze_befor', None)
+    pretrained_config = backbone_config.pop('pretrained_model', None)
+
+    backbone = eval(backbone_name)(**backbone_config)
+
+    # truncate the backbone at the given layer
+    if stop_layer_name:
+        backbone.stop_after(stop_layer_name=stop_layer_name)
+    # freeze parameter updates before the given layer
+    if freeze_layer_name:
+        ret = backbone.freeze_befor(freeze_layer_name)
+        if ret:
+            logger.info(
+                "moco_clas backbone: froze param updates before the layer: {}".
+                format(freeze_layer_name))
+        else:
+            logger.error(
+                "moco_clas backbone: failed to freeze param updates before the layer: {}".
+                format(freeze_layer_name))
+
+    freeze_params(backbone)
+    head_name = head_config.pop('name')
+    head = eval(head_name)(**head_config)
+    model = Classification(backbone=backbone, head=head)
+
+    # load the MoCo pretraining weights
+    _load_pretrained(pretrained_config, model, use_ssld=use_ssld)
+    return model
diff --git a/ppcls/configs/ImageNet/MoCo/MoCoV2_r50.yaml b/ppcls/configs/ImageNet/MoCo/MoCoV2_r50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cadc8d36ba970acfea43e36f31386528af69782a
--- /dev/null
+++ b/ppcls/configs/ImageNet/MoCo/MoCoV2_r50.yaml
@@ -0,0 +1,130 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 50
+  # handled by train_epoch_iter_two_samples
+  train_mode: iter_two_samples
+  eval_during_train: False
+  eval_interval: 1
+  epochs: 200
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+
+
+# model architecture
+Arch:
+  name: MoCo_V2
+  backbone:
+    name: ResNet50
+    stop_layer_name: avg_pool
+  neck:
+    name: NonLinearNeck
+    in_channels: 2048
+    hid_channels: 2048
+    out_channels: 128
+  head:
+    name: ContrastiveHead
+    temperature: 0.2
+
+# loss function config
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  weight_decay: 0.0001
+  lr:
+    name: Cosine
+    learning_rate: 0.03
+    T_max: 200
+
+
+# data loader for train
+DataLoader:
+  Train:
+    dataset:
+      name: MoCoImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      return_label: False
+      return_two_sample: True
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandomResizedCrop:
+            size: 224
+            scale: [0.2, 1.]
+      view_trans1:
+        - RandomApply:
+            transforms:
+              - RawColorJitter:
+                  brightness: 0.4
+                  contrast: 0.4
+                  saturation: 0.4
+                  hue: 0.1
+            p: 0.8
+        - RandomGrayscale:
+            p: 0.2
+        - RandomApply:
+            transforms:
+              - GaussianBlur:
+                  sigma: [0.1, 2.0]
+            p: 0.5
+        - RandomHorizontalFlip:
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+      view_trans2:
+        - RandomApply:
+            transforms:
+              - RawColorJitter:
+                  brightness: 0.4
+                  contrast: 0.4
+                  saturation: 0.4
+                  hue: 0.1
+            p: 0.8
+        - RandomGrayscale:
+            p: 0.2
+        - RandomApply:
+            transforms:
+              - GaussianBlur:
+                  sigma: [0.1, 2.0]
+            p: 0.5
+        - RandomHorizontalFlip:
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: True
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
diff --git a/ppcls/configs/ImageNet/MoCo/MoCo_clas.yaml b/ppcls/configs/ImageNet/MoCo/MoCo_clas.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..029b74bb9443caf8edce8e800c24eaedbdbeb419
--- /dev/null
+++ b/ppcls/configs/ImageNet/MoCo/MoCo_clas.yaml
@@ -0,0 +1,114 @@
+# global configs
+Global:
+  checkpoints: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 20
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+
+
+Arch:
+  name: MoCo_finetune
+  backbone:
+    name: ResNet50
+    stop_layer_name: avg_pool
+    freeze_befor: avg_pool
+    # MoCo v2 pretraining checkpoint, popped from this dict by MoCo_finetune
+    pretrained_model: ./pretrain/moco_v2_bs_256_epoch_200
+  head:
+    name: ClasHead
+    class_num: 1000
+
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: MultiStepDecay
+    learning_rate: 30.0
+    milestones: [60, 80]
+
+
+DataLoader:
+  Train:
+    dataset:
+      name: MoCoImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      return_label: True
+      return_two_sample: False
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandomResizedCrop:
+            size: 224
+        - RandomHorizontalFlip:
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      shuffle: True
+      drop_last: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: MoCoImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      return_label: True
+      return_two_sample: False
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - Resize:
+            size: 256
+        - CenterCrop:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      # evaluate every sample exactly once
+      shuffle: False
+      drop_last: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
\ No newline at end of file
diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py
index 391dcef65bdafc2309118a842beb9ef7aae32ae6..9af45e72508032ef69b32178da207bfcf40b389c 100644
--- a/ppcls/data/dataloader/__init__.py
+++ b/ppcls/data/dataloader/__init__.py
@@ -14,3 +14,4 @@ from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDat
 from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
 from ppcls.data.dataloader.cifar import Cifar10, Cifar100
 from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
+from ppcls.data.dataloader.moco_imagenet_dataset import MoCoImageNetDataset
diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py
index cc66007d90585fdda9d5bae1b1954bf5f6bb2bf1..398c847d2cae68cef5ea1345b43615a6bcd14491 100644
--- a/ppcls/data/dataloader/imagenet_dataset.py
+++ b/ppcls/data/dataloader/imagenet_dataset.py
@@ -72,4 +72,4 @@ class ImageNetDataset(CommonDataset):
             else:
                 self.labels.append(np.int64(line[1]))
             assert os.path.exists(self.images[
-                -1]), f"path {self.images[-1]} does not exist."
+                -1]), f"path {self.images[-1]} does not exist."
\ No newline at end of file
diff --git a/ppcls/data/dataloader/moco_imagenet_dataset.py b/ppcls/data/dataloader/moco_imagenet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..17106d8f13a22adb4184d2fca26cca6da0963a58
--- /dev/null
+++ b/ppcls/data/dataloader/moco_imagenet_dataset.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import os
+
+from ppcls.utils import logger
+from .common_dataset import CommonDataset, create_operators
+from ppcls.data.preprocess import transform
+
+
+class MoCoImageNetDataset(CommonDataset):
+    """MoCoImageNetDataset
+
+    Args:
+        image_root (str): image root, path to `ILSVRC2012`
+        cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
+        return_label (bool, optional): whether to return the original label.
+        return_two_sample (bool, optional): whether to return two augmented views of the original image.
+        transform_ops (list, optional): list of transform op(s). Defaults to None.
+        delimiter (str, optional): delimiter. Defaults to None.
+        relabel (bool, optional): whether to relabel when the original labels do not start from 0 or are discontinuous. Defaults to False.
+        view_trans1 (list): transform op(s) for view 1.
+        view_trans2 (list): transform op(s) for view 2.
+    """
+
+    def __init__(
+            self,
+            image_root,
+            cls_label_path,
+            return_label=True,
+            return_two_sample=False,
+            transform_ops=None,
+            delimiter=None,
+            relabel=False,
+            view_trans1=None,
+            view_trans2=None, ):
+        self.delimiter = delimiter if delimiter is not None else " "
+        self.relabel = relabel
+        super(MoCoImageNetDataset, self).__init__(image_root, cls_label_path,
+                                                  transform_ops)
+
+        self.return_label = return_label
+        self.return_two_sample = return_two_sample
+
+        if self.return_two_sample:
+            self.view_transform1 = create_operators(view_trans1)
+            self.view_transform2 = create_operators(view_trans2)
+
+    def __getitem__(self, idx):
+        try:
+            with open(self.images[idx], 'rb') as f:
+                img = f.read()
+
+            if self.return_two_sample:
+                # the shared ops are random, so each view gets its own crop
+                sample1 = transform(img, self._transform_ops)
+                sample2 = transform(img, self._transform_ops)
+                sample1 = transform(sample1, self.view_transform1)
+                sample2 = transform(sample2, self.view_transform2)
+
+                if self.return_label:
+                    return (sample1, sample2, self.labels[idx])
+                else:
+                    return (sample1, sample2)
+
+            if self._transform_ops:
+                img = transform(img, self._transform_ops)
+            img = img.transpose((2, 0, 1))
+
+            return (img, self.labels[idx])
+
+        except Exception as ex:
+            logger.error("Exception occurred when parsing line: {} with msg: {}".
+                         format(self.images[idx], ex))
+            rnd_idx = np.random.randint(self.__len__())
+            return self.__getitem__(rnd_idx)
+
+    def _load_anno(self, seed=None):
+        assert os.path.exists(
+            self._cls_path), f"path {self._cls_path} does not exist."
+        assert os.path.exists(
+            self._img_root), f"path {self._img_root} does not exist."
+        self.images = []
+        self.labels = []
+
+        with open(self._cls_path) as fd:
+            lines = fd.readlines()
+            if self.relabel:
+                label_set = set()
+                for line in lines:
+                    line = line.strip().split(self.delimiter)
+                    label_set.add(np.int64(line[1]))
+                label_map = {
+                    oldlabel: newlabel
+                    for newlabel, oldlabel in enumerate(label_set)
+                }
+
+            if seed is not None:
+                np.random.RandomState(seed).shuffle(lines)
+            for line in lines:
+                line = line.strip().split(self.delimiter)
+                self.images.append(os.path.join(self._img_root, line[0]))
+                if self.relabel:
+                    self.labels.append(label_map[np.int64(line[1])])
+                else:
+                    self.labels.append(np.int64(line[1]))
+                assert os.path.exists(self.images[
+                    -1]), f"path {self.images[-1]} does not exist."
diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py
index 66234a44bd23a7e4b55791d9183e9ac013f14d50..01800c81416c35a7c50ef415b915a2a32425cfec 100644
--- a/ppcls/data/preprocess/__init__.py
+++ b/ppcls/data/preprocess/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop, Transpose
 from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
 from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
 from ppcls.data.preprocess.ops.randaugment import RandomApply
@@ -48,6 +49,8 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
 from ppcls.data.preprocess.ops.operators import Padv2
 from ppcls.data.preprocess.ops.operators import RandomRot90
 from ppcls.data.preprocess.ops.operators import PCALighting
+from ppcls.data.preprocess.ops.operators import GaussianBlur
+
 from .ops.operators import format_data
 
 from paddle.vision.transforms import Pad as Pad_paddle_vision
@@ -58,6 +61,7 @@ import numpy as np
 from PIL import Image
 import random
 
+
 def transform(data, ops=[]):
     """ transform """
     for op in ops:
@@ -139,4 +143,4 @@ class TimmAutoAugment(RawTimmAutoAugment):
         if isinstance(img, Image.Image):
             img = np.asarray(img)
 
-        return img
\ No newline at end of file
+        return img
diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py
index d9604210e6ac097e720dbad6003fffcf5bfa6809..8243d0f41c5fe4271ed6b0b3a44a10cbd252a67c 100644
--- a/ppcls/data/preprocess/ops/operators.py
+++ b/ppcls/data/preprocess/ops/operators.py
@@ -24,12 +24,13 @@ import math
 import random
 import cv2
 import numpy as np
-from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
+from PIL import ImageFilter, Image, ImageOps, __version__ as PILLOW_VERSION
 from paddle.vision.transforms import ColorJitter as RawColorJitter
 from paddle.vision.transforms import CenterCrop, Resize
 from paddle.vision.transforms import RandomRotation as RawRandomRotation
 from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop
 from paddle.vision.transforms import functional as F
+from paddle.vision.transforms import transforms as T
 from .autoaugment import ImageNetPolicy
 from .functional import augmentations
 from ppcls.utils import logger
@@ -742,8 +743,8 @@ class Pad(object):
         # Process fill color for affine transforms
         major_found, minor_found = (int(v)
                                     for v in PILLOW_VERSION.split('.')[:2])
-        major_required, minor_required = (int(v) for v in
-                                          min_pil_version.split('.')[:2])
+        major_required, minor_required = (
+            int(v) for v in min_pil_version.split('.')[:2])
         if major_found < major_required or (major_found == major_required and
                                             minor_found < minor_required):
             if fill is None:
@@ -858,6 +859,25 @@ class BlurImage(object):
         return {"img": img, "blur_image": label}
 
 
+class GaussianBlur(object):
+    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
+
+    def __init__(self, sigma=[.1, 2.], backend="cv2"):
+        self.sigma = sigma
+        self.kernel_size = 23
+        self.backend = backend
+
+    def __call__(self, x):
+        sigma = np.random.uniform(self.sigma[0], self.sigma[1])
+        if self.backend == "PIL":
+            x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
+            return x
+        else:
+            x = cv2.GaussianBlur(
+                np.array(x), (self.kernel_size, self.kernel_size), sigma)
+            return Image.fromarray(x.astype(np.uint8))
+
+
 class RandomGrayscale(object):
     """Randomly convert image to grayscale with a probability of p (default 0.1).
 
@@ -878,14 +898,20 @@ class RandomGrayscale(object):
     def __call__(self, img):
         """
         Args:
-            img (PIL Image): Image to be converted to grayscale.
+            img (PIL.Image|np.array): Image to be converted to grayscale.
 
         Returns:
             PIL Image: Randomly grayscaled image.
""" - num_output_channels = 1 if img.mode == 'L' else 3 - if random.random() < self.p: - return F.to_grayscale(img, num_output_channels=num_output_channels) + if isinstance(img, Image.Image): + if img.mode == 'L': + num_output_channels = 1 + + if isinstance(img, np.ndarray) or isinstance(img, Image.Image): + num_output_channels = 3 + if random.random() < self.p: + return F.to_grayscale( + img, num_output_channels=num_output_channels) return img def __repr__(self): diff --git a/ppcls/data/preprocess/ops/randaugment.py b/ppcls/data/preprocess/ops/randaugment.py index f3a61bd506243125636db06f648ebdd69e6043f5..a53f1efd97e22b925212f477bf0a0668c192ceda 100644 --- a/ppcls/data/preprocess/ops/randaugment.py +++ b/ppcls/data/preprocess/ops/randaugment.py @@ -259,4 +259,4 @@ class RandAugmentV2(RandAugment): "equalize": lambda img, _: ImageOps.equalize(img), "invert": lambda img, _: ImageOps.invert(img), "cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0]) - } + } \ No newline at end of file diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 7058ec495e004ddeb85240da272ec349542eeb24..af2e4683f0ea68b7e98246694c8ee5192d8077e9 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -330,6 +330,7 @@ class Engine(object): if self.config["Global"]["distributed"]: dist.init_parallel_env() self.model = paddle.DataParallel(self.model) + if self.mode == 'train' and len(self.train_loss_func.parameters( )) > 0: self.train_loss_func = paddle.DataParallel( diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py index 50bf9037f4982354724d56f5814f47cf8b92decc..12c7ca45d7a771de6524d3c7bc7341a4b193c408 100644 --- a/ppcls/engine/train/__init__.py +++ b/ppcls/engine/train/__init__.py @@ -16,3 +16,4 @@ from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl from ppcls.engine.train.train_progressive import train_epoch_progressive from ppcls.engine.train.train_metabin import train_epoch_metabin +from ppcls.engine.train.train_iter_two_samples import train_epoch_iter_two_samples \ No newline at end of file diff --git a/ppcls/engine/train/train_iter_two_samples.py b/ppcls/engine/train/train_iter_two_samples.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcadd60e4465caa8fca4b37ea906471238b5a1c --- /dev/null +++ b/ppcls/engine/train/train_iter_two_samples.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import absolute_import, division, print_function
+
+import time
+import paddle
+from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
+from ppcls.utils import profiler
+
+
+def train_epoch_iter_two_samples(engine, epoch_id, print_batch_step):
+    tic = time.time()
+
+    if not hasattr(engine, "train_dataloader_iter"):
+        engine.train_dataloader_iter = iter(engine.train_dataloader)
+
+    for iter_id in range(engine.iter_per_epoch):
+        # fetch a data batch from the dataloader
+        try:
+            batch = next(engine.train_dataloader_iter)
+        except Exception:
+            # the iterator is exhausted; restart it
+            engine.train_dataloader_iter = iter(engine.train_dataloader)
+            batch = next(engine.train_dataloader_iter)
+
+        profiler.add_profiler_step(engine.config["profiler_options"])
+        if iter_id == 5:
+            for key in engine.time_info:
+                engine.time_info[key].reset()
+        engine.time_info["reader_cost"].update(time.time() - tic)
+
+        # view_1 samples: batch[0]; view_2 samples: batch[1]
+        batch_size = batch[0].shape[0]
+        engine.global_step += 1
+
+        # forward
+        if engine.amp:
+            amp_level = engine.config["AMP"].get("level", "O1").upper()
+            with paddle.amp.auto_cast(
+                    custom_black_list={
+                        "flatten_contiguous_range", "greater_than"
+                    },
+                    level=amp_level):
+                logits, labels = forward(engine, batch)
+                loss_dict = engine.train_loss_func(logits, labels)
+        else:
+            logits, labels = forward(engine, batch)
+            loss_dict = engine.train_loss_func(logits, labels)
+
+        # loss
+        loss = loss_dict["loss"] / engine.update_freq
+
+        # backward & optimizer step
+        if engine.amp:
+            scaled = engine.scaler.scale(loss)
+            scaled.backward()
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.scaler.minimize(engine.optimizer[i], scaled)
+        else:
+            loss.backward()
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.optimizer[i].step()
+
+        if (iter_id + 1) % engine.update_freq == 0:
+            # clear grad
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].clear_grad()
+            # step lr (by step)
+            for i in range(len(engine.lr_sch)):
+                if not getattr(engine.lr_sch[i], "by_epoch", False):
+                    engine.lr_sch[i].step()
+            # update ema
+            if engine.ema:
+                engine.model_ema.update(engine.model)
+
+        # the code below is only for logging
+        # update metric for the logger
+        update_metric(engine, logits, [labels], batch_size)
+        # update loss for the logger
+        update_loss(engine, loss_dict, batch_size)
+        engine.time_info["batch_cost"].update(time.time() - tic)
+        if iter_id % print_batch_step == 0:
+            log_info(engine, batch_size, epoch_id, iter_id)
+        tic = time.time()
+
+    # step lr (by epoch)
+    for i in range(len(engine.lr_sch)):
+        if getattr(engine.lr_sch[i], "by_epoch", False) and \
+                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
+            engine.lr_sch[i].step()
+
+
+def forward(engine, batch):
+    # batch = [view_1 samples, view_2 samples]; the MoCo model unpacks them
+    return engine.model(batch)