From 20b76f3deb056b7b7a9be1cbf6d9581bcbd3fc43 Mon Sep 17 00:00:00 2001 From: xujiaqi01 Date: Wed, 20 Mar 2019 16:10:35 +0800 Subject: [PATCH] init model support multi programs --- .../fleet/parameter_server/__init__.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py index b0cb6a0041..9084b0caad 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py @@ -122,12 +122,14 @@ class Fleet(object): print("You should run DistributedOptimizer.minimize() first") sys.exit(-1) - def init_worker(self, program): + def init_worker(self, programs): """ init_worker(): will be called by user. When a user knows current process is_server(), he/she should call init_worker() to initialize global information about worker and connect worker with pserver. """ + if not isinstance(programs, list): + programs = [programs] if self._opt_info: if "fleet_desc" in self._opt_info: self._dist_desc_str = text_format.MessageToString( @@ -145,14 +147,25 @@ class Fleet(object): self.role_maker_.barrier_worker() if self.role_maker_.is_first_worker(): tables = self._dist_desc.trainer_param.dense_table._values - for i in range(0, len(tables)): - table = tables[i]; - var_name_list = [] - for i in range(0, len(table.dense_variable_name)): - var_name_list.append(table.dense_variable_name[i]) - #print "table id ", table.table_id - #print "var_name_list ", var_name_list - self._fleet_ptr.init_model(program.desc, + for prog in programs: + prog_id = str(id(prog)) + prog_conf = self._opt_info['program_configs'][prog_id] + prog_tables = {} + for key in prog_conf: + if "dense" not in key: + continue + for table_id in prog_conf[key]: + prog_tables[int(table_id)] = 0 + for i in range(0, len(tables)): + table = tables[i] + if int(table.table_id) not in prog_tables: + continue + var_name_list = [] + for i in range(0, len(table.dense_variable_name)): + var_name_list.append(table.dense_variable_name[i]) + #print "table id ", table.table_id + #print "var_name_list ", var_name_list + self._fleet_ptr.init_model(prog.desc, int(table.table_id), var_name_list) self.role_maker_.barrier_worker() -- GitLab