diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index e957358479cb98d8bde3dac0d4b2785b8965c7bf..efea4ec46860e012d58274ac911a8144dfdef0e2 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -67,6 +67,7 @@ from ppcls.arch.backbone.model_zoo.pvt_v2 import PVT_V2_B0, PVT_V2_B1, PVT_V2_B2
 from ppcls.arch.backbone.model_zoo.mobilevit import MobileViT_XXS, MobileViT_XS, MobileViT_S
 from ppcls.arch.backbone.model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g4, RepVGG_B3g4
 from ppcls.arch.backbone.model_zoo.van import VAN_tiny
+from ppcls.arch.backbone.model_zoo.convnext import ConvNeXt_tiny
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
 from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
diff --git a/ppcls/arch/backbone/model_zoo/convnext.py b/ppcls/arch/backbone/model_zoo/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..f30894eab526b8deb5e61a964dc287415f1b1a02
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/convnext.py
@@ -0,0 +1,240 @@
+# MIT License
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Code was heavily based on https://github.com/facebookresearch/ConvNeXt
+
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant
+
+from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
+MODEL_URLS = {
+    "ConvNeXt_tiny": "",  # TODO
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+trunc_normal_ = TruncatedNormal(std=.02)
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    The original name is misleading, as 'Drop Connect' is a different form of dropout from a separate paper.
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = paddle.to_tensor(1 - drop_prob)
+    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+    random_tensor = paddle.floor(random_tensor)  # binarize
+    output = x.divide(keep_prob) * random_tensor
+    return output
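+
+
+# Example (illustrative): for a 4-D input, `random_tensor` has shape
+# [N, 1, 1, 1], so whole samples are kept (and rescaled) or dropped:
+#     x = paddle.ones([4, 3, 2, 2])
+#     y = drop_path(x, drop_prob=0.25, training=True)
+#     # each y[i] is either x[i] / 0.75 or all zeros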
+
+
+class DropPath(nn.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class ChannelsFirstLayerNorm(nn.Layer):
+    r""" LayerNorm for inputs in the channels_first format, i.e. with shape
+    (batch_size, channels, height, width). The channels_last case, with inputs
+    of shape (batch_size, height, width, channels), is covered by `nn.LayerNorm`.
+    """
+
+    def __init__(self, normalized_shape, epsilon=1e-5):
+        super().__init__()
+        self.weight = self.create_parameter(
+            shape=[normalized_shape], default_initializer=ones_)
+        self.bias = self.create_parameter(
+            shape=[normalized_shape], default_initializer=zeros_)
+        self.epsilon = epsilon
+        self.normalized_shape = [normalized_shape]
+
+    def forward(self, x):
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / paddle.sqrt(s + self.epsilon)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
+class Block(nn.Layer):
+    r""" ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2), which the ConvNeXt authors found slightly faster in PyTorch.
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+    """
+
+    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+        super().__init__()
+        self.dwconv = nn.Conv2D(
+            dim, dim, 7, padding=3, groups=dim)  # depthwise conv
+        self.norm = nn.LayerNorm(dim, epsilon=1e-6)
+        # pointwise/1x1 convs, implemented with linear layers
+        self.pwconv1 = nn.Linear(dim, 4 * dim)
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        if layer_scale_init_value > 0:
+            self.gamma = self.create_parameter(
+                shape=[dim],
+                default_initializer=Constant(value=layer_scale_init_value))
+        else:
+            self.gamma = None
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = x.transpose([0, 2, 3, 1])  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.transpose([0, 3, 1, 2])  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
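+
+
+# Shape check (illustrative): a Block maps (N, C, H, W) to the same shape, e.g.
+#     blk = Block(dim=96)
+#     y = blk(paddle.randn([1, 96, 56, 56]))  # -> [1, 96, 56, 56]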
+
+
+class ConvNeXt(nn.Layer):
+    r""" ConvNeXt
+        A Paddle impl of : `A ConvNet for the 2020s`  -
+        https://arxiv.org/pdf/2201.03545.pdf
+
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        class_num (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+    """
+
+    def __init__(self,
+                 in_chans=3,
+                 class_num=1000,
+                 depths=[3, 3, 9, 3],
+                 dims=[96, 192, 384, 768],
+                 drop_path_rate=0.,
+                 layer_scale_init_value=1e-6,
+                 head_init_scale=1.):
+        super().__init__()
+
+        # stem and 3 intermediate downsampling conv layers
+        self.downsample_layers = nn.LayerList()
+        stem = nn.Sequential(
+            nn.Conv2D(
+                in_chans, dims[0], 4, stride=4),
+            ChannelsFirstLayerNorm(
+                dims[0], epsilon=1e-6))
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                ChannelsFirstLayerNorm(
+                    dims[i], epsilon=1e-6),
+                nn.Conv2D(
+                    dims[i], dims[i + 1], 2, stride=2), )
+            self.downsample_layers.append(downsample_layer)
+
+        # 4 feature resolution stages, each consisting of multiple residual blocks
+        self.stages = nn.LayerList()
+        dp_rates = [
+            x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))
+        ]
+        cur = 0
+        for i in range(4):
+            stage = nn.Sequential(*[
+                Block(
+                    dim=dims[i],
+                    drop_path=dp_rates[cur + j],
+                    layer_scale_init_value=layer_scale_init_value)
+                for j in range(depths[i])
+            ])
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.norm = nn.LayerNorm(dims[-1], epsilon=1e-6)  # final norm layer
+        self.head = nn.Linear(dims[-1], class_num)
+
+        self.apply(self._init_weights)
+        self.head.weight.set_value(self.head.weight * head_init_scale)
+        self.head.bias.set_value(self.head.bias * head_init_scale)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2D, nn.Linear)):
+            trunc_normal_(m.weight)
+            if m.bias is not None:
+                zeros_(m.bias)
+
+    def forward_features(self, x):
+        for i in range(4):
+            x = self.downsample_layers[i](x)
+            x = self.stages[i](x)
+        # global average pooling, (N, C, H, W) -> (N, C)
+        return self.norm(x.mean([-2, -1]))
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def _load_pretrained(pretrained, model, model_url, use_ssld=False):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )
+
+
+def ConvNeXt_tiny(pretrained=False, use_ssld=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    _load_pretrained(
+        pretrained, model, MODEL_URLS["ConvNeXt_tiny"], use_ssld=use_ssld)
+    return model
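+
+
+# Usage sketch (illustrative; no pretrained weights are published yet, see the
+# TODO in MODEL_URLS):
+#     model = ConvNeXt_tiny(pretrained=False)
+#     logits = model(paddle.randn([1, 3, 224, 224]))  # -> [1, 1000]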
diff --git a/ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml b/ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fb6e3cbdbb2dc648e4ef0bd1cad59106efbf91db
--- /dev/null
+++ b/ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml
@@ -0,0 +1,170 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 300
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+  update_freq: 4  # for 8 cards
+
+# model ema
+EMA:
+  decay: 0.9999
+
+
+# model architecture
+Arch:
+  name: ConvNeXt_tiny
+  class_num: 1000
+  drop_path_rate: 0.1
+  layer_scale_init_value: 1e-6
+  head_init_scale: 1.0
+
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+        epsilon: 0.1
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  epsilon: 1e-8
+  weight_decay: 0.05
+  one_dim_param_no_weight_decay: True
+  lr:
+    # for 8 cards
+    name: Cosine
+    learning_rate: 4e-3  # lr 4e-3 for total_batch_size 4096
+    eta_min: 1e-6
+    warmup_epoch: 20
+    warmup_start_lr: 0
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+            interpolation: bicubic
+            backend: pil
+        - RandFlipImage:
+            flip_code: 1
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - RandomErasing:
+            EPSILON: 0.25
+            sl: 0.02
+            sh: 1.0/3.0
+            r1: 0.3
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
+      batch_transform_ops:
+        - OpSampler:
+            MixupOperator:
+              alpha: 0.8
+              prob: 0.5
+            CutmixOperator:
+              alpha: 1.0
+              prob: 0.5
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 128
+      drop_last: True
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+            interpolation: bicubic
+            backend: pil
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 128
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+
+Infer:
+  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+        interpolation: bicubic
+        backend: pil
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+
+Metric:
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
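+
+# Note (illustrative): with 8 cards, per-card batch_size 128 and update_freq 4,
+# the effective global batch size is 8 * 128 * 4 = 4096, which is what the
+# 4e-3 learning rate above is tuned for.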
diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py
index d87960e93fe7bc7e2e67f7c30d1b58d811153905..e617b8a71afffeb9e18e4be412f5a3374bd387ec 100644
--- a/ppcls/data/preprocess/ops/operators.py
+++ b/ppcls/data/preprocess/ops/operators.py
@@ -18,6 +18,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from functools import partial
+import io
 import six
 import math
 import random
@@ -138,28 +139,53 @@ class OperatorParamError(ValueError):
 class DecodeImage(object):
     """ decode image """
 
-    def __init__(self, to_rgb=True, to_np=False, channel_first=False):
-        self.to_rgb = to_rgb
+    def __init__(self,
+                 to_np=True,
+                 to_rgb=True,
+                 channel_first=False,
+                 backend="cv2"):
         self.to_np = to_np  # to numpy
+        self.to_rgb = to_rgb  # only enabled when to_np is True
         self.channel_first = channel_first  # only enabled when to_np is True
 
+        if backend.lower() not in ["cv2", "pil"]:
+            logger.warning(
+                f"The backend of DecodeImage only supports \"cv2\" or \"PIL\". \"{backend}\" is unavailable. Use \"cv2\" instead."
+            )
+            backend = "cv2"
+        self.backend = backend.lower()
+
+        if not to_np:
+            logger.warning(
+                f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
+            )
+
     def __call__(self, img):
-        if not isinstance(img, np.ndarray):
-            if six.PY2:
-                assert type(img) is str and len(
-                    img) > 0, "invalid input 'img' in DecodeImage"
+        if isinstance(img, Image.Image):
+            assert self.backend == "pil", "invalid input 'img' in DecodeImage"
+        elif isinstance(img, np.ndarray):
+            assert self.backend == "cv2", "invalid input 'img' in DecodeImage"
+        elif isinstance(img, bytes):
+            if self.backend == "pil":
+                data = io.BytesIO(img)
+                img = Image.open(data)
             else:
-                assert type(img) is bytes and len(
-                    img) > 0, "invalid input 'img' in DecodeImage"
-            data = np.frombuffer(img, dtype='uint8')
-            img = cv2.imdecode(data, 1)
-        if self.to_rgb:
-            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
-                img.shape)
-            img = img[:, :, ::-1]
-
-        if self.channel_first:
-            img = img.transpose((2, 0, 1))
+                data = np.frombuffer(img, dtype="uint8")
+                img = cv2.imdecode(data, 1)
+        else:
+            raise ValueError("invalid input 'img' in DecodeImage")
+
+        if self.to_np:
+            if self.backend == "pil":
+                assert img.mode == "RGB", f"invalid image mode[{img.mode}]"
+                img = np.asarray(img)[:, :, ::-1]  # RGB -> BGR
+
+            if self.to_rgb:
+                assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
+                img = img[:, :, ::-1]
+
+            if self.channel_first:
+                img = img.transpose((2, 0, 1))
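+
+        # Example (illustrative): decoding raw JPEG bytes with the PIL backend
+        #     op = DecodeImage(to_np=True, to_rgb=True, backend="pil")
+        #     img = op(open("demo.jpg", "rb").read())  # HWC uint8, RGB order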
 
         return img
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 884a05bb141947d70d2a20c2d88967bcbe6626ea..1aa0a1e05c306f46c77ff09b3fb6af344d3e01e3 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -34,6 +34,7 @@ from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
+from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 from ppcls.utils.save_load import init_model
 from ppcls.utils import save_load
@@ -99,6 +100,9 @@ class Engine(object):
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
 
+        # gradient accumulation
+        self.update_freq = self.config["Global"].get("update_freq", 1)
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -203,7 +207,7 @@
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader),
+                len(self.train_dataloader) // self.update_freq,
                 [self.model, self.train_loss_func])
 
         # AMP training and evaluating
@@ -277,6 +281,12 @@
                     level=self.amp_level,
                     save_dtype='float32')
 
+        # build EMA model
+        self.ema = "EMA" in self.config and self.mode == "train"
+        if self.ema:
+            self.model_ema = ExponentialMovingAverage(
+                self.model, self.config['EMA'].get("decay", 0.9999))
+
         # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
@@ -311,6 +321,10 @@
            "metric": -1.0,
            "epoch": 0,
        }
+        ema_module = None
+        if self.ema:
+            best_metric_ema = 0.0
+            ema_module = self.model_ema.module
        # key:
        # val: metrics list word
        self.output_info = dict()
@@ -325,12 +339,14 @@
 
        if self.config.Global.checkpoints is not None:
            metric_info = init_model(self.config.Global, self.model,
-                                     self.optimizer, self.train_loss_func)
+                                     self.optimizer, self.train_loss_func,
+                                     ema_module)
            if metric_info is not None:
                best_metric.update(metric_info)
 
        self.max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)
+        self.max_iter = self.max_iter // self.update_freq * self.update_freq
 
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
@@ -361,6 +377,7 @@
                        self.optimizer,
                        best_metric,
                        self.output_dir,
+                        ema=ema_module,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model",
                        loss=self.train_loss_func,
@@ -375,6 +392,32 @@
 
                self.model.train()
 
+            if self.ema:
+                ori_model, self.model = self.model, ema_module
+                acc_ema = self.eval(epoch_id)
+                self.model = ori_model
+                ema_module.eval()
+
+                if acc_ema > best_metric_ema:
+                    best_metric_ema = acc_ema
+                    save_load.save_model(
+                        self.model,
+                        self.optimizer,
+                        {"metric": acc_ema,
+                         "epoch": epoch_id},
+                        self.output_dir,
+                        ema=ema_module,
+                        model_name=self.config["Arch"]["name"],
+                        prefix="best_model_ema",
+                        loss=self.train_loss_func)
+                logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
+                    epoch_id, best_metric_ema))
+                logger.scaler(
+                    name="eval_acc_ema",
+                    value=acc_ema,
+                    step=epoch_id,
+                    writer=self.vdl_writer)
+
            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
@@ -382,6 +425,7 @@
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
+                    ema=ema_module,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
@@ -391,6 +435,7 @@
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
+                    ema=ema_module,
                    model_name=self.config["Arch"]["name"],
                    prefix="latest",
                    loss=self.train_loss_func)
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index 14db79e73e9e51d16d5784b7aa48a6afb12a7e0f..a41674da70c167959c2515ec696ca2a6686cf0f8 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -53,25 +53,33 @@ def train_epoch(engine, epoch_id, print_batch_step):
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
 
+        # loss
+        loss = loss_dict["loss"] / engine.update_freq
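+        # e.g. with update_freq=4, four sub-batch losses are scaled by 1/4 and
+        # their gradients accumulate before one optimizer step, matching a
+        # single step on a 4x larger batch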
+
         # backward & step opt
         if engine.amp:
-            scaled = engine.scaler.scale(loss_dict["loss"])
+            scaled = engine.scaler.scale(loss)
             scaled.backward()
-            for i in range(len(engine.optimizer)):
-                engine.scaler.minimize(engine.optimizer[i], scaled)
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
-            loss_dict["loss"].backward()
-            for i in range(len(engine.optimizer)):
-                engine.optimizer[i].step()
+            loss.backward()
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.optimizer[i].step()
 
-        # clear grad
-        for i in range(len(engine.optimizer)):
-            engine.optimizer[i].clear_grad()
-
-        # step lr(by step)
-        for i in range(len(engine.lr_sch)):
-            if not getattr(engine.lr_sch[i], "by_epoch", False):
-                engine.lr_sch[i].step()
+        if (iter_id + 1) % engine.update_freq == 0:
+            # clear grad
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].clear_grad()
+            # step lr(by step)
+            for i in range(len(engine.lr_sch)):
+                if not getattr(engine.lr_sch[i], "by_epoch", False):
+                    engine.lr_sch[i].step()
+            # update ema
+            if engine.ema:
+                engine.model_ema.update(engine.model)
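+                # note: EMA weights refresh once per optimizer step, i.e. once
+                # every update_freq iterations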
 
         # below code just for logging
         # update metric_for_logger
diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py
index ca211ff932f19ca63804a5a1ff52def5eb89477f..44e54660b6453b713b2325e26b1bd5590b23c933 100644
--- a/ppcls/engine/train/utils.py
+++ b/ppcls/engine/train/utils.py
@@ -54,12 +54,12 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
     ips_msg = "ips: {:.5f} samples/s".format(
         batch_size / trainer.time_info["batch_cost"].avg)
     eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-                ) * len(trainer.train_dataloader) - iter_id
+                ) * trainer.max_iter - iter_id
                ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
         epoch_id, trainer.config["Global"]["epochs"], iter_id,
-        len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
+        trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg,
         eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py
index b54cdb1b2030dc0a70394816a433e7e715e12996..8292781955210d68cea119b2fd887b534b3a6c04 100644
--- a/ppcls/utils/ema.py
+++ b/ppcls/utils/ema.py
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,52 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 import paddle
-import numpy as np
 
 
 class ExponentialMovingAverage():
     """
     Exponential Moving Average
-    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py
+    Code was heavily based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
     """
 
-    def __init__(self, model, decay, thres_steps=True):
-        self._model = model
-        self._decay = decay
-        self._thres_steps = thres_steps
-        self._shadow = {}
-        self._backup = {}
-
-    def register(self):
-        self._update_step = 0
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                self._shadow[name] = param.numpy().copy()
-
-    def update(self):
-        decay = min(self._decay, (1 + self._update_step) / (
-            10 + self._update_step)) if self._thres_steps else self._decay
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                new_val = np.array(param.numpy().copy())
-                old_val = np.array(self._shadow[name])
-                new_average = decay * old_val + (1 - decay) * new_val
-                self._shadow[name] = new_average
-        self._update_step += 1
-        return decay
-
-    def apply(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                self._backup[name] = np.array(param.numpy().copy())
-                param.set_value(np.array(self._shadow[name]))
-
-    def restore(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._backup
-                param.set_value(self._backup[name])
-        self._backup = {}
+    def __init__(self, model, decay=0.9999):
+        super().__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+
+    @paddle.no_grad()
+    def _update(self, model, update_fn):
+        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+            ema_v.set_value(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
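+        # e.g. with decay=0.9999 the shadow weights move 0.01% of the way
+        # toward the live weights on every call:
+        #     ema_w_new = 0.9999 * ema_w + 0.0001 * model_w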
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 04486cc273bbfe9e3d9863b4c4ded6a8d283eee3..31323e9ae11b3245c898f412057a15fb56734b0a 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -95,7 +95,11 @@ def load_distillation_model(model, pretrained_model):
             pretrained_model))
 
 
-def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
+def init_model(config,
+               net,
+               optimizer=None,
+               loss: paddle.nn.Layer=None,
+               ema=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -115,6 +119,11 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
         for i in range(len(optimizer)):
             optimizer[i].set_state_dict(opti_dict[i] if isinstance(
                 opti_dict, list) else opti_dict)
+        if ema is not None:
+            assert os.path.exists(checkpoints + ".ema.pdparams"), \
+                "Given dir {}.ema.pdparams does not exist.".format(checkpoints)
+            para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
+            ema.set_state_dict(para_ema_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
 
@@ -133,6 +142,7 @@ def save_model(net,
                optimizer,
                metric_info,
                model_path,
+               ema=None,
                model_name="",
                prefix='ppcls',
                loss: paddle.nn.Layer=None,
@@ -161,6 +171,8 @@
         paddle.save(s_params, model_path + "_student.pdparams")
 
     paddle.save(params_state_dict, model_path + ".pdparams")
+    if ema is not None:
+        paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
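+        # when EMA is enabled, a checkpoint prefix `X` thus yields
+        # X.ema.pdparams alongside X.pdparams, X.pdopt and X.pdstates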
     paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
diff --git a/test_tipc/config/ConvNeXt/ConvNeXt_tiny_train_infer_python.txt b/test_tipc/config/ConvNeXt/ConvNeXt_tiny_train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..11b4007ef9fbaef563b028a10bf9f42eb2581f94
--- /dev/null
+++ b/test_tipc/config/ConvNeXt/ConvNeXt_tiny_train_infer_python.txt
@@ -0,0 +1,54 @@
+===========================train_params===========================
+model_name:ConvNeXt_tiny
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+inference_dir:null
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256 -o PreProcess.transform_ops.1.CropImage.size=224
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
\ No newline at end of file