未验证 提交 b82e6520 编写于 作者: Z zhang wenhui 提交者: GitHub

fix pslib datanorm double bug (#20297)

* fix fc sort . test=develop
上级 53d8799b
...@@ -254,13 +254,13 @@ class DownpourServer(Server): ...@@ -254,13 +254,13 @@ class DownpourServer(Server):
table = self._server.downpour_server_param.downpour_table_param.add() table = self._server.downpour_server_param.downpour_table_param.add()
table.table_id = table_id table.table_id = table_id
table.table_class = strategy.get('datanorm_table_class', table.table_class = strategy.get('datanorm_table_class',
"DownpourDenseDoubleTable") 'DownpourDenseTable')
table.type = pslib.PS_DENSE_TABLE table.type = pslib.PS_DENSE_TABLE
table.compress_in_save = strategy.get('datanorm_compress_in_save', True) table.compress_in_save = strategy.get('datanorm_compress_in_save', True)
table.accessor.accessor_class = strategy.get( table.accessor.accessor_class = strategy.get(
'datanorm_accessor_class', "DownpourDenseValueDoubleAccessor") 'datanorm_accessor_class', 'DownpourDenseValueAccessor')
table.accessor.dense_sgd_param.name = strategy.get('datanorm_operation', table.accessor.dense_sgd_param.name = strategy.get('datanorm_operation',
"summarydouble") 'summary')
table.accessor.dense_sgd_param.summary.summary_decay_rate = strategy.get( table.accessor.dense_sgd_param.summary.summary_decay_rate = strategy.get(
'datanorm_decay_rate', 0.999999) 'datanorm_decay_rate', 0.999999)
table.accessor.fea_dim = fea_dim table.accessor.fea_dim = fea_dim
...@@ -377,30 +377,32 @@ class DownpourWorker(Worker): ...@@ -377,30 +377,32 @@ class DownpourWorker(Worker):
table = self._worker.dense_table.add() table = self._worker.dense_table.add()
table.table_id = table_id table.table_id = table_id
def cmp_fc(x, y): #def cmp_fc(x, y):
if x.startswith("fc_") and y.startswith("fc_"): # if x.startswith("fc_") and y.startswith("fc_"):
index_x = x.find('.') # index_x = x.find('.')
index_y = y.find('.') # index_y = y.find('.')
if index_x > 0 and index_y > 0: # if index_x > 0 and index_y > 0:
num_x = x[3:index_x] # num_x = x[3:index_x]
num_y = y[3:index_y] # num_y = y[3:index_y]
if num_x.isdigit() and num_y.isdigit(): # if num_x.isdigit() and num_y.isdigit():
if int(num_x) < int(num_y): # if int(num_x) < int(num_y):
return -1 # return -1
if int(num_x) > int(num_y): # if int(num_x) > int(num_y):
return 1 # return 1
if x[index_x + 1] == 'w' and y[index_y + 1] == 'b': # if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
return -1 # return -1
if x[index_x + 1] == 'b' and y[index_y + 1] == 'w': # if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
return 1 # return 1
if x < y: # if x < y:
return -1 # return -1
else: # else:
return 1 # return 1
table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc)) #table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc))
table.dense_gradient_variable_name.extend( #table.dense_gradient_variable_name.extend(
sorted(dense_grad_name, cmp_fc)) # sorted(dense_grad_name, cmp_fc))
table.dense_variable_name.extend(dense_param_name)
table.dense_gradient_variable_name.extend(dense_grad_name)
def get_desc(self): def get_desc(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册