Unverified commit bd35a7f0, authored by Z zhang wenhui and committed by GitHub

support fc sort by number, test=develop (#19466)

fleet_desc sorts fc parameter names in dictionary (lexicographic) order, but we want them sorted by number.
Parent 5c8f210c
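To illustrate the problem being fixed (this example is not part of the commit, and the parameter names are hypothetical): lexicographic sorting compares character by character, so `fc_10` lands before `fc_2`, while a numeric key on the layer index gives the intended order.

```python
# Illustration only, with hypothetical parameter names: why dictionary
# (lexicographic) order misplaces numbered fc layers.
names = ["fc_10.w_0", "fc_2.b_0", "fc_2.w_0", "fc_1.w_0"]

print(sorted(names))
# ['fc_1.w_0', 'fc_10.w_0', 'fc_2.b_0', 'fc_2.w_0']  -- fc_10 before fc_2

# Sorting by the numeric layer index (the digits between "fc_" and ".")
# recovers the intended order; equal keys keep their relative order.
print(sorted(names, key=lambda n: int(n[3:n.find('.')])))
# ['fc_1.w_0', 'fc_2.b_0', 'fc_2.w_0', 'fc_10.w_0']
```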
@@ -155,7 +155,7 @@ class DownpourServer(Server):
         table2.deconverter = "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
 
     def add_dense_table(self, table_id, param_var, grad_var, strategy,
-                        sparse_table_name):
+                        sparse_table_names):
         """
         Args:
             table_id(int): id of sparse params table
@@ -166,7 +166,7 @@ class DownpourServer(Server):
         fea_dim = 0
         dense_param_vars = []
         for p in param_var:
-            if p.name not in sparse_table_name:
+            if p.name not in sparse_table_names:
                 dense_param_vars.append(p)
 
         for param in dense_param_vars:
@@ -216,7 +216,7 @@ class DownpourServer(Server):
         table.accessor.fea_dim = fea_dim
 
     def add_data_norm_table(self, table_id, learning_rate, param_var, grad_var,
-                            strategy, sparse_table_name):
+                            strategy, sparse_table_names):
         """
         Args:
             table_id(int): id of datanorm table
@@ -227,7 +227,7 @@ class DownpourServer(Server):
         fea_dim = 0
         dense_param_vars = []
         for p in param_var:
-            if p.name not in sparse_table_name:
+            if p.name not in sparse_table_names:
                 dense_param_vars.append(p)
 
         for param in dense_param_vars:
@@ -325,7 +325,7 @@ class DownpourWorker(Worker):
             [var.name + "@GRAD" for var in slot_value_vars])
 
     def add_dense_table(self, table_id, learning_rate, param_vars, grad_vars,
-                        dense_start_table_id, sparse_table_name):
+                        dense_start_table_id, sparse_table_names):
         """
         Args:
             table_id(int): id of sparse params table
@@ -337,12 +337,12 @@ class DownpourWorker(Worker):
             return None
         """
         sparse_table_name_grad = []
-        for name in sparse_table_name:
+        for name in sparse_table_names:
             sparse_table_name_grad.append(name + "@GRAD")
 
         dense_param_name = []
         for p in param_vars:
-            if p.name not in sparse_table_name:
+            if p.name not in sparse_table_names:
                 dense_param_name.append(p.name)
 
         dense_grad_name = []
@@ -352,6 +352,7 @@ class DownpourWorker(Worker):
         dense_param_name.sort()
         dense_grad_name.sort()
 
         for table in self._worker.dense_table:
             if table.table_id == table_id:
                 desc_dense_param_name = list(self._worker.dense_table[
@@ -375,8 +376,31 @@ class DownpourWorker(Worker):
         table = self._worker.dense_table.add()
         table.table_id = table_id
-        table.dense_variable_name.extend(dense_param_name)
-        table.dense_gradient_variable_name.extend(dense_grad_name)
+
+        def cmp_fc(x, y):
+            if x.startswith("fc_") and y.startswith("fc_"):
+                index_x = x.find('.')
+                index_y = y.find('.')
+                if index_x > 0 and index_y > 0:
+                    num_x = x[3:index_x]
+                    num_y = y[3:index_y]
+                    if num_x.isdigit() and num_y.isdigit():
+                        if int(num_x) < int(num_y):
+                            return -1
+                        if int(num_x) > int(num_y):
+                            return 1
+                        if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
+                            return -1
+                        if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
+                            return 1
+            if x < y:
+                return -1
+            else:
+                return 1
+
+        table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc))
+        table.dense_gradient_variable_name.extend(
+            sorted(dense_grad_name, cmp_fc))
 
     def get_desc(self):
         """
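A note on the `sorted(list, cmp_fc)` calls above: the two-argument `sorted(iterable, cmp)` form is Python 2 only. Below is a minimal sketch (not part of the commit) of the same comparator applied on Python 3 via `functools.cmp_to_key`; `cmp_fc` is condensed from the diff above and the sample names are hypothetical.

```python
# Sketch only (not part of the commit): on Python 3, sorted() no longer
# accepts a cmp function, so the comparator must be wrapped with
# functools.cmp_to_key. cmp_fc below condenses the version in the diff.
import functools

def cmp_fc(x, y):
    if x.startswith("fc_") and y.startswith("fc_"):
        index_x, index_y = x.find('.'), y.find('.')
        if index_x > 0 and index_y > 0:
            num_x, num_y = x[3:index_x], y[3:index_y]
            if num_x.isdigit() and num_y.isdigit():
                # Compare fc layer indices numerically.
                if int(num_x) != int(num_y):
                    return -1 if int(num_x) < int(num_y) else 1
                # Same fc index: weights ('w') sort before biases ('b').
                if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
                    return -1
                if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
                    return 1
    # Fallback for non-fc names: plain string comparison.
    return -1 if x < y else 1

names = ["fc_10.w_0", "fc_2.w_0", "fc_2.b_0", "embedding_0.w_0"]
print(sorted(names, key=functools.cmp_to_key(cmp_fc)))
# ['embedding_0.w_0', 'fc_2.w_0', 'fc_2.b_0', 'fc_10.w_0']
```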
@@ -126,12 +126,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
             [optimize_ops, grads_and_weights]
         """
-        table_name = self._find_multi_distributed_lookup_table(losses)
+        sparse_table_names = self._find_multi_distributed_lookup_table(losses)
         inputs_dict = self._find_distributed_lookup_table_inputs(
-            losses[0].block.program, table_name)
+            losses[0].block.program, sparse_table_names)
 
         outputs_dict = self._find_distributed_lookup_table_outputs(
-            losses[0].block.program, table_name)
+            losses[0].block.program, sparse_table_names)
 
         ps_param = pslib.PSParameter()
         server = DownpourServer()
@@ -147,7 +147,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         worker.get_desc().CopyFrom(ps_param.trainer_param)
 
         sparse_table_index = 0
-        for tn in table_name:
+        for tn in sparse_table_names:
             if strategy.get(tn) is not None:
                 server.add_sparse_table(sparse_table_index, strategy[tn])
             else:
@@ -199,13 +199,14 @@ class DistributedAdam(DistributedOptimizerImplBase):
             if strategy.get('dense_table') is not None:
                 server.add_dense_table(dense_table_index, params, grads,
-                                       strategy['dense_table'], table_name)
+                                       strategy['dense_table'],
+                                       sparse_table_names)
             else:
                 server.add_dense_table(dense_table_index, params, grads, None,
-                                       table_name)
+                                       sparse_table_names)
             worker.add_dense_table(dense_table_index, self._learning_rate,
                                    params, grads, dense_start_table_id,
-                                   table_name)
+                                   sparse_table_names)
             program_configs[program_id]["pull_dense"] = [dense_table_index]
             program_configs[program_id]["push_dense"] = [dense_table_index]
             if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
@@ -214,15 +215,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
                 server.add_data_norm_table(
                     dense_table_index, self._learning_rate,
                     data_norm_params, data_norm_grads,
-                    strategy['datanorm_table'], table_name)
+                    strategy['datanorm_table'], sparse_table_names)
             else:
                 server.add_data_norm_table(
                     dense_table_index, self._learning_rate,
-                    data_norm_params, data_norm_grads, None, table_name)
+                    data_norm_params, data_norm_grads, None,
+                    sparse_table_names)
             worker.add_dense_table(dense_table_index, self._learning_rate,
                                    data_norm_params, data_norm_grads,
-                                   dense_start_table_id, table_name)
+                                   dense_start_table_id, sparse_table_names)
             program_configs[program_id]["pull_dense"].extend(
                 [dense_table_index])
             program_configs[program_id]["push_dense"].extend(
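For orientation (not part of the commit; the names below are hypothetical): `sparse_table_names` carries the distributed lookup-table (embedding) parameter names, and the `if p.name not in sparse_table_names` checks in the hunks above use it to keep those parameters out of the dense tables. A minimal sketch of that split:

```python
# Illustrative sketch with hypothetical names: how sparse_table_names is
# used to separate sparse (embedding) parameters from dense ones.
sparse_table_names = ["embedding_0.w_0"]
param_names = ["embedding_0.w_0", "fc_0.w_0", "fc_0.b_0", "fc_1.w_0"]

# Mirrors the `if p.name not in sparse_table_names` checks in the diff.
dense_param_name = [n for n in param_names if n not in sparse_table_names]
print(dense_param_name)  # ['fc_0.w_0', 'fc_0.b_0', 'fc_1.w_0']
```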