提交 21f068ab 编写于 作者: Q qiaolongfei

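Rename table_grad_list to distinguish the trainer-side list (self.trainer_side_table_grad_list) from the pserver-side list (pserver_side_table_grad_list)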

上级 16027ea1
...@@ -257,7 +257,7 @@ class DistributeTranspiler: ...@@ -257,7 +257,7 @@ class DistributeTranspiler:
][0] ][0]
table_grad_var = self.table_param_grad[1] table_grad_var = self.table_param_grad[1]
if self.sync_mode: if self.sync_mode:
self.table_grad_list = [ self.trainer_side_table_grad_list = [
program.global_block().create_var( program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" % name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index), (table_grad_var.name, trainer_id, index),
...@@ -267,7 +267,7 @@ class DistributeTranspiler: ...@@ -267,7 +267,7 @@ class DistributeTranspiler:
for index in range(len(self.pserver_endpoints)) for index in range(len(self.pserver_endpoints))
] ]
else: else:
self.table_grad_list = [ self.trainer_side_table_grad_list = [
program.global_block().create_var( program.global_block().create_var(
name="%s.pserver_%d" % (table_grad_var.name, index), name="%s.pserver_%d" % (table_grad_var.name, index),
type=table_grad_var.type, type=table_grad_var.type,
...@@ -648,11 +648,11 @@ class DistributeTranspiler: ...@@ -648,11 +648,11 @@ class DistributeTranspiler:
inputs={ inputs={
'Ids': [program.global_block().vars[table_grad_name]] 'Ids': [program.global_block().vars[table_grad_name]]
}, },
outputs={"Out": self.table_grad_list}) outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=op_index + 2,
type="send_vars", type="send_vars",
inputs={'X': self.table_grad_list}, inputs={'X': self.trainer_side_table_grad_list},
outputs={"RPCClient": rpc_client_var}, outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True, attrs={"sync_send": True,
"epmap": pserver_endpoints}) "epmap": pserver_endpoints})
...@@ -717,7 +717,7 @@ class DistributeTranspiler: ...@@ -717,7 +717,7 @@ class DistributeTranspiler:
if self.sync_mode: if self.sync_mode:
# create grad vars in pserver program # create grad vars in pserver program
table_grad_var = self.table_param_grad[1] table_grad_var = self.table_param_grad[1]
table_grad_list = [ pserver_side_table_grad_list = [
pserver_program.global_block().create_var( pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" % name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index), (table_grad_var.name, index, pserver_index),
...@@ -727,18 +727,19 @@ class DistributeTranspiler: ...@@ -727,18 +727,19 @@ class DistributeTranspiler:
for index in range(self.trainer_num) for index in range(self.trainer_num)
] ]
# append sum op for table_grad_list # append sum op for pserver_side_table_grad_list
table_opt_block.append_op( table_opt_block.append_op(
type="sum", type="sum",
inputs={"X": table_grad_list}, inputs={"X": pserver_side_table_grad_list},
outputs={"Out": [grad_var]}) outputs={"Out": [grad_var]})
else: else:
# in async_mode, the table gradient also needs to be split across the parameter servers # in async_mode, the table gradient also needs to be split across the parameter servers
origin_grad_name = grad_var.name origin_grad_name = grad_var.name
splited_grad_name = self.table_grad_list[pserver_index].name splited_grad_name = self.trainer_side_table_grad_list[
pserver_index].name
if not splited_grad_name.startswith(origin_grad_name): if not splited_grad_name.startswith(origin_grad_name):
raise ValueError("origin_grad_var: " + splited_grad_name + raise ValueError("origin_grad_var: " + splited_grad_name +
" grad_var:" + grad_var.name) " grad_var:" + grad_var.name)
grad_var = pserver_program.global_block().rename_var( grad_var = pserver_program.global_block().rename_var(
origin_grad_name, splited_grad_name) origin_grad_name, splited_grad_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册