未验证 提交 9d01a7c2 编写于 作者: Z zhang wenhui 提交者: GitHub

cherry pick , 1.6 fix bug , fix node.py for fc sort (#20461)

* cherry pick , 1.6 fix bug node fc sort

* cherry pick , 1.6 fix bug node fc sort

* cherry pick , 1.6 fix bug node fc sort, test=release/1.6

* cherry pick , 1.6 fix bug node fc sort& datanorm double, test=release/1.6
上级 f6beb0fa
......@@ -254,13 +254,13 @@ class DownpourServer(Server):
table = self._server.downpour_server_param.downpour_table_param.add()
table.table_id = table_id
table.table_class = strategy.get('datanorm_table_class',
"DownpourDenseDoubleTable")
'DownpourDenseTable')
table.type = pslib.PS_DENSE_TABLE
table.compress_in_save = strategy.get('datanorm_compress_in_save', True)
table.accessor.accessor_class = strategy.get(
'datanorm_accessor_class', "DownpourDenseValueDoubleAccessor")
'datanorm_accessor_class', 'DownpourDenseValueAccessor')
table.accessor.dense_sgd_param.name = strategy.get('datanorm_operation',
"summarydouble")
'summary')
table.accessor.dense_sgd_param.summary.summary_decay_rate = strategy.get(
'datanorm_decay_rate', 0.999999)
table.accessor.fea_dim = fea_dim
......@@ -377,30 +377,32 @@ class DownpourWorker(Worker):
table = self._worker.dense_table.add()
table.table_id = table_id
def cmp_fc(x, y):
if x.startswith("fc_") and y.startswith("fc_"):
index_x = x.find('.')
index_y = y.find('.')
if index_x > 0 and index_y > 0:
num_x = x[3:index_x]
num_y = y[3:index_y]
if num_x.isdigit() and num_y.isdigit():
if int(num_x) < int(num_y):
return -1
if int(num_x) > int(num_y):
return 1
if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
return -1
if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
return 1
if x < y:
return -1
else:
return 1
table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc))
table.dense_gradient_variable_name.extend(
sorted(dense_grad_name, cmp_fc))
#def cmp_fc(x, y):
# if x.startswith("fc_") and y.startswith("fc_"):
# index_x = x.find('.')
# index_y = y.find('.')
# if index_x > 0 and index_y > 0:
# num_x = x[3:index_x]
# num_y = y[3:index_y]
# if num_x.isdigit() and num_y.isdigit():
# if int(num_x) < int(num_y):
# return -1
# if int(num_x) > int(num_y):
# return 1
# if x[index_x + 1] == 'w' and y[index_y + 1] == 'b':
# return -1
# if x[index_x + 1] == 'b' and y[index_y + 1] == 'w':
# return 1
# if x < y:
# return -1
# else:
# return 1
#table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc))
#table.dense_gradient_variable_name.extend(
# sorted(dense_grad_name, cmp_fc))
table.dense_variable_name.extend(dense_param_name)
table.dense_gradient_variable_name.extend(dense_grad_name)
def get_desc(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册