diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 61630d7769206921fa29c0bc731c948f92fe8f91..3b4a3aacc06c6c37e16b44d85da408a5d5f79aa4 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -335,19 +335,27 @@ class DistributedAdam(DistributedOptimizerImplBase): if st.get("sparse_accessor_class") is not None: accessor = st["sparse_accessor_class"] - # set sparse_embedx_dim in strategy, - # user do not have to set it in config_fleet + # set sparse_embedx_dim in the strategy according to accessor and use_cvm config if accessor == "DownpourFeatureValueAccessor" \ or accessor == "DownpourCtrAccessor" \ or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourUnitAccessor": if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == True \ and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: raise ValueError("fleet config sparse_embedx_dim=%s not" " equal to embedding dim - 3 = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name] - 3)) - if st.get("sparse_embedx_dim") is None: + if st.get("sparse_embedx_dim") is not None \ + and strategy.get("use_cvm") == False \ + and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1: + raise ValueError("fleet config sparse_embedx_dim=%s not" + " equal to embedding dim - 1 = %s" % + (st["sparse_embedx_dim"], + emb_to_size[table_name] - 1)) + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == True: logger.warning( "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " @@ -355,6 +363,15 @@ class DistributedAdam(DistributedOptimizerImplBase): format(table_name, emb_to_size[table_name], emb_to_size[ table_name])) st["sparse_embedx_dim"] = emb_to_size[table_name] - 3 + if st.get("sparse_embedx_dim") is None \ + and strategy.get("use_cvm") == False: + logger.warning( + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " + "with same sparse table name is not set in config_fleet.py. " + "Hence automatically set sparse_embedx_dim = {} - 1.". + format(table_name, emb_to_size[table_name], emb_to_size[ + table_name])) + st["sparse_embedx_dim"] = emb_to_size[table_name] - 1 elif accessor == "DownpourSparseValueAccessor": if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name]: