提交 8a760fb8 编写于 作者: C cuicheng01

Add PPHGNet code

上级 50c1302b
# PP-HGNet 系列
---
## 目录
* [1. 概述](#1)
* [2. 精度、FLOPs 和参数量](#2)
<a name='1'></a>
## 1. 概述
PP-HGNet是百度自研的一个在 GPU 端上高性能的网络,该网络在 VOVNet 的基础上融合了 ResNet_vd、PPLCNet 的优点,使用了可学习的下采样层,组合成了一个在 GPU 设备上速度快、精度高的网络,超越其他 GPU 端 SOTA 模型。
<a name='2'></a>
## 2.精度、FLOPs 和参数量
| Models | Top1 | Top5 | FLOPs<br>(G) | Params<br/>(M) |
|:--:|:--:|:--:|:--:|:--:|
| PPHGNet_tiny | 79.83 | 95.04 | 4.54 | 14.75 |
| PPHGNet_tiny_ssld | 81.95 | 96.12 | 4.54 | 14.75 |
| PPHGNet_small | 81.51 | 95.82 | 8.53 | 24.38 |
关于 Inference speed 等信息,敬请期待。
......@@ -23,6 +23,7 @@ from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
from ppcls.arch.backbone.legendary_models.pp_hgnet import PPHGNet_tiny, PPHGNet_small, PPHGNet_base
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingNormal, Constant
from paddle.nn import Conv2D, BatchNorm2D, ReLU, AdaptiveAvgPool2D, MaxPool2D
from paddle.regularizer import L2Decay
from paddle import ParamAttr
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"PPHGNet_tiny":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams",
"PPHGNet_small":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams"
}
__all__ = list(MODEL_URLS.keys())
kaiming_normal_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
class ConvBNAct(TheseusLayer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
use_act=True):
super().__init__()
self.use_act = use_act
self.conv = Conv2D(
in_channels,
out_channels,
kernel_size,
stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias_attr=False)
self.bn = BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
if self.use_act:
self.act = ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
return x
class ESEModule(TheseusLayer):
def __init__(self, channels):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv = Conv2D(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv(x)
x = self.sigmoid(x)
return paddle.multiply(x=identity, y=x)
class _HG_Block(TheseusLayer):
def __init__(
self,
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False, ):
super().__init__()
self.identity = identity
self.layers = nn.LayerList()
self.layers.append(
ConvBNAct(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1))
for _ in range(layer_num - 1):
self.layers.append(
ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1))
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
self.aggregation_conv = ConvBNAct(
in_channels=total_channels,
out_channels=out_channels,
kernel_size=1,
stride=1)
self.att = ESEModule(out_channels)
def forward(self, x):
identity = x
output = []
output.append(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = paddle.concat(output, axis=1)
x = self.aggregation_conv(x)
x = self.att(x)
if self.identity:
x += identity
return x
class _HG_Stage(TheseusLayer):
def __init__(self,
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
downsample=True):
super().__init__()
self.downsample = downsample
if downsample:
self.downsample = ConvBNAct(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=2,
groups=in_channels,
use_act=False)
blocks_list = []
blocks_list.append(
_HG_Block(
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False))
for _ in range(block_num - 1):
blocks_list.append(
_HG_Block(
out_channels,
mid_channels,
out_channels,
layer_num,
identity=True))
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
if self.downsample:
x = self.downsample(x)
x = self.blocks(x)
return x
class PPHGNet(TheseusLayer):
"""
PPHGNet
Args:
stem_channels: list. Stem channel list of PPHGNet.
stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
layer_num: int. Number of layers of HG_Block.
use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific PPHGNet model depends on args.
"""
def __init__(self,
stem_channels,
stage_config,
layer_num,
use_last_conv=True,
class_expand=2048,
dropout_prob=0.0,
class_num=1000):
super().__init__()
self.use_last_conv = use_last_conv
self.class_expand = class_expand
# stem
stem_channels.insert(0, 3)
self.stem = nn.Sequential(* [
ConvBNAct(
in_channels=stem_channels[i],
out_channels=stem_channels[i + 1],
kernel_size=3,
stride=2 if i == 0 else 1) for i in range(
len(stem_channels) - 1)
])
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
# stages
self.stages = nn.LayerList()
for k in stage_config:
in_channels, mid_channels, out_channels, block_num, downsample = stage_config[
k]
self.stages.append(
_HG_Stage(in_channels, mid_channels, out_channels, block_num,
layer_num, downsample))
self.avg_pool = AdaptiveAvgPool2D(1)
if self.use_last_conv:
self.last_conv = Conv2D(
in_channels=out_channels,
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.act = nn.ReLU()
self.dropout = nn.Dropout(
p=dropout_prob, mode="downscale_in_infer")
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = nn.Linear(self.class_expand
if self.use_last_conv else out_channels, class_num)
self._init_weights()
def _init_weights(self):
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2D)):
ones_(m.weight)
zeros_(m.bias)
elif isinstance(m, nn.Linear):
zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
x = self.pool(x)
for stage in self.stages:
x = stage(x)
x = self.avg_pool(x)
if self.use_last_conv:
x = self.last_conv(x)
x = self.act(x)
x = self.dropout(x)
x = self.flatten(x)
x = self.fc(x)
return x
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNet_tiny
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_tiny` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [96, 96, 224, 1, False],
"stage2": [224, 128, 448, 1, True],
"stage3": [448, 160, 512, 2, True],
"stage4": [512, 192, 768, 1, True],
}
model = PPHGNet(
stem_channels=[48, 48, 96],
stage_config=stage_config,
layer_num=5,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPHGNet_tiny"], use_ssld)
return model
def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNet_small
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_small` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, False],
"stage2": [256, 160, 512, 1, True],
"stage3": [512, 192, 768, 2, True],
"stage4": [768, 224, 1024, 1, True],
}
model = PPHGNet(
stem_channels=[64, 64, 128],
stage_config=stage_config,
layer_num=6,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPHGNet_small"], use_ssld)
return model
def PPHGNet_base(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNet_base
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_base` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [160, 192, 320, 1, False],
"stage2": [320, 224, 640, 2, True],
"stage3": [640, 256, 960, 3, True],
"stage4": [960, 288, 1280, 2, True],
}
model = PPHGNet(
stem_channels=[96, 96, 160],
stage_config=stage_config,
layer_num=7,
dropout_prob=0.2,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPHGNet_base"], use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 600
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: PPHGNet_small
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.5
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m7-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.2
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 16
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 236
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 16
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 236
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 600
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
use_dali: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: PPHGNet_tiny
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.5
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m7-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.2
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 16
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 232
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 16
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 232
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
===========================train_params===========================
model_name:PPHGNet_small
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=236
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_params===========================
model_name:PPHGNet_tiny
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=232
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册