From b82e6520e1339e4127a7ecf0b05ee598312f5b6d Mon Sep 17 00:00:00 2001 From: zhang wenhui Date: Fri, 11 Oct 2019 10:36:37 +0800 Subject: [PATCH] fix pslib datanorm double bug (#20297) * fix fc sort . test=develop --- .../fleet/parameter_server/pslib/node.py | 56 ++++++++++--------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py index ce9a368e8c..bc8d8c2153 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/node.py @@ -254,13 +254,13 @@ class DownpourServer(Server): table = self._server.downpour_server_param.downpour_table_param.add() table.table_id = table_id table.table_class = strategy.get('datanorm_table_class', - "DownpourDenseDoubleTable") + 'DownpourDenseTable') table.type = pslib.PS_DENSE_TABLE table.compress_in_save = strategy.get('datanorm_compress_in_save', True) table.accessor.accessor_class = strategy.get( - 'datanorm_accessor_class', "DownpourDenseValueDoubleAccessor") + 'datanorm_accessor_class', 'DownpourDenseValueAccessor') table.accessor.dense_sgd_param.name = strategy.get('datanorm_operation', - "summarydouble") + 'summary') table.accessor.dense_sgd_param.summary.summary_decay_rate = strategy.get( 'datanorm_decay_rate', 0.999999) table.accessor.fea_dim = fea_dim @@ -377,30 +377,32 @@ class DownpourWorker(Worker): table = self._worker.dense_table.add() table.table_id = table_id - def cmp_fc(x, y): - if x.startswith("fc_") and y.startswith("fc_"): - index_x = x.find('.') - index_y = y.find('.') - if index_x > 0 and index_y > 0: - num_x = x[3:index_x] - num_y = y[3:index_y] - if num_x.isdigit() and num_y.isdigit(): - if int(num_x) < int(num_y): - return -1 - if int(num_x) > int(num_y): - return 1 - if x[index_x + 1] == 'w' and y[index_y + 1] == 'b': - return -1 - if x[index_x + 1] == 'b' and y[index_y + 1] == 'w': - return 1 - if x < y: - return -1 - else: - return 1 - - table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc)) - table.dense_gradient_variable_name.extend( - sorted(dense_grad_name, cmp_fc)) + #def cmp_fc(x, y): + # if x.startswith("fc_") and y.startswith("fc_"): + # index_x = x.find('.') + # index_y = y.find('.') + # if index_x > 0 and index_y > 0: + # num_x = x[3:index_x] + # num_y = y[3:index_y] + # if num_x.isdigit() and num_y.isdigit(): + # if int(num_x) < int(num_y): + # return -1 + # if int(num_x) > int(num_y): + # return 1 + # if x[index_x + 1] == 'w' and y[index_y + 1] == 'b': + # return -1 + # if x[index_x + 1] == 'b' and y[index_y + 1] == 'w': + # return 1 + # if x < y: + # return -1 + # else: + # return 1 + + #table.dense_variable_name.extend(sorted(dense_param_name, cmp_fc)) + #table.dense_gradient_variable_name.extend( + # sorted(dense_grad_name, cmp_fc)) + table.dense_variable_name.extend(dense_param_name) + table.dense_gradient_variable_name.extend(dense_grad_name) def get_desc(self): """ -- GitLab