提交 20b76f3d 编写于 作者: X xujiaqi01 提交者: dongdaxiang

init model support multi programs

上级 f5c6a14b
......@@ -122,12 +122,14 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def init_worker(self, program):
def init_worker(self, programs):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
worker with pserver.
"""
if not isinstance(programs, list):
programs = [programs]
if self._opt_info:
if "fleet_desc" in self._opt_info:
self._dist_desc_str = text_format.MessageToString(
......@@ -145,14 +147,25 @@ class Fleet(object):
self.role_maker_.barrier_worker()
if self.role_maker_.is_first_worker():
tables = self._dist_desc.trainer_param.dense_table._values
for i in range(0, len(tables)):
table = tables[i];
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self._fleet_ptr.init_model(program.desc,
for prog in programs:
prog_id = str(id(prog))
prog_conf = self._opt_info['program_configs'][prog_id]
prog_tables = {}
for key in prog_conf:
if "dense" not in key:
continue
for table_id in prog_conf[key]:
prog_tables[int(table_id)] = 0
for i in range(0, len(tables)):
table = tables[i]
if int(table.table_id) not in prog_tables:
continue
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self._fleet_ptr.init_model(prog.desc,
int(table.table_id),
var_name_list)
self.role_maker_.barrier_worker()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册