diff --git a/docs/zh_CN/algorithm_introduction/ImageNet_models.md b/docs/zh_CN/algorithm_introduction/ImageNet_models.md index d8d0a5bb04115733c08492e8a7dcf3e064708b56..f5c46ef1db125e6f6925330dc085e9682e1282b2 100644 --- a/docs/zh_CN/algorithm_introduction/ImageNet_models.md +++ b/docs/zh_CN/algorithm_introduction/ImageNet_models.md @@ -82,6 +82,7 @@ | MobileNetV3_small_x1_0_ssld | 0.713 | 0.682 | 0.031 | 6.546 | 0.123 | 2.94 | 12 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_ssld_pretrained.pdparams) | | GhostNet_x1_3_ssld | 0.794 | 0.757 | 0.037 | 19.983 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) | + #### Intel CPU端知识蒸馏模型 @@ -180,6 +181,10 @@ ResNet及其Vd系列模型的精度、速度指标如下表所示,更多关于 | GhostNet_
x1_0 | 0.7402 | 0.9165 | 13.5587 | 0.294 | 5.2 | 20 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_0_pretrained.pdparams) | | GhostNet_
x1_3 | 0.7579 | 0.9254 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_pretrained.pdparams) | | GhostNet_
x1_3_ssld | 0.7938 | 0.9449 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) | +| ESNet_x0_25 | 62.48 | 83.46 || 0.031 | 2.83 | 11 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_25_pretrained.pdparams) | +| ESNet_x0_5 | 68.82 | 88.04 || 0.067 | 3.25 | 13 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_5_pretrained.pdparams) | +| ESNet_x0_75 | 72.24 | 90.45 || 0.124 | 3.87 | 15 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_75_pretrained.pdparams) | +| ESNet_x1_0 | 73.92 | 91.40 || 0.197 | 4.64 | 18 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x1_0_pretrained.pdparams) | diff --git a/docs/zh_CN/models/ESNet.md b/docs/zh_CN/models/ESNet.md new file mode 100644 index 0000000000000000000000000000000000000000..bd5322b8587024c13e7f8567c2e76342dc0ed9de --- /dev/null +++ b/docs/zh_CN/models/ESNet.md @@ -0,0 +1,16 @@ +# ESNet系列 + +## 概述 + +ESNet(Enhanced ShuffleNet)是百度自研的一个轻量级网络,该网络在ShuffleNetV2的基础上融合了MobileNetV3、GhostNet、PPLCNet的优点,组合成了一个在ARM设备上速度更快、精度更高的网络,由于其出色的表现,所以在PaddleDetection推出的[PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet)使用了该模型做backbone,配合更强的目标检测算法,最终的指标一举刷新了目标检测模型在ARM设备上的SOTA指标。 + +## 精度、FLOPs和参数量 + +| Models | Top1 | Top5 | FLOPs
(M) | Params
(M) | +|:--:|:--:|:--:|:--:|:--:| +| ESNet_x0_25 | 62.48 | 83.46 | 30.9 | 2.83 | +| ESNet_x0_5 | 68.82 | 88.04 | 67.3 | 3.25 | +| ESNet_x0_75 | 72.24 | 90.45 | 123.7 | 3.87 | +| ESNet_x1_0 | 73.92 | 91.40 | 197.3 | 4.64 | + +关于Inference speed等信息,敬请期待。 diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 9dd929bf8872694624b511fce8fce01548760970..bd2b99b0e8c7039fa4e705d0c6b635c0bdd813d6 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19 from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3 from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5 +from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0 from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d diff --git a/ppcls/arch/backbone/legendary_models/esnet.py b/ppcls/arch/backbone/legendary_models/esnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9c9626ef144809ba1025416a71bc0ac5e1cce7 --- /dev/null +++ b/ppcls/arch/backbone/legendary_models/esnet.py @@ -0,0 +1,355 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function +import math +import paddle +from paddle import ParamAttr, reshape, transpose, concat, split +import paddle.nn as nn +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D +from paddle.nn.initializer import KaimingNormal +from paddle.regularizer import L2Decay + +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "ESNet_x0_25": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_25_pretrained.pdparams", + "ESNet_x0_5": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_5_pretrained.pdparams", + "ESNet_x0_75": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_75_pretrained.pdparams", + "ESNet_x1_0": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x1_0_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + + +def channel_shuffle(x, groups): + batch_size, num_channels, height, width = x.shape[0:4] + channels_per_group = num_channels // groups + x = reshape( + x=x, shape=[batch_size, groups, channels_per_group, height, width]) + x = transpose(x=x, perm=[0, 2, 1, 3, 4]) + x = reshape(x=x, shape=[batch_size, num_channels, height, width]) + return x + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNLayer(TheseusLayer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + if_act=True): + super().__init__() + self.conv = Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=KaimingNormal()), + bias_attr=False) + + self.bn = BatchNorm( + out_channels, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.if_act = if_act + self.hardswish = nn.Hardswish() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.if_act: + x = self.hardswish(x) + return x + + +class SEModule(TheseusLayer): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0) + self.relu = nn.ReLU() + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = nn.Hardsigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + x = paddle.multiply(x=identity, y=x) + return x + + +class ESBlock1(TheseusLayer): + def __init__(self, in_channels, out_channels): + super().__init__() + self.pw_1_1 = ConvBNLayer( + in_channels=in_channels // 2, + out_channels=out_channels // 2, + kernel_size=1, + stride=1) + self.dw_1 = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=3, + stride=1, + groups=out_channels // 2, + if_act=False) + self.se = SEModule(out_channels) + + self.pw_1_2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1) + + def forward(self, x): + x1, x2 = split( + x, num_or_sections=[x.shape[1] // 2, x.shape[1] // 2], axis=1) + x2 = self.pw_1_1(x2) + x3 = self.dw_1(x2) + x3 = concat([x2, x3], axis=1) + x3 = self.se(x3) + x3 = self.pw_1_2(x3) + x = concat([x1, x3], axis=1) + return channel_shuffle(x, 2) + + +class ESBlock2(TheseusLayer): + def __init__(self, in_channels, out_channels): + super().__init__() + + # branch1 + self.dw_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=2, + groups=in_channels, + if_act=False) + self.pw_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1) + # branch2 + self.pw_2_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels // 2, + kernel_size=1) + self.dw_2 = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=3, + stride=2, + groups=out_channels // 2, + if_act=False) + self.se = SEModule(out_channels // 2) + self.pw_2_2 = ConvBNLayer( + in_channels=out_channels // 2, + out_channels=out_channels // 2, + kernel_size=1) + self.concat_dw = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + groups=out_channels) + self.concat_pw = ConvBNLayer( + in_channels=out_channels, out_channels=out_channels, kernel_size=1) + + def forward(self, x): + x1 = self.dw_1(x) + x1 = self.pw_1(x1) + x2 = self.pw_2_1(x) + x2 = self.dw_2(x2) + x2 = self.se(x2) + x2 = self.pw_2_2(x2) + x = concat([x1, x2], axis=1) + x = self.concat_dw(x) + x = self.concat_pw(x) + return x + + +class ESNet(TheseusLayer): + def __init__(self, + class_num=1000, + scale=1.0, + dropout_prob=0.2, + class_expand=1280): + super().__init__() + self.scale = scale + self.class_num = class_num + self.class_expand = class_expand + stage_repeats = [3, 7, 3] + stage_out_channels = [ + -1, 24, make_divisible(116 * scale), make_divisible(232 * scale), + make_divisible(464 * scale), 1024 + ] + + self.conv1 = ConvBNLayer( + in_channels=3, + out_channels=stage_out_channels[1], + kernel_size=3, + stride=2) + self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1) + + block_list = [] + for stage_id, num_repeat in enumerate(stage_repeats): + for i in range(num_repeat): + if i == 0: + block = ESBlock2( + in_channels=stage_out_channels[stage_id + 1], + out_channels=stage_out_channels[stage_id + 2]) + else: + block = ESBlock1( + in_channels=stage_out_channels[stage_id + 2], + out_channels=stage_out_channels[stage_id + 2]) + block_list.append(block) + self.blocks = nn.Sequential(*block_list) + + self.conv2 = ConvBNLayer( + in_channels=stage_out_channels[-2], + out_channels=stage_out_channels[-1], + kernel_size=1) + + self.avg_pool = AdaptiveAvgPool2D(1) + + self.last_conv = Conv2D( + in_channels=stage_out_channels[-1], + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.hardswish = nn.Hardswish() + self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") + self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) + self.fc = Linear(self.class_expand, self.class_num) + + def forward(self, x): + x = self.conv1(x) + x = self.max_pool(x) + x = self.blocks(x) + x = self.conv2(x) + x = self.avg_pool(x) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def ESNet_x0_25(pretrained=False, use_ssld=False, **kwargs): + """ + ESNet_x0_25 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ESNet_x0_25` model depends on args. + """ + model = ESNet(scale=0.25, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_25"], use_ssld) + return model + + +def ESNet_x0_5(pretrained=False, use_ssld=False, **kwargs): + """ + ESNet_x0_5 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ESNet_x0_5` model depends on args. + """ + model = ESNet(scale=0.5, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_5"], use_ssld) + return model + + +def ESNet_x0_75(pretrained=False, use_ssld=False, **kwargs): + """ + ESNet_x0_75 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ESNet_x0_75` model depends on args. + """ + model = ESNet(scale=0.75, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_75"], use_ssld) + return model + + +def ESNet_x1_0(pretrained=False, use_ssld=False, **kwargs): + """ + ESNet_x1_0 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ESNet_x1_0` model depends on args. + """ + model = ESNet(scale=1.0, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ESNet_x1_0"], use_ssld) + return model diff --git a/ppcls/configs/ImageNet/ESNet/ESNet_x0_25.yaml b/ppcls/configs/ImageNet/ESNet/ESNet_x0_25.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e01e853ac06196c8b0bafe82fed5a973662280bd --- /dev/null +++ b/ppcls/configs/ImageNet/ESNet/ESNet_x0_25.yaml @@ -0,0 +1,129 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + class_num: 1000 + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 360 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference +# model architecture +Arch: + name: ESNet_x0_25 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.8 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00003 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 512 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/whl/demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/ESNet/ESNet_x0_5.yaml b/ppcls/configs/ImageNet/ESNet/ESNet_x0_5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc51351f361650c736ef4dc1b57ab046aa4f3cb0 --- /dev/null +++ b/ppcls/configs/ImageNet/ESNet/ESNet_x0_5.yaml @@ -0,0 +1,129 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + class_num: 1000 + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 360 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference +# model architecture +Arch: + name: ESNet_x0_5 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.8 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00003 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 512 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/whl/demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/ESNet/ESNet_x0_75.yaml b/ppcls/configs/ImageNet/ESNet/ESNet_x0_75.yaml new file mode 100644 index 0000000000000000000000000000000000000000..265a622d28dca21f1b805184a52be95d1d27a117 --- /dev/null +++ b/ppcls/configs/ImageNet/ESNet/ESNet_x0_75.yaml @@ -0,0 +1,129 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + class_num: 1000 + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 360 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference +# model architecture +Arch: + name: ESNet_x0_75 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.8 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00003 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 512 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/whl/demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/ESNet/ESNet_x1_0.yaml b/ppcls/configs/ImageNet/ESNet/ESNet_x1_0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44fd7d5f158ad81dcfbc8ab69fd44e00c6bf1558 --- /dev/null +++ b/ppcls/configs/ImageNet/ESNet/ESNet_x1_0.yaml @@ -0,0 +1,129 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + class_num: 1000 + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 360 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference +# model architecture +Arch: + name: ESNet_x1_0 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.8 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00003 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 512 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/whl/demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5]