Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
bcf563ff
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bcf563ff
编写于
6月 09, 2020
作者:
D
dyning
提交者:
GitHub
6月 09, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #152 from WuHaobo/dynamic
Dynamic Graph
上级
38ad51ca
f1c0a59f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
285 addition
and
480 deletion
+285
-480
ppcls/modeling/architectures/__init__.py
ppcls/modeling/architectures/__init__.py
+1
-35
ppcls/modeling/architectures/resnet.py
ppcls/modeling/architectures/resnet.py
+152
-196
ppcls/optimizer/optimizer.py
ppcls/optimizer/optimizer.py
+13
-11
tools/program.py
tools/program.py
+65
-150
tools/train.py
tools/train.py
+54
-88
未找到文件。
ppcls/modeling/architectures/__init__.py
浏览文件 @
bcf563ff
...
@@ -12,38 +12,4 @@
...
@@ -12,38 +12,4 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.alexnet
import
AlexNet
from
.resnet
import
*
from
.mobilenet_v1
import
MobileNetV1_x0_25
,
MobileNetV1_x0_5
,
MobileNetV1_x1_0
,
MobileNetV1_x0_75
,
MobileNetV1
from
.mobilenet_v2
import
MobileNetV2_x0_25
,
MobileNetV2_x0_5
,
MobileNetV2_x0_75
,
MobileNetV2_x1_0
,
MobileNetV2_x1_5
,
MobileNetV2_x2_0
,
MobileNetV2
from
.mobilenet_v3
import
MobileNetV3_small_x0_35
,
MobileNetV3_small_x0_5
,
MobileNetV3_small_x0_75
,
MobileNetV3_small_x1_0
,
MobileNetV3_small_x1_25
,
MobileNetV3_large_x0_35
,
MobileNetV3_large_x0_5
,
MobileNetV3_large_x0_75
,
MobileNetV3_large_x1_0
,
MobileNetV3_large_x1_25
from
.googlenet
import
GoogLeNet
from
.vgg
import
VGG11
,
VGG13
,
VGG16
,
VGG19
from
.resnet
import
ResNet18
,
ResNet34
,
ResNet50
,
ResNet101
,
ResNet152
from
.resnet_vc
import
ResNet50_vc
,
ResNet101_vc
,
ResNet152_vc
from
.resnet_vd
import
ResNet18_vd
,
ResNet34_vd
,
ResNet50_vd
,
ResNet101_vd
,
ResNet152_vd
,
ResNet200_vd
from
.resnext
import
ResNeXt50_64x4d
,
ResNeXt101_64x4d
,
ResNeXt152_64x4d
,
ResNeXt50_32x4d
,
ResNeXt101_32x4d
,
ResNeXt152_32x4d
from
.resnext_vd
import
ResNeXt50_vd_64x4d
,
ResNeXt101_vd_64x4d
,
ResNeXt152_vd_64x4d
,
ResNeXt50_vd_32x4d
,
ResNeXt101_vd_32x4d
,
ResNeXt152_vd_32x4d
from
.inception_v4
import
InceptionV4
from
.se_resnet_vd
import
SE_ResNet18_vd
,
SE_ResNet34_vd
,
SE_ResNet50_vd
,
SE_ResNet101_vd
,
SE_ResNet152_vd
,
SE_ResNet200_vd
from
.se_resnext
import
SE_ResNeXt50_32x4d
,
SE_ResNeXt101_32x4d
,
SE_ResNeXt152_32x4d
from
.se_resnext_vd
import
SE_ResNeXt50_vd_32x4d
,
SE_ResNeXt101_vd_32x4d
,
SENet154_vd
from
.dpn
import
DPN68
,
DPN92
,
DPN98
,
DPN107
,
DPN131
from
.shufflenet_v2_swish
import
ShuffleNetV2_swish
,
ShuffleNetV2_x0_5_swish
,
ShuffleNetV2_x1_0_swish
,
ShuffleNetV2_x1_5_swish
,
ShuffleNetV2_x2_0_swish
from
.shufflenet_v2
import
ShuffleNetV2_x0_25
,
ShuffleNetV2_x0_33
,
ShuffleNetV2_x0_5
,
ShuffleNetV2_x1_0
,
ShuffleNetV2_x1_5
,
ShuffleNetV2_x2_0
,
ShuffleNetV2
from
.xception
import
Xception41
,
Xception65
,
Xception71
from
.xception_deeplab
import
Xception41_deeplab
,
Xception65_deeplab
,
Xception71_deeplab
from
.densenet
import
DenseNet121
,
DenseNet161
,
DenseNet169
,
DenseNet201
,
DenseNet264
from
.squeezenet
import
SqueezeNet1_0
,
SqueezeNet1_1
from
.darknet
import
DarkNet53
from
.resnext101_wsl
import
ResNeXt101_32x8d_wsl
,
ResNeXt101_32x16d_wsl
,
ResNeXt101_32x32d_wsl
,
ResNeXt101_32x48d_wsl
,
Fix_ResNeXt101_32x48d_wsl
from
.efficientnet
import
EfficientNet
,
EfficientNetB0
,
EfficientNetB0_small
,
EfficientNetB1
,
EfficientNetB2
,
EfficientNetB3
,
EfficientNetB4
,
EfficientNetB5
,
EfficientNetB6
,
EfficientNetB7
from
.res2net
import
Res2Net50_48w_2s
,
Res2Net50_26w_4s
,
Res2Net50_14w_8s
,
Res2Net50_26w_6s
,
Res2Net50_26w_8s
,
Res2Net101_26w_4s
,
Res2Net152_26w_4s
from
.res2net_vd
import
Res2Net50_vd_48w_2s
,
Res2Net50_vd_26w_4s
,
Res2Net50_vd_14w_8s
,
Res2Net50_vd_26w_6s
,
Res2Net50_vd_26w_8s
,
Res2Net101_vd_26w_4s
,
Res2Net152_vd_26w_4s
,
Res2Net200_vd_26w_4s
from
.hrnet
import
HRNet_W18_C
,
HRNet_W30_C
,
HRNet_W32_C
,
HRNet_W40_C
,
HRNet_W44_C
,
HRNet_W48_C
,
HRNet_W60_C
,
HRNet_W64_C
,
SE_HRNet_W18_C
,
SE_HRNet_W30_C
,
SE_HRNet_W32_C
,
SE_HRNet_W40_C
,
SE_HRNet_W44_C
,
SE_HRNet_W48_C
,
SE_HRNet_W60_C
,
SE_HRNet_W64_C
from
.darts_gs
import
DARTS_GS_6M
,
DARTS_GS_4M
from
.resnet_acnet
import
ResNet18_ACNet
,
ResNet34_ACNet
,
ResNet50_ACNet
,
ResNet101_ACNet
,
ResNet152_ACNet
# distillation model
from
.distillation_models
import
ResNet50_vd_distill_MobileNetV3_large_x1_0
,
ResNeXt101_32x16d_wsl_distill_ResNet50_vd
from
.csp_resnet
import
CSPResNet50_leaky
\ No newline at end of file
ppcls/modeling/architectures/resnet.py
浏览文件 @
bcf563ff
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
#Licensed under the Apache License, Version 2.0 (the "License");
#
Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#
you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
#Unless required by applicable law or agreed to in writing, software
#
Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#
distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#
See the License for the specific language governing permissions and
#limitations under the License.
#
limitations under the License.
from
__future__
import
absolute_import
import
paddle.fluid
as
fluid
from
__future__
import
division
from
paddle.fluid.layer_helper
import
LayerHelper
from
__future__
import
print_function
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
import
math
import
math
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
__all__
=
[
__all__
=
[
"ResNet"
,
"ResNet18"
,
"ResNet34"
,
"ResNet50"
,
"ResNet101"
,
"ResNet152"
"ResNet18"
,
"ResNet34"
,
"ResNet50"
,
"ResNet101"
,
"ResNet152"
,
]
]
class
ResNet
():
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
layers
=
50
):
def
__init__
(
self
,
self
.
layers
=
layers
num_channels
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_conv
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
bias_attr
=
False
)
self
.
_batch_norm
=
BatchNorm
(
num_filters
,
act
=
act
)
def
forward
(
self
,
inputs
):
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
return
y
class
BottleneckBlock
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
shortcut
=
True
):
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
1
,
act
=
'relu'
)
self
.
conv1
=
ConvBNLayer
(
num_channels
=
num_filters
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
act
=
'relu'
)
self
.
conv2
=
ConvBNLayer
(
num_channels
=
num_filters
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
act
=
None
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
stride
=
stride
)
self
.
shortcut
=
shortcut
self
.
_num_channels_out
=
num_filters
*
4
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv2
)
layer_helper
=
LayerHelper
(
self
.
full_name
(),
act
=
'relu'
)
return
layer_helper
.
append_activation
(
y
)
def
net
(
self
,
input
,
class_dim
=
1000
,
data_format
=
"NCHW"
):
class
ResNet
(
fluid
.
dygraph
.
Layer
):
layers
=
self
.
layers
def
__init__
(
self
,
layers
=
50
,
class_dim
=
1000
):
supported_layers
=
[
18
,
34
,
50
,
101
,
152
]
super
(
ResNet
,
self
).
__init__
()
self
.
layers
=
layers
supported_layers
=
[
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
if
layers
==
50
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
depth
=
[
3
,
8
,
36
,
3
]
num_channels
=
[
64
,
256
,
512
,
1024
]
num_filters
=
[
64
,
128
,
256
,
512
]
num_filters
=
[
64
,
128
,
256
,
512
]
conv
=
self
.
conv_bn_l
ayer
(
self
.
conv
=
ConvBNL
ayer
(
input
=
input
,
num_channels
=
3
,
num_filters
=
64
,
num_filters
=
64
,
filter_size
=
7
,
filter_size
=
7
,
stride
=
2
,
stride
=
2
,
act
=
'relu'
,
act
=
'relu'
)
name
=
"conv1"
,
self
.
pool2d_max
=
Pool2D
(
data_format
=
data_format
)
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
conv
=
fluid
.
layers
.
pool2d
(
input
=
conv
,
self
.
bottleneck_block_list
=
[]
pool_size
=
3
,
for
block
in
range
(
len
(
depth
)):
pool_stride
=
2
,
shortcut
=
False
pool_padding
=
1
,
for
i
in
range
(
depth
[
block
]):
pool_type
=
'max'
,
bottleneck_block
=
self
.
add_sublayer
(
data_format
=
data_format
)
'bb_%d_%d'
%
(
block
,
i
),
if
layers
>=
50
:
BottleneckBlock
(
for
block
in
range
(
len
(
depth
)):
num_channels
=
num_channels
[
block
]
for
i
in
range
(
depth
[
block
]):
if
i
==
0
else
num_filters
[
block
]
*
4
,
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
conv
=
self
.
bottleneck_block
(
input
=
conv
,
num_filters
=
num_filters
[
block
],
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
name
=
conv_name
,
shortcut
=
shortcut
))
data_format
=
data_format
)
self
.
bottleneck_block_list
.
append
(
bottleneck_block
)
shortcut
=
True
else
:
self
.
pool2d_avg
=
Pool2D
(
for
block
in
range
(
len
(
depth
)):
pool_size
=
7
,
pool_type
=
'avg'
,
global_pooling
=
True
)
for
i
in
range
(
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
self
.
pool2d_avg_output
=
num_filters
[
len
(
num_filters
)
-
1
]
*
4
*
1
*
1
conv
=
self
.
basic_block
(
input
=
conv
,
stdv
=
1.0
/
math
.
sqrt
(
2048
*
1.0
)
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
self
.
out
=
Linear
(
is_first
=
block
==
i
==
0
,
self
.
pool2d_avg_output
,
name
=
conv_name
,
class_dim
,
data_format
=
data_format
)
pool
=
fluid
.
layers
.
pool2d
(
input
=
conv
,
pool_type
=
'avg'
,
global_pooling
=
True
,
data_format
=
data_format
)
stdv
=
1.0
/
math
.
sqrt
(
pool
.
shape
[
1
]
*
1.0
)
out
=
fluid
.
layers
.
fc
(
input
=
pool
,
size
=
class_dim
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
name
=
"fc_0.w_0"
,
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
)))
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
)),
bias_attr
=
ParamAttr
(
name
=
"fc_0.b_0"
))
return
out
def
conv_bn_layer
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
,
data_format
=
'NCHW'
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
,
name
=
name
+
'.conv2d.output.1'
,
data_format
=
data_format
)
if
name
==
"conv1"
:
def
forward
(
self
,
inputs
):
bn_name
=
"bn_"
+
name
y
=
self
.
conv
(
inputs
)
else
:
y
=
self
.
pool2d_max
(
y
)
bn_name
=
"bn"
+
name
[
3
:]
for
bottleneck_block
in
self
.
bottleneck_block_list
:
return
fluid
.
layers
.
batch_norm
(
y
=
bottleneck_block
(
y
)
input
=
conv
,
y
=
self
.
pool2d_avg
(
y
)
act
=
act
,
y
=
fluid
.
layers
.
reshape
(
y
,
shape
=
[
-
1
,
self
.
pool2d_avg_output
])
name
=
bn_name
+
'.output.1'
,
y
=
self
.
out
(
y
)
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
return
y
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
data_layout
=
data_format
)
def
shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
,
data_format
):
if
data_format
==
'NCHW'
:
ch_in
=
input
.
shape
[
1
]
else
:
ch_in
=
input
.
shape
[
-
1
]
if
ch_in
!=
ch_out
or
stride
!=
1
or
is_first
==
True
:
return
self
.
conv_bn_layer
(
input
,
ch_out
,
1
,
stride
,
name
=
name
,
data_format
=
data_format
)
else
:
return
input
def
bottleneck_block
(
self
,
input
,
num_filters
,
stride
,
name
,
data_format
):
conv0
=
self
.
conv_bn_layer
(
def
ResNet18
(
**
kwargs
):
input
=
input
,
model
=
ResNet
(
layers
=
18
,
**
kwargs
)
num_filters
=
num_filters
,
filter_size
=
1
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
,
data_format
=
data_format
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
act
=
'relu'
,
name
=
name
+
"_branch2b"
,
data_format
=
data_format
)
conv2
=
self
.
conv_bn_layer
(
input
=
conv1
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
act
=
None
,
name
=
name
+
"_branch2c"
,
data_format
=
data_format
)
short
=
self
.
shortcut
(
input
,
num_filters
*
4
,
stride
,
is_first
=
False
,
name
=
name
+
"_branch1"
,
data_format
=
data_format
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
,
name
=
name
+
".add.output.5"
)
def
basic_block
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
,
data_format
):
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
name
=
name
+
"_branch2a"
,
data_format
=
data_format
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
,
data_format
=
data_format
)
short
=
self
.
shortcut
(
input
,
num_filters
,
stride
,
is_first
,
name
=
name
+
"_branch1"
,
data_format
=
data_format
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
def
ResNet18
():
model
=
ResNet
(
layers
=
18
)
return
model
return
model
def
ResNet34
():
def
ResNet34
(
**
kwargs
):
model
=
ResNet
(
layers
=
34
)
model
=
ResNet
(
layers
=
34
,
**
kwargs
)
return
model
return
model
def
ResNet50
():
def
ResNet50
(
**
kwargs
):
model
=
ResNet
(
layers
=
50
)
model
=
ResNet
(
layers
=
50
,
**
kwargs
)
return
model
return
model
def
ResNet101
():
def
ResNet101
(
**
kwargs
):
model
=
ResNet
(
layers
=
101
)
model
=
ResNet
(
layers
=
101
,
**
kwargs
)
return
model
return
model
def
ResNet152
():
def
ResNet152
(
class_dim
=
1000
):
model
=
ResNet
(
layers
=
152
)
model
=
ResNet
(
layers
=
152
,
class_dim
=
class_dim
)
return
model
return
model
ppcls/optimizer/optimizer.py
浏览文件 @
bcf563ff
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
#Licensed under the Apache License, Version 2.0 (the "License");
#
Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#
you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
#Unless required by applicable law or agreed to in writing, software
#
Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#
distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#
See the License for the specific language governing permissions and
#limitations under the License.
#
limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
@@ -48,6 +48,8 @@ class OptimizerBuilder(object):
...
@@ -48,6 +48,8 @@ class OptimizerBuilder(object):
reg
=
getattr
(
pfreg
,
reg_func
)(
reg_factor
)
reg
=
getattr
(
pfreg
,
reg_func
)(
reg_factor
)
self
.
params
[
'regularization'
]
=
reg
self
.
params
[
'regularization'
]
=
reg
def
__call__
(
self
,
learning_rate
):
def
__call__
(
self
,
learning_rate
,
parameter_list
):
opt
=
getattr
(
pfopt
,
self
.
function
)
opt
=
getattr
(
pfopt
,
self
.
function
)
return
opt
(
learning_rate
=
learning_rate
,
**
self
.
params
)
return
opt
(
learning_rate
=
learning_rate
,
parameter_list
=
parameter_list
,
**
self
.
params
)
tools/program.py
浏览文件 @
bcf563ff
...
@@ -33,41 +33,12 @@ from ppcls.modeling.loss import GoogLeNetLoss
...
@@ -33,41 +33,12 @@ from ppcls.modeling.loss import GoogLeNetLoss
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.incubate.fleet.collective
import
fleet
from
paddle.fluid.incubate.fleet.collective
import
fleet
from
paddle.fluid.incubate.fleet.collective
import
DistributedStrategy
from
paddle.fluid.incubate.fleet.collective
import
DistributedStrategy
from
ema
import
ExponentialMovingAverage
def
create_dataloader
():
def
create_feeds
(
image_shape
,
use_mix
=
None
):
"""
Create feeds as model input
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
feeds(dict): dict of model input variables
"""
feeds
=
OrderedDict
()
feeds
[
'image'
]
=
fluid
.
data
(
name
=
"feed_image"
,
shape
=
[
None
]
+
image_shape
,
dtype
=
"float32"
)
if
use_mix
:
feeds
[
'feed_y_a'
]
=
fluid
.
data
(
name
=
"feed_y_a"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
feeds
[
'feed_y_b'
]
=
fluid
.
data
(
name
=
"feed_y_b"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
feeds
[
'feed_lam'
]
=
fluid
.
data
(
name
=
"feed_lam"
,
shape
=
[
None
,
1
],
dtype
=
"float32"
)
else
:
feeds
[
'label'
]
=
fluid
.
data
(
name
=
"feed_label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
return
feeds
def
create_dataloader
(
feeds
):
"""
"""
Create a dataloader with model input variables
Create a dataloader with model input variables
...
@@ -80,7 +51,6 @@ def create_dataloader(feeds):
...
@@ -80,7 +51,6 @@ def create_dataloader(feeds):
trainer_num
=
int
(
os
.
environ
.
get
(
'PADDLE_TRAINERS_NUM'
,
1
))
trainer_num
=
int
(
os
.
environ
.
get
(
'PADDLE_TRAINERS_NUM'
,
1
))
capacity
=
64
if
trainer_num
<=
1
else
8
capacity
=
64
if
trainer_num
<=
1
else
8
dataloader
=
fluid
.
io
.
DataLoader
.
from_generator
(
dataloader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
feeds
,
capacity
=
capacity
,
capacity
=
capacity
,
use_double_buffer
=
True
,
use_double_buffer
=
True
,
iterable
=
True
)
iterable
=
True
)
...
@@ -88,7 +58,7 @@ def create_dataloader(feeds):
...
@@ -88,7 +58,7 @@ def create_dataloader(feeds):
return
dataloader
return
dataloader
def
create_model
(
architecture
,
image
,
classes_num
,
is_train
):
def
create_model
(
architecture
,
classes_num
):
"""
"""
Create a model
Create a model
...
@@ -103,15 +73,11 @@ def create_model(architecture, image, classes_num, is_train):
...
@@ -103,15 +73,11 @@ def create_model(architecture, image, classes_num, is_train):
"""
"""
name
=
architecture
[
"name"
]
name
=
architecture
[
"name"
]
params
=
architecture
.
get
(
"params"
,
{})
params
=
architecture
.
get
(
"params"
,
{})
if
"is_test"
in
params
:
return
architectures
.
__dict__
[
name
](
class_dim
=
classes_num
,
**
params
)
params
[
'is_test'
]
=
not
is_train
model
=
architectures
.
__dict__
[
name
](
**
params
)
out
=
model
.
net
(
input
=
image
,
class_dim
=
classes_num
)
return
out
def
create_loss
(
out
,
def
create_loss
(
out
,
feeds
,
label
,
architecture
,
architecture
,
classes_num
=
1000
,
classes_num
=
1000
,
epsilon
=
None
,
epsilon
=
None
,
...
@@ -140,8 +106,7 @@ def create_loss(out,
...
@@ -140,8 +106,7 @@ def create_loss(out,
if
architecture
[
"name"
]
==
"GoogLeNet"
:
if
architecture
[
"name"
]
==
"GoogLeNet"
:
assert
len
(
out
)
==
3
,
"GoogLeNet should have 3 outputs"
assert
len
(
out
)
==
3
,
"GoogLeNet should have 3 outputs"
loss
=
GoogLeNetLoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
loss
=
GoogLeNetLoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
target
=
feeds
[
'label'
]
return
loss
(
out
[
0
],
out
[
1
],
out
[
2
],
label
)
return
loss
(
out
[
0
],
out
[
1
],
out
[
2
],
target
)
if
use_distillation
:
if
use_distillation
:
assert
len
(
out
)
==
2
,
(
"distillation output length must be 2, "
assert
len
(
out
)
==
2
,
(
"distillation output length must be 2, "
...
@@ -151,18 +116,18 @@ def create_loss(out,
...
@@ -151,18 +116,18 @@ def create_loss(out,
if
use_mix
:
if
use_mix
:
loss
=
MixCELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
loss
=
MixCELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
feed_y_a
=
feeds
[
'feed_y_a'
]
raise
NotImplementedError
feed_y_b
=
feeds
[
'feed_y_b'
]
#feed_y_a = feeds['feed_y_a']
feed_lam
=
feeds
[
'feed_lam'
]
#feed_y_b = feeds['feed_y_b']
return
loss
(
out
,
feed_y_a
,
feed_y_b
,
feed_lam
)
#feed_lam = feeds['feed_lam']
#return loss(out, feed_y_a, feed_y_b, feed_lam)
else
:
else
:
loss
=
CELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
loss
=
CELoss
(
class_dim
=
classes_num
,
epsilon
=
epsilon
)
target
=
feeds
[
'label'
]
return
loss
(
out
,
label
)
return
loss
(
out
,
target
)
def
create_metric
(
out
,
def
create_metric
(
out
,
feeds
,
label
,
architecture
,
architecture
,
topk
=
5
,
topk
=
5
,
classes_num
=
1000
,
classes_num
=
1000
,
...
@@ -190,19 +155,19 @@ def create_metric(out,
...
@@ -190,19 +155,19 @@ def create_metric(out,
fetchs
=
OrderedDict
()
fetchs
=
OrderedDict
()
# set top1 to fetchs
# set top1 to fetchs
top1
=
fluid
.
layers
.
accuracy
(
softmax_out
,
label
=
feeds
[
'label'
]
,
k
=
1
)
top1
=
fluid
.
layers
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
1
)
fetchs
[
'top1'
]
=
(
top1
,
AverageMeter
(
'top1'
,
'.4f'
,
need_avg
=
True
))
fetchs
[
'top1'
]
=
top1
# set topk to fetchs
# set topk to fetchs
k
=
min
(
topk
,
classes_num
)
k
=
min
(
topk
,
classes_num
)
topk
=
fluid
.
layers
.
accuracy
(
softmax_out
,
label
=
feeds
[
'label'
]
,
k
=
k
)
topk
=
fluid
.
layers
.
accuracy
(
softmax_out
,
label
=
label
,
k
=
k
)
topk_name
=
'top{}'
.
format
(
k
)
topk_name
=
'top{}'
.
format
(
k
)
fetchs
[
topk_name
]
=
(
topk
,
AverageMeter
(
topk_name
,
'.4f'
,
need_avg
=
True
))
fetchs
[
topk_name
]
=
topk
return
fetchs
return
fetchs
def
create_fetchs
(
out
,
def
create_fetchs
(
out
,
feeds
,
label
,
architecture
,
architecture
,
topk
=
5
,
topk
=
5
,
classes_num
=
1000
,
classes_num
=
1000
,
...
@@ -228,18 +193,17 @@ def create_fetchs(out,
...
@@ -228,18 +193,17 @@ def create_fetchs(out,
fetchs(dict): dict of model outputs(included loss and measures)
fetchs(dict): dict of model outputs(included loss and measures)
"""
"""
fetchs
=
OrderedDict
()
fetchs
=
OrderedDict
()
loss
=
create_loss
(
out
,
feeds
,
architecture
,
classes_num
,
epsilon
,
use_mix
,
fetchs
[
'loss'
]
=
create_loss
(
out
,
label
,
architecture
,
classes_num
,
epsilon
,
use_mix
,
use_distillation
)
use_distillation
)
fetchs
[
'loss'
]
=
(
loss
,
AverageMeter
(
'loss'
,
'7.4f'
,
need_avg
=
True
))
if
not
use_mix
:
if
not
use_mix
:
metric
=
create_metric
(
out
,
feeds
,
architecture
,
topk
,
classes_num
,
metric
=
create_metric
(
out
,
label
,
architecture
,
topk
,
classes_num
,
use_distillation
)
use_distillation
)
fetchs
.
update
(
metric
)
fetchs
.
update
(
metric
)
return
fetchs
return
fetchs
def
create_optimizer
(
config
):
def
create_optimizer
(
config
,
parameter_list
=
None
):
"""
"""
Create an optimizer using config, usually including
Create an optimizer using config, usually including
learning rate and regularization.
learning rate and regularization.
...
@@ -274,7 +238,7 @@ def create_optimizer(config):
...
@@ -274,7 +238,7 @@ def create_optimizer(config):
# create optimizer instance
# create optimizer instance
opt_config
=
config
[
'OPTIMIZER'
]
opt_config
=
config
[
'OPTIMIZER'
]
opt
=
OptimizerBuilder
(
**
opt_config
)
opt
=
OptimizerBuilder
(
**
opt_config
)
return
opt
(
lr
)
return
opt
(
lr
,
parameter_list
)
def
dist_optimizer
(
config
,
optimizer
):
def
dist_optimizer
(
config
,
optimizer
):
...
@@ -314,7 +278,7 @@ def mixed_precision_optimizer(config, optimizer):
...
@@ -314,7 +278,7 @@ def mixed_precision_optimizer(config, optimizer):
return
optimizer
return
optimizer
def
build
(
config
,
main_prog
,
startup_prog
,
is_train
=
True
):
def
compute
(
config
,
out
,
label
,
mode
=
'train'
):
"""
"""
Build a program using a model and an optimizer
Build a program using a model and an optimizer
1. create feeds
1. create feeds
...
@@ -333,79 +297,20 @@ def build(config, main_prog, startup_prog, is_train=True):
...
@@ -333,79 +297,20 @@ def build(config, main_prog, startup_prog, is_train=True):
dataloader(): a bridge between the model and the data
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
fetchs(dict): dict of model outputs(included loss and measures)
"""
"""
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
fetchs
=
create_fetchs
(
with
fluid
.
unique_name
.
guard
():
out
,
use_mix
=
config
.
get
(
'use_mix'
)
and
is_train
label
,
use_distillation
=
config
.
get
(
'use_distillation'
)
config
.
ARCHITECTURE
,
feeds
=
create_feeds
(
config
.
image_shape
,
use_mix
=
use_mix
)
config
.
topk
,
dataloader
=
create_dataloader
(
feeds
.
values
())
config
.
classes_num
,
out
=
create_model
(
config
.
ARCHITECTURE
,
feeds
[
'image'
],
epsilon
=
config
.
get
(
'ls_epsilon'
),
config
.
classes_num
,
is_train
)
use_mix
=
config
.
get
(
'use_mix'
)
and
mode
==
'train'
,
fetchs
=
create_fetchs
(
use_distillation
=
config
.
get
(
'use_distillation'
))
out
,
feeds
,
config
.
ARCHITECTURE
,
config
.
topk
,
config
.
classes_num
,
epsilon
=
config
.
get
(
'ls_epsilon'
),
use_mix
=
use_mix
,
use_distillation
=
use_distillation
)
if
is_train
:
optimizer
=
create_optimizer
(
config
)
lr
=
optimizer
.
_global_learning_rate
()
fetchs
[
'lr'
]
=
(
lr
,
AverageMeter
(
'lr'
,
'f'
,
need_avg
=
False
))
optimizer
=
mixed_precision_optimizer
(
config
,
optimizer
)
optimizer
=
dist_optimizer
(
config
,
optimizer
)
optimizer
.
minimize
(
fetchs
[
'loss'
][
0
])
if
config
.
get
(
'use_ema'
):
global_steps
=
fluid
.
layers
.
learning_rate_scheduler
.
_decay_step_counter
(
)
ema
=
ExponentialMovingAverage
(
config
.
get
(
'ema_decay'
),
thres_steps
=
global_steps
)
ema
.
update
()
return
dataloader
,
fetchs
,
ema
return
dataloader
,
fetchs
def
compile
(
config
,
program
,
loss_name
=
None
):
"""
Compile the program
Args:
return
fetchs
config(dict): config
program(): the program which is wrapped by
loss_name(str): loss name
Returns:
compiled_program(): a compiled program
"""
build_strategy
=
fluid
.
compiler
.
BuildStrategy
()
exec_strategy
=
fluid
.
ExecutionStrategy
()
exec_strategy
.
num_threads
=
1
exec_strategy
.
num_iteration_per_drop_scope
=
10
compiled_program
=
fluid
.
CompiledProgram
(
program
).
with_data_parallel
(
loss_name
=
loss_name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
return
compiled_program
total_step
=
0
def
run
(
dataloader
,
def
run
(
dataloader
,
config
,
net
,
optimizer
=
None
,
epoch
=
0
,
mode
=
'train'
):
exe
,
program
,
fetchs
,
epoch
=
0
,
mode
=
'train'
,
vdl_writer
=
None
):
"""
"""
Feed data to the model and fetch the measures and loss
Feed data to the model and fetch the measures and loss
...
@@ -419,48 +324,58 @@ def run(dataloader,
...
@@ -419,48 +324,58 @@ def run(dataloader,
Returns:
Returns:
"""
"""
fetch_list
=
[
f
[
0
]
for
f
in
fetchs
.
values
()]
topk_name
=
'top{}'
.
format
(
config
.
topk
)
metric_list
=
[
f
[
1
]
for
f
in
fetchs
.
values
()]
metric_list
=
OrderedDict
([
for
m
in
metric_list
:
(
"loss"
,
AverageMeter
(
'loss'
,
'7.4f'
)),
m
.
reset
()
(
"top1"
,
AverageMeter
(
'top1'
,
'.4f'
)),
batch_time
=
AverageMeter
(
'elapse'
,
'.3f'
)
(
topk_name
,
AverageMeter
(
topk_name
,
'.4f'
)),
(
"lr"
,
AverageMeter
(
'lr'
,
'f'
,
need_avg
=
False
)),
(
"batch_time"
,
AverageMeter
(
'elapse'
,
'.3f'
)),
])
tic
=
time
.
time
()
tic
=
time
.
time
()
for
idx
,
batch
in
enumerate
(
dataloader
()):
for
idx
,
(
img
,
label
)
in
enumerate
(
dataloader
()):
metrics
=
exe
.
run
(
program
=
program
,
feed
=
batch
,
fetch_list
=
fetch_list
)
label
=
to_variable
(
label
.
numpy
().
astype
(
'int64'
).
reshape
(
-
1
,
1
))
batch_time
.
update
(
time
.
time
()
-
tic
)
fetchs
=
compute
(
config
,
net
(
img
),
label
,
mode
)
if
mode
==
'train'
:
avg_loss
=
net
.
scale_loss
(
fetchs
[
'loss'
])
avg_loss
.
backward
()
net
.
apply_collective_grads
()
optimizer
.
minimize
(
avg_loss
)
net
.
clear_gradients
()
metric_list
[
'lr'
].
update
(
optimizer
.
_global_learning_rate
().
numpy
()[
0
],
len
(
img
))
for
name
,
fetch
in
fetchs
.
items
():
metric_list
[
name
].
update
(
fetch
.
numpy
()[
0
],
len
(
img
))
metric_list
[
'batch_time'
].
update
(
time
.
time
()
-
tic
)
tic
=
time
.
time
()
tic
=
time
.
time
()
for
i
,
m
in
enumerate
(
metrics
):
metric_list
[
i
].
update
(
m
[
0
],
len
(
batch
[
0
]))
fetchs_str
=
' '
.
join
([
str
(
m
.
value
)
for
m
in
metric_list
.
values
()])
fetchs_str
=
''
.
join
([
str
(
m
.
value
)
+
' '
for
m
in
metric_list
]
+
[
batch_time
.
value
])
+
's'
if
vdl_writer
:
global
total_step
logger
.
scaler
(
'loss'
,
metrics
[
0
][
0
],
total_step
,
vdl_writer
)
total_step
+=
1
if
mode
==
'eval'
:
if
mode
==
'eval'
:
logger
.
info
(
"{:s} step:{:<4d} {:s}s"
.
format
(
mode
,
idx
,
fetchs_str
))
logger
.
info
(
"{:s} step:{:<4d} {:s}s"
.
format
(
mode
,
idx
,
fetchs_str
))
else
:
else
:
epoch_str
=
"epoch:{:<3d}"
.
format
(
epoch
)
epoch_str
=
"epoch:{:<3d}"
.
format
(
epoch
)
step_str
=
"{:s} step:{:<4d}"
.
format
(
mode
,
idx
)
step_str
=
"{:s} step:{:<4d}"
.
format
(
mode
,
idx
)
logger
.
info
(
"{:s} {:s} {:s}"
.
format
(
logger
.
info
(
"{:s} {:s} {:s}
s
"
.
format
(
logger
.
coloring
(
epoch_str
,
"HEADER"
)
logger
.
coloring
(
epoch_str
,
"HEADER"
)
if
idx
==
0
else
epoch_str
,
if
idx
==
0
else
epoch_str
,
logger
.
coloring
(
step_str
,
"PURPLE"
),
logger
.
coloring
(
step_str
,
"PURPLE"
),
logger
.
coloring
(
fetchs_str
,
'OKGREEN'
)))
logger
.
coloring
(
fetchs_str
,
'OKGREEN'
)))
end_str
=
''
.
join
([
str
(
m
.
mean
)
+
' '
end_str
=
' '
.
join
([
str
(
m
.
mean
)
for
m
in
metric_list
.
values
()]
+
[
metric_list
[
'batch_time'
].
total
])
for
m
in
metric_list
]
+
[
batch_time
.
total
])
+
's'
if
mode
==
'eval'
:
if
mode
==
'eval'
:
logger
.
info
(
"END {:s} {:s}s"
.
format
(
mode
,
end_str
))
logger
.
info
(
"END {:s} {:s}s"
.
format
(
mode
,
end_str
))
else
:
else
:
end_epoch_str
=
"END epoch:{:<3d}"
.
format
(
epoch
)
end_epoch_str
=
"END epoch:{:<3d}"
.
format
(
epoch
)
logger
.
info
(
"{:s} {:s} {:s}"
.
format
(
logger
.
info
(
"{:s} {:s} {:s}
s
"
.
format
(
logger
.
coloring
(
end_epoch_str
,
"RED"
),
logger
.
coloring
(
end_epoch_str
,
"RED"
),
logger
.
coloring
(
mode
,
"PURPLE"
),
logger
.
coloring
(
mode
,
"PURPLE"
),
logger
.
coloring
(
end_str
,
"OKGREEN"
)))
logger
.
coloring
(
end_str
,
"OKGREEN"
)))
# return top1_acc in order to save the best model
# return top1_acc in order to save the best model
if
mode
==
'valid'
:
if
mode
==
'valid'
:
return
fetchs
[
"top1"
][
1
].
avg
return
metric_list
[
'top1'
].
avg
\ No newline at end of file
tools/train.py
浏览文件 @
bcf563ff
...
@@ -19,10 +19,7 @@ from __future__ import print_function
...
@@ -19,10 +19,7 @@ from __future__ import print_function
import
argparse
import
argparse
import
os
import
os
from
visualdl
import
LogWriter
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.incubate.fleet.base
import
role_maker
from
paddle.fluid.incubate.fleet.collective
import
fleet
from
ppcls.data
import
Reader
from
ppcls.data
import
Reader
from
ppcls.utils.config
import
get_config
from
ppcls.utils.config
import
get_config
...
@@ -39,11 +36,6 @@ def parse_args():
...
@@ -39,11 +36,6 @@ def parse_args():
type
=
str
,
type
=
str
,
default
=
'configs/ResNet/ResNet50.yaml'
,
default
=
'configs/ResNet/ResNet50.yaml'
,
help
=
'config file path'
)
help
=
'config file path'
)
parser
.
add_argument
(
'--vdl_dir'
,
type
=
str
,
default
=
None
,
help
=
'VisualDL logging directory for image.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'-o'
,
'-o'
,
'--override'
,
'--override'
,
...
@@ -55,91 +47,65 @@ def parse_args():
...
@@ -55,91 +47,65 @@ def parse_args():
def
main
(
args
):
def
main
(
args
):
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
fleet
.
init
(
role
)
config
=
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
config
=
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
# assign the place
# assign the place
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
gpu_id
=
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
place
=
fluid
.
CUDAPlace
(
gpu_id
)
place
=
fluid
.
CUDAPlace
(
gpu_id
)
# startup_prog is used to do some parameter init work,
with
fluid
.
dygraph
.
guard
(
place
):
# and train prog is used to hold the network
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
startup_prog
=
fluid
.
Program
()
net
=
program
.
create_model
(
config
.
ARCHITECTURE
,
config
.
classes_num
)
train_prog
=
fluid
.
Program
()
net
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
net
,
strategy
)
best_top1_acc
=
0.0
# best top1 acc record
optimizer
=
program
.
create_optimizer
(
config
,
parameter_list
=
net
.
parameters
())
if
not
config
.
get
(
'use_ema'
):
train_dataloader
,
train_fetchs
=
program
.
build
(
# load model from checkpoint or pretrained model
config
,
train_prog
,
startup_prog
,
is_train
=
True
)
init_model
(
config
,
net
,
optimizer
)
else
:
train_dataloader
,
train_fetchs
,
ema
=
program
.
build
(
train_dataloader
=
program
.
create_dataloader
()
config
,
train_prog
,
startup_prog
,
is_train
=
True
)
train_reader
=
Reader
(
config
,
'train'
)()
train_dataloader
.
set_sample_list_generator
(
train_reader
,
place
)
if
config
.
validate
:
valid_prog
=
fluid
.
Program
()
if
config
.
validate
:
valid_dataloader
,
valid_fetchs
=
program
.
build
(
valid_dataloader
=
program
.
create_dataloader
()
config
,
valid_prog
,
startup_prog
,
is_train
=
False
)
valid_reader
=
Reader
(
config
,
'valid'
)()
# clone to prune some content which is irrelevant in valid_prog
valid_dataloader
.
set_sample_list_generator
(
valid_reader
,
place
)
valid_prog
=
valid_prog
.
clone
(
for_test
=
True
)
best_top1_acc
=
0.0
# best top1 acc record
# create the "Executor" with the statement of which place
for
epoch_id
in
range
(
config
.
epochs
):
exe
=
fluid
.
Executor
(
place
)
net
.
train
()
# Parameter initialization
# 1. train with train dataset
exe
.
run
(
startup_prog
)
program
.
run
(
train_dataloader
,
config
,
net
,
optimizer
,
epoch_id
,
'train'
)
# load model from 1. checkpoint to resume training, 2. pretrained model to finetune
init_model
(
config
,
train_prog
,
exe
)
if
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
:
# 2. validate with validate dataset
train_reader
=
Reader
(
config
,
'train'
)()
if
config
.
validate
and
epoch_id
%
config
.
valid_interval
==
0
:
train_dataloader
.
set_sample_list_generator
(
train_reader
,
place
)
net
.
eval
()
top1_acc
=
program
.
run
(
valid_dataloader
,
config
,
net
,
None
,
if
config
.
validate
:
epoch_id
,
'valid'
)
valid_reader
=
Reader
(
config
,
'valid'
)()
if
top1_acc
>
best_top1_acc
:
valid_dataloader
.
set_sample_list_generator
(
valid_reader
,
place
)
best_top1_acc
=
top1_acc
compiled_valid_prog
=
program
.
compile
(
config
,
valid_prog
)
message
=
"The best top1 acc {:.5f}, in epoch: {:d}"
.
format
(
best_top1_acc
,
epoch_id
)
compiled_train_prog
=
fleet
.
main_program
logger
.
info
(
"{:s}"
.
format
(
vdl_writer
=
LogWriter
(
args
.
vdl_dir
)
if
args
.
vdl_dir
else
None
logger
.
coloring
(
message
,
"RED"
)))
if
epoch_id
%
config
.
save_interval
==
0
:
for
epoch_id
in
range
(
config
.
epochs
):
# 1. train with train dataset
model_path
=
os
.
path
.
join
(
program
.
run
(
train_dataloader
,
exe
,
compiled_train_prog
,
train_fetchs
,
config
.
model_save_dir
,
epoch_id
,
'train'
,
vdl_writer
)
config
.
ARCHITECTURE
[
"name"
])
if
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
==
0
:
save_model
(
net
,
optimizer
,
model_path
,
# 2. validate with validate dataset
"best_model_in_epoch_"
+
str
(
epoch_id
))
if
config
.
validate
and
epoch_id
%
config
.
valid_interval
==
0
:
if
config
.
get
(
'use_ema'
):
# 3. save the persistable model
logger
.
info
(
logger
.
coloring
(
"EMA validate start..."
))
if
epoch_id
%
config
.
save_interval
==
0
:
with
ema
.
apply
(
exe
):
model_path
=
os
.
path
.
join
(
config
.
model_save_dir
,
top1_acc
=
program
.
run
(
valid_dataloader
,
exe
,
config
.
ARCHITECTURE
[
"name"
])
compiled_valid_prog
,
save_model
(
net
,
optimizer
,
model_path
,
epoch_id
)
valid_fetchs
,
epoch_id
,
'valid'
)
logger
.
info
(
logger
.
coloring
(
"EMA validate over!"
))
top1_acc
=
program
.
run
(
valid_dataloader
,
exe
,
compiled_valid_prog
,
valid_fetchs
,
epoch_id
,
'valid'
)
if
top1_acc
>
best_top1_acc
:
best_top1_acc
=
top1_acc
message
=
"The best top1 acc {:.5f}, in epoch: {:d}"
.
format
(
best_top1_acc
,
epoch_id
)
logger
.
info
(
"{:s}"
.
format
(
logger
.
coloring
(
message
,
"RED"
)))
if
epoch_id
%
config
.
save_interval
==
0
:
model_path
=
os
.
path
.
join
(
config
.
model_save_dir
,
config
.
ARCHITECTURE
[
"name"
])
save_model
(
train_prog
,
model_path
,
"best_model_in_epoch_"
+
str
(
epoch_id
))
# 3. save the persistable model
if
epoch_id
%
config
.
save_interval
==
0
:
model_path
=
os
.
path
.
join
(
config
.
model_save_dir
,
config
.
ARCHITECTURE
[
"name"
])
save_model
(
train_prog
,
model_path
,
epoch_id
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
main
(
args
)
main
(
args
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录