Commit 9f4b66f6 authored by qiaolongfei

table gradient should be split and sent to each pserver

Parent 25f47fc0
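For context: the distributed lookup table is sharded across parameter servers, so its gradient rows must be routed to whichever server owns each id. A minimal standalone sketch of that split-and-send idea, assuming rows are routed by id modulo the number of pservers (a common sharding rule; this is an illustration, not code from this patch):

    import numpy as np

    def split_table_grad(ids, grad_rows, num_pservers):
        # group sparse gradient rows by the pserver that owns each id
        shards = {i: [] for i in range(num_pservers)}
        for row_id, row in zip(ids, grad_rows):
            shards[row_id % num_pservers].append((row_id, row))
        return shards

    ids = [0, 3, 4, 7]
    grad_rows = np.random.rand(4, 8)
    shards = split_table_grad(ids, grad_rows, num_pservers=2)
    # pserver 0 receives the rows for ids 0 and 4; pserver 1 those for ids 3 and 7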
@@ -797,7 +797,7 @@ class Block(object):
         Rename variable in vars and ops' inputs and outputs
         """
         if not self.has_var(name):
-            raise ValueError("var %s is not in current" % name)
+            raise ValueError("var %s is not in current block" % name)
         v = self.var(name)
         if type(v) == Parameter:
             var_type = "Parameter"
@@ -843,6 +843,7 @@ class Block(object):
         self.vars[new_name] = var
         del self.vars[name]
         self.sync_with_cpp()
+        return var
 
     def remove_var(self, name):
         self.sync_with_cpp()
...
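The new return value saves callers a second lookup after renaming; the transpiler hunk at the end of this commit consumes it directly. A usage sketch (block stands for any Block instance; the variable names are illustrative):

    grad_var = block.rename_var("emb@GRAD", "emb@GRAD.pserver_0")
    # previously rename_var returned None, forcing a re-lookup:
    #   block.rename_var("emb@GRAD", "emb@GRAD.pserver_0")
    #   grad_var = block.var("emb@GRAD.pserver_0")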
@@ -256,15 +256,25 @@ class DistributeTranspiler:
             if param_grad[0].name == self.table_name
         ][0]
         table_grad_var = self.table_param_grad[1]
-        self.table_grad_list = [
-            program.global_block().create_var(
-                name="%s.trainer_%d.pserver_%d" %
-                (table_grad_var.name, trainer_id, index),
-                type=table_grad_var.type,
-                shape=table_grad_var.shape,
-                dtype=table_grad_var.dtype)
-            for index in range(len(self.pserver_endpoints))
-        ]
+        if self.sync_mode:
+            self.table_grad_list = [
+                program.global_block().create_var(
+                    name="%s.trainer_%d.pserver_%d" %
+                    (table_grad_var.name, trainer_id, index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(len(self.pserver_endpoints))
+            ]
+        else:
+            self.table_grad_list = [
+                program.global_block().create_var(
+                    name="%s.pserver_%d" % (table_grad_var.name, index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(len(self.pserver_endpoints))
+            ]
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
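The only difference between the two branches is the name template: sync mode keeps one split per trainer per pserver, while async mode drops the trainer id so every trainer sends into the same per-pserver variable. For a table gradient named emb@GRAD (an illustrative name), trainer_id=0, and two pservers, the created variables would be:

    # sync_mode:
    ["emb@GRAD.trainer_0.pserver_0", "emb@GRAD.trainer_0.pserver_1"]
    # async mode:
    ["emb@GRAD.pserver_0", "emb@GRAD.pserver_1"]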
@@ -328,7 +338,7 @@ class DistributeTranspiler:
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
-                                                        eplist)
+                                                        pserver_endpoints)
             self._split_table_grad_and_add_send_vars(program, rpc_client_var,
                                                      pserver_endpoints)
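Together these two calls rewrite the trainer program around the distributed table. Roughly (forward op names per the comment in the next hunk; the backward half is an inference from the name _split_table_grad_and_add_send_vars, not shown in this commit):

    # forward:  ids -> split_ids_op -> prefetch_op (one request per pserver) -> concat
    # backward: table gradient -> split per pserver -> send to each endpoint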
@@ -551,7 +561,7 @@ class DistributeTranspiler:
     # transpiler function for distributed lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
-                                               eplist):
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -602,7 +612,7 @@ class DistributeTranspiler:
                 "Out": self.prefetch_output_vars,
                 "RPCClient": rpc_client_var
             },
-            attrs={"epmap": eplist})
+            attrs={"epmap": pserver_endpoints})
 
         # insert concat_op
         program.global_block().insert_op(
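The epmap attribute pairs each prefetch output with the endpoint that serves it, aligned by index, which is why the full pserver_endpoints list replaces eplist here. An illustration of that pairing (the endpoints and variable names are made up):

    pserver_endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
    prefetch_output_vars = ["ids.prefetch_0", "ids.prefetch_1"]
    # epmap[i] is the endpoint that answers for prefetch_output_vars[i]
    for var_name, endpoint in zip(prefetch_output_vars, pserver_endpoints):
        print(var_name, "<-", endpoint)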
@@ -731,6 +741,12 @@ class DistributeTranspiler:
                 type="sum",
                 inputs={"X": table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async mode the table gradient is also split per pserver, so rename it
+            old_name = grad_var.name
+            new_name = old_name + ".pserver_" + str(pserver_index)
+            grad_var = pserver_program.global_block().rename_var(old_name,
+                                                                 new_name)
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
...
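On the pserver side, the sum op above folds the per-trainer splits back together in sync mode; the new else branch covers async mode, where each split is applied as it arrives. The two paths, with illustrative names on pserver 1:

    # sync mode:  sum(X=[emb@GRAD.trainer_0.pserver_1,
    #                    emb@GRAD.trainer_1.pserver_1]) -> emb@GRAD
    # async mode: no sum; rename so the table optimizer reads the split directly
    #   emb@GRAD -> emb@GRAD.pserver_1   (uses the new return value of rename_var)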