提交 9f4b66f6 编写于 作者: Q qiaolongfei

Table gradient should be split and sent to each pserver

上级 25f47fc0
...@@ -797,7 +797,7 @@ class Block(object): ...@@ -797,7 +797,7 @@ class Block(object):
Rename variable in vars and ops' inputs and outputs Rename variable in vars and ops' inputs and outputs
""" """
if not self.has_var(name): if not self.has_var(name):
raise ValueError("var %s is not in current" % name) raise ValueError("var %s is not in current block" % name)
v = self.var(name) v = self.var(name)
if type(v) == Parameter: if type(v) == Parameter:
var_type = "Parameter" var_type = "Parameter"
...@@ -843,6 +843,7 @@ class Block(object): ...@@ -843,6 +843,7 @@ class Block(object):
self.vars[new_name] = var self.vars[new_name] = var
del self.vars[name] del self.vars[name]
self.sync_with_cpp() self.sync_with_cpp()
return var
def remove_var(self, name): def remove_var(self, name):
self.sync_with_cpp() self.sync_with_cpp()
......
...@@ -256,6 +256,7 @@ class DistributeTranspiler: ...@@ -256,6 +256,7 @@ class DistributeTranspiler:
if param_grad[0].name == self.table_name if param_grad[0].name == self.table_name
][0] ][0]
table_grad_var = self.table_param_grad[1] table_grad_var = self.table_param_grad[1]
if self.sync_mode:
self.table_grad_list = [ self.table_grad_list = [
program.global_block().create_var( program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" % name="%s.trainer_%d.pserver_%d" %
...@@ -265,6 +266,15 @@ class DistributeTranspiler: ...@@ -265,6 +266,15 @@ class DistributeTranspiler:
dtype=table_grad_var.dtype) dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints)) for index in range(len(self.pserver_endpoints))
] ]
else:
self.table_grad_list = [
program.global_block().create_var(
name="%s.pserver_%d" % (table_grad_var.name, index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints))
]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
...@@ -328,7 +338,7 @@ class DistributeTranspiler: ...@@ -328,7 +338,7 @@ class DistributeTranspiler:
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist) pserver_endpoints)
self._split_table_grad_and_add_send_vars(program, rpc_client_var, self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints) pserver_endpoints)
...@@ -551,7 +561,7 @@ class DistributeTranspiler: ...@@ -551,7 +561,7 @@ class DistributeTranspiler:
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist): pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None self.prefetch_input_vars = None
self.prefetch_output_vars = None self.prefetch_output_vars = None
...@@ -602,7 +612,7 @@ class DistributeTranspiler: ...@@ -602,7 +612,7 @@ class DistributeTranspiler:
"Out": self.prefetch_output_vars, "Out": self.prefetch_output_vars,
"RPCClient": rpc_client_var "RPCClient": rpc_client_var
}, },
attrs={"epmap": eplist}) attrs={"epmap": pserver_endpoints})
# insert concat_op # insert concat_op
program.global_block().insert_op( program.global_block().insert_op(
...@@ -731,6 +741,12 @@ class DistributeTranspiler: ...@@ -731,6 +741,12 @@ class DistributeTranspiler:
type="sum", type="sum",
inputs={"X": table_grad_list}, inputs={"X": table_grad_list},
outputs={"Out": [grad_var]}) outputs={"Out": [grad_var]})
else:
# in async_mode, the table gradient also needs to be split across the parameter servers
old_name = grad_var.name
new_name = old_name + ".pserver_" + str(pserver_index)
grad_var = pserver_program.global_block().rename_var(old_name,
new_name)
lr_var = pserver_program.global_block().vars[table_opt_op.input( lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]] "LearningRate")[0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册