diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index 737f34d75914f4d2098a092ddd96257676ae1651..b62b5a64df348e257beee174eeb5bff1007f1d3e 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -64,6 +64,7 @@ from ppcls.arch.backbone.model_zoo.cspnet import CSPDarkNet53
 from ppcls.arch.backbone.model_zoo.pvt_v2 import PVT_V2_B0, PVT_V2_B1, PVT_V2_B2_Linear, PVT_V2_B2, PVT_V2_B3, PVT_V2_B4, PVT_V2_B5
 from ppcls.arch.backbone.model_zoo.mobilevit import MobileViT_XXS, MobileViT_XS, MobileViT_S
 from ppcls.arch.backbone.model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG_B1, RepVGG_B2, RepVGG_B1g2, RepVGG_B1g4, RepVGG_B2g4, RepVGG_B3g4
+from ppcls.arch.backbone.model_zoo.van import VAN_tiny
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
 from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
diff --git a/ppcls/arch/backbone/model_zoo/van.py b/ppcls/arch/backbone/model_zoo/van.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cef8c96c59a167bd8a129909d68a0c942a08d3a
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/van.py
@@ -0,0 +1,312 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Code was heavily based on https://github.com/Visual-Attention-Network/VAN-Classification
+
+from functools import partial
+import math
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant
+
+from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
+MODEL_URLS = {
+    "VAN_tiny": "",  # TODO
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+trunc_normal_ = TruncatedNormal(std=.02)
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
+    The original name is misleading, as 'DropConnect' is a different form of dropout from a
+    separate paper; see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = paddle.to_tensor(1 - drop_prob)
+    # one random keep/drop decision per sample, broadcast over all other dims
+    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+    random_tensor = paddle.floor(random_tensor)  # binarize
+    # rescale by 1/keep_prob so the expected output is unchanged
+    output = x.divide(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+@paddle.jit.not_to_static
+def swapdim(x, dim1, dim2):
+    a = list(range(len(x.shape)))
+    a[dim1], a[dim2] = a[dim2], a[dim1]
+    return x.transpose(a)
+
+
+class Mlp(nn.Layer):
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Conv2D(in_features, hidden_features, 1)
+        self.dwconv = DWConv(hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Conv2D(hidden_features, out_features, 1)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
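+# Large Kernel Attention (LKA). Following the VAN paper, a large 21x21
+# convolution is decomposed into a 5x5 depth-wise conv, a 7x7 depth-wise
+# conv with dilation 3, and a 1x1 point-wise conv; the result is used as
+# an attention map that gates the input element-wise (x * attn).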
+class LKA(nn.Layer):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv0 = nn.Conv2D(dim, dim, 5, padding=2, groups=dim)
+        self.conv_spatial = nn.Conv2D(
+            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
+        self.conv1 = nn.Conv2D(dim, dim, 1)
+
+    def forward(self, x):
+        attn = self.conv0(x)
+        attn = self.conv_spatial(attn)
+        attn = self.conv1(attn)
+        return x * attn
+
+
+class Attention(nn.Layer):
+    def __init__(self, d_model):
+        super().__init__()
+        self.proj_1 = nn.Conv2D(d_model, d_model, 1)
+        self.activation = nn.GELU()
+        self.spatial_gating_unit = LKA(d_model)
+        self.proj_2 = nn.Conv2D(d_model, d_model, 1)
+
+    def forward(self, x):
+        shortcut = x
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shortcut
+        return x
+
+
+class Block(nn.Layer):
+    def __init__(self,
+                 dim,
+                 mlp_ratio=4.,
+                 drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU):
+        super().__init__()
+        self.norm1 = nn.BatchNorm2D(dim)
+        self.attn = Attention(dim)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = nn.BatchNorm2D(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer,
+                       drop=drop)
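+        # LayerScale-style learnable per-channel scales for the two residual
+        # branches, initialized to a small constant (1e-2) so each block
+        # starts close to an identity mapping.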
+        layer_scale_init_value = 1e-2
+        self.layer_scale_1 = self.create_parameter(
+            shape=[dim, 1, 1],
+            default_initializer=Constant(value=layer_scale_init_value))
+        self.layer_scale_2 = self.create_parameter(
+            shape=[dim, 1, 1],
+            default_initializer=Constant(value=layer_scale_init_value))
+
+    def forward(self, x):
+        x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class OverlapPatchEmbed(nn.Layer):
+    """Image to Patch Embedding with overlapping patches (a strided convolution).
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=7,
+                 stride=4,
+                 in_chans=3,
+                 embed_dim=768):
+        super().__init__()
+        self.proj = nn.Conv2D(
+            in_chans,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=patch_size // 2)
+        self.norm = nn.BatchNorm2D(embed_dim)
+
+    def forward(self, x):
+        x = self.proj(x)
+        _, _, H, W = x.shape
+        x = self.norm(x)
+        return x, H, W
+
+
+class VAN(nn.Layer):
+    def __init__(self,
+                 img_size=224,
+                 in_chans=3,
+                 class_num=1000,
+                 embed_dims=[64, 128, 256, 512],
+                 mlp_ratios=[4, 4, 4, 4],
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_layer=nn.LayerNorm,
+                 depths=[3, 4, 6, 3],
+                 num_stages=4,
+                 flag=False):
+        super().__init__()
+        if not flag:
+            self.class_num = class_num
+        self.depths = depths
+        self.num_stages = num_stages
+
+        dpr = [
+            x for x in paddle.linspace(0, drop_path_rate, sum(depths))
+        ]  # stochastic depth decay rule
+        cur = 0
+
+        for i in range(num_stages):
+            patch_embed = OverlapPatchEmbed(
+                img_size=img_size if i == 0 else img_size // (2**(i + 1)),
+                patch_size=7 if i == 0 else 3,
+                stride=4 if i == 0 else 2,
+                in_chans=in_chans if i == 0 else embed_dims[i - 1],
+                embed_dim=embed_dims[i])
+
+            block = nn.LayerList([
+                Block(
+                    dim=embed_dims[i],
+                    mlp_ratio=mlp_ratios[i],
+                    drop=drop_rate,
+                    drop_path=dpr[cur + j]) for j in range(depths[i])
+            ])
+            norm = norm_layer(embed_dims[i])
+            cur += depths[i]
+
+            setattr(self, f"patch_embed{i + 1}", patch_embed)
+            setattr(self, f"block{i + 1}", block)
+            setattr(self, f"norm{i + 1}", norm)
+
+        # classification head
+        self.head = nn.Linear(embed_dims[3],
+                              class_num) if class_num > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight)
+            if m.bias is not None:
+                zeros_(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            zeros_(m.bias)
+            ones_(m.weight)
+        elif isinstance(m, nn.Conv2D):
+            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+            fan_out //= m._groups
+            m.weight.set_value(
+                paddle.normal(
+                    std=math.sqrt(2.0 / fan_out), shape=m.weight.shape))
+            if m.bias is not None:
+                zeros_(m.bias)
+
+    def forward_features(self, x):
+        B = x.shape[0]
+
+        for i in range(self.num_stages):
+            patch_embed = getattr(self, f"patch_embed{i + 1}")
+            block = getattr(self, f"block{i + 1}")
+            norm = getattr(self, f"norm{i + 1}")
+            x, H, W = patch_embed(x)
+            for blk in block:
+                x = blk(x)
+            # (B, C, H, W) -> (B, H*W, C) so LayerNorm runs over channels
+            x = x.flatten(2)
+            x = swapdim(x, 1, 2)
+            x = norm(x)
+            if i != self.num_stages - 1:
+                x = x.reshape([B, H, W, x.shape[2]]).transpose([0, 3, 1, 2])
+
+        # global average pooling over tokens
+        return x.mean(axis=1)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+
+        return x
+
+
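+# Depth-wise 3x3 convolution applied inside the MLP to add local spatial
+# mixing on the (B, C, H, W) feature map, similar to the convolutional
+# feed-forward network used in PVTv2.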
+class DWConv(nn.Layer):
+    def __init__(self, dim=768):
+        super().__init__()
+        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)
+
+    def forward(self, x):
+        x = self.dwconv(x)
+        return x
+
+
+def _load_pretrained(pretrained, model, model_url, use_ssld=False):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )
+
+
+def VAN_tiny(pretrained=False, use_ssld=False, **kwargs):
+    model = VAN(embed_dims=[32, 64, 160, 256],
+                mlp_ratios=[8, 8, 4, 4],
+                norm_layer=partial(
+                    nn.LayerNorm, epsilon=1e-6),
+                depths=[3, 3, 5, 2],
+                **kwargs)
+    _load_pretrained(
+        pretrained, model, MODEL_URLS["VAN_tiny"], use_ssld=use_ssld)
+    return model
diff --git a/ppcls/configs/ImageNet/VAN/VAN_tiny.yaml b/ppcls/configs/ImageNet/VAN/VAN_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae08dcfa9e99d0ccab4f051abe9351ff61c9b736
--- /dev/null
+++ b/ppcls/configs/ImageNet/VAN/VAN_tiny.yaml
@@ -0,0 +1,158 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 300
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+
+# model architecture
+Arch:
+  name: VAN_tiny
+  class_num: 1000
+  drop_path_rate: 0.1
+  drop_rate: 0.0
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+        epsilon: 0.1
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  epsilon: 1e-8
+  weight_decay: 0.05
+  one_dim_param_no_weight_decay: True
+  lr:
+    name: Cosine
+    learning_rate: 1e-3
+    eta_min: 1e-6
+    warmup_epoch: 5
+    warmup_start_lr: 1e-6
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+            interpolation: random
+            backend: pil
+        - RandFlipImage:
+            flip_code: 1
+        - TimmAutoAugment:
+            config_str: rand-m9-mstd0.5-inc1
+            interpolation: random
+            img_size: 224
+            mean: [0.5, 0.5, 0.5]
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+            order: ''
+        - RandomErasing:
+            EPSILON: 0.25
+            sl: 0.02
+            sh: 1.0/3.0
+            r1: 0.3
+            attempt: 10
+            use_log_aspect: True
+            mode: pixel
+      batch_transform_ops:
+        - OpSampler:
+            MixupOperator:
+              alpha: 0.8
+              prob: 0.5
+            CutmixOperator:
+              alpha: 1.0
+              prob: 0.5
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: True
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 248
+            interpolation: bicubic
+            backend: pil
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.5, 0.5, 0.5]
+            std: [0.5, 0.5, 0.5]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
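+# inference-time preprocessing mirrors Eval: resize the short side to 248,
+# then center-crop to 224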
+Infer:
+  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 248
+        interpolation: bicubic
+        backend: pil
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.5, 0.5, 0.5]
+        std: [0.5, 0.5, 0.5]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
diff --git a/test_tipc/config/VAN/VAN_tiny.txt b/test_tipc/config/VAN/VAN_tiny.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4ad311b58a653c38e9ae2d656024023102c46737
--- /dev/null
+++ b/test_tipc/config/VAN/VAN_tiny.txt
@@ -0,0 +1,58 @@
+===========================train_params===========================
+model_name:VAN_tiny
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/VAN/VAN_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/VAN/VAN_tiny.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/VAN/VAN_tiny.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+inference_dir:null
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=248 -o PreProcess.transform_ops.2.NormalizeImage.mean=[0.5,0.5,0.5] -o PreProcess.transform_ops.2.NormalizeImage.std=[0.5,0.5,0.5]
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+null:null
+===========================train_benchmark_params==========================
+batch_size:128
+fp_items:fp32
+epoch:1
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
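
A quick smoke test for the new backbone (an illustrative sketch, not part of this diff; it assumes PaddlePaddle is installed and the repository root is on PYTHONPATH):

    import paddle
    from ppcls.arch.backbone import VAN_tiny

    # build without pretrained weights; no download URL is registered for VAN_tiny yet
    model = VAN_tiny(pretrained=False)
    model.eval()
    x = paddle.randn([1, 3, 224, 224])
    with paddle.no_grad():
        logits = model(x)
    print(logits.shape)  # expect [1, 1000]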