diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 1e750031f337cf5a6ed0f24797c5676869cce3a7..cd97543ba8b7cc7fbd1bf846b3a06d9e6d2925b4 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -206,12 +206,21 @@ class DistributedAdam(DistributedOptimizerImplBase): or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourUnitAccessor": if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == True \ and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: raise ValueError("fleet config sparse_embedx_dim=%s not" " equal to embedding dim - 3 = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name] - 3)) - if st.get("sparse_embedx_dim") is None: + if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == False \ + and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1: + raise ValueError("fleet config sparse_embedx_dim=%s not" + " equal to embedding dim - 1 = %s" % + (st["sparse_embedx_dim"], + emb_to_size[table_name] - 1)) + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == True: logger.warning( "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " @@ -219,6 +228,15 @@ class DistributedAdam(DistributedOptimizerImplBase): format(table_name, emb_to_size[table_name], emb_to_size[ table_name])) st["sparse_embedx_dim"] = emb_to_size[table_name] - 3 + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == False: + logger.warning( + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " + "with same sparse table name is not set in config_fleet.py. " + "Hence automatically set sparse_embedx_dim = {} - 1.". + format(table_name, emb_to_size[table_name], emb_to_size[ + table_name])) + st["sparse_embedx_dim"] = emb_to_size[table_name] - 1 elif accessor == "DownpourSparseValueAccessor": if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name]: