提交 63055a3e，作者：qiaolongfei

complete grad_to_id

上级 39892fee
......@@ -328,7 +328,8 @@ from send_op and send back variables to recv_op.
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
"grad_to_id",
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not")
......
......@@ -408,9 +408,9 @@ class DistributeTranspiler:
in_name.startswith("beta2_pow_acc"):
global_ops.append(op)
def __append_optimize_op__(op, block):
def __append_optimize_op__(op, block, grad_to_block_id):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint,
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
default_main_program())
else:
self._append_pserver_non_opt_ops(block, op)
......@@ -424,13 +424,14 @@ class DistributeTranspiler:
self._append_pserver_non_opt_ops(lr_decay_block, op)
# append op to the current block
grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block)
__append_optimize_op__(op, per_opt_block, grad_to_block_id)
# append global ops
opt_state_block = None
......@@ -476,7 +477,7 @@ class DistributeTranspiler:
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block,
"sync_mode": self.sync_mode,
"grad_to_id": []
"grad_to_id": grad_to_block_id
})
pserver_program.sync_with_cpp()
......@@ -883,7 +884,7 @@ class DistributeTranspiler:
return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
origin_program):
grad_to_block_id, origin_program):
program = optimize_block.program
pserver_block = program.global_block()
new_inputs = dict()
......@@ -904,6 +905,8 @@ class DistributeTranspiler:
return
merged_var = \
pserver_block.vars[self._orig_varname(grad_block.name)]
grad_to_block_id.append(merged_var.name + ":" + str(
optimize_block.idx))
if self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainer_num):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册