From 15cb05c8031075c44f2a73214e7f2b16b0497956 Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Tue, 26 Oct 2021 18:40:39 +0800 Subject: [PATCH] [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#36215) * [CPU-PSLIB] Add consistency insepection of use_var_list and data_generator data * [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py --- .../pslib/optimizer_factory.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 1e750031f33..cd97543ba8b 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -206,12 +206,21 @@ class DistributedAdam(DistributedOptimizerImplBase): or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourUnitAccessor": if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == True \ and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: raise ValueError("fleet config sparse_embedx_dim=%s not" " equal to embedding dim - 3 = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name] - 3)) - if st.get("sparse_embedx_dim") is None: + if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == False \ + and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1: + raise ValueError("fleet config sparse_embedx_dim=%s not" + " equal to embedding dim - 1 = %s" % + (st["sparse_embedx_dim"], + emb_to_size[table_name] - 1)) + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == True: logger.warning( "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " @@ -219,6 +228,15 @@ class DistributedAdam(DistributedOptimizerImplBase): format(table_name, emb_to_size[table_name], emb_to_size[ table_name])) st["sparse_embedx_dim"] = emb_to_size[table_name] - 3 + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == False: + logger.warning( + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " + "with same sparse table name is not set in config_fleet.py. " + "Hence automatically set sparse_embedx_dim = {} - 1.". + format(table_name, emb_to_size[table_name], emb_to_size[ + table_name])) + st["sparse_embedx_dim"] = emb_to_size[table_name] - 1 elif accessor == "DownpourSparseValueAccessor": if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name]: -- GitLab