提交 d868a637 编写于 作者: C cuicheng01

Add ESNet code and pretrained models

上级 507a9213
...@@ -82,6 +82,7 @@ ...@@ -82,6 +82,7 @@
| MobileNetV3_small_x1_0_ssld | 0.713 | 0.682 | 0.031 | 6.546 | 0.123 | 2.94 | 12 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_ssld_pretrained.pdparams) | | MobileNetV3_small_x1_0_ssld | 0.713 | 0.682 | 0.031 | 6.546 | 0.123 | 2.94 | 12 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_ssld_pretrained.pdparams) |
| GhostNet_x1_3_ssld | 0.794 | 0.757 | 0.037 | 19.983 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) | | GhostNet_x1_3_ssld | 0.794 | 0.757 | 0.037 | 19.983 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) |
<a name="Intel-CPU端知识蒸馏模型"></a> <a name="Intel-CPU端知识蒸馏模型"></a>
#### Intel CPU端知识蒸馏模型 #### Intel CPU端知识蒸馏模型
...@@ -180,6 +181,10 @@ ResNet及其Vd系列模型的精度、速度指标如下表所示,更多关于 ...@@ -180,6 +181,10 @@ ResNet及其Vd系列模型的精度、速度指标如下表所示,更多关于
| GhostNet_<br>x1_0 | 0.7402 | 0.9165 | 13.5587 | 0.294 | 5.2 | 20 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_0_pretrained.pdparams) | | GhostNet_<br>x1_0 | 0.7402 | 0.9165 | 13.5587 | 0.294 | 5.2 | 20 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_0_pretrained.pdparams) |
| GhostNet_<br>x1_3 | 0.7579 | 0.9254 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_pretrained.pdparams) | | GhostNet_<br>x1_3 | 0.7579 | 0.9254 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_pretrained.pdparams) |
| GhostNet_<br>x1_3_ssld | 0.7938 | 0.9449 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) | | GhostNet_<br>x1_3_ssld | 0.7938 | 0.9449 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) |
| ESNet_x0_25 | 62.48 | 83.46 || 0.031 | 2.83 | 11 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x0_25_pretrained.pdparams) |
| ESNet_x0_5 | 68.82 | 88.04 || 0.067 | 3.25 | 13 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x0_5_pretrained.pdparams) |
| ESNet_x0_75 | 72.24 | 90.45 || 0.124 | 3.87 | 15 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x0_75_pretrained.pdparams) |
| ESNet_x1_0 | 73.92 | 91.40 || 0.197 | 4.64 | 18 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x1_0_pretrained.pdparams) |
<a name="SEResNeXt与Res2Net系列"></a> <a name="SEResNeXt与Res2Net系列"></a>
......
# ESNet系列
## 概述
ESNet(Enhanced ShuffleNet)是百度自研的一个轻量级网络,该网络在ShuffleNetV2的基础上融合了MobileNetV3、GhostNet、PPLCNet的优点,组合成了一个在ARM设备上速度更快、精度更高的网络,由于其出色的表现,所以在PaddleDetection推出的[PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet)使用了该模型做backbone,配合更强的目标检测算法,最终的指标一举刷新了目标检测模型在ARM设备上的SOTA指标。
## 精度、FLOPs和参数量
| Models | Top1 | Top5 | FLOPs<br>(M) | Params<br/>(M) |
|:--:|:--:|:--:|:--:|:--:|
| ESNet_x0_25 | 62.48 | 83.46 | 30.9 | 2.83 |
| ESNet_x0_5 | 68.82 | 88.04 | 67.3 | 3.25 |
| ESNet_x0_75 | 72.24 | 90.45 | 123.7 | 3.87 |
| ESNet_x1_0 | 73.92 | 91.40 | 197.3 | 4.64 |
关于Inference speed等信息,敬请期待。
...@@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19 ...@@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19
from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3 from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5 from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import math
import paddle
from paddle import ParamAttr, reshape, transpose, concat, split
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"ESNet_x0_25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_25_pretrained.pdparams",
"ESNet_x0_5":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_5_pretrained.pdparams",
"ESNet_x0_75":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_75_pretrained.pdparams",
"ESNet_x1_0":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x1_0_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
def channel_shuffle(x, groups):
batch_size, num_channels, height, width = x.shape[0:4]
channels_per_group = num_channels // groups
x = reshape(
x=x, shape=[batch_size, groups, channels_per_group, height, width])
x = transpose(x=x, perm=[0, 2, 1, 3, 4])
x = reshape(x=x, shape=[batch_size, num_channels, height, width])
return x
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNLayer(TheseusLayer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
if_act=True):
super().__init__()
self.conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self.bn = BatchNorm(
out_channels,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.if_act = if_act
self.hardswish = nn.Hardswish()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
x = self.hardswish(x)
return x
class SEModule(TheseusLayer):
def __init__(self, channel, reduction=4):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0)
self.relu = nn.ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0)
self.hardsigmoid = nn.Hardsigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = paddle.multiply(x=identity, y=x)
return x
class ESBlock1(TheseusLayer):
def __init__(self,
in_channels,
out_channels):
super().__init__()
self.pw_1_1 = ConvBNLayer(
in_channels=in_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1)
self.dw_1 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=1,
groups=out_channels // 2,
if_act=False)
self.se = SEModule(out_channels)
self.pw_1_2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1)
def forward(self, x):
x1, x2 = split(
x,
num_or_sections=[x.shape[1] // 2, x.shape[1] // 2],
axis=1)
x2 = self.pw_1_1(x2)
x3 = self.dw_1(x2)
x3 = concat([x2, x3], axis=1)
x3 = self.se(x3)
x3 = self.pw_1_2(x3)
x = concat([x1, x3], axis=1)
return channel_shuffle(x, 2)
class ESBlock2(TheseusLayer):
def __init__(self,
in_channels,
out_channels):
super().__init__()
# branch1
self.dw_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=2,
groups=in_channels,
if_act=False)
self.pw_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1)
# branch2
self.pw_2_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1)
self.dw_2 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=2,
groups=out_channels // 2,
if_act=False)
self.se = SEModule(out_channels // 2)
self.pw_2_2 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=1)
self.concat_dw = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
groups=out_channels)
self.concat_pw = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=1)
def forward(self, x):
x1 = self.dw_1(x)
x1 = self.pw_1(x1)
x2 = self.pw_2_1(x)
x2 = self.dw_2(x2)
x2 = self.se(x2)
x2 = self.pw_2_2(x2)
x = concat([x1, x2], axis=1)
x = self.concat_dw(x)
x = self.concat_pw(x)
return x
class ESNet(TheseusLayer):
def __init__(self, class_num=1000, scale=1.0, dropout_prob=0.2, class_expand=1280):
super().__init__()
self.scale = scale
self.class_num = class_num
self.class_expand = class_expand
stage_repeats = [3, 7, 3]
stage_out_channels = [-1, 24, make_divisible(116*scale), make_divisible(232*scale), make_divisible(464*scale), 1024]
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=stage_out_channels[1],
kernel_size=3,
stride=2)
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
block_list = []
for stage_id, num_repeat in enumerate(stage_repeats):
for i in range(num_repeat):
if i == 0:
block = ESBlock2(
in_channels=stage_out_channels[stage_id + 1],
out_channels=stage_out_channels[stage_id + 2])
else:
block = ESBlock1(
in_channels=stage_out_channels[stage_id + 2],
out_channels=stage_out_channels[stage_id + 2])
block_list.append(block)
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=stage_out_channels[-2],
out_channels=stage_out_channels[-1],
kernel_size=1)
self.avg_pool = AdaptiveAvgPool2D(1)
self.last_conv = Conv2D(
in_channels=stage_out_channels[-1],
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.hardswish = nn.Hardswish()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = Linear(self.class_expand, self.class_num)
def forward(self, x):
x = self.conv1(x)
x = self.max_pool(x)
x = self.blocks(x)
x = self.conv2(x)
x = self.avg_pool(x)
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
x = self.flatten(x)
x = self.fc(x)
return x
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def ESNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
"""
ESNet_x0_25
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ESNet_x0_25` model depends on args.
"""
model = ESNet(scale=0.25, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_25"], use_ssld)
return model
def ESNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
"""
ESNet_x0_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ESNet_x0_5` model depends on args.
"""
model = ESNet(scale=0.5, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_5"], use_ssld)
return model
def ESNet_x0_75(pretrained=False, use_ssld=False, **kwargs):
"""
ESNet_x0_75
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ESNet_x0_75` model depends on args.
"""
model = ESNet(scale=0.75, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_75"], use_ssld)
return model
def ESNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
"""
ESNet_x1_0
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ESNet_x1_0` model depends on args.
"""
model = ESNet(scale=1.0, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x1_0"], use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 360
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: ESNet_x0_25
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.8
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 360
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: ESNet_x0_5
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.8
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 360
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: ESNet_x0_75
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.8
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 360
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: ESNet_x1_0
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.8
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00003
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册