diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 89bc24802751340b6d4657be8673d714f3d3dc2b..7a3cf1230b29b40d762b978d8ca504f5c53c85d4 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -236,6 +236,22 @@ class DistributeTranspiler(object):
         else:
             raise ValueError("must set trainer_id > 0")
 
+    def _get_all_sparse_update_op(self, main_program):
+        sparse_update_ops = []
+        sparse_update_op_types = ["lookup_table"]
+        for op in main_program.global_block().ops:
+            if op.type in sparse_update_op_types and op.attr(
+                    'is_sparse') is True and not op.attr('is_distributed'):
+                sparse_update_ops.append(op)
+        return sparse_update_ops
+
+    def _update_sparse_update_op(self, param_varname, height_sections,
+                                 endpoint_map):
+        for op in self.sparse_update_ops:
+            if param_varname in op.input_arg_names:
+                op._set_attr('epmap', endpoint_map)
+                op._set_attr('height_sections', height_sections)
+
     def transpile(self,
                   trainer_id,
                   program=None,
@@ -299,6 +315,11 @@ class DistributeTranspiler(object):
                 self.param_name_to_grad_name[param_var.name] = grad_var.name
                 self.grad_name_to_param_name[grad_var.name] = param_var.name
 
+        # get all sparse update ops
+        self.sparse_update_ops = self._get_all_sparse_update_op(
+            self.origin_program)
+        self.sparse_param_to_height_sections = dict()
+
         # add distributed attrs to program
         self.origin_program._is_distributed = True
         self.origin_program._endpoints = self.pserver_endpoints
@@ -425,18 +446,24 @@ class DistributeTranspiler(object):
             if len(splited_trainer_grad) == 1:
                 recv_op_role_var_name = splited_trainer_grad[0].name
 
-            program.global_block().append_op(
-                type="recv",
-                inputs={"X": [recv_dep_in]},
-                outputs={"Out": splited_var},
-                attrs={
-                    "epmap": eps,
-                    "trainer_id": self.trainer_id,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
-                    OP_ROLE_VAR_ATTR_NAME:
-                    [param_varname, recv_op_role_var_name],
-                    "sync_mode": not self.sync_mode
-                })
+            if param_varname in self.sparse_param_to_height_sections:
+                height_sections = self.sparse_param_to_height_sections[
+                    param_varname]
+                self._update_sparse_update_op(param_varname, height_sections,
+                                              eps)
+            else:
+                program.global_block().append_op(
+                    type="recv",
+                    inputs={"X": [recv_dep_in]},
+                    outputs={"Out": splited_var},
+                    attrs={
+                        "epmap": eps,
+                        "trainer_id": self.trainer_id,
+                        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                        OP_ROLE_VAR_ATTR_NAME:
+                        [param_varname, recv_op_role_var_name],
+                        "sync_mode": not self.sync_mode
+                    })
 
         if self.sync_mode:
             # form a WAW dependency
@@ -454,14 +481,15 @@ class DistributeTranspiler(object):
             if len(splited_var) <= 1:
                 continue
             orig_param = program.global_block().vars[param_varname]
-            program.global_block().append_op(
-                type="concat",
-                inputs={"X": splited_var},
-                outputs={"Out": [orig_param]},
-                attrs={
-                    "axis": 0,
-                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
-                })
+            if param_varname not in self.sparse_param_to_height_sections:
+                program.global_block().append_op(
+                    type="concat",
+                    inputs={"X": splited_var},
+                    outputs={"Out": [orig_param]},
+                    attrs={
+                        "axis": 0,
+                        RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                    })
 
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
 
@@ -1237,9 +1265,8 @@ to transpile() call.")
         # create table param and grad var in pserver program
         # create table optimize block in pserver program
         table_opt_op = [
-            op for op in self.optimize_ops
-            if 'Param' in op.input_names and op.input("Param")[0] ==
-            self.table_name
+            op for op in self.optimize_ops if 'Param' in op.input_names and
+            op.input("Param")[0] == self.table_name
         ][0]
 
         origin_param_var = self.origin_program.global_block().vars[
@@ -1418,6 +1445,10 @@ to transpile() call.")
             height_sections = []
             for v in splited_vars:
                 height_sections.append(v.shape[0])
+            sparse_param_name = self.grad_name_to_param_name[orig_var.name]
+            if sparse_param_name != self.table_name:
+                self.sparse_param_to_height_sections[
+                    sparse_param_name] = height_sections
             program.global_block()._insert_op(
                 index=index + 1,
                 type="split_selected_rows",
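
Reviewer note: the two new helpers are small enough to sanity-check in isolation. Below is a minimal, dependency-free sketch of their selection and rewrite logic; `FakeOp` is a hypothetical stand-in for `fluid.framework.Operator` that models only the members the helpers touch (`type`, `input_arg_names`, `attr`, `_set_attr`), and the parameter name, endpoints, and section sizes are made up for illustration.

```python
class FakeOp(object):
    """Hypothetical stand-in for fluid.framework.Operator."""

    def __init__(self, type, input_arg_names, attrs):
        self.type = type
        self.input_arg_names = input_arg_names
        self._attrs = attrs

    def attr(self, name):
        return self._attrs.get(name)

    def _set_attr(self, name, value):
        self._attrs[name] = value


def get_all_sparse_update_ops(ops):
    # Mirrors _get_all_sparse_update_op: keep lookup_table ops that do a
    # sparse, non-distributed lookup; only these can be rewritten to read
    # remote parameter shards directly.
    sparse_update_op_types = ["lookup_table"]
    return [
        op for op in ops if op.type in sparse_update_op_types and
        op.attr('is_sparse') is True and not op.attr('is_distributed')
    ]


def update_sparse_update_ops(sparse_update_ops, param_varname,
                             height_sections, endpoint_map):
    # Mirrors _update_sparse_update_op: attach the pserver endpoints and the
    # row count of each shard, replacing the trainer-side recv/concat pair.
    for op in sparse_update_ops:
        if param_varname in op.input_arg_names:
            op._set_attr('epmap', endpoint_map)
            op._set_attr('height_sections', height_sections)


ops = [
    FakeOp("lookup_table", ["emb.w_0", "ids"],
           {"is_sparse": True, "is_distributed": False}),
    FakeOp("mul", ["fc.w_0", "x"], {}),
]
sparse_ops = get_all_sparse_update_ops(ops)
update_sparse_update_ops(sparse_ops, "emb.w_0", [5000, 5000],
                         ["127.0.0.1:6174", "127.0.0.1:6175"])
assert len(sparse_ops) == 1
assert sparse_ops[0].attr("epmap") == ["127.0.0.1:6174", "127.0.0.1:6175"]
assert sparse_ops[0].attr("height_sections") == [5000, 5000]
```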
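For end-to-end context, here is a sketch of a trainer program that should exercise this path, assuming a Fluid 1.x build with this patch applied; the endpoints and layer sizes are arbitrary. After `transpile()`, the `lookup_table` op reading the sparse embedding should carry `epmap`/`height_sections` attributes, and no recv/concat pair should be emitted for `emb`'s weight:

```python
import paddle.fluid as fluid

# A toy network with a sparse-updated embedding: its lookup_table op is
# eligible for the rewrite (is_sparse=True, is_distributed=False).
ids = fluid.layers.data(name="ids", shape=[1], dtype="int64", lod_level=1)
emb = fluid.layers.embedding(input=ids, size=[100000, 64], is_sparse=True)
pool = fluid.layers.sequence_pool(input=emb, pool_type="sum")
cost = fluid.layers.mean(fluid.layers.fc(input=pool, size=1))
fluid.optimizer.SGD(learning_rate=0.01).minimize(cost)

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2)
trainer_prog = t.get_trainer_program()

# Inspect the rewritten lookup: epmap/height_sections should now be set.
for op in trainer_prog.global_block().ops:
    if op.type == "lookup_table":
        print(op.attr("epmap"), op.attr("height_sections"))
```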