Commit 63055a3e authored by Q qiaolongfei

complete grad_to_id

Parent 39892fee
@@ -328,7 +328,8 @@ from send_op and send back variables to recv_op.
         .SetDefault("127.0.0.1:6164")
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
     AddAttr<std::vector<std::string>>(
-        "grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
+        "grad_to_id",
+        "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
         "a map from grad name to it's optimize block id")
         .SetDefault({});
     AddAttr<bool>("sync_mode", "if works at sync_mode or not")
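Each entry of the renamed grad_to_id attribute packs a gradient variable name and the index of the optimize block that handles it into one colon-separated string, e.g. 'param1@GRAD.block0:1'. A minimal Python sketch of how such entries could be decoded on the receiving side (the helper name is illustrative; the op's actual parsing lives in its C++ implementation, which is not shown in this commit):

    def parse_grad_to_id(entries):
        # Split "grad_name:block_idx" strings into a {grad_name: block_idx} map.
        # The block index follows the last colon, so rpartition stays correct
        # even if a variable name were ever to contain a colon itself.
        grad_to_block = {}
        for entry in entries:
            name, _, idx = entry.rpartition(":")
            grad_to_block[name] = int(idx)
        return grad_to_block

    parse_grad_to_id(["param1@GRAD.block0:1", "param2@GRAD.blockn:2"])
    # -> {'param1@GRAD.block0': 1, 'param2@GRAD.blockn': 2}

The Python transpiler changes below build these entries and wire them into the op.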
@@ -408,9 +408,9 @@ class DistributeTranspiler:
                     in_name.startswith("beta2_pow_acc"):
                 global_ops.append(op)

-        def __append_optimize_op__(op, block):
+        def __append_optimize_op__(op, block, grad_to_block_id):
             if self._is_opt_op(op):
-                self._append_pserver_ops(block, op, endpoint,
+                self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          default_main_program())
             else:
                 self._append_pserver_non_opt_ops(block, op)
@@ -424,13 +424,14 @@ class DistributeTranspiler:
                 self._append_pserver_non_opt_ops(lr_decay_block, op)

         # append op to the current block
+        grad_to_block_id = []
         pre_block_idx = pserver_program.num_blocks - 1
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program.create_block(pre_block_idx)
             for _, op in enumerate(self.optimize_ops):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block)
+                    __append_optimize_op__(op, per_opt_block, grad_to_block_id)

         # append global ops
         opt_state_block = None
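Note that grad_to_block_id is created once and handed to the nested helper, which mutates it in place; Python lists are passed by reference, so every per-optimizer block appends its entries to the same accumulator. A standalone sketch of this pattern (all names illustrative, not the transpiler's real objects):

    def build_blocks(opt_ops):
        grad_to_block_id = []  # shared accumulator, filled in by the helper

        def append_op(op_name, block_idx, grad_to_block_id):
            # record which block will run this gradient's update
            grad_to_block_id.append("%s@GRAD:%d" % (op_name, block_idx))

        for idx, op_name in enumerate(opt_ops, 1):
            append_op(op_name, idx, grad_to_block_id)
        return grad_to_block_id

    build_blocks(["param1", "param2"])
    # -> ['param1@GRAD:1', 'param2@GRAD:2']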
@@ -476,7 +477,7 @@ class DistributeTranspiler:
                 "Fanin": self.trainer_num,
                 "PrefetchBlock": prefetch_block,
                 "sync_mode": self.sync_mode,
-                "grad_to_id": []
+                "grad_to_id": grad_to_block_id
             })

         pserver_program.sync_with_cpp()
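With this change the attribute is no longer hard-coded to an empty list: by the time the listen_and_serv op is appended, grad_to_block_id holds one 'grad_name:block_idx' string per gradient optimized on this pserver, for example ['param1@GRAD:1', 'param2@GRAD:2'] when two gradients are handled by blocks 1 and 2 (values illustrative).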
@@ -883,7 +884,7 @@ class DistributeTranspiler:
         return orig_var_name

     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            origin_program):
+                            grad_to_block_id, origin_program):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
@@ -904,6 +905,8 @@ class DistributeTranspiler:
                     return
                 merged_var = \
                     pserver_block.vars[self._orig_varname(grad_block.name)]
+                grad_to_block_id.append(merged_var.name + ":" + str(
+                    optimize_block.idx))
                 if self.trainer_num > 1:
                     vars2merge = []
                     for i in xrange(self.trainer_num):
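Each appended entry joins the merged gradient variable's name with the index of the block that will optimize it, producing the colon-separated format documented on the C++ attribute. A trivial sketch of the string being built (_Stub, merged_var, and optimize_block are stand-ins for the real Variable and Block objects):

    class _Stub(object):
        def __init__(self, **kw):
            self.__dict__.update(kw)

    merged_var = _Stub(name="param1@GRAD")  # illustrative merged gradient var
    optimize_block = _Stub(idx=1)           # illustrative optimize block
    entry = merged_var.name + ":" + str(optimize_block.idx)
    # entry == "param1@GRAD:1"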