Unverified Commit bd35a7f0, authored by zhang wenhui, committed by GitHub

support fc sort by number, test=develop (#19466)

fleet_desc sorts fc names in dictionary (lexicographic) order, but we want them sorted by their numeric index.
Parent 5c8f210c
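For context, a minimal illustration of the ordering problem this commit fixes (the parameter names below are hypothetical, not taken from the patch). Python's default string sort compares character by character, so "fc_10" lands before "fc_2":

    # Dictionary order compares characters, so '1' < '2' puts fc_10 first.
    names = ["fc_2.b_0", "fc_10.w_0", "fc_0.w_0", "fc_2.w_0"]
    print(sorted(names))
    # ['fc_0.w_0', 'fc_10.w_0', 'fc_2.b_0', 'fc_2.w_0']
    #
    # The order this commit wants: layers by numeric index, and within a
    # layer the weight ('w') before the bias ('b'):
    # ['fc_0.w_0', 'fc_2.w_0', 'fc_2.b_0', 'fc_10.w_0']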
@@ -155,7 +155,7 @@ class DownpourServer(Server):
             table2.deconverter = "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
 
     def add_dense_table(self, table_id, param_var, grad_var, strategy,
-                        sparse_table_name):
+                        sparse_table_names):
         """
         Args:
             table_id(int): id of sparse params table
@@ -166,7 +166,7 @@ class DownpourServer(Server):
         fea_dim = 0
         dense_param_vars = []
         for p in param_var:
-            if p.name not in sparse_table_name:
+            if p.name not in sparse_table_names:
                 dense_param_vars.append(p)
 
         for param in dense_param_vars:
@@ -216,7 +216,7 @@ class DownpourServer(Server):
         table.accessor.fea_dim = fea_dim
 
     def add_data_norm_table(self, table_id, learning_rate, param_var, grad_var,
-                            strategy, sparse_table_name):
+                            strategy, sparse_table_names):
         """
         Args:
             table_id(int): id of datanorm table
@@ -227,7 +227,7 @@ class DownpourServer(Server):
         fea_dim = 0
         dense_param_vars = []
         for p in param_var:
-            if p.name not in sparse_table_name:
+            if p.name not in sparse_table_names:
                 dense_param_vars.append(p)
 
         for param in dense_param_vars:
@@ -325,7 +325,7 @@ class DownpourWorker(Worker):
             [var.name + "@GRAD" for var in slot_value_vars])
 
     def add_dense_table(self, table_id, learning_rate, param_vars, grad_vars,
-                        dense_start_table_id, sparse_table_name):
+                        dense_start_table_id, sparse_table_names):
         """
         Args:
             table_id(int): id of sparse params table
@@ -337,12 +337,12 @@ class DownpourWorker(Worker):
             return None
         """
         sparse_table_name_grad = []
-        for name in sparse_table_name:
+        for name in sparse_table_names:
             sparse_table_name_grad.append(name + "@GRAD")
 
         dense_param_name = []
         for p in param_vars:
-            if p.name not in sparse_table_name:
+            if p.name not in sparse_table_names:
                 dense_param_name.append(p.name)
 
         dense_grad_name = []
@@ -352,6 +352,7 @@ class DownpourWorker(Worker):
         dense_param_name.sort()
         dense_grad_name.sort()
 
         for table in self._worker.dense_table:
             if table.table_id == table_id:
                 desc_dense_param_name = list(self._worker.dense_table[
@@ -375,8 +376,31 @@ class DownpourWorker(Worker):
 
         table = self._worker.dense_table.add()
         table.table_id = table_id
-        table.dense_variable_name.extend(dense_param_name)
-        table.dense_gradient_variable_name.extend(dense_grad_name)
+
+        def cmp_fc(x, y):
+            if x.startswith("fc_") and y.startswith("fc_"):
+                index_x = x.find('.')
+                index_y = y.find('.')
+                if index_x > 0 and index_y > 0:
+                    num_x = x[3:index_x]
+                    num_y = y[3:index_y]
+                    if num_x.isdigit() and num_y.isdigit():
+                        if int(num_x) < int(num_y):
+                            return -1
+                        if int(num_x) > int(num_y):
+                            return 1
+                        if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
+                            return -1
+                        if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
+                            return 1
+            if x < y:
+                return -1
+            else:
+                return 1
+
+        table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc))
+        table.dense_gradient_variable_name.extend(
+            sorted(dense_grad_name, cmp_fc))
 
     def get_desc(self):
         """
...
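A minimal sketch of the same comparator in use outside the patch. Note that `sorted(dense_param_name, cmp_fc)` above relies on Python 2's positional `cmp` argument to `sorted`; on Python 3 the comparator would be wrapped with `functools.cmp_to_key` (that adaptation and the sample names are mine, not part of the commit):

    import functools

    def cmp_fc(x, y):
        # Same logic as the comparator added above: compare the layer index
        # numerically, then put the weight ('w') before the bias ('b');
        # anything else falls back to plain string comparison. Equal inputs
        # never return 0, but parameter names are unique, so sorting over
        # distinct names is unaffected.
        if x.startswith("fc_") and y.startswith("fc_"):
            index_x = x.find('.')
            index_y = y.find('.')
            if index_x > 0 and index_y > 0:
                num_x = x[3:index_x]
                num_y = y[3:index_y]
                if num_x.isdigit() and num_y.isdigit():
                    if int(num_x) < int(num_y):
                        return -1
                    if int(num_x) > int(num_y):
                        return 1
                    if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
                        return -1
                    if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
                        return 1
        if x < y:
            return -1
        else:
            return 1

    names = ["fc_10.w_0", "fc_2.b_0", "fc_2.w_0", "fc_0.b_0", "fc_0.w_0"]
    print(sorted(names, key=functools.cmp_to_key(cmp_fc)))
    # ['fc_0.w_0', 'fc_0.b_0', 'fc_2.w_0', 'fc_2.b_0', 'fc_10.w_0']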
@@ -126,12 +126,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
             [optimize_ops, grads_and_weights]
         """
 
-        table_name = self._find_multi_distributed_lookup_table(losses)
+        sparse_table_names = self._find_multi_distributed_lookup_table(losses)
         inputs_dict = self._find_distributed_lookup_table_inputs(
-            losses[0].block.program, table_name)
+            losses[0].block.program, sparse_table_names)
 
         outputs_dict = self._find_distributed_lookup_table_outputs(
-            losses[0].block.program, table_name)
+            losses[0].block.program, sparse_table_names)
 
         ps_param = pslib.PSParameter()
         server = DownpourServer()
@@ -147,7 +147,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         worker.get_desc().CopyFrom(ps_param.trainer_param)
 
         sparse_table_index = 0
-        for tn in table_name:
+        for tn in sparse_table_names:
             if strategy.get(tn) is not None:
                 server.add_sparse_table(sparse_table_index, strategy[tn])
             else:
@@ -199,13 +199,14 @@ class DistributedAdam(DistributedOptimizerImplBase):
             if strategy.get('dense_table') is not None:
                 server.add_dense_table(dense_table_index, params, grads,
-                                       strategy['dense_table'], table_name)
+                                       strategy['dense_table'],
+                                       sparse_table_names)
             else:
                 server.add_dense_table(dense_table_index, params, grads, None,
-                                       table_name)
+                                       sparse_table_names)
             worker.add_dense_table(dense_table_index, self._learning_rate,
                                    params, grads, dense_start_table_id,
-                                   table_name)
+                                   sparse_table_names)
             program_configs[program_id]["pull_dense"] = [dense_table_index]
             program_configs[program_id]["push_dense"] = [dense_table_index]
             if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
@@ -214,15 +215,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
                     server.add_data_norm_table(
                         dense_table_index, self._learning_rate,
                         data_norm_params, data_norm_grads,
-                        strategy['datanorm_table'], table_name)
+                        strategy['datanorm_table'], sparse_table_names)
                 else:
                     server.add_data_norm_table(
                         dense_table_index, self._learning_rate,
-                        data_norm_params, data_norm_grads, None, table_name)
+                        data_norm_params, data_norm_grads, None,
+                        sparse_table_names)
                 worker.add_dense_table(dense_table_index, self._learning_rate,
                                        data_norm_params, data_norm_grads,
-                                       dense_start_table_id, table_name)
+                                       dense_start_table_id, sparse_table_names)
                 program_configs[program_id]["pull_dense"].extend(
                     [dense_table_index])
                 program_configs[program_id]["push_dense"].extend(
...