Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
22b8805b
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
22b8805b
编写于
2月 12, 2020
作者:
C
chajchaj
提交者:
GitHub
2月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add features: train with cpu, save and load checkpoint (#4259)
上级
9983e3a9
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
149 addition
and
43 deletion
+149
-43
dygraph/mobilenet/README.md
dygraph/mobilenet/README.md
+67
-0
dygraph/mobilenet/mobilenet_v1.py
dygraph/mobilenet/mobilenet_v1.py
+2
-3
dygraph/mobilenet/mobilenet_v2.py
dygraph/mobilenet/mobilenet_v2.py
+2
-6
dygraph/mobilenet/run_cpu_v1.sh
dygraph/mobilenet/run_cpu_v1.sh
+1
-0
dygraph/mobilenet/run_cpu_v2.sh
dygraph/mobilenet/run_cpu_v2.sh
+1
-0
dygraph/mobilenet/run_mul_v1.sh
dygraph/mobilenet/run_mul_v1.sh
+1
-1
dygraph/mobilenet/run_mul_v1_checkpoint.sh
dygraph/mobilenet/run_mul_v1_checkpoint.sh
+2
-0
dygraph/mobilenet/run_mul_v2.sh
dygraph/mobilenet/run_mul_v2.sh
+1
-1
dygraph/mobilenet/run_mul_v2_checkpoint.sh
dygraph/mobilenet/run_mul_v2_checkpoint.sh
+2
-0
dygraph/mobilenet/run_sing_v1.sh
dygraph/mobilenet/run_sing_v1.sh
+1
-1
dygraph/mobilenet/run_sing_v1_checkpoint.sh
dygraph/mobilenet/run_sing_v1_checkpoint.sh
+2
-0
dygraph/mobilenet/run_sing_v2.sh
dygraph/mobilenet/run_sing_v2.sh
+1
-1
dygraph/mobilenet/run_sing_v2_checkpoint.sh
dygraph/mobilenet/run_sing_v2_checkpoint.sh
+2
-0
dygraph/mobilenet/train.py
dygraph/mobilenet/train.py
+64
-30
未找到文件。
dygraph/mobilenet/R
ADE
ME.md
→
dygraph/mobilenet/R
EAD
ME.md
浏览文件 @
22b8805b
...
@@ -4,15 +4,21 @@
...
@@ -4,15 +4,21 @@
**代码结构**
**代码结构**
├── run_mul_v1.sh # 多卡训练启动脚本_v1
├── run_mul_v1.sh # 多卡训练启动脚本_v1
├── run_mul_v2.sh # 多卡训练启动脚本_v2
├── run_mul_v1_checkpoint.sh # 加载checkpoint多卡训练启动脚本_v1
├── run_sing_v1.sh # 单卡训练启动脚本_v1
├── run_mul_v2.sh # 多卡训练启动脚本_v2
├── run_sing_v2.sh # 单卡训练启动脚本_v2
├── run_mul_v2_checkpoint.sh # 加载checkpoint多卡训练启动脚本_v2
├── train.py # 训练入口
├── run_sing_v1.sh # 单卡训练启动脚本_v1
├── mobilenet_v1.py # 网络结构v1
├── run_sing_v1_checkpoint.sh # 加载checkpoint单卡训练启动脚本_v1
├── mobilenet_v2.py # 网络结构v2
├── run_sing_v2.sh # 单卡训练启动脚本_v2
├── reader.py # 数据reader
├── run_sing_v2_checkpoint.sh # 加载checkpoint单卡训练启动脚本_v2
├── utils # 基础工具目录
├── run_cpu_v1.sh # CPU训练启动脚本_v1
├── run_cpu_v2.sh # CPU训练启动脚本_v2
├── train.py # 训练入口
├── mobilenet_v1.py # 网络结构v1
├── mobilenet_v2.py # 网络结构v2
├── reader.py # 数据reader
├── utils # 基础工具目录
**数据准备**
**数据准备**
...
@@ -24,18 +30,35 @@
...
@@ -24,18 +30,35 @@
bash run_mul_v1.sh
bash run_mul_v1.sh
bash run_mul_v2.sh
bash run_mul_v2.sh
若使用单卡训练,启动方式如下:
若使用单卡训练,启动方式如下:
bash run_sing_v1.sh
bash run_sing_v1.sh
bash run_sing_v2.sh
bash run_sing_v2.sh
**模型精度**
若使用CPU训练,启动方式如下:
bash run_cpu_v1.sh
bash run_cpu_v2.sh
训练过程中,checkpoint会保存在参数model_save_dir指定的文件夹中,我们支持加载checkpoint继续训练.
加载checkpoint使用4卡训练,启动方式如下:
bash run_mul_v1_checkpoint.sh
bash run_mul_v2_checkpoint.sh
加载checkpoint使用单卡训练,启动方式如下:
bash run_sing_v1_checkpoint.sh
bash run_sing_v2_checkpoint.sh
**模型性能**
Model
Top-1 Top-5
Model
Top-1(单卡/4卡) Top-5(单卡/4卡) 收敛时间(单卡/4卡)
MobileNetV1 0.707
0.895
MobileNetV1 0.707
/0.711 0.897/0.899 116小时/30.9小时
MobileNetV2 0.
626 0.845
MobileNetV2 0.
708/0.724 0.899/0.906 227.8小时/60.8小时
**参考论文**
**参考论文**
...
...
dygraph/mobilenet/mobilenet_v1.py
浏览文件 @
22b8805b
...
@@ -12,12 +12,13 @@
...
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
#order: standard library, third party, local library
import
os
import
os
import
time
import
time
import
sys
import
sys
import
math
import
numpy
as
np
import
numpy
as
np
import
argparse
import
argparse
import
ast
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.initializer
import
MSRA
from
paddle.fluid.initializer
import
MSRA
...
@@ -26,8 +27,6 @@ from paddle.fluid.layer_helper import LayerHelper
...
@@ -26,8 +27,6 @@ from paddle.fluid.layer_helper import LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
import
math
import
sys
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
...
...
dygraph/mobilenet/mobilenet_v2.py
浏览文件 @
22b8805b
...
@@ -12,14 +12,13 @@
...
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
#order: standard library, third party, local library
import
os
import
os
import
numpy
as
np
import
time
import
time
import
sys
import
math
import
sys
import
sys
import
numpy
as
np
import
numpy
as
np
import
argparse
import
argparse
import
ast
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.initializer
import
MSRA
from
paddle.fluid.initializer
import
MSRA
...
@@ -27,11 +26,8 @@ from paddle.fluid.param_attr import ParamAttr
...
@@ -27,11 +26,8 @@ from paddle.fluid.param_attr import ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
import
math
import
sys
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
...
...
dygraph/mobilenet/run_cpu_v1.sh
0 → 100644
浏览文件 @
22b8805b
python3 train.py
--use_gpu
=
False
--batch_size
=
64
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
dygraph/mobilenet/run_cpu_v2.sh
0 → 100644
浏览文件 @
22b8805b
python3 train.py
--use_gpu
=
False
--batch_size
=
64
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output/
--lr_strategy
=
cosine_decay
--lr
=
0.1
--num_epochs
=
240
--data_dir
=
/ssd9/chaj//data/ILSVRC2012
--l2_decay
=
4e-5
--model
=
MobileNetV2
dygraph/mobilenet/run_mul_v1.sh
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
--log_dir
./mylog.
time train.py
--use_data_parallel
1
--batch_size
=
256
--reader_thread
=
8
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
../../PaddleCV/image_classification/data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
python3
-m
paddle.distributed.launch
--log_dir
./mylog.
v1 train.py
--use_data_parallel
1
--batch_size
=
256
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
--model_save_dir
=
output.v1.mul/
--num_epochs
=
120
dygraph/mobilenet/run_mul_v1_checkpoint.sh
0 → 100644
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
--log_dir
./mylog.v1.checkpoint train.py
--use_data_parallel
1
--batch_size
=
256
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
--model_save_dir
=
output.v1.mul.checkpoint/
--num_epochs
=
120
--checkpoint
=
./output.v1.mul/_mobilenet_v1_epoch50
dygraph/mobilenet/run_mul_v2.sh
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
--log_dir
./mylog.
time train.py
--use_data_parallel
1
--batch_size
=
256
--reader_thread
=
8
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
../../PaddleCV/image_classification/data/ILSVRC2012
--l2_decay
=
3
e-5
--model
=
MobileNetV2
python3
-m
paddle.distributed.launch
--log_dir
./mylog.
v2 train.py
--use_data_parallel
1
--batch_size
=
500
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output.v2.mul/
--lr_strategy
=
cosine_decay
--lr
=
0.1
--num_epochs
=
240
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
4
e-5
--model
=
MobileNetV2
dygraph/mobilenet/run_mul_v2_checkpoint.sh
0 → 100644
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3
-m
paddle.distributed.launch
--log_dir
./mylog.v2.checkpoint train.py
--use_data_parallel
1
--batch_size
=
500
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output.v2.mul.checkpoint/
--lr_strategy
=
cosine_decay
--lr
=
0.1
--num_epochs
=
240
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
4e-5
--model
=
MobileNetV2
--checkpoint
=
./output.v2.mul/_mobilenet_v2_epoch50
dygraph/mobilenet/run_sing_v1.sh
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0
export
CUDA_VISIBLE_DEVICES
=
0
python3 train.py
--batch_size
=
256
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output
/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
../../PaddleCV/image_classification
/data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
python3 train.py
--batch_size
=
256
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output
.v1.sing/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
.
/data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
dygraph/mobilenet/run_sing_v1_checkpoint.sh
0 → 100644
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0
python3 train.py
--batch_size
=
256
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output.v1.sing/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
3e-5
--model
=
MobileNetV1
--checkpoint
=
./output.v1.sing/_mobilenet_v1_epoch50
dygraph/mobilenet/run_sing_v2.sh
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0
export
CUDA_VISIBLE_DEVICES
=
0
python3 train.py
--batch_size
=
128
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output/
--lr_strategy
=
piecewise_decay
--lr
=
0.1
--data_dir
=
../../PaddleCV/image_classification/data/ILSVRC2012
--model
=
MobileNetV2
python3 train.py
--batch_size
=
500
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output.v2.sing/
--lr_strategy
=
cosine_decay
--lr
=
0.1
--num_epochs
=
240
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
4e-5
--model
=
MobileNetV2
dygraph/mobilenet/run_sing_v2_checkpoint.sh
0 → 100644
浏览文件 @
22b8805b
export
CUDA_VISIBLE_DEVICES
=
0
python3 train.py
--batch_size
=
500
--total_images
=
1281167
--class_dim
=
1000
--image_shape
=
3,224,224
--model_save_dir
=
output.v2.sing/
--lr_strategy
=
cosine_decay
--lr
=
0.1
--num_epochs
=
240
--data_dir
=
./data/ILSVRC2012
--l2_decay
=
4e-5
--model
=
MobileNetV2
--checkpoint
=
./output.v2.sing/_mobilenet_v2_epoch50
dygraph/mobilenet/train.py
浏览文件 @
22b8805b
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,35 +12,24 @@
...
@@ -12,35 +12,24 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
mobilenet_v1
import
*
#order: standard library, third party, local library
from
mobilenet_v2
import
*
import
os
import
os
import
numpy
as
np
import
time
import
time
import
sys
import
sys
import
sys
import
math
import
numpy
as
np
import
argparse
import
argparse
import
ast
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.initializer
import
MSRA
from
paddle.fluid.initializer
import
MSRA
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
#from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
import
math
import
sys
import
reader
import
reader
from
utils
import
*
from
utils
import
*
from
mobilenet_v1
import
*
IMAGENET1000
=
1281167
from
mobilenet_v2
import
*
base_lr
=
0.1
momentum_rate
=
0.9
l2_decay
=
1e-4
args
=
parse_args
()
args
=
parse_args
()
if
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
==
0
:
if
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
0
))
==
0
:
...
@@ -56,7 +45,7 @@ def eval(net, test_data_loader, eop):
...
@@ -56,7 +45,7 @@ def eval(net, test_data_loader, eop):
for
img
,
label
in
test_data_loader
():
for
img
,
label
in
test_data_loader
():
t1
=
time
.
time
()
t1
=
time
.
time
()
label
=
to_variable
(
label
.
numpy
().
astype
(
'int64'
).
reshape
(
label
=
to_variable
(
label
.
numpy
().
astype
(
'int64'
).
reshape
(
int
(
args
.
batch_size
/
paddle
.
fluid
.
core
.
get_cuda_device_count
()),
int
(
args
.
batch_size
/
/
paddle
.
fluid
.
core
.
get_cuda_device_count
()),
1
))
1
))
out
=
net
(
img
)
out
=
net
(
img
)
softmax_out
=
fluid
.
layers
.
softmax
(
out
,
use_cudnn
=
False
)
softmax_out
=
fluid
.
layers
.
softmax
(
out
,
use_cudnn
=
False
)
...
@@ -80,10 +69,14 @@ def eval(net, test_data_loader, eop):
...
@@ -80,10 +69,14 @@ def eval(net, test_data_loader, eop):
def
train_mobilenet
():
def
train_mobilenet
():
epoch
=
args
.
num_epochs
if
not
args
.
use_gpu
:
place
=
fluid
.
CUDAPlace
(
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
)
\
place
=
fluid
.
CPUPlace
()
if
args
.
use_data_parallel
else
fluid
.
CUDAPlace
(
0
)
elif
not
args
.
use_data_parallel
:
place
=
fluid
.
CUDAPlace
(
0
)
else
:
place
=
fluid
.
CUDAPlace
(
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
)
with
fluid
.
dygraph
.
guard
(
place
):
with
fluid
.
dygraph
.
guard
(
place
):
# 1. init net and optimizer
if
args
.
ce
:
if
args
.
ce
:
print
(
"ce mode"
)
print
(
"ce mode"
)
seed
=
33
seed
=
33
...
@@ -93,13 +86,12 @@ def train_mobilenet():
...
@@ -93,13 +86,12 @@ def train_mobilenet():
if
args
.
use_data_parallel
:
if
args
.
use_data_parallel
:
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
net
=
None
if
args
.
model
==
"MobileNetV1"
:
if
args
.
model
==
"MobileNetV1"
:
net
=
MobileNetV1
(
class_dim
=
args
.
class_dim
)
net
=
MobileNetV1
(
class_dim
=
args
.
class_dim
,
scale
=
1.0
)
para_name
=
'mobilenet_v1_params
'
model_path_pre
=
'mobilenet_v1
'
elif
args
.
model
==
"MobileNetV2"
:
elif
args
.
model
==
"MobileNetV2"
:
net
=
MobileNetV2
(
class_dim
=
args
.
class_dim
,
scale
=
1.0
)
net
=
MobileNetV2
(
class_dim
=
args
.
class_dim
,
scale
=
1.0
)
para_name
=
'mobilenet_v2_params
'
model_path_pre
=
'mobilenet_v2
'
else
:
else
:
print
(
print
(
"wrong model name, please try model = MobileNetV1 or MobileNetV2"
"wrong model name, please try model = MobileNetV1 or MobileNetV2"
...
@@ -109,6 +101,18 @@ def train_mobilenet():
...
@@ -109,6 +101,18 @@ def train_mobilenet():
optimizer
=
create_optimizer
(
args
=
args
,
parameter_list
=
net
.
parameters
())
optimizer
=
create_optimizer
(
args
=
args
,
parameter_list
=
net
.
parameters
())
if
args
.
use_data_parallel
:
if
args
.
use_data_parallel
:
net
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
net
,
strategy
)
net
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
net
,
strategy
)
# 2. load checkpoint
if
args
.
checkpoint
:
assert
os
.
path
.
exists
(
args
.
checkpoint
+
".pdparams"
),
\
"Given dir {}.pdparams not exist."
.
format
(
args
.
checkpoint
)
assert
os
.
path
.
exists
(
args
.
checkpoint
+
".pdopt"
),
\
"Given dir {}.pdopt not exist."
.
format
(
args
.
checkpoint
)
para_dict
,
opti_dict
=
fluid
.
dygraph
.
load_dygraph
(
args
.
checkpoint
)
net
.
set_dict
(
para_dict
)
optimizer
.
set_dict
(
opti_dict
)
# 3. reader
train_data_loader
,
train_data
=
utility
.
create_data_loader
(
train_data_loader
,
train_data
=
utility
.
create_data_loader
(
is_train
=
True
,
args
=
args
)
is_train
=
True
,
args
=
args
)
test_data_loader
,
test_data
=
utility
.
create_data_loader
(
test_data_loader
,
test_data
=
utility
.
create_data_loader
(
...
@@ -119,7 +123,9 @@ def train_mobilenet():
...
@@ -119,7 +123,9 @@ def train_mobilenet():
test_reader
=
imagenet_reader
.
val
(
settings
=
args
)
test_reader
=
imagenet_reader
.
val
(
settings
=
args
)
train_data_loader
.
set_sample_list_generator
(
train_reader
,
place
)
train_data_loader
.
set_sample_list_generator
(
train_reader
,
place
)
test_data_loader
.
set_sample_list_generator
(
test_reader
,
place
)
test_data_loader
.
set_sample_list_generator
(
test_reader
,
place
)
for
eop
in
range
(
epoch
):
# 4. train loop
for
eop
in
range
(
args
.
num_epochs
):
if
num_trainers
>
1
:
if
num_trainers
>
1
:
imagenet_reader
.
set_shuffle_seed
(
eop
+
(
imagenet_reader
.
set_shuffle_seed
(
eop
+
(
args
.
random_seed
if
args
.
random_seed
else
0
))
args
.
random_seed
if
args
.
random_seed
else
0
))
...
@@ -130,13 +136,17 @@ def train_mobilenet():
...
@@ -130,13 +136,17 @@ def train_mobilenet():
total_sample
=
0
total_sample
=
0
batch_id
=
0
batch_id
=
0
t_last
=
0
t_last
=
0
# 4.1 for each batch, call net() , backward(), and minimize()
for
img
,
label
in
train_data_loader
():
for
img
,
label
in
train_data_loader
():
t1
=
time
.
time
()
t1
=
time
.
time
()
label
=
to_variable
(
label
.
numpy
().
astype
(
'int64'
).
reshape
(
label
=
to_variable
(
label
.
numpy
().
astype
(
'int64'
).
reshape
(
int
(
args
.
batch_size
/
int
(
args
.
batch_size
/
/
paddle
.
fluid
.
core
.
get_cuda_device_count
()),
1
))
paddle
.
fluid
.
core
.
get_cuda_device_count
()),
1
))
t_start
=
time
.
time
()
t_start
=
time
.
time
()
# 4.1.1 call net()
out
=
net
(
img
)
out
=
net
(
img
)
t_end
=
time
.
time
()
t_end
=
time
.
time
()
softmax_out
=
fluid
.
layers
.
softmax
(
out
,
use_cudnn
=
False
)
softmax_out
=
fluid
.
layers
.
softmax
(
out
,
use_cudnn
=
False
)
loss
=
fluid
.
layers
.
cross_entropy
(
loss
=
fluid
.
layers
.
cross_entropy
(
...
@@ -145,14 +155,20 @@ def train_mobilenet():
...
@@ -145,14 +155,20 @@ def train_mobilenet():
acc_top1
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top1
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
acc_top5
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
t_start_back
=
time
.
time
()
t_start_back
=
time
.
time
()
# 4.1.2 call backward()
if
args
.
use_data_parallel
:
if
args
.
use_data_parallel
:
avg_loss
=
net
.
scale_loss
(
avg_loss
)
avg_loss
=
net
.
scale_loss
(
avg_loss
)
avg_loss
.
backward
()
avg_loss
.
backward
()
net
.
apply_collective_grads
()
net
.
apply_collective_grads
()
else
:
else
:
avg_loss
.
backward
()
avg_loss
.
backward
()
t_end_back
=
time
.
time
()
t_end_back
=
time
.
time
()
# 4.1.3 call minimize()
optimizer
.
minimize
(
avg_loss
)
optimizer
.
minimize
(
avg_loss
)
net
.
clear_gradients
()
net
.
clear_gradients
()
t2
=
time
.
time
()
t2
=
time
.
time
()
train_batch_elapse
=
t2
-
t1
train_batch_elapse
=
t2
-
t1
...
@@ -174,13 +190,31 @@ def train_mobilenet():
...
@@ -174,13 +190,31 @@ def train_mobilenet():
print
(
"epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f %2.4f sec"
%
\
print
(
"epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f %2.4f sec"
%
\
(
eop
,
batch_id
,
total_loss
/
total_sample
,
\
(
eop
,
batch_id
,
total_loss
/
total_sample
,
\
total_acc1
/
total_sample
,
total_acc5
/
total_sample
,
train_batch_elapse
))
total_acc1
/
total_sample
,
total_acc5
/
total_sample
,
train_batch_elapse
))
net
.
eval
()
eval
(
net
,
test_data_loader
,
eop
)
# 4.2 save checkpoint
save_parameters
=
(
not
args
.
use_data_parallel
)
or
(
save_parameters
=
(
not
args
.
use_data_parallel
)
or
(
args
.
use_data_parallel
and
args
.
use_data_parallel
and
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
)
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
)
if
save_parameters
:
if
save_parameters
:
fluid
.
save_dygraph
(
net
.
state_dict
(),
para_name
)
if
not
os
.
path
.
isdir
(
args
.
model_save_dir
):
os
.
makedirs
(
args
.
model_save_dir
)
model_path
=
os
.
path
.
join
(
args
.
model_save_dir
,
"_"
+
model_path_pre
+
"_epoch{}"
.
format
(
eop
))
fluid
.
dygraph
.
save_dygraph
(
net
.
state_dict
(),
model_path
)
fluid
.
dygraph
.
save_dygraph
(
optimizer
.
state_dict
(),
model_path
)
# 4.3 validation
net
.
eval
()
eval
(
net
,
test_data_loader
,
eop
)
# 5. save final results
save_parameters
=
(
not
args
.
use_data_parallel
)
or
(
args
.
use_data_parallel
and
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
)
if
save_parameters
:
model_path
=
os
.
path
.
join
(
args
.
model_save_dir
,
"_"
+
model_path_pre
+
"_final"
)
fluid
.
dygraph
.
save_dygraph
(
net
.
state_dict
(),
model_path
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录