Commit 6e5635fd authored by Yancey1989

update

Parent b1e51836
@@ -279,11 +279,20 @@ class DistributeTranspiler:
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
+        assert (len(grad_blocks) == len(param_blocks))
         # step2: Create new vars for the parameters and gradients blocks and
         # add ops to do the split.
-        grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
+        grad_var_mapping = self._create_vars_from_blocklist(
+            program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
+        grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
+                param_var_mapping[p_name][int(p_bid)]
         rpc_client_var = program.global_block().create_var(
             name=RPC_CLIENT_VAR_NAME,
             persistable=True,
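The pairing added above relies on each entry of grad_blocks and param_blocks being a colon-separated descriptor; judging from the split(":") calls, the format is name:block_id:size. A minimal standalone sketch of the mapping logic, with made-up block strings and plain strings standing in for the per-block variables:

    # Hypothetical block descriptors in the "name:block_id:size" format
    # implied by the split(":") calls above.
    grad_blocks = ["fc_w@GRAD:0:2048", "fc_w@GRAD:1:2048"]
    param_blocks = ["fc_w:0:2048", "fc_w:1:2048"]

    # Stand-ins for grad_var_mapping / param_var_mapping: each maps a
    # variable name to the list of per-block variables created for it.
    grad_var_mapping = {"fc_w@GRAD": ["fc_w@GRAD.block0", "fc_w@GRAD.block1"]}
    param_var_mapping = {"fc_w": ["fc_w.block0", "fc_w.block1"]}

    grad_param_mapping = dict()
    for g, p in zip(grad_blocks, param_blocks):
        g_name, g_bid, _ = g.split(":")
        p_name, p_bid, _ = p.split(":")
        # Each gradient block is keyed to the parameter block it updates.
        grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
            param_var_mapping[p_name][int(p_bid)]

    print(grad_param_mapping)
    # {'fc_w@GRAD.block0': 'fc_w.block0', 'fc_w@GRAD.block1': 'fc_w.block1'}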
@@ -304,15 +313,21 @@ class DistributeTranspiler:
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
-        for varname, send_vars in grad_var_mapping.items():
+        send_vars = []
+        for varname, splited_vars in grad_var_mapping.items():
             index = find_op_by_output_arg(program.global_block(), varname)
-            eplist = ps_dispatcher.dispatch(send_vars)
+            eplist = ps_dispatcher.dispatch(splited_vars)
+            if len(splited_vars) > 1:
+                self._insert_split_op(program, varname, splited_vars)
+                index += 1
             program.global_block().insert_op(
-                index=index,
+                index=index + 1,
                 type="send_vars",
-                inputs={"X": send_vars},
+                inputs={"X": splited_vars},
                 outputs={"RPCClient": rpc_client_var},
                 attrs={"epmap": eplist})
+            for _, var in enumerate(splited_vars):
+                send_vars.append(var)

         if self.sync_mode:
             program.global_block().append_op(
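The index arithmetic in this loop is easy to misread: when a gradient was sliced into more than one block, the split op must land immediately after the op that produced the gradient, and send_vars immediately after the split, hence the index += 1 followed by index + 1. A toy list-of-ops model of that ordering (op names are illustrative stand-ins):

    # Toy model: a program block as a flat op list; new ops are inserted
    # right after the op that outputs the gradient, mirroring step 3.1.
    ops = ["fill_constant", "mul_grad"]      # "mul_grad" outputs the gradient

    index = ops.index("mul_grad")            # find_op_by_output_arg analogue
    splited_vars = ["g.block0", "g.block1"]  # gradient split into two blocks
    if len(splited_vars) > 1:
        ops.insert(index + 1, "split_byref") # _insert_split_op places the split
        index += 1                           # so send moves one slot later
    ops.insert(index + 1, "send_vars")

    print(ops)
    # ['fill_constant', 'mul_grad', 'split_byref', 'send_vars']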
@@ -322,21 +337,12 @@ class DistributeTranspiler:
                 attrs={"endpoints": pserver_endpoints})

         # step 3.2: insert recv op to receive parameters from parameter server
-        ps_dispatcher.reset()
         recv_vars = []
-        for b in param_blocks:
-            varname, block_id, _ = b.split(":")
-            recv_vars.append(param_var_mapping[varname][int(block_id)])
-        for b in grad_blocks:
-            varname, block_id, _ = b.split(":")
-            send_vars.append(grad_var_mapping[varname][int(block_id)])
+        for _, var in enumerate(send_vars):
+            recv_vars.append(grad_param_mapping[var])
+        ps_dispatcher.reset()
         eplist = ps_dispatcher.dispatch(recv_vars)
-        for i, ep in enumerate(eplist):
-            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
-            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

         program.global_block().append_op(
             type="recv",
             inputs={},
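Building recv_vars by walking send_vars through grad_param_mapping, instead of re-parsing param_blocks, keeps the two lists index-aligned; with the dispatcher reset just before dispatching recv_vars, each parameter block presumably lands on the same endpoint as the gradient block that updates it. A rough sketch of that alignment, using a trivial round-robin as a stand-in for ps_dispatcher (which is assumed to cycle endpoints in dispatch order):

    grad_param_mapping = {"g0": "p0", "g1": "p1", "g2": "p2"}
    send_vars = ["g0", "g1", "g2"]
    endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]

    recv_vars = []
    for _, var in enumerate(send_vars):
        recv_vars.append(grad_param_mapping[var])

    # Round-robin stand-in: position i -> endpoint i % len(endpoints).
    eplist = [endpoints[i % len(endpoints)] for i in range(len(recv_vars))]
    for g, p, ep in zip(send_vars, recv_vars, eplist):
        print("%s -> %s, which serves %s" % (g, ep, p))
    # g0 -> 127.0.0.1:6174, which serves p0
    # g1 -> 127.0.0.1:6175, which serves p1
    # g2 -> 127.0.0.1:6174, which serves p2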
@@ -344,6 +350,10 @@ class DistributeTranspiler:
                      "RPCClient": rpc_client_var},
             attrs={"epmap": eplist})
+        for i, ep in enumerate(eplist):
+            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
+            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

         # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
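Moving the param_grad_ep_mapping bookkeeping to after the recv op does not change what it records: for every endpoint, the parameter blocks it serves and the gradient blocks it receives, index-aligned with eplist. A self-contained sketch of the resulting structure (endpoints and block names are made up):

    from collections import defaultdict

    # Stand-in for self.param_grad_ep_mapping.
    param_grad_ep_mapping = defaultdict(lambda: {"params": [], "grads": []})
    eplist = ["127.0.0.1:6174", "127.0.0.1:6175", "127.0.0.1:6174"]
    recv_vars = ["p0", "p1", "p2"]
    send_vars = ["g0", "g1", "g2"]

    for i, ep in enumerate(eplist):
        param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
        param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

    print(dict(param_grad_ep_mapping))
    # {'127.0.0.1:6174': {'params': ['p0', 'p2'], 'grads': ['g0', 'g2']},
    #  '127.0.0.1:6175': {'params': ['p1'], 'grads': ['g1']}}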
@@ -848,6 +858,34 @@ class DistributeTranspiler:
             lod_level=var.lod_level,
             persistable=persistable)

+    def _insert_split_op(self, program, orig_varname, splited_vars):
+        orig_var = program.global_block().vars[orig_varname]
+        index = find_op_by_output_arg(program.global_block(), orig_varname)
+        if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            height_sections = []
+            for v in splited_vars:
+                height_sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_selected_rows",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"height_sections": height_sections})
+        elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
+            sections = []
+            for v in splited_vars:
+                sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_byref",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"sections": sections}  # assume split evenly
+            )
+        else:
+            raise AssertionError("Variable type should be in set "
+                                 "[LOD_TENSOR, SELECTED_ROWS]")
+
     def _append_split_op(self, program, gradblocks):
         # Split variables that need to be split and append respective ops
         add_suffix = False
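Both branches of _insert_split_op partition the variable along its first dimension, with per-section sizes read off the split vars' shape[0]; split_selected_rows calls them height_sections and split_byref calls them sections. A NumPy analogue of the dense LOD_TENSOR case (this only mimics the shapes; the real op splits by reference inside the program):

    import numpy as np

    orig = np.arange(12).reshape(6, 2)   # a 6 x 2 gradient tensor
    sections = [2, 2, 2]                 # shape[0] of each split variable

    # np.split takes cut points, not section lengths, hence the cumsum.
    pieces = np.split(orig, np.cumsum(sections)[:-1], axis=0)
    for i, piece in enumerate(pieces):
        print("block %d: shape %s" % (i, piece.shape))
    # prints three row-wise blocks, each of shape (2, 2)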
@@ -860,11 +898,13 @@ class DistributeTranspiler:
             if len(splited_vars) <= 1:
                 continue
             orig_var = program.global_block().vars[varname]
+            index = find_op_by_output_arg(program.global_block(), orig_var.name)
             if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
                 height_sections = []
                 for v in splited_vars:
                     height_sections.append(v.shape[0])
-                program.global_block().append_op(
+                program.global_block().insert_op(
+                    index=index + 1,
                     type="split_selected_rows",
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},
@@ -873,7 +913,8 @@ class DistributeTranspiler:
                 sections = []
                 for v in splited_vars:
                     sections.append(v.shape[0])
-                program.global_block().append_op(
+                program.global_block().insert_op(
+                    index=index + 1,
                     type="split_byref",
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},
......