PaddlePaddle / PaddleClas

Commit c351dac6

add tinynet

Authored by Yang Nie on Mar 23, 2023. Committed by Tingquan Gao on May 06, 2023.
Parent: f8fdc5fd
11 changed files with 1126 additions and 47 deletions (+1126, -47).
Changed files:

  ppcls/arch/backbone/__init__.py                              +1    -0
  ppcls/arch/backbone/model_zoo/efficientnet.py                +76   -42
  ppcls/arch/backbone/model_zoo/tinynet.py                     +191  -0
  ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml                +155  -0
  ppcls/configs/ImageNet/TinyNet/TinyNet_B.yaml                +155  -0
  ppcls/configs/ImageNet/TinyNet/TinyNet_C.yaml                +155  -0
  ppcls/configs/ImageNet/TinyNet/TinyNet_D.yaml                +155  -0
  ppcls/configs/ImageNet/TinyNet/TinyNet_E.yaml                +155  -0
  ppcls/optimizer/learning_rate.py                             +2    -2
  ppcls/optimizer/optimizer.py                                 +27   -3
  test_tipc/configs/TinyNet/TinyNet_A_train_infer_python.txt   +54   -0
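In brief: the commit registers TinyNet (variants A through E) as a new backbone family, implemented as a thin subclass of the existing EfficientNet. efficientnet.py is generalized to make that possible (a configurable depth_trunc rounding mode, an optional fixed stem, a configurable head width via num_features, and BN momentum/epsilon read from GlobalParams instead of hard-coded constants). Five ImageNet training configs and a TIPC test config are added, and the Step LR schedule and the RMSProp wrapper gain the float step_size and no-weight-decay options those configs rely on.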
ppcls/arch/backbone/__init__.py

@@ -80,6 +80,7 @@ from .model_zoo.micronet import MicroNet_M0, MicroNet_M1, MicroNet_M2, MicroNet_
 from .model_zoo.mobilenext import MobileNeXt_x0_35, MobileNeXt_x0_5, MobileNeXt_x0_75, MobileNeXt_x1_0, MobileNeXt_x1_4
 from .model_zoo.mobilevit_v2 import MobileViTV2_x0_5, MobileViTV2_x0_75, MobileViTV2_x1_0, MobileViTV2_x1_25, MobileViTV2_x1_5, MobileViTV2_x1_75, MobileViTV2_x2_0
 from .model_zoo.mobilevit_v3 import MobileViTv3_XXS, MobileViTv3_XS, MobileViTv3_S, MobileViTv3_XXS_L2, MobileViTv3_XS_L2, MobileViTv3_S_L2, MobileViTv3_x0_5, MobileViTv3_x0_75, MobileViTv3_x1_0
+from .model_zoo.tinynet import TinyNet_A, TinyNet_B, TinyNet_C, TinyNet_D, TinyNet_E
 from .variant_models.resnet_variant import ResNet50_last_stage_stride1
 from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
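For orientation, an illustrative sketch of what this export enables (not part of the commit; assumes a PaddleClas checkout with Paddle installed, and uses the 192x192 input size from TinyNet_A.yaml below):

    import paddle
    from ppcls.arch.backbone import TinyNet_A  # made importable by the added line

    model = TinyNet_A(pretrained=False)  # class_num defaults to 1000
    x = paddle.randn([1, 3, 192, 192])
    print(model(x).shape)                # expected: [1, 1000]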
ppcls/arch/backbone/model_zoo/efficientnet.py

@@ -60,6 +60,7 @@ GlobalParams = collections.namedtuple('GlobalParams', [
     'width_coefficient',
     'depth_coefficient',
     'depth_divisor',
+    'depth_trunc',
     'min_depth',
     'drop_connect_rate',
 ])
@@ -77,6 +78,7 @@ def efficientnet_params(model_name):
     """ Map EfficientNet model name to parameter coefficients. """
     params_dict = {
+        # Coefficients: width,depth,resolution,dropout
         'efficientnet-b0-small': (1.0, 1.0, 224, 0.2),
         'efficientnet-b0': (1.0, 1.0, 224, 0.2),
         'efficientnet-b1': (1.0, 1.1, 240, 0.2),
         'efficientnet-b2': (1.1, 1.2, 260, 0.3),
@@ -114,6 +116,7 @@ def efficientnet(width_coefficient=None,
         width_coefficient=width_coefficient,
         depth_coefficient=depth_coefficient,
         depth_divisor=8,
+        depth_trunc='ceil',
         min_depth=None)

     return blocks_args, global_params
@@ -154,7 +157,10 @@ def round_repeats(repeats, global_params):
     multiplier = global_params.depth_coefficient
     if not multiplier:
         return repeats
-    return int(math.ceil(multiplier * repeats))
+    if global_params.depth_trunc == 'round':
+        return max(1, round(multiplier * repeats))
+    else:
+        return int(math.ceil(multiplier * repeats))


 class BlockDecoder(object):
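A standalone illustration of what the new depth_trunc mode changes (a sketch, not part of the diff; the 0.6 depth coefficient is TinyNet-E's, from tinynet.py below). 'ceil' can only keep or deepen a stage, while 'round' lets the depth multiplier make a stage shallower:

    import math

    def round_repeats(repeats, depth_coefficient, depth_trunc='ceil'):
        # standalone copy of the patched helper, for illustration only
        if not depth_coefficient:
            return repeats
        if depth_trunc == 'round':
            return max(1, round(depth_coefficient * repeats))
        return int(math.ceil(depth_coefficient * repeats))

    print(round_repeats(2, 0.6))           # 2: 'ceil' never shrinks a stage
    print(round_repeats(2, 0.6, 'round'))  # 1: 'round' can, which TinyNet needs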
@@ -314,10 +320,10 @@ class Conv2ds(TheseusLayer):
             padding = ((stride - 1) + dilation * (filter_size - 1)) // 2
             return padding

-        inps = 1 if model_name == None and cur_stage == None else inp_shape[
-            model_name][cur_stage]
         self.need_crop = False
         if padding_type == "SAME":
+            inps = 1 if model_name == None and cur_stage == None else inp_shape[
+                model_name][cur_stage]
             top_padding, bottom_padding = cal_padding(inps, stride, filter_size)
             left_padding, right_padding = cal_padding(inps, stride,
@@ -398,12 +404,13 @@ class ConvBNLayer(TheseusLayer):
         if use_bn is True:
             bn_name = name + bn_name
             param_attr, bias_attr = init_batch_norm_layer(bn_name)
+            momentum = global_params.batch_norm_momentum
+            epsilon = global_params.batch_norm_epsilon
             self._bn = BatchNorm(
                 num_channels=output_channels,
                 act=bn_act,
-                momentum=0.99,
+                momentum=momentum,
+                epsilon=epsilon,
                 moving_mean_name=bn_name + "_mean",
                 moving_variance_name=bn_name + "_variance",
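How these values get here: the TinyNet configs set Arch.override_params in the yaml files below, and get_model_params merges them into GlobalParams via namedtuple._replace. A minimal sketch of that mechanism (stand-in namedtuple, illustration only):

    import collections

    # Stand-in for the real GlobalParams namedtuple.
    GlobalParams = collections.namedtuple(
        'GlobalParams', ['batch_norm_momentum', 'batch_norm_epsilon'])

    defaults = GlobalParams(batch_norm_momentum=0.99, batch_norm_epsilon=1e-3)
    override = {'batch_norm_momentum': 0.9, 'batch_norm_epsilon': 1e-5}
    merged = defaults._replace(**override)  # what get_model_params does
    print(merged.batch_norm_momentum, merged.batch_norm_epsilon)  # 0.9 1e-05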
@@ -501,12 +508,12 @@ class ProjectConvNorm(TheseusLayer):
                  cur_stage=None):
         super(ProjectConvNorm, self).__init__()

-        final_oup = block_args.output_filters
+        self.final_oup = block_args.output_filters
         self._conv = ConvBNLayer(
             input_channels,
             1,
-            final_oup,
+            self.final_oup,
             global_params=global_params,
             bn_act=None,
             padding_type=padding_type,
@@ -619,6 +626,8 @@ class MbConvBlock(TheseusLayer):
             model_name=model_name,
             cur_stage=cur_stage)
+        self.final_oup = self._pcn.final_oup
+
     def forward(self, inputs):
         x = inputs
         if self.expand_ratio != 1:
@@ -647,10 +656,11 @@ class ConvStemNorm(TheseusLayer):
                  _global_params,
                  name=None,
                  model_name=None,
+                 fix_stem=False,
                  cur_stage=None):
         super(ConvStemNorm, self).__init__()

-        output_channels = round_filters(32, _global_params)
+        output_channels = 32 if fix_stem else round_filters(32, _global_params)
         self._conv = ConvBNLayer(
             input_channels,
             filter_size=3,
@@ -676,7 +686,8 @@ class ExtractFeatures(TheseusLayer):
                  _global_params,
                  padding_type,
                  use_se,
-                 model_name=None):
+                 model_name=None,
+                 fix_stem=False):
         super(ExtractFeatures, self).__init__()
         self._global_params = _global_params
@@ -686,6 +697,7 @@ class ExtractFeatures(TheseusLayer):
             padding_type=padding_type,
             _global_params=_global_params,
             model_name=model_name,
+            fix_stem=fix_stem,
             cur_stage=0)
         self.block_args_copy = copy.deepcopy(_block_args)
@@ -702,12 +714,14 @@ class ExtractFeatures(TheseusLayer):
             for _ in range(block_arg.num_repeat - 1):
                 block_size += 1

+        self.final_oup = None
         self.conv_seq = []
         cur_stage = 1
-        for block_args in _block_args:
-            block_args = block_args._replace(
-                input_filters=round_filters(block_args.input_filters,
-                                            _global_params),
+        for block_idx, block_args in enumerate(_block_args):
+            if not (fix_stem and block_idx == 0):
+                block_args = block_args._replace(input_filters=round_filters(
+                    block_args.input_filters, _global_params))
+            block_args = block_args._replace(
                 output_filters=round_filters(block_args.output_filters,
                                              _global_params),
                 num_repeat=round_repeats(block_args.num_repeat,
@@ -730,6 +744,7 @@ class ExtractFeatures(TheseusLayer):
                     model_name=model_name,
                     cur_stage=cur_stage))
             self.conv_seq.append(_mc_block)
+            self.final_oup = _mc_block.final_oup
             idx += 1
             if block_args.num_repeat > 1:
                 block_args = block_args._replace(
@@ -751,6 +766,7 @@ class ExtractFeatures(TheseusLayer):
                         model_name=model_name,
                         cur_stage=cur_stage))
                 self.conv_seq.append(_mc_block)
+                self.final_oup = _mc_block.final_oup
                 idx += 1
             cur_stage += 1
@@ -764,17 +780,20 @@
 class EfficientNet(TheseusLayer):
     def __init__(self,
+                 block_args,
+                 global_params,
                  name="b0",
                  padding_type="SAME",
-                 override_params=None,
                  use_se=True,
+                 fix_stem=False,
+                 num_features=None,
                  class_num=1000):
         super(EfficientNet, self).__init__()
-        model_name = 'efficientnet-' + name
         self.name = name
-        self._block_args, self._global_params = get_model_params(
-            model_name, override_params)
+        self.fix_stem = fix_stem
+        self._block_args = block_args
+        self._global_params = global_params
         self.padding_type = padding_type
         self.use_se = use_se

@@ -784,25 +803,13 @@ class EfficientNet(TheseusLayer):
             self._global_params,
             self.padding_type,
             self.use_se,
-            model_name=self.name)
-        output_channels = round_filters(1280, self._global_params)
-        if name == "b0_small" or name == "b0" or name == "b1":
-            oup = 320
-        elif name == "b2":
-            oup = 352
-        elif name == "b3":
-            oup = 384
-        elif name == "b4":
-            oup = 448
-        elif name == "b5":
-            oup = 512
-        elif name == "b6":
-            oup = 576
-        elif name == "b7":
-            oup = 640
+            model_name=self.name,
+            fix_stem=self.fix_stem)
+        output_channels = num_features or round_filters(1280,
+                                                        self._global_params)
         self._conv = ConvBNLayer(
-            oup,
+            self._ef.final_oup,
             1,
             output_channels,
             global_params=self._global_params,
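Context for the num_features change: round_filters scales channel counts by the width coefficient, so a hard-coded 1280 head would shrink for narrow models. A sketch of the usual EfficientNet rounding rule (a stand-in assuming the standard formula; the real helper also honors min_depth):

    def round_filters(filters, width_coefficient, depth_divisor=8):
        filters *= width_coefficient
        new_filters = max(depth_divisor, int(filters + depth_divisor / 2) //
                          depth_divisor * depth_divisor)
        if new_filters < 0.9 * filters:  # never round down by more than 10%
            new_filters += depth_divisor
        return int(new_filters)

    # With TinyNet-E's width_coefficient=0.51 the scaled head would be 656
    # channels; `num_features or ...` lets the TinyNet wrappers pin it at 1280.
    print(round_filters(1280, 0.51))  # 656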
@@ -856,10 +863,13 @@ def EfficientNetB0_small(padding_type='DYNAMIC',
                          pretrained=False,
                          use_ssld=False,
                          **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b0-small",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b0',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB0_small"])

@@ -872,10 +882,13 @@ def EfficientNetB0(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b0",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b0',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB0"])

@@ -888,10 +901,13 @@ def EfficientNetB1(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b1",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b1',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB1"])

@@ -904,10 +920,13 @@ def EfficientNetB2(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b2",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b2',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB2"])

@@ -920,10 +939,13 @@ def EfficientNetB3(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b3",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b3',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB3"])

@@ -936,10 +958,13 @@ def EfficientNetB4(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b4",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b4',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB4"])

@@ -952,10 +977,13 @@ def EfficientNetB5(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b5",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b5',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB5"])

@@ -968,10 +996,13 @@ def EfficientNetB6(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b6",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b6',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB6"])

@@ -984,10 +1015,13 @@ def EfficientNetB7(padding_type='SAME',
                    pretrained=False,
                    use_ssld=False,
                    **kwargs):
+    block_args, global_params = get_model_params("efficientnet-b7",
+                                                 override_params)
     model = EfficientNet(
+        block_args,
+        global_params,
         name='b7',
         padding_type=padding_type,
-        override_params=override_params,
         use_se=use_se,
         **kwargs)
     _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB7"])
ppcls/arch/backbone/model_zoo/tinynet.py (new file, mode 100644)

# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://gitee.com/mindspore/models/tree/master/research/cv/tinynet
# reference: https://arxiv.org/abs/2010.14819

import paddle.nn as nn

from .efficientnet import EfficientNet, efficientnet
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "TinyNet_A": "",
    "TinyNet_B": "",
    "TinyNet_C": "",
    "TinyNet_D": "",
    "TinyNet_E": "",
}

__all__ = list(MODEL_URLS.keys())


def tinynet_params(model_name):
    """ Map TinyNet model name to parameter coefficients. """
    params_dict = {
        # Coefficients: width,depth,resolution,dropout
        "tinynet-a": (1.00, 1.200, 192, 0.2),
        "tinynet-b": (0.75, 1.100, 188, 0.2),
        "tinynet-c": (0.54, 0.850, 184, 0.2),
        "tinynet-d": (0.54, 0.695, 152, 0.2),
        "tinynet-e": (0.51, 0.600, 106, 0.2),
    }
    return params_dict[model_name]


def get_model_params(model_name, override_params):
    """ Get the block args and global params for a given model """
    if model_name.startswith('tinynet'):
        w, d, _, p = tinynet_params(model_name)
        blocks_args, global_params = efficientnet(
            width_coefficient=w, depth_coefficient=d, dropout_rate=p)
    else:
        raise NotImplementedError('model name is not pre-defined: %s' %
                                  model_name)
    if override_params:
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params


class TinyNet(EfficientNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            fin_in = m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3]
            std = (2 / fin_in)**0.5
            nn.initializer.Normal(std=std)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)
        elif isinstance(m, nn.Linear):
            fin_in = m.weight.shape[0]
            bound = 1 / fin_in**0.5
            nn.initializer.Uniform(-bound, bound)(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0)(m.bias)


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def TinyNet_A(padding_type='DYNAMIC',
              override_params=None,
              use_se=True,
              pretrained=False,
              use_ssld=False,
              **kwargs):
    block_args, global_params = get_model_params("tinynet-a", override_params)
    model = TinyNet(
        block_args,
        global_params,
        name='a',
        padding_type=padding_type,
        use_se=use_se,
        fix_stem=True,
        num_features=1280,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["TinyNet_A"], use_ssld)
    return model


def TinyNet_B(padding_type='DYNAMIC',
              override_params=None,
              use_se=True,
              pretrained=False,
              use_ssld=False,
              **kwargs):
    block_args, global_params = get_model_params("tinynet-b", override_params)
    model = TinyNet(
        block_args,
        global_params,
        name='b',
        padding_type=padding_type,
        use_se=use_se,
        fix_stem=True,
        num_features=1280,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["TinyNet_B"], use_ssld)
    return model


def TinyNet_C(padding_type='DYNAMIC',
              override_params=None,
              use_se=True,
              pretrained=False,
              use_ssld=False,
              **kwargs):
    block_args, global_params = get_model_params("tinynet-c", override_params)
    model = TinyNet(
        block_args,
        global_params,
        name='c',
        padding_type=padding_type,
        use_se=use_se,
        fix_stem=True,
        num_features=1280,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["TinyNet_C"], use_ssld)
    return model


def TinyNet_D(padding_type='DYNAMIC',
              override_params=None,
              use_se=True,
              pretrained=False,
              use_ssld=False,
              **kwargs):
    block_args, global_params = get_model_params("tinynet-d", override_params)
    model = TinyNet(
        block_args,
        global_params,
        name='d',
        padding_type=padding_type,
        use_se=use_se,
        fix_stem=True,
        num_features=1280,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["TinyNet_D"], use_ssld)
    return model


def TinyNet_E(padding_type='DYNAMIC',
              override_params=None,
              use_se=True,
              pretrained=False,
              use_ssld=False,
              **kwargs):
    block_args, global_params = get_model_params("tinynet-e", override_params)
    model = TinyNet(
        block_args,
        global_params,
        name='e',
        padding_type=padding_type,
        use_se=use_se,
        fix_stem=True,
        num_features=1280,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["TinyNet_E"], use_ssld)
    return model
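A hedged usage sketch (not from the commit), mirroring what TinyNet_E.yaml below passes through Arch.override_params:

    import paddle
    from ppcls.arch.backbone.model_zoo.tinynet import TinyNet_E

    model = TinyNet_E(override_params={
        "batch_norm_momentum": 0.9,
        "batch_norm_epsilon": 1e-5,
        "depth_trunc": "round",
        "drop_connect_rate": 0.0,
    })
    x = paddle.randn([1, 3, 106, 106])  # TinyNet-E's native resolution
    print(model(x).shape)               # expected: [1, 1000]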
ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml (new file, mode 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 450
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 192, 192]
  save_inference_dir: ./inference

# model ema
EMA:
  decay: 0.9999

# model architecture
Arch:
  name: TinyNet_A
  class_num: 1000
  override_params:
    batch_norm_momentum: 0.9
    batch_norm_epsilon: 1e-5
    depth_trunc: round
    drop_connect_rate: 0.1

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: RMSProp
  momentum: 0.9
  rho: 0.9
  epsilon: 0.001
  one_dim_param_no_weight_decay: True
  lr:
    name: Step
    learning_rate: 0.048
    step_size: 2.4
    gamma: 0.97
    warmup_epoch: 3
    warmup_start_lr: 1e-6
  regularizer:
    name: 'L2'
    coeff: 1e-5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - RandCropImage:
            size: 192
            interpolation: bicubic
            backend: pil
            use_log_aspect: True
        - RandFlipImage:
            flip_code: 1
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
            saturation: 0.4
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_np: False
            channel_first: False
            backend: pil
        - ResizeImage:
            resize_short: 219
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 192
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_np: False
        channel_first: False
    - ResizeImage:
        resize_short: 219
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 192
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:

PostProcess:
  name: Topk
  topk: 5
  class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
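Note the fractional step_size: 2.4 with gamma: 0.97 under lr: this is the EfficientNet-style schedule (multiply the LR by 0.97 every 2.4 epochs), and it is what the learning_rate.py patch further down exists to support. Approximately, per epoch e:

    # by-epoch view of Step(learning_rate=0.048, step_size=2.4, gamma=0.97)
    lr = lambda e: 0.048 * 0.97**int(e / 2.4)
    print(lr(0), lr(12), lr(449))  # 0.048, 0.048*0.97**5, 0.048*0.97**187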
ppcls/configs/ImageNet/TinyNet/TinyNet_B.yaml (new file, mode 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 450
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 188, 188]
  save_inference_dir: ./inference

# model ema
EMA:
  decay: 0.9999

# model architecture
Arch:
  name: TinyNet_B
  class_num: 1000
  override_params:
    batch_norm_momentum: 0.9
    batch_norm_epsilon: 1e-5
    depth_trunc: round
    drop_connect_rate: 0.1

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: RMSProp
  momentum: 0.9
  rho: 0.9
  epsilon: 0.001
  one_dim_param_no_weight_decay: True
  lr:
    name: Step
    learning_rate: 0.048
    step_size: 2.4
    gamma: 0.97
    warmup_epoch: 3
    warmup_start_lr: 1e-6
  regularizer:
    name: 'L2'
    coeff: 1e-5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - RandCropImage:
            size: 188
            interpolation: bicubic
            backend: pil
            use_log_aspect: True
        - RandFlipImage:
            flip_code: 1
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
            saturation: 0.4
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_np: False
            channel_first: False
            backend: pil
        - ResizeImage:
            resize_short: 214
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 188
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_np: False
        channel_first: False
    - ResizeImage:
        resize_short: 214
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 188
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:

PostProcess:
  name: Topk
  topk: 5
  class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
ppcls/configs/ImageNet/TinyNet/TinyNet_C.yaml (new file, mode 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 450
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 184, 184]
  save_inference_dir: ./inference

# model ema
EMA:
  decay: 0.9999

# model architecture
Arch:
  name: TinyNet_C
  class_num: 1000
  override_params:
    batch_norm_momentum: 0.9
    batch_norm_epsilon: 1e-5
    depth_trunc: round
    drop_connect_rate: 0.0

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: RMSProp
  momentum: 0.9
  rho: 0.9
  epsilon: 0.001
  one_dim_param_no_weight_decay: True
  lr:
    name: Step
    learning_rate: 0.048
    step_size: 2.4
    gamma: 0.97
    warmup_epoch: 3
    warmup_start_lr: 1e-6
  regularizer:
    name: 'L2'
    coeff: 1e-5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - RandCropImage:
            size: 184
            interpolation: bicubic
            backend: pil
            use_log_aspect: True
        - RandFlipImage:
            flip_code: 1
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
            saturation: 0.4
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_np: False
            channel_first: False
            backend: pil
        - ResizeImage:
            resize_short: 210
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 184
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_np: False
        channel_first: False
    - ResizeImage:
        resize_short: 210
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 184
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:

PostProcess:
  name: Topk
  topk: 5
  class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
ppcls/configs/ImageNet/TinyNet/TinyNet_D.yaml (new file, mode 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 450
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 152, 152]
  save_inference_dir: ./inference

# model ema
EMA:
  decay: 0.9999

# model architecture
Arch:
  name: TinyNet_D
  class_num: 1000
  override_params:
    batch_norm_momentum: 0.9
    batch_norm_epsilon: 1e-5
    depth_trunc: round
    drop_connect_rate: 0.0

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: RMSProp
  momentum: 0.9
  rho: 0.9
  epsilon: 0.001
  one_dim_param_no_weight_decay: True
  lr:
    name: Step
    learning_rate: 0.048
    step_size: 2.4
    gamma: 0.97
    warmup_epoch: 3
    warmup_start_lr: 1e-6
  regularizer:
    name: 'L2'
    coeff: 1e-5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - RandCropImage:
            size: 152
            interpolation: bicubic
            backend: pil
            use_log_aspect: True
        - RandFlipImage:
            flip_code: 1
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
            saturation: 0.4
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_np: False
            channel_first: False
            backend: pil
        - ResizeImage:
            resize_short: 173
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 152
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_np: False
        channel_first: False
    - ResizeImage:
        resize_short: 173
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 152
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:

PostProcess:
  name: Topk
  topk: 5
  class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
ppcls/configs/ImageNet/TinyNet/TinyNet_E.yaml (new file, mode 100644)

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 450
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 106, 106]
  save_inference_dir: ./inference

# model ema
EMA:
  decay: 0.9999

# model architecture
Arch:
  name: TinyNet_E
  class_num: 1000
  override_params:
    batch_norm_momentum: 0.9
    batch_norm_epsilon: 1e-5
    depth_trunc: round
    drop_connect_rate: 0.0

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: RMSProp
  momentum: 0.9
  rho: 0.9
  epsilon: 0.001
  one_dim_param_no_weight_decay: True
  lr:
    name: Step
    learning_rate: 0.048
    step_size: 2.4
    gamma: 0.97
    warmup_epoch: 3
    warmup_start_lr: 1e-6
  regularizer:
    name: 'L2'
    coeff: 1e-5

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
            backend: pil
        - RandCropImage:
            size: 106
            interpolation: bicubic
            backend: pil
            use_log_aspect: True
        - RandFlipImage:
            flip_code: 1
        - ColorJitter:
            brightness: 0.4
            contrast: 0.4
            saturation: 0.4
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_np: False
            channel_first: False
            backend: pil
        - ResizeImage:
            resize_short: 121
            interpolation: bicubic
            backend: pil
        - CropImage:
            size: 106
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 128
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_np: False
        channel_first: False
    - ResizeImage:
        resize_short: 121
        interpolation: bicubic
        backend: pil
    - CropImage:
        size: 106
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:

PostProcess:
  name: Topk
  topk: 5
  class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
ppcls/optimizer/learning_rate.py

@@ -339,7 +339,7 @@ class Step(LRBase):
         epochs (int): total epoch(s)
         step_each_epoch (int): number of iterations within an epoch
         learning_rate (float): learning rate
-        step_size (int): the interval to update.
+        step_size (int|float): the interval to update.
         gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
         warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
         warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.

@@ -361,7 +361,7 @@ class Step(LRBase):
         super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
                                    warmup_epoch, warmup_start_lr, last_epoch,
                                    by_epoch)
-        self.step_size = step_size * step_each_epoch
+        self.step_size = int(step_size * step_each_epoch)
         self.gamma = gamma
         if self.by_epoch:
             self.step_size = step_size
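Why the int(...) cast: with by_epoch=False the schedule steps in iterations, and a float step_size (2.4 epochs, as in the TinyNet configs) times the iterations per epoch yields a float, which Paddle's StepDecay appears to reject when it validates step_size as an int. Hypothetical numbers:

    step_each_epoch = 5005             # hypothetical iterations per epoch
    print(2.4 * step_each_epoch)       # 12012.0, a float
    print(int(2.4 * step_each_epoch))  # 12012, acceptable to StepDecay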
ppcls/optimizer/optimizer.py

@@ -215,7 +215,9 @@ class RMSProp(object):
                  epsilon=1e-6,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=False,
+                 no_weight_decay_name=None,
+                 one_dim_param_no_weight_decay=False):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum

@@ -223,11 +225,33 @@ class RMSProp(object):
         self.epsilon = epsilon
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
+        self.no_weight_decay_name_list = no_weight_decay_name.split(
+        ) if no_weight_decay_name else []
+        self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay

     def __call__(self, model_list):
         # model_list is None in static graph
-        parameters = sum([m.parameters() for m in model_list],
-                         []) if model_list else None
+        parameters = None
+        if len(self.no_weight_decay_name_list) > 0:
+            params_with_decay = []
+            params_without_decay = []
+            for m in model_list:
+                params = [p for n, p in m.named_parameters() \
+                    if not any(nd in n for nd in self.no_weight_decay_name_list)]
+                params_with_decay.extend(params)
+                params = [p for n, p in m.named_parameters() \
+                    if any(nd in n for nd in self.no_weight_decay_name_list) or
+                    (self.one_dim_param_no_weight_decay and len(p.shape) == 1)]
+                params_without_decay.extend(params)
+            parameters = [{
+                "params": params_with_decay,
+                "weight_decay": self.weight_decay
+            }, {
+                "params": params_without_decay,
+                "weight_decay": 0.0
+            }]
+        else:
+            parameters = sum([m.parameters() for m in model_list],
+                             []) if model_list else None
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
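A minimal sketch of the parameter grouping the patched wrapper builds (an illustration, assuming Paddle's parameter-group API): 1-D parameters such as BN scales/offsets and biases go into a group with weight_decay=0.0:

    import paddle
    import paddle.nn as nn

    model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8))
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        # one_dim_param_no_weight_decay: 1-D params are exempt from decay
        (no_decay if len(p.shape) == 1 else decay).append(p)
    opt = paddle.optimizer.RMSProp(
        learning_rate=0.048,
        rho=0.9,
        parameters=[{"params": decay, "weight_decay": 1e-5},
                    {"params": no_decay, "weight_decay": 0.0}])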
test_tipc/configs/TinyNet/TinyNet_A_train_infer_python.txt (new file, mode 100644)
===========================train_params===========================
model_name:TinyNet_A
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
inference_dir:null
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,192,192]}]
\ No newline at end of file
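For reference, TIPC configs like this one are normally driven by the PaddleClas TIPC scripts (test_tipc/prepare.sh followed by test_tipc/test_train_inference_python.sh, passing this file plus a mode such as lite_train_lite_infer as arguments).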