diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py index b0cb6a004174494c09329f19c269f62ab7bee5f1..9084b0caad88450c99d5e6550ffa6bea86e06d65 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py @@ -122,12 +122,14 @@ class Fleet(object): print("You should run DistributedOptimizer.minimize() first") sys.exit(-1) - def init_worker(self, program): + def init_worker(self, programs): """ init_worker(): will be called by user. When a user knows current process is_server(), he/she should call init_worker() to initialize global information about worker and connect worker with pserver. """ + if not isinstance(programs, list): + programs = [programs] if self._opt_info: if "fleet_desc" in self._opt_info: self._dist_desc_str = text_format.MessageToString( @@ -145,14 +147,25 @@ class Fleet(object): self.role_maker_.barrier_worker() if self.role_maker_.is_first_worker(): tables = self._dist_desc.trainer_param.dense_table._values - for i in range(0, len(tables)): - table = tables[i]; - var_name_list = [] - for i in range(0, len(table.dense_variable_name)): - var_name_list.append(table.dense_variable_name[i]) - #print "table id ", table.table_id - #print "var_name_list ", var_name_list - self._fleet_ptr.init_model(program.desc, + for prog in programs: + prog_id = str(id(prog)) + prog_conf = self._opt_info['program_configs'][prog_id] + prog_tables = {} + for key in prog_conf: + if "dense" not in key: + continue + for table_id in prog_conf[key]: + prog_tables[int(table_id)] = 0 + for i in range(0, len(tables)): + table = tables[i] + if int(table.table_id) not in prog_tables: + continue + var_name_list = [] + for i in range(0, len(table.dense_variable_name)): + var_name_list.append(table.dense_variable_name[i]) + #print "table id ", table.table_id + #print "var_name_list ", var_name_list + self._fleet_ptr.init_model(prog.desc, int(table.table_id), var_name_list) self.role_maker_.barrier_worker()