diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index efea4ec46860e012d58274ac911a8144dfdef0e2..d3bb4541981fb4c01befc82b3b569a2e098ac92b 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -67,7 +67,9 @@ from ppcls.arch.backbone.model_zoo.pvt_v2 import PVT_V2_B0, PVT_V2_B1, PVT_V2_B2
 from ppcls.arch.backbone.model_zoo.mobilevit import MobileViT_XXS, MobileViT_XS, MobileViT_S
 from ppcls.arch.backbone.model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g4, RepVGG_B3g4
 from ppcls.arch.backbone.model_zoo.van import VAN_tiny
+from ppcls.arch.backbone.model_zoo.peleenet import PeleeNet
 from ppcls.arch.backbone.model_zoo.convnext import ConvNeXt_tiny
+
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
 from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
diff --git a/ppcls/arch/backbone/model_zoo/peleenet.py b/ppcls/arch/backbone/model_zoo/peleenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09091af23d7d2a67c2f8303b4f8c119f77e8593
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/peleenet.py
@@ -0,0 +1,239 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Code was heavily based on https://github.com/Robert-JunWang/PeleeNet
+# reference: https://arxiv.org/pdf/1804.06882.pdf
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import Normal, Constant
+
+from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
+MODEL_URLS = {
+    "peleenet": ""  # TODO
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+# In-place initializer helpers used by `_initialize_weights` below.
+normal_ = lambda x, mean=0, std=1: Normal(mean, std)(x)
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
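+
+
+# Two-way dense layer: each branch starts with a 1x1 bottleneck conv of the
+# same width; branch 1 applies a single 3x3 conv, while branch 2 stacks two
+# 3x3 convs for a larger receptive field. Concatenating both branches with
+# the input adds `growth_rate` channels per layer.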
+class _DenseLayer(nn.Layer):
+    def __init__(self, num_input_features, growth_rate, bottleneck_width,
+                 drop_rate):
+        super(_DenseLayer, self).__init__()
+
+        # each branch contributes growth_rate / 2 channels
+        growth_rate = int(growth_rate / 2)
+        inter_channel = int(growth_rate * bottleneck_width / 4) * 4
+
+        if inter_channel > num_input_features / 2:
+            inter_channel = int(num_input_features / 8) * 4
+            print('adjust inter_channel to ', inter_channel)
+
+        self.branch1a = BasicConv2D(
+            num_input_features, inter_channel, kernel_size=1)
+        self.branch1b = BasicConv2D(
+            inter_channel, growth_rate, kernel_size=3, padding=1)
+
+        self.branch2a = BasicConv2D(
+            num_input_features, inter_channel, kernel_size=1)
+        self.branch2b = BasicConv2D(
+            inter_channel, growth_rate, kernel_size=3, padding=1)
+        self.branch2c = BasicConv2D(
+            growth_rate, growth_rate, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        branch1 = self.branch1a(x)
+        branch1 = self.branch1b(branch1)
+
+        branch2 = self.branch2a(x)
+        branch2 = self.branch2b(branch2)
+        branch2 = self.branch2c(branch2)
+
+        return paddle.concat([x, branch1, branch2], 1)
+
+
+class _DenseBlock(nn.Sequential):
+    def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
+                 drop_rate):
+        super(_DenseBlock, self).__init__()
+        for i in range(num_layers):
+            layer = _DenseLayer(num_input_features + i * growth_rate,
+                                growth_rate, bn_size, drop_rate)
+            setattr(self, 'denselayer%d' % (i + 1), layer)
+
+
+class _StemBlock(nn.Layer):
+    def __init__(self, num_input_channels, num_init_features):
+        super(_StemBlock, self).__init__()
+
+        num_stem_features = int(num_init_features / 2)
+
+        self.stem1 = BasicConv2D(
+            num_input_channels, num_init_features,
+            kernel_size=3, stride=2, padding=1)
+        self.stem2a = BasicConv2D(
+            num_init_features, num_stem_features,
+            kernel_size=1, stride=1, padding=0)
+        self.stem2b = BasicConv2D(
+            num_stem_features, num_init_features,
+            kernel_size=3, stride=2, padding=1)
+        self.stem3 = BasicConv2D(
+            2 * num_init_features, num_init_features,
+            kernel_size=1, stride=1, padding=0)
+        self.pool = nn.MaxPool2D(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        out = self.stem1(x)
+
+        branch2 = self.stem2a(out)
+        branch2 = self.stem2b(branch2)
+        branch1 = self.pool(out)
+
+        out = paddle.concat([branch1, branch2], 1)
+        out = self.stem3(out)
+
+        return out
+
+
+class BasicConv2D(nn.Layer):
+    """Conv2D + BatchNorm, with an optional ReLU."""
+
+    def __init__(self, in_channels, out_channels, activation=True, **kwargs):
+        super(BasicConv2D, self).__init__()
+        self.conv = nn.Conv2D(
+            in_channels, out_channels, bias_attr=False, **kwargs)
+        self.norm = nn.BatchNorm2D(out_channels)
+        self.activation = activation
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        if self.activation:
+            return F.relu(x)
+        return x
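+
+
+# Overall assembly: stem block, then four dense blocks, each followed by a
+# 1x1 transition conv (plus 2x2 average pooling between stages), and finally
+# a global-average-pooled linear classifier.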
+class PeleeNetDY(nn.Layer):
+    r"""PeleeNet model class, based on
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_ and
+    `"Pelee: A Real-Time Object Detection System on Mobile Devices" <https://arxiv.org/abs/1804.06882>`_
+
+    Args:
+        growth_rate (int or list of 4 ints) - how many filters to add each layer (`k` in paper)
+        block_config (list of 4 ints) - how many layers in each pooling block
+        num_init_features (int) - the number of filters to learn in the first convolution layer
+        bottleneck_width (int or list of 4 ints) - multiplicative factor for the number of bottleneck
+            layers (i.e. bn_size * k features in the bottleneck layer)
+        drop_rate (float) - dropout rate after each dense layer
+        class_num (int) - number of classification classes
+    """
+
+    def __init__(self, growth_rate=32, block_config=[3, 4, 8, 6],
+                 num_init_features=32, bottleneck_width=[1, 2, 4, 4],
+                 drop_rate=0.05, class_num=1000):
+
+        super(PeleeNetDY, self).__init__()
+
+        self.features = nn.Sequential(
+            ('stemblock', _StemBlock(3, num_init_features)))
+
+        if isinstance(growth_rate, list):
+            growth_rates = growth_rate
+            assert len(growth_rates) == 4, \
+                'growth_rate must be an int or a list of 4 ints'
+        else:
+            growth_rates = [growth_rate] * 4
+
+        if isinstance(bottleneck_width, list):
+            bottleneck_widths = bottleneck_width
+            assert len(bottleneck_widths) == 4, \
+                'bottleneck_width must be an int or a list of 4 ints'
+        else:
+            bottleneck_widths = [bottleneck_width] * 4
+
+        # Each dense block, followed by a 1x1 transition conv; every stage
+        # except the last also downsamples with 2x2 average pooling.
+        num_features = num_init_features
+        for i, num_layers in enumerate(block_config):
+            block = _DenseBlock(num_layers=num_layers,
+                                num_input_features=num_features,
+                                bn_size=bottleneck_widths[i],
+                                growth_rate=growth_rates[i],
+                                drop_rate=drop_rate)
+            setattr(self.features, 'denseblock%d' % (i + 1), block)
+            num_features = num_features + num_layers * growth_rates[i]
+
+            setattr(self.features, 'transition%d' % (i + 1),
+                    BasicConv2D(num_features, num_features,
+                                kernel_size=1, stride=1, padding=0))
+
+            if i != len(block_config) - 1:
+                setattr(self.features, 'transition%d_pool' % (i + 1),
+                        nn.AvgPool2D(kernel_size=2, stride=2))
+
+        # Linear layer
+        self.classifier = nn.Linear(num_features, class_num)
+        self.drop_rate = drop_rate
+
+        self.apply(self._initialize_weights)
+
+    def forward(self, x):
+        features = self.features(x)
+        # global average pooling over the remaining spatial dimensions
+        out = F.avg_pool2d(
+            features, kernel_size=features.shape[2:4]).flatten(1)
+        if self.drop_rate > 0:
+            out = F.dropout(out, p=self.drop_rate, training=self.training)
+        out = self.classifier(out)
+        return out
+
+    def _initialize_weights(self, m):
+        if isinstance(m, nn.Conv2D):
+            # He (Kaiming) normal initialization for conv weights
+            n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            normal_(m.weight, std=math.sqrt(2. / n))
+            if m.bias is not None:
+                zeros_(m.bias)
+        elif isinstance(m, nn.BatchNorm2D):
+            ones_(m.weight)
+            zeros_(m.bias)
+        elif isinstance(m, nn.Linear):
+            normal_(m.weight, std=0.01)
+            zeros_(m.bias)
+
+
+def _load_pretrained(pretrained, model, model_url, use_ssld):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )
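+
+
+# `pretrained` may be False (random init), True (download the weights in
+# MODEL_URLS), or a string path to a local checkpoint.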
+def PeleeNet(pretrained=False, use_ssld=False, **kwargs):
+    model = PeleeNetDY(**kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["peleenet"], use_ssld)
+    return model
diff --git a/ppcls/configs/ImageNet/PeleeNet/PeleeNet.yaml b/ppcls/configs/ImageNet/PeleeNet/PeleeNet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..648f97040e36c135d4896386ced20b23d328d746
--- /dev/null
+++ b/ppcls/configs/ImageNet/PeleeNet/PeleeNet.yaml
@@ -0,0 +1,130 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 120
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+
+# model architecture
+Arch:
+  name: PeleeNet
+  class_num: 1000
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.18  # for total batch size 512
+  regularizer:
+    name: 'L2'
+    coeff: 0.0001
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 128
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256  # for 2 cards
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
diff --git a/test_tipc/config/PeleeNet/PeleeNet_train_infer_python.txt b/test_tipc/config/PeleeNet/PeleeNet_train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f2c3a82a1dea3bc226940ac711790d32939dc541
--- /dev/null
+++ b/test_tipc/config/PeleeNet/PeleeNet_train_infer_python.txt
@@ -0,0 +1,54 @@
+===========================train_params===========================
+model_name:PeleeNet
+python:python3
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/PeleeNet/PeleeNet.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/PeleeNet/PeleeNet.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PeleeNet/PeleeNet.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+inference_dir:null
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
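
A minimal usage sketch of the new backbone (not part of the diff; assumes a PaddleClas checkout on PYTHONPATH, and keeps `pretrained=False` because MODEL_URLS["peleenet"] is still a TODO):

    import paddle
    from ppcls.arch.backbone import PeleeNet

    # default ImageNet-1k configuration: growth_rate=32, block_config=[3, 4, 8, 6]
    model = PeleeNet(pretrained=False, class_num=1000)
    x = paddle.randn([1, 3, 224, 224])
    logits = model(x)
    print(logits.shape)  # [1, 1000]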