未验证 提交 0d794983 编写于 作者: Z zhang wenhui 提交者: GitHub

fix fleet_desc bug && support format for abacus hotstart (#19430)

fix fleet_desc dense_table unsort bug ,not  support format for abacus hotstart yet.
上级 2e3ec66b
......@@ -154,7 +154,8 @@ class DownpourServer(Server):
table2.converter = "(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)"
table2.deconverter = "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)"
def add_dense_table(self, table_id, param_var, grad_var, strategy):
def add_dense_table(self, table_id, param_var, grad_var, strategy,
sparse_table_name):
"""
Args:
table_id(int): id of sparse params table
......@@ -163,8 +164,12 @@ class DownpourServer(Server):
return None
"""
fea_dim = 0
for param in filter(lambda x: x.name.find("embedding") == -1,
param_var):
dense_param_vars = []
for p in param_var:
if p.name not in sparse_table_name:
dense_param_vars.append(p)
for param in dense_param_vars:
fea_dim += reduce(lambda x, y: x * y, param.shape, 1)
for table in self._server.downpour_server_param.downpour_table_param:
......@@ -211,7 +216,7 @@ class DownpourServer(Server):
table.accessor.fea_dim = fea_dim
def add_data_norm_table(self, table_id, learning_rate, param_var, grad_var,
strategy):
strategy, sparse_table_name):
"""
Args:
table_id(int): id of datanorm table
......@@ -220,8 +225,12 @@ class DownpourServer(Server):
return None
"""
fea_dim = 0
for param in filter(lambda x: x.name.find("embedding") == -1,
param_var):
dense_param_vars = []
for p in param_var:
if p.name not in sparse_table_name:
dense_param_vars.append(p)
for param in dense_param_vars:
fea_dim += reduce(lambda x, y: x * y, param.shape, 1)
for table in self._server.downpour_server_param.downpour_table_param:
......@@ -316,7 +325,7 @@ class DownpourWorker(Worker):
[var.name + "@GRAD" for var in slot_value_vars])
def add_dense_table(self, table_id, learning_rate, param_vars, grad_vars,
dense_start_table_id):
dense_start_table_id, sparse_table_name):
"""
Args:
table_id(int): id of sparse params table
......@@ -327,12 +336,34 @@ class DownpourWorker(Worker):
Returns:
return None
"""
sparse_table_name_grad = []
for name in sparse_table_name:
sparse_table_name_grad.append(name + "@GRAD")
dense_param_name = []
for p in param_vars:
if p.name not in sparse_table_name:
dense_param_name.append(p.name)
dense_grad_name = []
for g in grad_vars:
if g.name not in sparse_table_name_grad:
dense_grad_name.append(g.name)
dense_param_name.sort()
dense_grad_name.sort()
for table in self._worker.dense_table:
if table.table_id == table_id:
if filter(lambda x: x.find("embedding") == -1, [p.name for p in param_vars]) ==\
self._worker.dense_table[table_id - dense_start_table_id].dense_variable_name:
if filter(lambda x: x.find("embedding") == -1, [g.name for g in grad_vars]) ==\
self._worker.dense_table[table_id - dense_start_table_id].dense_gradient_variable_name:
desc_dense_param_name = list(self._worker.dense_table[
table_id - dense_start_table_id].dense_variable_name)
desc_dense_param_name.sort()
if dense_param_name == desc_dense_param_name:
desc_dense_grad_name = list(self._worker.dense_table[
table_id - dense_start_table_id]
.dense_gradient_variable_name)
desc_dense_grad_name.sort()
if dense_grad_name == desc_dense_grad_name:
return
else:
raise ValueError(
......@@ -344,12 +375,8 @@ class DownpourWorker(Worker):
table = self._worker.dense_table.add()
table.table_id = table_id
table.dense_variable_name.extend(
filter(lambda x: x.find("embedding") == -1,
[p.name for p in param_vars]))
table.dense_gradient_variable_name.extend(
filter(lambda x: x.find("embedding") == -1,
[g.name for g in grad_vars]))
table.dense_variable_name.extend(dense_param_name)
table.dense_gradient_variable_name.extend(dense_grad_name)
def get_desc(self):
"""
......
......@@ -127,12 +127,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
"""
table_name = self._find_multi_distributed_lookup_table(losses)
prefetch_slots = find_distributed_lookup_table_inputs(
losses[0].block.program, table_name[0])
inputs_dict = self._find_distributed_lookup_table_inputs(
losses[0].block.program, table_name)
prefetch_slots_emb = find_distributed_lookup_table_outputs(
losses[0].block.program, table_name[0])
outputs_dict = self._find_distributed_lookup_table_outputs(
losses[0].block.program, table_name)
......@@ -191,6 +187,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
data_norm_params.append(i[0])
if not is_data_norm_data:
params.append(i[0])
for i in params_grads:
is_data_norm_data = False
for data_norm_grad in self.data_norm_name:
......@@ -199,13 +196,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
data_norm_grads.append(i[1])
if not is_data_norm_data:
grads.append(i[1])
if strategy.get('dense_table') is not None:
server.add_dense_table(dense_table_index, params, grads,
strategy['dense_table'])
strategy['dense_table'], table_name)
else:
server.add_dense_table(dense_table_index, params, grads, None)
server.add_dense_table(dense_table_index, params, grads, None,
table_name)
worker.add_dense_table(dense_table_index, self._learning_rate,
params, grads, dense_start_table_id)
params, grads, dense_start_table_id,
table_name)
program_configs[program_id]["pull_dense"] = [dense_table_index]
program_configs[program_id]["push_dense"] = [dense_table_index]
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
......@@ -214,15 +214,15 @@ class DistributedAdam(DistributedOptimizerImplBase):
server.add_data_norm_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
strategy['datanorm_table'])
strategy['datanorm_table'], table_name)
else:
server.add_data_norm_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads, None)
data_norm_params, data_norm_grads, None, table_name)
worker.add_dense_table(dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
dense_start_table_id)
dense_start_table_id, table_name)
program_configs[program_id]["pull_dense"].extend(
[dense_table_index])
program_configs[program_id]["push_dense"].extend(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册