PaddlePaddle / PaddleDetection

Commit 276bb001 (unverified)
add fleet training (#961)

Authored Jun 19, 2020 by littletomatodonkey; committed via GitHub on Jun 19, 2020
Parent: d5702896
Showing 3 changed files with 443 additions and 5 deletions:

* docs/tutorials/MULTI_MACHINE_TRAINING_cn.md (+69, −0)
* ppdet/data/reader.py (+8, −5)
* tools/train_multi_machine.py (+366, −0)

docs/tutorials/MULTI_MACHINE_TRAINING_cn.md (new file, mode 100644)
# Multi-Machine Training

## Introduction

* High-performance distributed training is one of PaddlePaddle's core strengths; on classification tasks it reaches a nearly linear speedup. [Fleet](https://github.com/PaddlePaddle/Fleet) is the high-level API for PaddlePaddle distributed training; with this interface, users can easily switch a program over to distributed training. So that single-machine and multi-machine training share one code path, [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection/) adopts the Fleet API (a minimal sketch of the pattern follows). For more on distributed training, see the [Fleet API design documentation](https://github.com/PaddlePaddle/Fleet/blob/develop/README.md).
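The switch to Fleet adds only a few lines around an otherwise unchanged training program. Below is a minimal sketch of the collective pattern this commit uses (Paddle 1.x `paddle.fluid.incubate.fleet` API, run under `paddle.distributed.launch`); the toy network is an illustrative stand-in for the detector, not PaddleDetection code:

```python
from paddle import fluid
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker

# Read the trainer topology from env vars set by paddle.distributed.launch.
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)

# A toy regression network stands in for the detector; the Fleet calls are the point.
image = fluid.data(name='image', shape=[None, 8], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='float32')
pred = fluid.layers.fc(input=image, size=1)
loss = fluid.layers.reduce_mean(fluid.layers.square(pred - label))

optimizer = fluid.optimizer.SGD(learning_rate=0.01)
# Wrapping the optimizer rewrites the program for multi-GPU / multi-node training.
optimizer = fleet.distributed_optimizer(optimizer, strategy=DistributedStrategy())
optimizer.minimize(loss)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
# The training loop then runs fleet.main_program instead of the original program.
```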
## Usage

* `tools/train_multi_machine.py` launches Fleet-based training; single-machine single-GPU, single-machine multi-GPU, and multi-machine multi-GPU training are all supported.
* The optional arguments are identical to those of `tools/train.py`; see the [getting started guide](./GETTING_STARTED_cn.md).

### Single-machine training

* The training script is shown below.

```bash
# set the PYTHONPATH
export PYTHONPATH=$PYTHONPATH:.
# set the visible GPU cards
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# launch training
python -m paddle.distributed.launch \
    --selected_gpus 0,1,2,3,4,5,6,7 \
    tools/train_multi_machine.py \
        -c configs/faster_rcnn_r50_fpn_1x.yml
```
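`paddle.distributed.launch` starts one worker process per GPU listed in `--selected_gpus` and exports `FLAGS_selected_gpus` to each of them; the entry script picks its device from that variable. A minimal sketch, mirroring the device-selection logic in `tools/train_multi_machine.py` below:

```python
import os
from paddle import fluid

env = os.environ
# Each launched worker sees its own GPU id in FLAGS_selected_gpus.
if 'FLAGS_selected_gpus' in env:
    device_id = int(env['FLAGS_selected_gpus'])
else:
    device_id = 0  # plain single-process run without the launcher
place = fluid.CUDAPlace(device_id)
exe = fluid.Executor(place)
```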
### Multi-machine training

* The training script is shown below, where ip1 and ip2 are the IP addresses of the different machines; the `PADDLE_TRAINER_ID` environment variable is assigned in increasing order following the order of the IPs in `cluster_node_ips` (see the sketch after this block).
* Note: to launch a multi-machine job, the code running on the different machines must be identical.

```bash
export PYTHONPATH=$PYTHONPATH:.
# set the visible GPU cards
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# launch training
node_ip=`hostname -i`
python -m paddle.distributed.launch \
    --use_paddlecloud \
    --cluster_node_ips ${ip1},${ip2} \
    --node_ip ${node_ip} \
    tools/train_multi_machine.py \
        -c configs/faster_rcnn_r50_fpn_1x.yml
```
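Under the hood, the launcher passes the cluster topology to every worker through environment variables, and the entry script only has to read them; this is exactly what `tools/train_multi_machine.py` does at startup:

```python
import os

env = os.environ
# Set by paddle.distributed.launch: total trainer count and this worker's rank.
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 0))
trainer_id = int(env.get('PADDLE_TRAINER_ID', 0))
assert num_trainers != 0, "start this script with paddle.distributed.launch"
print('trainer {} of {}'.format(trainer_id, num_trainers))
```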
## Training time statistics

* Taking the Faster RCNN R50_vd FPN 1x experiment as an example, the table below compares Fleet-based distributed training times across different numbers of machines.
* All experiments were run on V100 GPUs.
* A 1x experiment means training for 90000 minibatches on 8 GPUs with a per-GPU batch size of 2 (when the number of GPUs or the batch size changes, the learning rate and the total number of iterations must be adjusted accordingly).

| Model | Schedule | Machines | GPUs per machine | Training time | COCO bbox mAP | Speedup |
| :----------------------: | :------------: | :------------: | :---------------: | :----------: | :-----------: | :-----------: |
| Faster RCNN R50_vd FPN | 1x | 1 | 4 | 15.1h | 38.3% | - |
| Faster RCNN R50_vd FPN | 1x | 2 | 4 | 9.8h | 38.2% | 76% |
| Faster RCNN R50_vd FPN | 1x | 1 | 8 | 8.6h | 38.2% | - |
| Faster RCNN R50_vd FPN | 1x | 2 | 8 | 5.1h | 38.0% | 84% |

* As the table shows, 2 machines with 8 GPUs each reach an 84% speedup over a single 8-GPU machine, and 2 machines with 4 GPUs each reach 76% over a single 4-GPU machine, with almost no loss of accuracy.
* A 1x experiment corresponds to roughly 13 epochs over the COCO dataset, so with many trainers each trainer may not complete a full epoch, which can cause small accuracy differences; these can be closed by suitably increasing the number of iterations. In our experiments with the multi-scale 3x schedule (config file: [configs/rcnn_enhance/faster_rcnn_dcn_r50_vd_fpn_3x_server_side.yml](../../configs/rcnn_enhance/faster_rcnn_dcn_r50_vd_fpn_3x_server_side.yml)), the accuracy of distributed training matches that of single-machine training.
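As a back-of-envelope check of the epoch figure above (the COCO train-set size here is our assumption, ~118k images in train2017):

```python
# 1x schedule: 8 GPUs x per-GPU batch size 2 = 16 images per iteration, 90000 iterations.
images_per_iter = 8 * 2
iters = 90000
coco_train_size = 118287          # assumed size of COCO train2017
epochs = iters * images_per_iter / coco_train_size
print('{:.1f} epochs'.format(epochs))  # ~12.2, on the order of the ~13 epochs cited above
```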
ppdet/data/reader.py

```diff
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import copy
 import functools
 import collections
@@ -278,6 +279,11 @@ class Reader(object):
     def reset(self):
         """implementation of Dataset.reset
         """
+        if self._epoch < 0:
+            self._epoch = 0
+        else:
+            self._epoch += 1
+
         self.indexes = [i for i in range(self.size())]
         if self._class_aware_sampling:
             self.indexes = np.random.choice(
@@ -287,6 +293,8 @@ class Reader(object):
                 p=self.img_weights)
 
         if self._shuffle:
+            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
+            np.random.seed(self._epoch + trainer_id)
             np.random.shuffle(self.indexes)
 
         if self._mixup_epoch > 0 and len(self.indexes) < 2:
@@ -298,11 +306,6 @@ class Reader(object):
                 "less than 2 samples")
             self._cutmix_epoch = -1
 
-        if self._epoch < 0:
-            self._epoch = 0
-        else:
-            self._epoch += 1
-
         self._pos = 0
 
     def __next__(self):
```
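The effect of this seeding change is that each trainer draws a different but reproducible shuffle order every epoch (the epoch counter now advances before the shuffle, so the seed changes per epoch). A standalone sketch of the behavior, outside the `Reader` class:

```python
import numpy as np

def shuffled_indexes(size, epoch, trainer_id):
    # Same seeding rule as the patched Reader.reset(): seed = epoch + trainer_id.
    indexes = list(range(size))
    np.random.seed(epoch + trainer_id)
    np.random.shuffle(indexes)
    return indexes

# Each trainer gets its own reproducible order, and the order changes per epoch.
print(shuffled_indexes(8, epoch=0, trainer_id=0))
print(shuffled_indexes(8, epoch=0, trainer_id=1))  # differs from trainer 0
print(shuffled_indexes(8, epoch=0, trainer_id=0))  # identical to the first call
```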
tools/train_multi_machine.py (new file, mode 100644)
```python
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

import time
import numpy as np
import random
import datetime
import six
from collections import deque

from paddle.fluid import profiler

from paddle import fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.optimizer import ExponentialMovingAverage

from ppdet.experimental import mixed_precision_context
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.data.reader import create_reader

from ppdet.utils import dist_utils
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu, check_version, check_config
import ppdet.utils.checkpoint as checkpoint

from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy  # new line 1
from paddle.fluid.incubate.fleet.base import role_maker  # new line 2

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def main():
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)  # new line 3
    fleet.init(role)  # new line 4

    env = os.environ

    num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 0))
    assert num_trainers != 0, "multi-machine training process must be started using distributed.launch..."
    trainer_id = int(env.get("PADDLE_TRAINER_ID", 0))
    # set different seeds for different trainers
    random.seed(trainer_id)
    np.random.seed(trainer_id)

    if FLAGS.enable_ce:
        random.seed(0)
        np.random.seed(0)

    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    save_only = getattr(cfg, 'save_prediction_only', False)
    if save_only:
        raise NotImplementedError('The config file only support prediction,'
                                  ' training stage is not implemented now')

    main_arch = cfg.architecture

    assert cfg.use_gpu == True, "GPU must be supported for multi-machine training..."
    devices_num = fluid.core.get_cuda_device_count()

    if 'FLAGS_selected_gpus' in env:
        device_id = int(env['FLAGS_selected_gpus'])
    else:
        device_id = 0
    place = fluid.CUDAPlace(device_id) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    lr_builder = create('LearningRate')
    optim_builder = create('OptimizerBuilder')

    # build program
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    if FLAGS.enable_ce:
        startup_prog.random_seed = 1000
        train_prog.random_seed = 1000
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            model = create(main_arch)
            if FLAGS.fp16:
                assert (getattr(model.backbone, 'norm_type', None)
                        != 'affine_channel'), \
                    '--fp16 currently does not support affine channel, ' \
                    ' please modify backbone settings to use batch norm'

            with mixed_precision_context(FLAGS.loss_scale, FLAGS.fp16) as ctx:
                inputs_def = cfg['TrainReader']['inputs_def']
                feed_vars, train_loader = model.build_inputs(**inputs_def)
                train_fetches = model.train(feed_vars)
                loss = train_fetches['loss']
                if FLAGS.fp16:
                    loss *= ctx.get_loss_scale_var()
                lr = lr_builder()
                optimizer = optim_builder(lr)

                dist_strategy = DistributedStrategy()
                sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
                dist_strategy.sync_batch_norm = sync_bn
                dist_strategy.nccl_comm_num = 1
                exec_strategy = fluid.ExecutionStrategy()
                exec_strategy.num_threads = 3
                exec_strategy.num_iteration_per_drop_scope = 30
                dist_strategy.exec_strategy = exec_strategy
                dist_strategy.fuse_all_reduce_ops = True
                optimizer = fleet.distributed_optimizer(
                    optimizer, strategy=dist_strategy)  # new line 5

                optimizer.minimize(loss)

                if FLAGS.fp16:
                    loss /= ctx.get_loss_scale_var()

            if 'use_ema' in cfg and cfg['use_ema']:
                global_steps = _decay_step_counter()
                ema = ExponentialMovingAverage(
                    cfg['ema_decay'], thres_steps=global_steps)
                ema.update()

    # parse train fetches
    train_keys, train_values, _ = parse_fetches(train_fetches)
    train_values.append(lr)

    if FLAGS.eval:
        eval_prog = fluid.Program()
        with fluid.program_guard(eval_prog, startup_prog):
            with fluid.unique_name.guard():
                model = create(main_arch)
                inputs_def = cfg['EvalReader']['inputs_def']
                feed_vars, eval_loader = model.build_inputs(**inputs_def)
                fetches = model.eval(feed_vars)
        eval_prog = eval_prog.clone(True)

        eval_reader = create_reader(cfg.EvalReader, devices_num=1)
        eval_loader.set_sample_list_generator(eval_reader, place)

        # parse eval fetches
        extra_keys = []
        if cfg.metric == 'COCO':
            extra_keys = ['im_info', 'im_id', 'im_shape']
        if cfg.metric == 'VOC':
            extra_keys = ['gt_bbox', 'gt_class', 'is_difficult']
        if cfg.metric == 'WIDERFACE':
            extra_keys = ['im_id', 'im_shape', 'gt_bbox']
        eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
                                                         extra_keys)

    exe.run(startup_prog)
    compiled_train_prog = fleet.main_program

    if FLAGS.eval:
        compiled_eval_prog = fluid.CompiledProgram(eval_prog)

    fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
    ignore_params = cfg.finetune_exclude_pretrained_params \
        if 'finetune_exclude_pretrained_params' in cfg else []
    start_iter = 0

    if FLAGS.resume_checkpoint:
        checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
        start_iter = checkpoint.global_step()
    elif cfg.pretrain_weights and fuse_bn and not ignore_params:
        checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
    elif cfg.pretrain_weights:
        checkpoint.load_params(
            exe, train_prog, cfg.pretrain_weights, ignore_params=ignore_params)

    train_reader = create_reader(
        cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num,
        cfg,
        devices_num=devices_num)
    train_loader.set_sample_list_generator(train_reader, place)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # if map_type not set, use default 11point, only use in VOC eval
    map_type = cfg.map_type if 'map_type' in cfg else '11point'

    train_stats = TrainingStats(cfg.log_smooth_window, train_keys)
    train_loader.start()
    start_time = time.time()
    end_time = time.time()

    cfg_name = os.path.basename(FLAGS.config).split('.')[0]
    save_dir = os.path.join(cfg.save_dir, cfg_name)
    time_stat = deque(maxlen=cfg.log_smooth_window)
    best_box_ap_list = [0.0, 0]  #[map, iter]

    # use VisualDL to log data
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_loss_step = 0
        vdl_mAP_step = 0

    for it in range(start_iter, cfg.max_iters):
        start_time = end_time
        end_time = time.time()
        time_stat.append(end_time - start_time)
        time_cost = np.mean(time_stat)
        eta_sec = (cfg.max_iters - it) * time_cost
        eta = str(datetime.timedelta(seconds=int(eta_sec)))
        outs = exe.run(compiled_train_prog, fetch_list=train_values)
        stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}

        # use vdl-paddle to log loss
        if FLAGS.use_vdl:
            if it % cfg.log_iter == 0:
                for loss_name, loss_value in stats.items():
                    vdl_writer.add_scalar(loss_name, loss_value, vdl_loss_step)
                vdl_loss_step += 1

        train_stats.update(stats)
        logs = train_stats.log()
        if it % cfg.log_iter == 0 and trainer_id == 0:
            strs = 'iter: {}, lr: {:.6f}, {}, time: {:.3f}, eta: {}'.format(
                it, np.mean(outs[-1]), logs, time_cost, eta)
            logger.info(strs)

        # NOTE : profiler tools, used for benchmark
        if FLAGS.is_profiler and it == 5:
            profiler.start_profiler("All")
        elif FLAGS.is_profiler and it == 10:
            profiler.stop_profiler("total", FLAGS.profiler_path)
            return

        if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
                and trainer_id == 0:
            save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.apply_program)
            checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

            if FLAGS.eval:
                # evaluation
                resolution = None
                if 'Mask' in cfg.architecture:
                    resolution = model.mask_head.resolution
                results = eval_run(
                    exe,
                    compiled_eval_prog,
                    eval_loader,
                    eval_keys,
                    eval_values,
                    eval_cls,
                    cfg,
                    resolution=resolution)
                box_ap_stats = eval_results(
                    results, cfg.metric, cfg.num_classes, resolution,
                    is_bbox_normalized, FLAGS.output_eval, map_type,
                    cfg['EvalReader']['dataset'])

                # use vdl_paddle to log mAP
                if FLAGS.use_vdl:
                    vdl_writer.add_scalar("mAP", box_ap_stats[0], vdl_mAP_step)
                    vdl_mAP_step += 1

                if box_ap_stats[0] > best_box_ap_list[0]:
                    best_box_ap_list[0] = box_ap_stats[0]
                    best_box_ap_list[1] = it
                    checkpoint.save(exe, train_prog,
                                    os.path.join(save_dir, "best_model"))
                logger.info("Best test box ap: {}, in iter: {}".format(
                    best_box_ap_list[0], best_box_ap_list[1]))

            if 'use_ema' in cfg and cfg['use_ema']:
                exe.run(ema.restore_program)

    train_loader.reset()


if __name__ == '__main__':
    parser = ArgsParser()
    parser.add_argument(
        "-r",
        "--resume_checkpoint",
        default=None,
        type=str,
        help="Checkpoint path for resuming training.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--loss_scale",
        default=8.,
        type=float,
        help="Mixed precision training loss scale.")
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "--output_eval",
        default=None,
        type=str,
        help="Evaluation directory, default is current directory.")
    parser.add_argument(
        "--use_vdl",
        type=bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/scalar",
        help='VisualDL logging directory for scalar.')
    parser.add_argument(
        "--enable_ce",
        type=bool,
        default=False,
        help="If set True, enable continuous evaluation job."
        "This flag is only used for internal test.")
    #NOTE:args for profiler tools, used for benchmark
    parser.add_argument(
        '--is_profiler',
        type=int,
        default=0,
        help='The switch of profiler tools. (used for benchmark)')
    parser.add_argument(
        '--profiler_path',
        type=str,
        default="./detection.profiler",
        help='The profiler output file path. (used for benchmark)')
    FLAGS = parser.parse_args()
    main()
```