Commit 3036cf55 authored by linan17

update for paddle pslib: make trainer_param a repeated field, accessed via trainer_param[0]

Change-Id: I8e42b1d6fac4599ba86ae30c8e4f43bffd3886ea
Parent 25d9cef9
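
What changed: PSParameter.trainer_param moves from a singular message to a repeated field, so every access below becomes trainer_param[0] and new entries are created with trainer_param.add(). As a minimal sketch of repeated-message semantics in the protobuf Python API (using descriptor_pb2.FileDescriptorProto, whose message_type field is repeated, as a stand-in for the pslib proto, which this diff does not show):

    # FileDescriptorProto.message_type is a repeated message field; it stands
    # in for PSParameter.trainer_param here purely for illustration.
    from google.protobuf import descriptor_pb2

    fd = descriptor_pb2.FileDescriptorProto()

    # A repeated message field cannot be written with CopyFrom directly;
    # create an element with add() ...
    tp = fd.message_type.add()            # cf. ps_param.trainer_param.add()
    tp.name = "trainer_param_entry"

    # ... and address existing elements by index:
    fd.message_type[0].name = "renamed"   # cf. ps_param.trainer_param[0]
    assert len(fd.message_type) == 1
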
@@ -140,23 +140,23 @@ class DownpourSGD(DeviceWorker):
         trainer_desc.device_worker_name = "DownpourWorker"
         pull_thread = trainer_desc.pull_dense_param
         pull_thread.device_num = trainer_desc.thread_num
-        for i in self._fleet_desc.trainer_param.dense_table:
+        for i in self._fleet_desc.trainer_param[0].dense_table:
             if i.table_id in dense_table_set:
                 dense_table = pull_thread.dense_table.add()
                 dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.table_id = \
                     i.table_id
-        sparse_len = len(self._fleet_desc.trainer_param.sparse_table)
+        sparse_len = len(self._fleet_desc.trainer_param[0].sparse_table)
         for i in range(sparse_len):
             sparse_table = downpour.sparse_table.add()
             sparse_table.table_id = \
-                self._fleet_desc.trainer_param.sparse_table[i].table_id
+                self._fleet_desc.trainer_param[0].sparse_table[i].table_id
             sparse_table.sparse_key_name.extend(
-                self._fleet_desc.trainer_param.sparse_table[i].slot_key)
+                self._fleet_desc.trainer_param[0].sparse_table[i].slot_key)
             sparse_table.sparse_value_name.extend(
-                self._fleet_desc.trainer_param.sparse_table[i].slot_value)
+                self._fleet_desc.trainer_param[0].sparse_table[i].slot_value)
             sparse_table.sparse_grad_name.extend(
-                self._fleet_desc.trainer_param.sparse_table[i].slot_gradient)
+                self._fleet_desc.trainer_param[0].sparse_table[i].slot_gradient)
             if opt_info["use_cvm"]:
                 sparse_table.emb_dim = \
                     self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
@@ -173,14 +173,14 @@ class DownpourSGD(DeviceWorker):
         for i in opt_info["stat_var_names"]:
             downpour.stat_var_names.extend([i])
-        for i in self._fleet_desc.trainer_param.dense_table:
+        for i in self._fleet_desc.trainer_param[0].dense_table:
             if i.table_id in dense_table_set:
                 dense_table = downpour.dense_table.add()
                 dense_table.table_id = i.table_id
                 dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.dense_grad_name.extend(
                     i.dense_gradient_variable_name)
-        downpour.skip_ops.extend(self._fleet_desc.trainer_param.skip_op)
+        downpour.skip_ops.extend(self._fleet_desc.trainer_param[0].skip_op)
         if self._infer:
             downpour.push_dense = False
             downpour.push_sparse = False
......
@@ -91,8 +91,9 @@ class DownpourSGD(object):
         dense_table_index = 1
         program_configs = []
         param_grads_list = []
+        tp = ps_param.trainer_param.add()
         for loss_index in range(len(losses)):
-            program_config = ps_param.trainer_param.program_config.add()
+            program_config = tp.program_config.add()
             program_config.program_id = str(
                 id(losses[loss_index].block.program))
             program_config.pull_sparse_table_id.extend([sparse_table_index])
@@ -140,13 +141,13 @@ class DownpourSGD(object):
             dense_table_index += 1
             program_configs.append(program_config)
         ps_param.server_param.CopyFrom(server.get_desc())
-        ps_param.trainer_param.CopyFrom(worker.get_desc())
+        ps_param.trainer_param[0].CopyFrom(worker.get_desc())
         for program_config in program_configs:
-            ps_param.trainer_param.program_config.extend([program_config])
+            ps_param.trainer_param[0].program_config.extend([program_config])
         # Todo(guru4elephant): figure out how to support more sparse parameters
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
-        ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
+        ps_param.trainer_param[0].skip_op.extend(worker_skipped_ops)
         # all fleet operations should be defined in operators in the future
         # we want to return an object here containing:
......
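
In DownpourSGD.minimize above, the single trainer_param entry is now created once with add(), and each loss's program_config is appended to that entry rather than to ps_param.trainer_param directly. A sketch of the same nested repeated-field pattern, again with descriptor_pb2 types as stand-ins (the analogy is ours, not the commit's code):

    # DescriptorProto plays the trainer_param entry; its repeated `field`
    # plays program_config. Both are stand-ins for illustration only.
    from google.protobuf import descriptor_pb2

    fd = descriptor_pb2.FileDescriptorProto()
    tp = fd.message_type.add()              # cf. tp = ps_param.trainer_param.add()
    for loss_name in ["loss_a", "loss_b"]:  # one config per loss, as in the loop above
        pc = tp.field.add()                 # cf. program_config = tp.program_config.add()
        pc.name = loss_name
    assert len(fd.message_type) == 1
    assert len(fd.message_type[0].field) == 2
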
@@ -89,7 +89,7 @@ class PSLib(Fleet):
         # barrier for init model
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
-            tables = self._dist_desc.trainer_param.dense_table
+            tables = self._dist_desc.trainer_param[0].dense_table
             for prog, scope in zip(self._main_programs, self._scopes):
                 prog_id = str(id(prog))
                 prog_conf = self._opt_info['program_configs'][prog_id]
@@ -304,7 +304,7 @@ class PSLib(Fleet):
         """
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
-            for i in self._opt_info["fleet_desc"].trainer_param.sparse_table:
+            for i in self._opt_info["fleet_desc"].trainer_param[0].sparse_table:
                 self._fleet_ptr.shrink_sparse_table(i.table_id)
         self._role_maker._barrier_worker()
@@ -330,7 +330,7 @@ class PSLib(Fleet):
             scope = fluid.global_scope()
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
-            for i in self._opt_info["fleet_desc"].trainer_param.dense_table:
+            for i in self._opt_info["fleet_desc"].trainer_param[0].dense_table:
                 if table_id is not None and table_id != i.table_id:
                     continue
                 var_list = [var for var in i.dense_variable_name]
@@ -476,7 +476,7 @@ class PSLib(Fleet):
             if ret != 0:
                 raise RuntimeError("download model proto file failed")
             model_proto_file = dest
-        for i in self._opt_info["fleet_desc"].trainer_param.dense_table:
+        for i in self._opt_info["fleet_desc"].trainer_param[0].dense_table:
             if table_id is not None and table_id != i.table_id:
                 continue
             table_var_names = [var for var in i.dense_variable_name]
......
@@ -144,7 +144,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
             with open(fleet_desc_file) as f:
                 text_format.Merge(f.read(), ps_param)
             server.get_desc().CopyFrom(ps_param.server_param)
-            worker.get_desc().CopyFrom(ps_param.trainer_param)
+            worker.get_desc().CopyFrom(ps_param.trainer_param[0])
         sparse_table_index = 0
         for tn in sparse_table_names:
@@ -231,12 +231,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
                 [dense_table_index])
             dense_table_index += 1
         ps_param.server_param.CopyFrom(server.get_desc())
-        ps_param.trainer_param.CopyFrom(worker.get_desc())
+        if len(ps_param.trainer_param) == 0:
+            tp = ps_param.trainer_param.add()
+            tp.CopyFrom(worker.get_desc())
+        else:
+            ps_param.trainer_param[0].CopyFrom(worker.get_desc())
         # Todo(guru4elephant): figure out how to support more sparse parameters
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
-        if len(ps_param.trainer_param.skip_op) == 0:
-            ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
+        if len(ps_param.trainer_param[0].skip_op) == 0:
+            ps_param.trainer_param[0].skip_op.extend(worker_skipped_ops)
         opt_info = {}
         opt_info["program_configs"] = program_configs
......
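
The len(ps_param.trainer_param) == 0 guard added in DistributedAdam._minimize matters because the earlier hunk (@@ -144,7 +144,7 @@) may already have merged a fleet_desc_file into ps_param; in that case trainer_param holds an entry that must be updated in place, and an unconditional add() would append a duplicate. A sketch of the read-or-create pattern, under the same stand-in types as above:

    # text_format.Merge pre-populates the repeated field, as a fleet_desc_file
    # would; the guard then updates in place instead of appending.
    from google.protobuf import descriptor_pb2, text_format

    fd = descriptor_pb2.FileDescriptorProto()
    text_format.Merge('message_type { name: "from_file" }', fd)

    if len(fd.message_type) == 0:
        tp = fd.message_type.add()
    else:
        tp = fd.message_type[0]
    tp.name = "from_worker"
    assert len(fd.message_type) == 1
    assert fd.message_type[0].name == "from_worker"
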
@@ -832,7 +832,7 @@ class FleetUtil(object):
         """
         fleet._role_maker._barrier_worker()
         if fleet._role_maker.is_first_worker():
-            tables = fleet._dist_desc.trainer_param.dense_table
+            tables = fleet._dist_desc.trainer_param[0].dense_table
             prog_id = str(id(program))
             prog_conf = fleet._opt_info['program_configs'][prog_id]
             prog_tables = {}
......