未验证 提交 c4846196 编写于 作者: C Chengmo 提交者: GitHub

fix Heter Ps multi thread (#26876)

* fix heter-ps multi thread
上级 35ae1027
# -*- coding: UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -441,7 +442,23 @@ def find_heter_ops(program, default_device="cpu"):
def create_heter_program(program, config, heter_program, heter_ops,
block_var_detail, current_device):
# add heter op
# This function mainly includes the following contents:
# 1. For every heter block:
# a) copy heter device op from origin program
# b) create variables which belong to heter op:
# -> if variable is persistable, clone it in global_scope
# -> if variable is temp, create it in heter block
# c) create communicate related op as follow:
# joint_var.0_1 -> slice -> reshape -> origin_var
# origin_var -> origin_program
# reshape -> concat -> joint_var.1_2
# d) copy send op from origin program for var@grad which loacted in current heter block
# e) re-check every op in current blcok if its device is not current heter devie
# 2. Create send op for step counter in last heter-block
# 3. Create Listen&Serv OP for distributed training
# 4. update CompileTimeStrategy for heter_program
optimizer_block = []
grad_to_block_id = []
send_grad_var_list = []
......@@ -453,17 +470,10 @@ def create_heter_program(program, config, heter_program, heter_ops,
for _, op in enumerate(heter_block_ops):
block_append_op(heter_program, program, heter_block, op)
# add relate variables
inputs = _get_input_map_from_op(program.global_block().vars, op)
add_vars_by_op_map(inputs, heter_program)
outputs = _get_output_map_from_op(program.global_block().vars, op)
add_vars_by_op_map(outputs, heter_program)
entrance_vars = block_var_detail[index]["entrance"]
add_vars_by_var_list(entrance_vars, program, heter_program)
add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
exit_vars = block_var_detail[index]["exit"]
add_vars_by_var_list(exit_vars, program, heter_program)
add_vars_by_var_list(exit_vars, program, heter_program, heter_block)
comm_info = get_communicate_var_info(program, index, entrance_vars,
exit_vars)
......@@ -471,13 +481,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
grad_to_block_id.append(comm_info["block_input_var_name"] + ":" + str(
heter_block.idx))
# create slice op
first_op_index = 0
get_type_var_name = comm_info["input_var_reshape_name"][0].split(
".input_reshape@Heter")[0]
get_type_var = heter_program.global_block().vars[get_type_var_name]
get_type_var = heter_block.vars[get_type_var_name]
# create slice op
insert_recv_slice_op(
heter_program, heter_block, first_op_index,
comm_info["block_input_var_name"],
......@@ -487,6 +497,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
for i in range(len(comm_info["input_var_reshape_dim"]))
])
first_op_index += len(comm_info["input_var_reshape_dim"])
heter_program.global_block().create_var(
name=comm_info["block_input_var_name"],
shape=(-1, sum(comm_info["input_var_reshape_dim"])),
dtype=get_type_var.dtype,
type=get_type_var.type)
# create reshape op
for i in range(len(comm_info["input_var_reshape_name"])):
var_name = entrance_vars[i]
......@@ -514,13 +531,14 @@ def create_heter_program(program, config, heter_program, heter_ops,
comm_info["block_output_var_name"],
[-1, sum(comm_info["output_var_reshape_dim"])])
check_op_device(heter_block, current_device)
# add send op
send_grad_var_list = send_grad_var_list + add_heter_send_op(
program, heter_program, heter_block, block_var_detail[index])
# add step conter
send_input_vars = []
dummy_output = []
trainer_id = config.get_role_id()
pserver_endpoints = config.get_ps_endpoints()
optimizer_block[-1].append_op(
type="send",
......@@ -555,7 +573,6 @@ def create_heter_program(program, config, heter_program, heter_ops,
# append the listen_and_serv op
heter_program.global_block().append_op(
type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs)
check_heter_compile_time_strategy(program, config, send_grad_var_list)
......@@ -574,6 +591,16 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list):
def create_trainer_program(program, config, heter_ops, block_var_detail):
# This function mainly includes the following contents:
# 1. For every heter block in origin program
# a) delete heter op and related variables
# b) add send&recv op
# c) add communicate ops as follows:
# origin_var -> reshape -> concat -> joint_var.0_1
# send&recv op(send joint_var.0_1; recv joint_var.1_2)
# joint_var.1_2 -> slice -> reshape -> origin_var
# d) remove send op which related var@grad is not in trainer program
# 2. check every op's device
for device in heter_ops.keys():
for heter_block_index in sorted(heter_ops[device]):
replace_ops_by_communicate_op(program, config, heter_block_index,
......@@ -932,19 +959,19 @@ def insert_reshape_op(program,
var_name,
new_var_name,
new_var_shape=None):
input_var = program.global_block().vars[var_name]
input_var = block.vars[var_name]
if new_var_name not in program.global_block().vars:
out = program.global_block().create_var(
if new_var_name not in block.vars:
out = block.create_var(
name=new_var_name,
shape=new_var_shape,
dtype=input_var.dtype,
type=input_var.type)
else:
out = program.global_block().vars[new_var_name]
out = block.vars[new_var_name]
new_var_shape = out.shape
x_shape = program.global_block().create_var(
x_shape = block.create_var(
name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype)
block._insert_op(
index=index,
......@@ -957,9 +984,7 @@ def insert_reshape_op(program,
def insert_send_concat_op(program, block, index, var_name_list, new_var_name,
new_var_shape):
input_var_list = [
program.global_block().vars[var_name] for var_name in var_name_list
]
input_var_list = [block.vars[var_name] for var_name in var_name_list]
out = program.global_block().create_var(
name=new_var_name,
......@@ -987,14 +1012,14 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
out_list = []
for i in range(len(new_var_name_list)):
if new_var_name_list[i] not in program.global_block().vars:
out = program.global_block().create_var(
if new_var_name_list[i] not in block.vars:
out = block.create_var(
name=new_var_name_list[i],
shape=new_var_shape_list[i],
dtype=input_var.dtype,
type=input_var.type)
else:
out = program.global_block().vars[new_var_name_list[i]]
out = block.vars[new_var_name_list[i]]
out_list.append(out)
start_index = 0
......@@ -1037,21 +1062,33 @@ def deleter_trainer_useless_var(program):
def block_append_op(program, origin_program, block, op):
inputs = _get_input_map_from_op(origin_program.global_block().vars, op)
merge_ordereddict = origin_program.global_block().vars.copy()
merge_ordereddict.update(block.vars)
inputs = _get_input_map_from_op(merge_ordereddict, op)
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if var.name not in program.global_block().vars:
program.global_block()._clone_variable(var)
if var.name not in program.global_block(
).vars and var.name not in block.vars:
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
if var.name not in program.global_block().vars:
program.global_block()._clone_variable(var)
if var.name not in program.global_block(
).vars and var.name not in block.vars:
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
if "_grad" not in op.type:
# for forward op
......@@ -1076,21 +1113,15 @@ def block_append_op(program, origin_program, block, op):
block._sync_with_cpp()
def add_vars_by_op_map(var_map, program):
for key, varlist in six.iteritems(var_map):
if not isinstance(varlist, list):
varlist = [varlist]
for i in range(len(varlist)):
var = varlist[i]
if var.name not in program.global_block().vars:
program.global_block()._clone_variable(var)
def add_vars_by_var_list(var_name_list, origin_program, program):
def add_vars_by_var_list(var_name_list, origin_program, program, block):
for var_name in var_name_list:
if var_name not in program.global_block().vars:
var = origin_program.global_block().vars[var_name]
program.global_block()._clone_variable(var)
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
def get_varlist_from_op_map(var_map):
......
......@@ -153,7 +153,7 @@ def gen_fake_line(dnn_data_num=7,
return line
def prepare_fake_data(file_nums=8, file_lines=1000):
def prepare_fake_data(file_nums=9, file_lines=1000):
"""
Create fake data with same type as avazu_ctr_data
"""
......
......@@ -177,7 +177,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
fleet.init_worker()
exe.run(fluid.default_startup_program())
thread_num = 1
thread_num = int(os.getenv("CPU_NUM", 2))
batch_size = 128
filelist = fleet_util.get_file_shard(train_file_list)
print("filelist: {}".format(filelist))
......
......@@ -36,13 +36,45 @@ class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "1"
"CPU_NUM": "3"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "4"
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_heter_ctr.py", delta=1e-5, check_error_log=True)
class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase):
def _setup_config(self):
self._mode = "async"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "3"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册