diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 05ffdefe05b3d47cf43fb5eca67761d452bcad4b..7b8bf17f27ca308184fab03d3bbf90cb1c5943ee 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -279,11 +279,20 @@ class DistributeTranspiler:
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
+        assert (len(grad_blocks) == len(param_blocks))
         # step2: Create new vars for the parameters and gradients blocks and
         # add ops to do the split.
-        grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
+        grad_var_mapping = self._create_vars_from_blocklist(
+            program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
+        grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
+                param_var_mapping[p_name][int(p_bid)]
+
         rpc_client_var = program.global_block().create_var(
             name=RPC_CLIENT_VAR_NAME,
             persistable=True,
@@ -304,15 +313,21 @@ class DistributeTranspiler:
 
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
-        for varname, send_vars in grad_var_mapping.items():
+        send_vars = []
+        for varname, splited_vars in grad_var_mapping.items():
             index = find_op_by_output_arg(program.global_block(), varname)
-            eplist = ps_dispatcher.dispatch(send_vars)
+            eplist = ps_dispatcher.dispatch(splited_vars)
+            if len(splited_vars) > 1:
+                self._insert_split_op(program, varname, splited_vars)
+                index += 1
             program.global_block().insert_op(
-                index=index,
+                index=index + 1,
                 type="send_vars",
-                inputs={"X": send_vars},
+                inputs={"X": splited_vars},
                 outputs={"RPCClient": rpc_client_var},
                 attrs={"epmap": eplist})
+            for _, var in enumerate(splited_vars):
+                send_vars.append(var)
 
         if self.sync_mode:
             program.global_block().append_op(
@@ -322,21 +337,12 @@ class DistributeTranspiler:
                 attrs={"endpoints": pserver_endpoints})
 
         # step 3.2: insert recv op to receive parameters from parameter server
-        ps_dispatcher.reset()
         recv_vars = []
-        for b in param_blocks:
-            varname, block_id, _ = b.split(":")
-            recv_vars.append(param_var_mapping[varname][int(block_id)])
-        for b in grad_blocks:
-            varname, block_id, _ = b.split(":")
-            send_vars.append(grad_var_mapping[varname][int(block_id)])
-
+        for _, var in enumerate(send_vars):
+            recv_vars.append(grad_param_mapping[var])
+        ps_dispatcher.reset()
         eplist = ps_dispatcher.dispatch(recv_vars)
 
-        for i, ep in enumerate(eplist):
-            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
-            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
-
         program.global_block().append_op(
             type="recv",
             inputs={},
@@ -344,6 +350,10 @@ class DistributeTranspiler:
                      "RPCClient": rpc_client_var},
             attrs={"epmap": eplist})
 
+        for i, ep in enumerate(eplist):
+            self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
+            self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
+
         # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
@@ -848,6 +858,34 @@ class DistributeTranspiler:
             lod_level=var.lod_level,
             persistable=persistable)
 
+    def _insert_split_op(self, program, orig_varname, splited_vars):
+        orig_var = program.global_block().vars[orig_varname]
+        index = find_op_by_output_arg(program.global_block(), orig_varname)
+        if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            height_sections = []
+            for v in splited_vars:
+                height_sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_selected_rows",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"height_sections": height_sections})
+        elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
+            sections = []
+            for v in splited_vars:
+                sections.append(v.shape[0])
+            program.global_block().insert_op(
+                index=index + 1,
+                type="split_byref",
+                inputs={"X": orig_var},
+                outputs={"Out": splited_vars},
+                attrs={"sections": sections}  # assume split evenly
+            )
+        else:
+            raise AssertionError("Variable type should be in set "
+                                 "[LOD_TENSOR, SELECTED_ROWS]")
+
     def _append_split_op(self, program, gradblocks):
         # Split variables that need to be split and append respective ops
         add_suffix = False
@@ -860,11 +898,13 @@ class DistributeTranspiler:
             if len(splited_vars) <= 1:
                 continue
             orig_var = program.global_block().vars[varname]
+            index = find_op_by_output_arg(program.global_block(), orig_var.name)
             if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
                 height_sections = []
                 for v in splited_vars:
                     height_sections.append(v.shape[0])
-                program.global_block().append_op(
+                program.global_block().insert_op(
+                    index=index + 1,
                     type="split_selected_rows",
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},
@@ -873,7 +913,8 @@ class DistributeTranspiler:
                 sections = []
                 for v in splited_vars:
                     sections.append(v.shape[0])
-                program.global_block().append_op(
+                program.global_block().insert_op(
+                    index=index + 1,
                     type="split_byref",
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},
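The core of this patch is the new pairing logic: split gradient variables are mapped one-to-one to split parameter variables via `grad_param_mapping`, `send_vars` is accumulated while the `send_vars` ops are inserted, and `recv_vars` is then derived from it so each endpoint receives exactly the parameters whose gradients it was sent. The sketch below is a minimal standalone illustration of that idea, not Paddle code: plain strings stand in for the transpiler's `Variable` objects, and the block strings, variable names, and round-robin endpoint list are invented for the example.

```python
# Toy illustration of the grad/param block pairing introduced by this patch.
# Blocks are "name:block_id:size" strings, as produced by split_dense_variable();
# the mappings below are hand-written stand-ins for _create_vars_from_blocklist().
grad_blocks = ["w@GRAD:0:4", "w@GRAD:1:4", "b@GRAD:0:2"]
param_blocks = ["w:0:4", "w:1:4", "b:0:2"]

grad_var_mapping = {"w@GRAD": ["w@GRAD.block0", "w@GRAD.block1"],
                    "b@GRAD": ["b@GRAD.block0"]}
param_var_mapping = {"w": ["w.block0", "w.block1"],
                     "b": ["b.block0"]}

# Same pairing logic as the patch: the i-th gradient block corresponds to the
# i-th parameter block, so each split grad var maps to a split param var.
grad_param_mapping = {}
for g, p in zip(grad_blocks, param_blocks):
    g_name, g_bid, _ = g.split(":")
    p_name, p_bid, _ = p.split(":")
    grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
        param_var_mapping[p_name][int(p_bid)]

# send_vars is collected while send ops are inserted; recv_vars is derived from
# it, so params and grads stay aligned index-by-index per endpoint.
send_vars = [v for splited in grad_var_mapping.values() for v in splited]
recv_vars = [grad_param_mapping[v] for v in send_vars]

# Hypothetical round-robin endpoint assignment, standing in for ps_dispatcher.
endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
eplist = [endpoints[i % len(endpoints)] for i in range(len(recv_vars))]
for ep, p, g in zip(eplist, recv_vars, send_vars):
    print(ep, "param:", p, "grad:", g)
```

Because `recv_vars` is built from `send_vars` through `grad_param_mapping` before a fresh `ps_dispatcher.reset()`, the i-th entry of `eplist` serves both `param_grad_ep_mapping[ep]["params"][i]` and `["grads"][i]`, which is what the reordered bookkeeping after the `recv` op relies on.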