提交 21f068ab 编写于 作者: Q qiaolongfei

optimize the name of table_grad_list

上级 16027ea1
......@@ -257,7 +257,7 @@ class DistributeTranspiler:
][0]
table_grad_var = self.table_param_grad[1]
if self.sync_mode:
self.table_grad_list = [
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index),
......@@ -267,7 +267,7 @@ class DistributeTranspiler:
for index in range(len(self.pserver_endpoints))
]
else:
self.table_grad_list = [
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.pserver_%d" % (table_grad_var.name, index),
type=table_grad_var.type,
......@@ -648,11 +648,11 @@ class DistributeTranspiler:
inputs={
'Ids': [program.global_block().vars[table_grad_name]]
},
outputs={"Out": self.table_grad_list})
outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op(
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
inputs={'X': self.trainer_side_table_grad_list},
outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True,
"epmap": pserver_endpoints})
......@@ -717,7 +717,7 @@ class DistributeTranspiler:
if self.sync_mode:
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_side_table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
......@@ -727,15 +727,16 @@ class DistributeTranspiler:
for index in range(self.trainer_num)
]
# append sum op for table_grad_list
# append sum op for pserver_side_table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
inputs={"X": pserver_side_table_grad_list},
outputs={"Out": [grad_var]})
else:
# in async_mode, for table gradient, it also need to be splited to each parameter server
origin_grad_name = grad_var.name
splited_grad_name = self.table_grad_list[pserver_index].name
splited_grad_name = self.trainer_side_table_grad_list[
pserver_index].name
if not splited_grad_name.startswith(origin_grad_name):
raise ValueError("origin_grad_var: " + splited_grad_name +
" grad_var:" + grad_var.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册