From 5970e8ac5ecc405a97e12bcc93b32d622b36e8d8 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Mon, 4 Nov 2019 18:41:34 +0800 Subject: [PATCH] find lookup table in order (#20932) test=develop --- .../parameter_server/pslib/optimizer_factory.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 0169de22ed1..59b220c1080 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -130,13 +130,22 @@ class DistributedAdam(DistributedOptimizerImplBase): find multi-sparse-table """ table_names = set() + cnt = 0 + tmp_list = [] + ret_list = [] for loss in losses: for op in loss.block.program.global_block().ops: if op.type == "lookup_table": if op.attr('is_distributed') is True: table_name = op.input("W")[0] - table_names.add(table_name) - return list(table_names) + if table_name not in table_names: + table_names.add(table_name) + tmp_list.append([table_name, cnt]) + cnt += 1 + tmp_list.sort(key=lambda k: k[1]) + for x in tmp_list: + ret_list.append(x[0]) + return ret_list def _minimize(self, losses, -- GitLab