diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc
index f409b13f01bb209c4ea2a67c8d121288ab238b4d..d2b514b7b48665a1ef25674bfffe5b8a0b0d57c9 100644
--- a/paddle/fluid/operators/distributed/parameter_prefetch.cc
+++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc
@@ -154,7 +154,7 @@ inline void MergeMultipleVarsIntoOneBySection(
 }
 
 void prefetch(const std::string& id_name, const std::string& out_name,
-              const std::string& table_name,
+              const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context) {
@@ -190,7 +190,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
                << " to get " << out_var_names[i] << " back";
       rets.push_back(rpc_client->AsyncPrefetchVar(
           epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i],
-          table_name + ".block" + std::to_string(i)));
+          table_names[i]));
     } else {
       VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
     }
diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h
index 9e680ec20bf382177d93bf7f08632d6af9fafe50..0693cfc1fd2b5bb1bee1609149d7e056557d65db 100644
--- a/paddle/fluid/operators/distributed/parameter_prefetch.h
+++ b/paddle/fluid/operators/distributed/parameter_prefetch.h
@@ -24,7 +24,7 @@ namespace operators {
 namespace distributed {
 
 void prefetch(const std::string& id_name, const std::string& out_name,
-              const std::string& table_name,
+              const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context);
diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index faf91775e4fe68b87e4bebd292e346de53cd0b85..ab6518641bd72ff4cfc0fc2af081c24c307be46e 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -99,6 +99,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
         "(string vector, default 127.0.0.1:6164)"
         "Server endpoints in the order of input variables for mapping")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "table_names",
+        "(string vector, the split table names that will be fetched from "
+        "the parameter server "
+        "in the order of input variables for mapping)")
+        .SetDefault({});
     AddComment(R"DOC(
 Lookup Table Operator.
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 223de413b2152329eef23a156df6d79ea48c02e0..12c5f8f1eb6be340811fe7eca5cade27744d70db 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -54,13 +54,14 @@ class LookupTableKernel : public framework::OpKernel<T> {
     auto remote_prefetch = context.Attr<bool>("remote_prefetch");
     auto height_sections =
         context.Attr<std::vector<int64_t>>("height_sections");
+    auto table_names = context.Attr<std::vector<std::string>>("table_names");
 
     if (remote_prefetch) {
       // if emap is not empty, then the parameter will be fetched from remote
       // parameter
       // server
 #ifdef PADDLE_WITH_DISTRIBUTE
-      operators::distributed::prefetch(id_name, out_name, table_name, epmap,
+      operators::distributed::prefetch(id_name, out_name, table_names, epmap,
                                        height_sections, context);
 #else
       PADDLE_THROW(
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 59f89e331dac5a10a68cdfd7fc82564ff72d953e..a1ccb704b2d43434beaceb3305783b09dd1f02e2 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -247,10 +247,11 @@ class DistributeTranspiler(object):
         return sparse_update_ops
 
     def _update_remote_sparse_update_op(self, param_varname, height_sections,
-                                        endpint_map):
+                                        endpint_map, table_names):
         for op in self.sparse_update_ops:
             if param_varname in op.input_arg_names:
                 op._set_attr('epmap', endpint_map)
+                op._set_attr('table_names', table_names)
                 op._set_attr('height_sections', height_sections)
                 op._set_attr('trainer_id', self.trainer_id)
 
@@ -326,6 +327,7 @@ class DistributeTranspiler(object):
         # get all sparse update ops
         self.sparse_update_ops = self._get_all_remote_sparse_update_op(
             self.origin_program)
+        # maps sparse update param name -> split height sections
         self.sparse_param_to_height_sections = dict()
 
         # add distributed attrs to program
@@ -365,6 +367,13 @@ class DistributeTranspiler(object):
                 splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(
                     program.global_block(), splited_grad_varname, reverse=True)
+                if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
+                    sparse_param_name = self.grad_name_to_param_name[
+                        splited_grad_varname]
+                    if self._is_input_of_remote_sparse_update_op(
+                            sparse_param_name):
+                        self.sparse_param_to_height_sections[
+                            sparse_param_name] = [splited_vars[0].shape[0]]
             elif len(splited_vars) > 1:
                 orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(
@@ -435,9 +444,11 @@ class DistributeTranspiler(object):
         all_recv_outputs = []
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
+            table_names = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
+                table_names.append(var.name)
             if self.sync_mode:
                 recv_dep_in = send_barrier_out
             else:
@@ -457,8 +468,8 @@ class DistributeTranspiler(object):
             if param_varname in self.sparse_param_to_height_sections:
                 height_sections = self.sparse_param_to_height_sections[
                     param_varname]
-                self._update_remote_sparse_update_op(param_varname,
-                                                     height_sections, eps)
+                self._update_remote_sparse_update_op(
+                    param_varname, height_sections, eps, table_names)
             else:
                 all_recv_outputs.extend(splited_var)
                 program.global_block().append_op(
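For reviewers, a minimal self-contained Python sketch (using hypothetical split-variable names and endpoints, not taken from the patch) of the pairing the transpiler change relies on: `eps` and `table_names` are collected in the same order, so `table_names[i]` is the split table hosted at `epmap[i]` when the lookup_table op later prefetches.

```python
# Hypothetical split embedding table served by two parameter servers.
param_var_mapping = {"emb.w_0": ["emb.w_0.block0", "emb.w_0.block1"]}
recv_var_names = ["emb.w_0.block0", "emb.w_0.block1"]
eplist = ["127.0.0.1:6174", "127.0.0.1:6175"]

for param_varname, splited_var in param_var_mapping.items():
    eps, table_names = [], []
    for var_name in splited_var:
        index = recv_var_names.index(var_name)
        eps.append(eplist[index])      # endpoint that holds this split
        table_names.append(var_name)   # split table name on that endpoint
    # These two lists are what _update_remote_sparse_update_op writes into
    # the lookup_table op as the 'epmap' and 'table_names' attributes.
    print(param_varname, list(zip(eps, table_names)))
```

On the C++ side, `prefetch` then asks `epmap[i]` directly for `table_names[i]` instead of reconstructing the name as `table_name + ".block" + std::to_string(i)`, so the attribute, not a naming convention, decides which table is fetched.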