Commit 3036cf55 authored by linan17

update for paddle pslib

Change-Id: I8e42b1d6fac4599ba86ae30c8e4f43bffd3886ea
Parent: 25d9cef9
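Every hunk below applies one schema change: `trainer_param` on the `ps_param` descriptor goes from a singular sub-message to a repeated field, so read sites index the first entry (`trainer_param[0]`) and write sites create it with `add()`. As a minimal sketch of the protobuf repeated-field semantics involved, the snippet below uses `google.protobuf.struct_pb2.ListValue` purely as a stand-in (its `values` field is likewise a repeated message; nothing here is PSLib API):

```python
from google.protobuf import struct_pb2

# ListValue.values is a repeated message field; it stands in for the
# now-repeated trainer_param field. Stand-in only, not PSLib API.
desc = struct_pb2.ListValue()

entry = desc.values.add()            # write side: append an entry with add()
entry.string_value = "first"

print(len(desc.values))              # 1 -- len() counts the repeated entries
print(desc.values[0].string_value)   # read side: index the entry, values[0]

# CopyFrom is a message-level operation, so once the field is repeated the
# call sites must target an element, e.g. trainer_param[0].CopyFrom(...).
desc.values[0].CopyFrom(struct_pb2.Value(string_value="replacement"))
print(desc.values[0].string_value)   # replacement
```

With the old singular field, `ps_param.trainer_param.dense_table` worked as a plain attribute access; on a repeated field the same access raises AttributeError, which is why every call site in this commit gains the `[0]`.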
@@ -140,23 +140,23 @@ class DownpourSGD(DeviceWorker):
         trainer_desc.device_worker_name = "DownpourWorker"
         pull_thread = trainer_desc.pull_dense_param
         pull_thread.device_num = trainer_desc.thread_num
-        for i in self._fleet_desc.trainer_param.dense_table:
+        for i in self._fleet_desc.trainer_param[0].dense_table:
             if i.table_id in dense_table_set:
                 dense_table = pull_thread.dense_table.add()
                 dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.table_id = \
                     i.table_id
-        sparse_len = len(self._fleet_desc.trainer_param.sparse_table)
+        sparse_len = len(self._fleet_desc.trainer_param[0].sparse_table)
         for i in range(sparse_len):
             sparse_table = downpour.sparse_table.add()
             sparse_table.table_id = \
-                self._fleet_desc.trainer_param.sparse_table[i].table_id
+                self._fleet_desc.trainer_param[0].sparse_table[i].table_id
             sparse_table.sparse_key_name.extend(
-                self._fleet_desc.trainer_param.sparse_table[i].slot_key)
+                self._fleet_desc.trainer_param[0].sparse_table[i].slot_key)
             sparse_table.sparse_value_name.extend(
-                self._fleet_desc.trainer_param.sparse_table[i].slot_value)
+                self._fleet_desc.trainer_param[0].sparse_table[i].slot_value)
             sparse_table.sparse_grad_name.extend(
-                self._fleet_desc.trainer_param.sparse_table[i].slot_gradient)
+                self._fleet_desc.trainer_param[0].sparse_table[i].slot_gradient)
         if opt_info["use_cvm"]:
             sparse_table.emb_dim = \
                 self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
@@ -173,14 +173,14 @@ class DownpourSGD(DeviceWorker):
         for i in opt_info["stat_var_names"]:
             downpour.stat_var_names.extend([i])
-        for i in self._fleet_desc.trainer_param.dense_table:
+        for i in self._fleet_desc.trainer_param[0].dense_table:
             if i.table_id in dense_table_set:
                 dense_table = downpour.dense_table.add()
                 dense_table.table_id = i.table_id
                 dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.dense_grad_name.extend(
                     i.dense_gradient_variable_name)
-        downpour.skip_ops.extend(self._fleet_desc.trainer_param.skip_op)
+        downpour.skip_ops.extend(self._fleet_desc.trainer_param[0].skip_op)
         if self._infer:
             downpour.push_dense = False
             downpour.push_sparse = False
......
@@ -91,8 +91,9 @@ class DownpourSGD(object):
         dense_table_index = 1
         program_configs = []
         param_grads_list = []
+        tp = ps_param.trainer_param.add()
         for loss_index in range(len(losses)):
-            program_config = ps_param.trainer_param.program_config.add()
+            program_config = tp.program_config.add()
             program_config.program_id = str(
                 id(losses[loss_index].block.program))
             program_config.pull_sparse_table_id.extend([sparse_table_index])
@@ -140,13 +141,13 @@ class DownpourSGD(object):
             dense_table_index += 1
             program_configs.append(program_config)
         ps_param.server_param.CopyFrom(server.get_desc())
-        ps_param.trainer_param.CopyFrom(worker.get_desc())
+        ps_param.trainer_param[0].CopyFrom(worker.get_desc())
         for program_config in program_configs:
-            ps_param.trainer_param.program_config.extend([program_config])
+            ps_param.trainer_param[0].program_config.extend([program_config])
         # Todo(guru4elephant): figure out how to support more sparse parameters
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
-        ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
+        ps_param.trainer_param[0].skip_op.extend(worker_skipped_ops)
         # all fleet operations should be defined in operators in the future
         # we want to return an object here containing:
......
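Note on the two hunks above: the writer side now creates the single `trainer_param` entry once with `add()` and hangs one `program_config` per loss off it. A sketch of the same idiom with the `ListValue` stand-in, where the nested `list_value.values` plays the role of `program_config` and the loss names are made up:

```python
from google.protobuf import struct_pb2

desc = struct_pb2.ListValue()
tp = desc.values.add()                      # like: tp = ps_param.trainer_param.add()
for loss_name in ["ctr_loss", "cvr_loss"]:  # hypothetical loss names
    cfg = tp.list_value.values.add()        # like: cfg = tp.program_config.add()
    cfg.string_value = loss_name

print(len(desc.values))                       # 1 -- the single top-level entry
print(len(desc.values[0].list_value.values))  # 2 -- one nested entry per loss
```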
@@ -89,7 +89,7 @@ class PSLib(Fleet):
         # barrier for init model
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
-            tables = self._dist_desc.trainer_param.dense_table
+            tables = self._dist_desc.trainer_param[0].dense_table
             for prog, scope in zip(self._main_programs, self._scopes):
                 prog_id = str(id(prog))
                 prog_conf = self._opt_info['program_configs'][prog_id]
@@ -304,7 +304,7 @@ class PSLib(Fleet):
         """
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
-            for i in self._opt_info["fleet_desc"].trainer_param.sparse_table:
+            for i in self._opt_info["fleet_desc"].trainer_param[0].sparse_table:
                 self._fleet_ptr.shrink_sparse_table(i.table_id)
         self._role_maker._barrier_worker()
@@ -330,7 +330,7 @@ class PSLib(Fleet):
             scope = fluid.global_scope()
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
-            for i in self._opt_info["fleet_desc"].trainer_param.dense_table:
+            for i in self._opt_info["fleet_desc"].trainer_param[0].dense_table:
                 if table_id is not None and table_id != i.table_id:
                     continue
                 var_list = [var for var in i.dense_variable_name]
@@ -476,7 +476,7 @@ class PSLib(Fleet):
             if ret != 0:
                 raise RuntimeError("download model proto file failed")
             model_proto_file = dest
-            for i in self._opt_info["fleet_desc"].trainer_param.dense_table:
+            for i in self._opt_info["fleet_desc"].trainer_param[0].dense_table:
                 if table_id is not None and table_id != i.table_id:
                     continue
                 table_var_names = [var for var in i.dense_variable_name]
......
@@ -144,7 +144,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
             with open(fleet_desc_file) as f:
                 text_format.Merge(f.read(), ps_param)
             server.get_desc().CopyFrom(ps_param.server_param)
-            worker.get_desc().CopyFrom(ps_param.trainer_param)
+            worker.get_desc().CopyFrom(ps_param.trainer_param[0])
         sparse_table_index = 0
         for tn in sparse_table_names:
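The hunk above leaves the `text_format.Merge` load path untouched; only the `CopyFrom` source gains the index. A descriptor file written against the repeated schema simply carries one `trainer_param { ... }` block per entry. For reference, a hedged sketch of how text-format parsing fills a repeated message field, again with the `ListValue` stand-in and an inline string in place of `fleet_desc_file`:

```python
from google.protobuf import text_format
from google.protobuf import struct_pb2

# One block per repeated entry in the text format; ListValue.values stands
# in for trainer_param here.
desc = struct_pb2.ListValue()
text_format.Merge('values { string_value: "from_file" }', desc)
print(desc.values[0].string_value)   # from_file -- index the parsed entry
```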
@@ -231,12 +231,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
                     [dense_table_index])
                 dense_table_index += 1
         ps_param.server_param.CopyFrom(server.get_desc())
-        ps_param.trainer_param.CopyFrom(worker.get_desc())
+        if len(ps_param.trainer_param) == 0:
+            tp = ps_param.trainer_param.add()
+            tp.CopyFrom(worker.get_desc())
+        else:
+            ps_param.trainer_param[0].CopyFrom(worker.get_desc())
         # Todo(guru4elephant): figure out how to support more sparse parameters
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
-        if len(ps_param.trainer_param.skip_op) == 0:
-            ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
+        if len(ps_param.trainer_param[0].skip_op) == 0:
+            ps_param.trainer_param[0].skip_op.extend(worker_skipped_ops)
         opt_info = {}
         opt_info["program_configs"] = program_configs
......
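The `len(...) == 0` guard above is what keeps pre-change descriptor files working: a file with no `trainer_param` block yields an empty repeated field, so the entry is created on demand, while a file that already has one gets its first entry overwritten. A minimal sketch of that idempotent pattern, same `ListValue` stand-in, with `set_first` as a hypothetical helper:

```python
from google.protobuf import struct_pb2

def set_first(container, msg):
    # Create the first repeated entry on demand, then overwrite in place,
    # mirroring the len(...) == 0 guard in the hunk above.
    if len(container) == 0:
        container.add().CopyFrom(msg)
    else:
        container[0].CopyFrom(msg)

desc = struct_pb2.ListValue()
set_first(desc.values, struct_pb2.Value(string_value="worker_v1"))
set_first(desc.values, struct_pb2.Value(string_value="worker_v2"))
print(len(desc.values), desc.values[0].string_value)   # 1 worker_v2
```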
@@ -832,7 +832,7 @@ class FleetUtil(object):
         """
         fleet._role_maker._barrier_worker()
         if fleet._role_maker.is_first_worker():
-            tables = fleet._dist_desc.trainer_param.dense_table
+            tables = fleet._dist_desc.trainer_param[0].dense_table
             prog_id = str(id(program))
             prog_conf = fleet._opt_info['program_configs'][prog_id]
             prog_tables = {}
......