From 63055a3e08768b1e376dbca749fa4c0fb42ef0a0 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 23 Apr 2018 11:17:59 +0800 Subject: [PATCH] complete grad_to_id --- paddle/fluid/operators/listen_and_serv_op.cc | 3 ++- python/paddle/fluid/distribute_transpiler.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index d5cb165351..a01f0aef8d 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -328,7 +328,8 @@ from send_op and send back variables to recv_op. .SetDefault("127.0.0.1:6164") .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); AddAttr>( - "grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])", + "grad_to_id", + "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] " "a map from grad name to it's optimize block id") .SetDefault({}); AddAttr("sync_mode", "if works at sync_mode or not") diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 5f37a790f1..9a514bd9a3 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -408,9 +408,9 @@ class DistributeTranspiler: in_name.startswith("beta2_pow_acc"): global_ops.append(op) - def __append_optimize_op__(op, block): + def __append_optimize_op__(op, block, grad_to_block_id): if self._is_opt_op(op): - self._append_pserver_ops(block, op, endpoint, + self._append_pserver_ops(block, op, endpoint, grad_to_block_id, default_main_program()) else: self._append_pserver_non_opt_ops(block, op) @@ -424,13 +424,14 @@ class DistributeTranspiler: self._append_pserver_non_opt_ops(lr_decay_block, op) # append op to the current block + grad_to_block_id = [] pre_block_idx = pserver_program.num_blocks - 1 for idx, opt_op in enumerate(opt_op_on_pserver): per_opt_block = pserver_program.create_block(pre_block_idx) for _, op in enumerate(self.optimize_ops): # optimizer is connected to itself if ufind.is_connected(op, opt_op) and op not in global_ops: - __append_optimize_op__(op, per_opt_block) + __append_optimize_op__(op, per_opt_block, grad_to_block_id) # append global ops opt_state_block = None @@ -476,7 +477,7 @@ class DistributeTranspiler: "Fanin": self.trainer_num, "PrefetchBlock": prefetch_block, "sync_mode": self.sync_mode, - "grad_to_id": [] + "grad_to_id": grad_to_block_id }) pserver_program.sync_with_cpp() @@ -883,7 +884,7 @@ class DistributeTranspiler: return orig_var_name def _append_pserver_ops(self, optimize_block, opt_op, endpoint, - origin_program): + grad_to_block_id, origin_program): program = optimize_block.program pserver_block = program.global_block() new_inputs = dict() @@ -904,6 +905,8 @@ class DistributeTranspiler: return merged_var = \ pserver_block.vars[self._orig_varname(grad_block.name)] + grad_to_block_id.append(merged_var.name + ":" + str( + optimize_block.idx)) if self.trainer_num > 1: vars2merge = [] for i in xrange(self.trainer_num): -- GitLab