Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1ae04a73
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1ae04a73
编写于
9月 04, 2020
作者:
C
Chengmo
提交者:
GitHub
9月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix Heter Ps multi thread (#26876) (#27016)
* fix heter-ps multi thread
上级
cbb0f59d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
109 addition
and
46 deletion
+109
-46
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
.../fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
+73
-42
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
+1
-1
python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
+1
-1
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
...paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
+34
-2
未找到文件。
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
浏览文件 @
1ae04a73
# -*- coding: UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -441,7 +442,23 @@ def find_heter_ops(program, default_device="cpu"):
def
create_heter_program
(
program
,
config
,
heter_program
,
heter_ops
,
block_var_detail
,
current_device
):
# add heter op
# This function mainly includes the following contents:
# 1. For every heter block:
# a) copy heter device op from origin program
# b) create variables which belong to heter op:
# -> if variable is persistable, clone it in global_scope
# -> if variable is temp, create it in heter block
# c) create communicate related op as follow:
# joint_var.0_1 -> slice -> reshape -> origin_var
# origin_var -> origin_program
# reshape -> concat -> joint_var.1_2
# d) copy send op from origin program for var@grad which loacted in current heter block
# e) re-check every op in current blcok if its device is not current heter devie
# 2. Create send op for step counter in last heter-block
# 3. Create Listen&Serv OP for distributed training
# 4. update CompileTimeStrategy for heter_program
optimizer_block
=
[]
grad_to_block_id
=
[]
send_grad_var_list
=
[]
...
...
@@ -453,17 +470,10 @@ def create_heter_program(program, config, heter_program, heter_ops,
for
_
,
op
in
enumerate
(
heter_block_ops
):
block_append_op
(
heter_program
,
program
,
heter_block
,
op
)
# add relate variables
inputs
=
_get_input_map_from_op
(
program
.
global_block
().
vars
,
op
)
add_vars_by_op_map
(
inputs
,
heter_program
)
outputs
=
_get_output_map_from_op
(
program
.
global_block
().
vars
,
op
)
add_vars_by_op_map
(
outputs
,
heter_program
)
entrance_vars
=
block_var_detail
[
index
][
"entrance"
]
add_vars_by_var_list
(
entrance_vars
,
program
,
heter_program
)
add_vars_by_var_list
(
entrance_vars
,
program
,
heter_program
,
heter_block
)
exit_vars
=
block_var_detail
[
index
][
"exit"
]
add_vars_by_var_list
(
exit_vars
,
program
,
heter_program
)
add_vars_by_var_list
(
exit_vars
,
program
,
heter_program
,
heter_block
)
comm_info
=
get_communicate_var_info
(
program
,
index
,
entrance_vars
,
exit_vars
)
...
...
@@ -471,13 +481,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
grad_to_block_id
.
append
(
comm_info
[
"block_input_var_name"
]
+
":"
+
str
(
heter_block
.
idx
))
# create slice op
first_op_index
=
0
get_type_var_name
=
comm_info
[
"input_var_reshape_name"
][
0
].
split
(
".input_reshape@Heter"
)[
0
]
get_type_var
=
heter_
program
.
global_block
()
.
vars
[
get_type_var_name
]
get_type_var
=
heter_
block
.
vars
[
get_type_var_name
]
# create slice op
insert_recv_slice_op
(
heter_program
,
heter_block
,
first_op_index
,
comm_info
[
"block_input_var_name"
],
...
...
@@ -487,6 +497,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
for
i
in
range
(
len
(
comm_info
[
"input_var_reshape_dim"
]))
])
first_op_index
+=
len
(
comm_info
[
"input_var_reshape_dim"
])
heter_program
.
global_block
().
create_var
(
name
=
comm_info
[
"block_input_var_name"
],
shape
=
(
-
1
,
sum
(
comm_info
[
"input_var_reshape_dim"
])),
dtype
=
get_type_var
.
dtype
,
type
=
get_type_var
.
type
)
# create reshape op
for
i
in
range
(
len
(
comm_info
[
"input_var_reshape_name"
])):
var_name
=
entrance_vars
[
i
]
...
...
@@ -514,13 +531,14 @@ def create_heter_program(program, config, heter_program, heter_ops,
comm_info
[
"block_output_var_name"
],
[
-
1
,
sum
(
comm_info
[
"output_var_reshape_dim"
])])
check_op_device
(
heter_block
,
current_device
)
# add send op
send_grad_var_list
=
send_grad_var_list
+
add_heter_send_op
(
program
,
heter_program
,
heter_block
,
block_var_detail
[
index
])
# add step conter
send_input_vars
=
[]
dummy_output
=
[]
trainer_id
=
config
.
get_role_id
()
pserver_endpoints
=
config
.
get_ps_endpoints
()
optimizer_block
[
-
1
].
append_op
(
type
=
"send"
,
...
...
@@ -555,7 +573,6 @@ def create_heter_program(program, config, heter_program, heter_ops,
# append the listen_and_serv op
heter_program
.
global_block
().
append_op
(
type
=
"listen_and_serv"
,
inputs
=
{
'X'
:
[]},
outputs
=
{},
attrs
=
attrs
)
check_heter_compile_time_strategy
(
program
,
config
,
send_grad_var_list
)
...
...
@@ -574,6 +591,16 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list):
def
create_trainer_program
(
program
,
config
,
heter_ops
,
block_var_detail
):
# This function mainly includes the following contents:
# 1. For every heter block in origin program
# a) delete heter op and related variables
# b) add send&recv op
# c) add communicate ops as follows:
# origin_var -> reshape -> concat -> joint_var.0_1
# send&recv op(send joint_var.0_1; recv joint_var.1_2)
# joint_var.1_2 -> slice -> reshape -> origin_var
# d) remove send op which related var@grad is not in trainer program
# 2. check every op's device
for
device
in
heter_ops
.
keys
():
for
heter_block_index
in
sorted
(
heter_ops
[
device
]):
replace_ops_by_communicate_op
(
program
,
config
,
heter_block_index
,
...
...
@@ -932,19 +959,19 @@ def insert_reshape_op(program,
var_name
,
new_var_name
,
new_var_shape
=
None
):
input_var
=
program
.
global_block
()
.
vars
[
var_name
]
input_var
=
block
.
vars
[
var_name
]
if
new_var_name
not
in
program
.
global_block
()
.
vars
:
out
=
program
.
global_block
()
.
create_var
(
if
new_var_name
not
in
block
.
vars
:
out
=
block
.
create_var
(
name
=
new_var_name
,
shape
=
new_var_shape
,
dtype
=
input_var
.
dtype
,
type
=
input_var
.
type
)
else
:
out
=
program
.
global_block
()
.
vars
[
new_var_name
]
out
=
block
.
vars
[
new_var_name
]
new_var_shape
=
out
.
shape
x_shape
=
program
.
global_block
()
.
create_var
(
x_shape
=
block
.
create_var
(
name
=
"{}.xshape@Heter"
.
format
(
var_name
),
dtype
=
input_var
.
dtype
)
block
.
_insert_op
(
index
=
index
,
...
...
@@ -957,9 +984,7 @@ def insert_reshape_op(program,
def
insert_send_concat_op
(
program
,
block
,
index
,
var_name_list
,
new_var_name
,
new_var_shape
):
input_var_list
=
[
program
.
global_block
().
vars
[
var_name
]
for
var_name
in
var_name_list
]
input_var_list
=
[
block
.
vars
[
var_name
]
for
var_name
in
var_name_list
]
out
=
program
.
global_block
().
create_var
(
name
=
new_var_name
,
...
...
@@ -987,14 +1012,14 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
out_list
=
[]
for
i
in
range
(
len
(
new_var_name_list
)):
if
new_var_name_list
[
i
]
not
in
program
.
global_block
()
.
vars
:
out
=
program
.
global_block
()
.
create_var
(
if
new_var_name_list
[
i
]
not
in
block
.
vars
:
out
=
block
.
create_var
(
name
=
new_var_name_list
[
i
],
shape
=
new_var_shape_list
[
i
],
dtype
=
input_var
.
dtype
,
type
=
input_var
.
type
)
else
:
out
=
program
.
global_block
()
.
vars
[
new_var_name_list
[
i
]]
out
=
block
.
vars
[
new_var_name_list
[
i
]]
out_list
.
append
(
out
)
start_index
=
0
...
...
@@ -1037,21 +1062,33 @@ def deleter_trainer_useless_var(program):
def
block_append_op
(
program
,
origin_program
,
block
,
op
):
inputs
=
_get_input_map_from_op
(
origin_program
.
global_block
().
vars
,
op
)
merge_ordereddict
=
origin_program
.
global_block
().
vars
.
copy
()
merge_ordereddict
.
update
(
block
.
vars
)
inputs
=
_get_input_map_from_op
(
merge_ordereddict
,
op
)
for
key
,
varlist
in
six
.
iteritems
(
inputs
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
var
in
varlist
:
if
var
.
name
not
in
program
.
global_block
().
vars
:
program
.
global_block
().
_clone_variable
(
var
)
if
var
.
name
not
in
program
.
global_block
(
).
vars
and
var
.
name
not
in
block
.
vars
:
if
var
.
persistable
:
program
.
global_block
().
_clone_variable
(
var
,
force_persistable
=
False
)
else
:
block
.
_clone_variable
(
var
,
force_persistable
=
False
)
outputs
=
_get_output_map_from_op
(
origin_program
.
global_block
().
vars
,
op
)
for
key
,
varlist
in
six
.
iteritems
(
outputs
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
var
in
varlist
:
if
var
.
name
not
in
program
.
global_block
().
vars
:
program
.
global_block
().
_clone_variable
(
var
)
if
var
.
name
not
in
program
.
global_block
(
).
vars
and
var
.
name
not
in
block
.
vars
:
if
var
.
persistable
:
program
.
global_block
().
_clone_variable
(
var
,
force_persistable
=
False
)
else
:
block
.
_clone_variable
(
var
,
force_persistable
=
False
)
if
"_grad"
not
in
op
.
type
:
# for forward op
...
...
@@ -1076,21 +1113,15 @@ def block_append_op(program, origin_program, block, op):
block
.
_sync_with_cpp
()
def
add_vars_by_op_map
(
var_map
,
program
):
for
key
,
varlist
in
six
.
iteritems
(
var_map
):
if
not
isinstance
(
varlist
,
list
):
varlist
=
[
varlist
]
for
i
in
range
(
len
(
varlist
)):
var
=
varlist
[
i
]
if
var
.
name
not
in
program
.
global_block
().
vars
:
program
.
global_block
().
_clone_variable
(
var
)
def
add_vars_by_var_list
(
var_name_list
,
origin_program
,
program
):
def
add_vars_by_var_list
(
var_name_list
,
origin_program
,
program
,
block
):
for
var_name
in
var_name_list
:
if
var_name
not
in
program
.
global_block
().
vars
:
var
=
origin_program
.
global_block
().
vars
[
var_name
]
program
.
global_block
().
_clone_variable
(
var
)
if
var
.
persistable
:
program
.
global_block
().
_clone_variable
(
var
,
force_persistable
=
False
)
else
:
block
.
_clone_variable
(
var
,
force_persistable
=
False
)
def
get_varlist_from_op_map
(
var_map
):
...
...
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
浏览文件 @
1ae04a73
...
...
@@ -153,7 +153,7 @@ def gen_fake_line(dnn_data_num=7,
return
line
def
prepare_fake_data
(
file_nums
=
8
,
file_lines
=
1000
):
def
prepare_fake_data
(
file_nums
=
9
,
file_lines
=
1000
):
"""
Create fake data with same type as avazu_ctr_data
"""
...
...
python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
浏览文件 @
1ae04a73
...
...
@@ -177,7 +177,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
fleet
.
init_worker
()
exe
.
run
(
fluid
.
default_startup_program
())
thread_num
=
1
thread_num
=
int
(
os
.
getenv
(
"CPU_NUM"
,
2
))
batch_size
=
128
filelist
=
fleet_util
.
get_file_shard
(
train_file_list
)
print
(
"filelist: {}"
.
format
(
filelist
))
...
...
python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
浏览文件 @
1ae04a73
...
...
@@ -36,13 +36,45 @@ class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
,
"CPU_NUM"
:
"
1
"
"CPU_NUM"
:
"
3
"
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"4"
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
tr0_losses
,
tr1_losses
=
self
.
_run_cluster
(
model_file
,
required_envs
)
def
test_dist_train
(
self
):
self
.
check_with_place
(
"dist_fleet_heter_ctr.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
class
TestDistHeterPyreaderAsync2x2
(
TestFleetHeterBase
):
def
_setup_config
(
self
):
self
.
_mode
=
"async"
self
.
_reader
=
"pyreader"
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"FLAGS_rpc_deadline"
:
"5000"
,
# 5sec to fail fast
"http_proxy"
:
""
,
"CPU_NUM"
:
"3"
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"3"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
tr0_losses
,
tr1_losses
=
self
.
_run_cluster
(
model_file
,
required_envs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录