From 9f4b66f6840a6417e43a3f162c64b398ef31cb04 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Mon, 28 May 2018 17:32:09 +0800
Subject: [PATCH] table gradient should be split and sent to each pserver

---
 python/paddle/fluid/framework.py              |  3 +-
 .../fluid/transpiler/distribute_transpiler.py | 40 +++++++++++++------
 2 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 08b756d95b9..33b5caa0eab 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -797,7 +797,7 @@ class Block(object):
         Rename variable in vars and ops' inputs and outputs
         """
         if not self.has_var(name):
-            raise ValueError("var %s is not in current" % name)
+            raise ValueError("var %s is not in current block" % name)
         v = self.var(name)
         if type(v) == Parameter:
             var_type = "Parameter"
@@ -843,6 +843,7 @@ class Block(object):
         self.vars[new_name] = var
         del self.vars[name]
         self.sync_with_cpp()
+        return var
 
     def remove_var(self, name):
         self.sync_with_cpp()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 42ff0a9eb11..d497f79e9c9 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -256,15 +256,25 @@ class DistributeTranspiler:
                 if param_grad[0].name == self.table_name
             ][0]
             table_grad_var = self.table_param_grad[1]
-            self.table_grad_list = [
-                program.global_block().create_var(
-                    name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
-                    type=table_grad_var.type,
-                    shape=table_grad_var.shape,
-                    dtype=table_grad_var.dtype)
-                for index in range(len(self.pserver_endpoints))
-            ]
+            if self.sync_mode:
+                self.table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
@@ -328,7 +338,7 @@ class DistributeTranspiler:
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program,
                                                         rpc_client_var,
-                                                        eplist)
+                                                        pserver_endpoints)
             self._split_table_grad_and_add_send_vars(program, rpc_client_var,
                                                      pserver_endpoints)
 
@@ -551,7 +561,7 @@ class DistributeTranspiler:
 
     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
-                                               eplist):
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -602,7 +612,7 @@ class DistributeTranspiler:
                 "Out": self.prefetch_output_vars,
                 "RPCClient": rpc_client_var
             },
-            attrs={"epmap": eplist})
+            attrs={"epmap": pserver_endpoints})
 
         # insert concat_op
         program.global_block().insert_op(
@@ -731,6 +741,12 @@ class DistributeTranspiler:
                 type="sum",
                 inputs={"X": table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async mode, the table gradient also needs to be split and sent to each parameter server
+            old_name = grad_var.name
+            new_name = old_name + ".pserver_" + str(pserver_index)
+            grad_var = pserver_program.global_block().rename_var(old_name,
+                                                                 new_name)
 
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
--
GitLab
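For reference, the naming scheme this patch introduces for the split lookup-table gradient can be illustrated with a minimal standalone sketch. Only the two format strings come from the patch; the gradient name "table@GRAD", trainer_id 0, and the endpoint list are hypothetical placeholders.

    # Standalone sketch (not part of the patch): how the split table-gradient
    # variables are named for each parameter server in sync vs. async mode.
    grad_name = "table@GRAD"      # hypothetical table gradient variable name
    trainer_id = 0                # hypothetical trainer index
    pserver_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171", "127.0.0.1:6172"]

    def split_table_grad_names(sync_mode):
        names = []
        for index in range(len(pserver_endpoints)):
            if sync_mode:
                # sync mode: one gradient slice per (trainer, pserver) pair
                names.append("%s.trainer_%d.pserver_%d" %
                             (grad_name, trainer_id, index))
            else:
                # async mode: one gradient slice per pserver only
                names.append("%s.pserver_%d" % (grad_name, index))
        return names

    print(split_table_grad_names(sync_mode=True))
    # ['table@GRAD.trainer_0.pserver_0', 'table@GRAD.trainer_0.pserver_1', 'table@GRAD.trainer_0.pserver_2']
    print(split_table_grad_names(sync_mode=False))
    # ['table@GRAD.pserver_0', 'table@GRAD.pserver_1', 'table@GRAD.pserver_2']

In sync mode each trainer sends its own slice, so the trainer id is part of the variable name and the pserver sums the slices from all trainers. In async mode the slices are keyed only by pserver index, which is why the pserver side of the patch renames the received gradient to "<name>.pserver_<index>" instead of inserting a sum op.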