From 9574bcd7c0f361225f7b774b4db4ae8e09f29a19 Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Fri, 12 Nov 2021 10:53:16 +0800 Subject: [PATCH] [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#36753) * [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py * [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py --- .../pslib/optimizer_factory.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 61630d77692..3b4a3aacc06 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -335,19 +335,27 @@ class DistributedAdam(DistributedOptimizerImplBase): if st.get("sparse_accessor_class") is not None: accessor = st["sparse_accessor_class"] - # set sparse_embedx_dim in strategy, - # user do not have to set it in config_fleet + # set sparse_embedx_dim in the strategy according to accessor and use_cvm config if accessor == "DownpourFeatureValueAccessor" \ or accessor == "DownpourCtrAccessor" \ or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourUnitAccessor": if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == True \ and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: raise ValueError("fleet config sparse_embedx_dim=%s not" " equal to embedding dim - 3 = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name] - 3)) - if st.get("sparse_embedx_dim") is None: + if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == False \ + and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1: + raise ValueError("fleet config sparse_embedx_dim=%s not" + " equal to embedding dim - 1 = %s" % + (st["sparse_embedx_dim"], + emb_to_size[table_name] - 1)) + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == True: logger.warning( "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " @@ -355,6 +363,15 @@ class DistributedAdam(DistributedOptimizerImplBase): format(table_name, emb_to_size[table_name], emb_to_size[ table_name])) st["sparse_embedx_dim"] = emb_to_size[table_name] - 3 + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == False: + logger.warning( + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " + "with same sparse table name is not set in config_fleet.py. " + "Hence automatically set sparse_embedx_dim = {} - 1.". + format(table_name, emb_to_size[table_name], emb_to_size[ + table_name])) + st["sparse_embedx_dim"] = emb_to_size[table_name] - 1 elif accessor == "DownpourSparseValueAccessor": if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name]: -- GitLab