未验证 提交 c4846196 编写于 作者: C Chengmo 提交者: GitHub

fix Heter Ps multi thread (#26876)

* fix heter-ps multi thread
上级 35ae1027
# -*- coding: UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -441,7 +442,23 @@ def find_heter_ops(program, default_device="cpu"): ...@@ -441,7 +442,23 @@ def find_heter_ops(program, default_device="cpu"):
def create_heter_program(program, config, heter_program, heter_ops, def create_heter_program(program, config, heter_program, heter_ops,
block_var_detail, current_device): block_var_detail, current_device):
# add heter op
# This function mainly includes the following contents:
# 1. For every heter block:
# a) copy heter device op from origin program
# b) create variables which belong to heter op:
# -> if variable is persistable, clone it in global_scope
# -> if variable is temp, create it in heter block
# c) create communicate related op as follow:
# joint_var.0_1 -> slice -> reshape -> origin_var
# origin_var -> origin_program
# reshape -> concat -> joint_var.1_2
# d) copy send op from origin program for var@grad which loacted in current heter block
# e) re-check every op in current blcok if its device is not current heter devie
# 2. Create send op for step counter in last heter-block
# 3. Create Listen&Serv OP for distributed training
# 4. update CompileTimeStrategy for heter_program
optimizer_block = [] optimizer_block = []
grad_to_block_id = [] grad_to_block_id = []
send_grad_var_list = [] send_grad_var_list = []
...@@ -453,17 +470,10 @@ def create_heter_program(program, config, heter_program, heter_ops, ...@@ -453,17 +470,10 @@ def create_heter_program(program, config, heter_program, heter_ops,
for _, op in enumerate(heter_block_ops): for _, op in enumerate(heter_block_ops):
block_append_op(heter_program, program, heter_block, op) block_append_op(heter_program, program, heter_block, op)
# add relate variables
inputs = _get_input_map_from_op(program.global_block().vars, op)
add_vars_by_op_map(inputs, heter_program)
outputs = _get_output_map_from_op(program.global_block().vars, op)
add_vars_by_op_map(outputs, heter_program)
entrance_vars = block_var_detail[index]["entrance"] entrance_vars = block_var_detail[index]["entrance"]
add_vars_by_var_list(entrance_vars, program, heter_program) add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
exit_vars = block_var_detail[index]["exit"] exit_vars = block_var_detail[index]["exit"]
add_vars_by_var_list(exit_vars, program, heter_program) add_vars_by_var_list(exit_vars, program, heter_program, heter_block)
comm_info = get_communicate_var_info(program, index, entrance_vars, comm_info = get_communicate_var_info(program, index, entrance_vars,
exit_vars) exit_vars)
...@@ -471,13 +481,13 @@ def create_heter_program(program, config, heter_program, heter_ops, ...@@ -471,13 +481,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
grad_to_block_id.append(comm_info["block_input_var_name"] + ":" + str( grad_to_block_id.append(comm_info["block_input_var_name"] + ":" + str(
heter_block.idx)) heter_block.idx))
# create slice op
first_op_index = 0 first_op_index = 0
get_type_var_name = comm_info["input_var_reshape_name"][0].split( get_type_var_name = comm_info["input_var_reshape_name"][0].split(
".input_reshape@Heter")[0] ".input_reshape@Heter")[0]
get_type_var = heter_program.global_block().vars[get_type_var_name] get_type_var = heter_block.vars[get_type_var_name]
# create slice op
insert_recv_slice_op( insert_recv_slice_op(
heter_program, heter_block, first_op_index, heter_program, heter_block, first_op_index,
comm_info["block_input_var_name"], comm_info["block_input_var_name"],
...@@ -487,6 +497,13 @@ def create_heter_program(program, config, heter_program, heter_ops, ...@@ -487,6 +497,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
for i in range(len(comm_info["input_var_reshape_dim"])) for i in range(len(comm_info["input_var_reshape_dim"]))
]) ])
first_op_index += len(comm_info["input_var_reshape_dim"]) first_op_index += len(comm_info["input_var_reshape_dim"])
heter_program.global_block().create_var(
name=comm_info["block_input_var_name"],
shape=(-1, sum(comm_info["input_var_reshape_dim"])),
dtype=get_type_var.dtype,
type=get_type_var.type)
# create reshape op # create reshape op
for i in range(len(comm_info["input_var_reshape_name"])): for i in range(len(comm_info["input_var_reshape_name"])):
var_name = entrance_vars[i] var_name = entrance_vars[i]
...@@ -514,13 +531,14 @@ def create_heter_program(program, config, heter_program, heter_ops, ...@@ -514,13 +531,14 @@ def create_heter_program(program, config, heter_program, heter_ops,
comm_info["block_output_var_name"], comm_info["block_output_var_name"],
[-1, sum(comm_info["output_var_reshape_dim"])]) [-1, sum(comm_info["output_var_reshape_dim"])])
check_op_device(heter_block, current_device) check_op_device(heter_block, current_device)
# add send op
send_grad_var_list = send_grad_var_list + add_heter_send_op( send_grad_var_list = send_grad_var_list + add_heter_send_op(
program, heter_program, heter_block, block_var_detail[index]) program, heter_program, heter_block, block_var_detail[index])
# add step conter # add step conter
send_input_vars = [] send_input_vars = []
dummy_output = [] dummy_output = []
trainer_id = config.get_role_id()
pserver_endpoints = config.get_ps_endpoints() pserver_endpoints = config.get_ps_endpoints()
optimizer_block[-1].append_op( optimizer_block[-1].append_op(
type="send", type="send",
...@@ -555,7 +573,6 @@ def create_heter_program(program, config, heter_program, heter_ops, ...@@ -555,7 +573,6 @@ def create_heter_program(program, config, heter_program, heter_ops,
# append the listen_and_serv op # append the listen_and_serv op
heter_program.global_block().append_op( heter_program.global_block().append_op(
type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs) type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs)
check_heter_compile_time_strategy(program, config, send_grad_var_list) check_heter_compile_time_strategy(program, config, send_grad_var_list)
...@@ -574,6 +591,16 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list): ...@@ -574,6 +591,16 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list):
def create_trainer_program(program, config, heter_ops, block_var_detail): def create_trainer_program(program, config, heter_ops, block_var_detail):
# This function mainly includes the following contents:
# 1. For every heter block in origin program
# a) delete heter op and related variables
# b) add send&recv op
# c) add communicate ops as follows:
# origin_var -> reshape -> concat -> joint_var.0_1
# send&recv op(send joint_var.0_1; recv joint_var.1_2)
# joint_var.1_2 -> slice -> reshape -> origin_var
# d) remove send op which related var@grad is not in trainer program
# 2. check every op's device
for device in heter_ops.keys(): for device in heter_ops.keys():
for heter_block_index in sorted(heter_ops[device]): for heter_block_index in sorted(heter_ops[device]):
replace_ops_by_communicate_op(program, config, heter_block_index, replace_ops_by_communicate_op(program, config, heter_block_index,
...@@ -932,19 +959,19 @@ def insert_reshape_op(program, ...@@ -932,19 +959,19 @@ def insert_reshape_op(program,
var_name, var_name,
new_var_name, new_var_name,
new_var_shape=None): new_var_shape=None):
input_var = program.global_block().vars[var_name] input_var = block.vars[var_name]
if new_var_name not in program.global_block().vars: if new_var_name not in block.vars:
out = program.global_block().create_var( out = block.create_var(
name=new_var_name, name=new_var_name,
shape=new_var_shape, shape=new_var_shape,
dtype=input_var.dtype, dtype=input_var.dtype,
type=input_var.type) type=input_var.type)
else: else:
out = program.global_block().vars[new_var_name] out = block.vars[new_var_name]
new_var_shape = out.shape new_var_shape = out.shape
x_shape = program.global_block().create_var( x_shape = block.create_var(
name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype) name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype)
block._insert_op( block._insert_op(
index=index, index=index,
...@@ -957,9 +984,7 @@ def insert_reshape_op(program, ...@@ -957,9 +984,7 @@ def insert_reshape_op(program,
def insert_send_concat_op(program, block, index, var_name_list, new_var_name, def insert_send_concat_op(program, block, index, var_name_list, new_var_name,
new_var_shape): new_var_shape):
input_var_list = [ input_var_list = [block.vars[var_name] for var_name in var_name_list]
program.global_block().vars[var_name] for var_name in var_name_list
]
out = program.global_block().create_var( out = program.global_block().create_var(
name=new_var_name, name=new_var_name,
...@@ -987,14 +1012,14 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype, ...@@ -987,14 +1012,14 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
out_list = [] out_list = []
for i in range(len(new_var_name_list)): for i in range(len(new_var_name_list)):
if new_var_name_list[i] not in program.global_block().vars: if new_var_name_list[i] not in block.vars:
out = program.global_block().create_var( out = block.create_var(
name=new_var_name_list[i], name=new_var_name_list[i],
shape=new_var_shape_list[i], shape=new_var_shape_list[i],
dtype=input_var.dtype, dtype=input_var.dtype,
type=input_var.type) type=input_var.type)
else: else:
out = program.global_block().vars[new_var_name_list[i]] out = block.vars[new_var_name_list[i]]
out_list.append(out) out_list.append(out)
start_index = 0 start_index = 0
...@@ -1037,21 +1062,33 @@ def deleter_trainer_useless_var(program): ...@@ -1037,21 +1062,33 @@ def deleter_trainer_useless_var(program):
def block_append_op(program, origin_program, block, op): def block_append_op(program, origin_program, block, op):
inputs = _get_input_map_from_op(origin_program.global_block().vars, op) merge_ordereddict = origin_program.global_block().vars.copy()
merge_ordereddict.update(block.vars)
inputs = _get_input_map_from_op(merge_ordereddict, op)
for key, varlist in six.iteritems(inputs): for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
if var.name not in program.global_block().vars: if var.name not in program.global_block(
program.global_block()._clone_variable(var) ).vars and var.name not in block.vars:
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
outputs = _get_output_map_from_op(origin_program.global_block().vars, op) outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
for key, varlist in six.iteritems(outputs): for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
if var.name not in program.global_block().vars: if var.name not in program.global_block(
program.global_block()._clone_variable(var) ).vars and var.name not in block.vars:
if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
if "_grad" not in op.type: if "_grad" not in op.type:
# for forward op # for forward op
...@@ -1076,21 +1113,15 @@ def block_append_op(program, origin_program, block, op): ...@@ -1076,21 +1113,15 @@ def block_append_op(program, origin_program, block, op):
block._sync_with_cpp() block._sync_with_cpp()
def add_vars_by_op_map(var_map, program): def add_vars_by_var_list(var_name_list, origin_program, program, block):
for key, varlist in six.iteritems(var_map):
if not isinstance(varlist, list):
varlist = [varlist]
for i in range(len(varlist)):
var = varlist[i]
if var.name not in program.global_block().vars:
program.global_block()._clone_variable(var)
def add_vars_by_var_list(var_name_list, origin_program, program):
for var_name in var_name_list: for var_name in var_name_list:
if var_name not in program.global_block().vars: if var_name not in program.global_block().vars:
var = origin_program.global_block().vars[var_name] var = origin_program.global_block().vars[var_name]
program.global_block()._clone_variable(var) if var.persistable:
program.global_block()._clone_variable(
var, force_persistable=False)
else:
block._clone_variable(var, force_persistable=False)
def get_varlist_from_op_map(var_map): def get_varlist_from_op_map(var_map):
......
...@@ -153,7 +153,7 @@ def gen_fake_line(dnn_data_num=7, ...@@ -153,7 +153,7 @@ def gen_fake_line(dnn_data_num=7,
return line return line
def prepare_fake_data(file_nums=8, file_lines=1000): def prepare_fake_data(file_nums=9, file_lines=1000):
""" """
Create fake data with same type as avazu_ctr_data Create fake data with same type as avazu_ctr_data
""" """
......
...@@ -177,7 +177,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): ...@@ -177,7 +177,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
fleet.init_worker() fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
thread_num = 1 thread_num = int(os.getenv("CPU_NUM", 2))
batch_size = 128 batch_size = 128
filelist = fleet_util.get_file_shard(train_file_list) filelist = fleet_util.get_file_shard(train_file_list)
print("filelist: {}".format(filelist)) print("filelist: {}".format(filelist))
......
...@@ -36,13 +36,45 @@ class TestDistHeterDatasetAsync2x2(TestFleetHeterBase): ...@@ -36,13 +36,45 @@ class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "", "http_proxy": "",
"CPU_NUM": "1" "CPU_NUM": "3"
} }
required_envs.update(need_envs) required_envs.update(need_envs)
if check_error_log: if check_error_log:
required_envs["GLOG_v"] = "4" required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_heter_ctr.py", delta=1e-5, check_error_log=True)
class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase):
def _setup_config(self):
self._mode = "async"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "3"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册