提交 cfab4c17 编写于 作者: G gaotingquan 提交者: cuicheng01

feat: add pplcnetv2_small and pplcnetv2_big

上级 130328e7
......@@ -22,7 +22,7 @@ from .legendary_models.vgg import VGG11, VGG13, VGG16, VGG19
from .legendary_models.inception_v3 import InceptionV3
from .legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
from .legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
from .legendary_models.pp_lcnet_v2 import PPLCNetV2_base
from .legendary_models.pp_lcnet_v2 import PPLCNetV2_small, PPLCNetV2_base, PPLCNetV2_large
from .legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
from .legendary_models.pp_hgnet import PPHGNet_tiny, PPHGNet_small, PPHGNet_base
......
......@@ -26,8 +26,12 @@ from ..base.theseus_layer import TheseusLayer
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"PPLCNetV2_small":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_small_pretrained.pdparams",
"PPLCNetV2_base":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams",
"PPLCNetV2_large":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_large_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
......@@ -340,6 +344,23 @@ def _load_pretrained(pretrained, model, model_url, use_ssld):
)
def PPLCNetV2_small(pretrained=False, use_ssld=False, **kwargs):
"""
PPLCNetV2_small
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPLCNetV2_base` model depends on args.
"""
model = PPLCNetV2(
scale=0.75, depths=[2, 2, 4, 2], dropout_prob=0.2, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_small"],
use_ssld)
return model
def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs):
"""
PPLCNetV2_base
......@@ -354,3 +375,20 @@ def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs):
scale=1.0, depths=[2, 2, 6, 2], dropout_prob=0.2, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
return model
def PPLCNetV2_large(pretrained=False, use_ssld=False, **kwargs):
"""
PPLCNetV2_large
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPLCNetV2_base` model depends on args.
"""
model = PPLCNetV2(
scale=1.25, depths=[2, 2, 8, 2], dropout_prob=0.2, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_large"],
use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 480
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNetV2_large
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.4
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: MultiScaleDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
# support to specify width and height respectively:
# scales: [(160,160), (192,192), (224,224) (288,288) (320,320)]
sampler:
name: MultiScaleSampler
scales: [160, 192, 224, 288, 320]
# first_bs: batch size for the first image resolution in the scales list
# divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
first_bs: 250
divided_factor: 32
is_training: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 480
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNetV2_small
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.8
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00002
# data loader for train and eval
DataLoader:
Train:
dataset:
name: MultiScaleDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
# support to specify width and height respectively:
# scales: [(160,160), (192,192), (224,224) (288,288) (320,320)]
sampler:
name: MultiScaleSampler
scales: [160, 192, 224, 288, 320]
# first_bs: batch size for the first image resolution in the scales list
# divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
first_bs: 500
divided_factor: 32
is_training: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册