Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
4150df91
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4150df91
编写于
9月 07, 2021
作者:
C
cuicheng01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add LCNet codes and config
上级
3f0b406f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
1433 addition
and
0 deletion
+1433
-0
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/legendary_models/lcnet.py
ppcls/arch/backbone/legendary_models/lcnet.py
+399
-0
ppcls/configs/ImageNet/LCNet/LCNet_x0_25.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x0_25.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x0_35.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x0_35.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x0_5.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x0_5.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x0_75.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x0_75.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x1_0.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x1_0.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x1_5.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x1_5.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x2_0.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x2_0.yaml
+129
-0
ppcls/configs/ImageNet/LCNet/LCNet_x2_5.yaml
ppcls/configs/ImageNet/LCNet/LCNet_x2_5.yaml
+130
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
4150df91
...
...
@@ -21,6 +21,7 @@ from ppcls.arch.backbone.legendary_models.resnet import ResNet18, ResNet18_vd, R
from
ppcls.arch.backbone.legendary_models.vgg
import
VGG11
,
VGG13
,
VGG16
,
VGG19
from
ppcls.arch.backbone.legendary_models.inception_v3
import
InceptionV3
from
ppcls.arch.backbone.legendary_models.hrnet
import
HRNet_W18_C
,
HRNet_W30_C
,
HRNet_W32_C
,
HRNet_W40_C
,
HRNet_W44_C
,
HRNet_W48_C
,
HRNet_W60_C
,
HRNet_W64_C
,
SE_HRNet_W64_C
from
ppcls.arch.backbone.legendary_models.lcnet
import
LCNet_x0_25
,
LCNet_x0_35
,
LCNet_x0_5
,
LCNet_x0_75
,
LCNet_x1_0
,
LCNet_x1_5
,
LCNet_x2_0
,
LCNet_x2_5
from
ppcls.arch.backbone.model_zoo.resnet_vc
import
ResNet50_vc
from
ppcls.arch.backbone.model_zoo.resnext
import
ResNeXt50_32x4d
,
ResNeXt50_64x4d
,
ResNeXt101_32x4d
,
ResNeXt101_64x4d
,
ResNeXt152_32x4d
,
ResNeXt152_64x4d
...
...
ppcls/arch/backbone/legendary_models/lcnet.py
0 → 100644
浏览文件 @
4150df91
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
paddle
import
paddle.nn
as
nn
from
paddle
import
ParamAttr
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm
,
Conv2D
,
Dropout
,
Linear
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingNormal
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
MODEL_URLS
=
{
"LCNet_x0_25"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x0_25_pretrained.pdparams"
,
"LCNet_x0_35"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x0_35_pretrained.pdparams"
,
"LCNet_x0_5"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x0_5_pretrained.pdparams"
,
"LCNet_x0_75"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x0_75_pretrained.pdparams"
,
"LCNet_x1_0"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x1_0_pretrained.pdparams"
,
"LCNet_x1_5"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x1_5_pretrained.pdparams"
,
"LCNet_x2_0"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x2_0_pretrained.pdparams"
,
"LCNet_x2_5"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/LCNet_x2_5_pretrained.pdparams"
}
__all__
=
list
(
MODEL_URLS
.
keys
())
# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
# k: kernel_size
# in_c: input channel number in depthwise block
# out_c: output channel number in depthwise block
# s: stride in depthwise block
# use_se: whether to use SE block
NET_CONFIG
=
{
"blocks2"
:
#k, in_c, out_c, s, use_se
[[
3
,
16
,
32
,
1
,
False
]],
"blocks3"
:
[[
3
,
32
,
64
,
2
,
False
],
[
3
,
64
,
64
,
1
,
False
]],
"blocks4"
:
[[
3
,
64
,
128
,
2
,
False
],
[
3
,
128
,
128
,
1
,
False
]],
"blocks5"
:
[[
3
,
128
,
256
,
2
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
]],
"blocks6"
:
[[
5
,
256
,
512
,
2
,
True
],
[
5
,
512
,
512
,
1
,
True
]]
}
def
make_divisible
(
v
,
divisor
=
8
,
min_value
=
None
):
if
min_value
is
None
:
min_value
=
divisor
new_v
=
max
(
min_value
,
int
(
v
+
divisor
/
2
)
//
divisor
*
divisor
)
if
new_v
<
0.9
*
v
:
new_v
+=
divisor
return
new_v
class
ConvBNLayer
(
TheseusLayer
):
def
__init__
(
self
,
num_channels
,
filter_size
,
num_filters
,
stride
,
num_groups
=
1
):
super
().
__init__
()
self
.
conv
=
Conv2D
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()),
bias_attr
=
False
)
self
.
bn
=
BatchNorm
(
num_filters
,
param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
self
.
hardswish
=
nn
.
Hardswish
()
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
x
=
self
.
hardswish
(
x
)
return
x
class
DepthwiseSeparable
(
TheseusLayer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
dw_size
=
3
,
use_se
=
False
):
super
().
__init__
()
self
.
use_se
=
use_se
self
.
dw_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
num_channels
,
filter_size
=
dw_size
,
stride
=
stride
,
num_groups
=
num_channels
)
if
use_se
:
self
.
se
=
SEModule
(
num_channels
)
self
.
pw_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
filter_size
=
1
,
num_filters
=
num_filters
,
stride
=
1
)
def
forward
(
self
,
x
):
x
=
self
.
dw_conv
(
x
)
if
self
.
use_se
:
x
=
self
.
se
(
x
)
x
=
self
.
pw_conv
(
x
)
return
x
class
SEModule
(
TheseusLayer
):
def
__init__
(
self
,
channel
,
reduction
=
4
):
super
().
__init__
()
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
conv1
=
Conv2D
(
in_channels
=
channel
,
out_channels
=
channel
//
reduction
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
Conv2D
(
in_channels
=
channel
//
reduction
,
out_channels
=
channel
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
hardsigmoid
=
nn
.
Hardsigmoid
()
def
forward
(
self
,
x
):
identity
=
x
x
=
self
.
avg_pool
(
x
)
x
=
self
.
conv1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
hardsigmoid
(
x
)
x
=
paddle
.
multiply
(
x
=
identity
,
y
=
x
)
return
x
class
LCNet
(
TheseusLayer
):
def
__init__
(
self
,
scale
=
1.0
,
class_num
=
1000
,
dropout_prob
=
0.2
,
class_expand
=
1280
):
super
().
__init__
()
self
.
scale
=
scale
self
.
class_expand
=
class_expand
self
.
conv1
=
ConvBNLayer
(
num_channels
=
3
,
filter_size
=
3
,
num_filters
=
make_divisible
(
16
*
scale
),
stride
=
2
)
self
.
blocks2
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks2"
])
])
self
.
blocks3
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks3"
])
])
self
.
blocks4
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks4"
])
])
self
.
blocks5
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks5"
])
])
self
.
blocks6
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks6"
])
])
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
last_conv
=
Conv2D
(
in_channels
=
make_divisible
(
NET_CONFIG
[
"blocks6"
][
-
1
][
2
]
*
scale
),
out_channels
=
self
.
class_expand
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias_attr
=
False
)
self
.
hardswish
=
nn
.
Hardswish
()
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
self
.
fc
=
Linear
(
self
.
class_expand
,
class_num
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
blocks2
(
x
)
x
=
self
.
blocks3
(
x
)
x
=
self
.
blocks4
(
x
)
x
=
self
.
blocks5
(
x
)
x
=
self
.
blocks6
(
x
)
x
=
self
.
avg_pool
(
x
)
x
=
self
.
last_conv
(
x
)
x
=
self
.
hardswish
(
x
)
x
=
self
.
dropout
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc
(
x
)
return
x
def
_load_pretrained
(
pretrained
,
model
,
model_url
,
use_ssld
):
if
pretrained
is
False
:
pass
elif
pretrained
is
True
:
load_dygraph_pretrain_from_url
(
model
,
model_url
,
use_ssld
=
use_ssld
)
elif
isinstance
(
pretrained
,
str
):
load_dygraph_pretrain
(
model
,
pretrained
)
else
:
raise
RuntimeError
(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def
LCNet_x0_25
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x0_25
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x0_25` model depends on args.
"""
model
=
LCNet
(
scale
=
0.25
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x0_25"
],
use_ssld
)
return
model
def
LCNet_x0_35
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x0_35
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x0_35` model depends on args.
"""
model
=
LCNet
(
scale
=
0.35
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x0_35"
],
use_ssld
)
return
model
def
LCNet_x0_5
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x0_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x0_5` model depends on args.
"""
model
=
LCNet
(
scale
=
0.5
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x0_5"
],
use_ssld
)
return
model
def
LCNet_x0_75
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x0_75
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x0_75` model depends on args.
"""
model
=
LCNet
(
scale
=
0.75
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x0_75"
],
use_ssld
)
return
model
def
LCNet_x1_0
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x1_0
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x1_0` model depends on args.
"""
model
=
LCNet
(
scale
=
1.0
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x1_0"
],
use_ssld
)
return
model
def
LCNet_x1_5
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x1_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x1_5` model depends on args.
"""
model
=
LCNet
(
scale
=
1.5
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x1_5"
],
use_ssld
)
return
model
def
LCNet_x2_0
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x2_0
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x2_0` model depends on args.
"""
model
=
LCNet
(
scale
=
2.0
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x2_0"
],
use_ssld
)
return
model
def
LCNet_x2_5
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
LCNet_x2_5
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `LCNet_x2_5` model depends on args.
"""
model
=
LCNet
(
scale
=
2.5
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"LCNet_x2_5"
],
use_ssld
)
return
model
ppcls/configs/ImageNet/LCNet/LCNet_x0_25.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x0_25
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00003
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x0_35.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x0_35
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00003
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x0_5.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x0_5
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00003
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x0_75.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x0_75
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00003
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x1_0.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x1_0
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00003
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x1_5.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x1_5
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x2_0.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x2_0
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/LCNet/LCNet_x2_5.yaml
0 → 100644
浏览文件 @
4150df91
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
class_num
:
1000
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
LCNet_x2_5
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
AutoAugment
:
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/whl/demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录