diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 7d5a8d22ccbf71d53ec9bb2fc959a14f532423e8..311c6271f2f0ada98eda559bcf776d80a2cfbfdc 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -180,7 +180,7 @@ class DistributedAdam(DistributedOptimizerImplBase): if op.type in self.supported_embedding_types: if op.attr('is_distributed') is True: table_name = op.input("W")[0] - emb_size = local_vars[table_name].shape[1] + emb_size = local_vars[table_name].shape[-1] if d_size.get(table_name) is None: d_size[table_name] = emb_size elif d_size[table_name] != emb_size: @@ -195,13 +195,10 @@ class DistributedAdam(DistributedOptimizerImplBase): strategy[table_name] = dict() st = strategy[table_name] - accessor = None + accessor = "DownpourCtrAccessor" if st.get("sparse_accessor_class") is not None: accessor = st["sparse_accessor_class"] - if accessor is None: - accessor = "DownpourCtrAccessor" - # set sparse_embedx_dim in strategy, # user do not have to set it in config_fleet if accessor == "DownpourFeatureValueAccessor" \ @@ -211,12 +208,12 @@ class DistributedAdam(DistributedOptimizerImplBase): if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: raise ValueError("fleet config sparse_embedx_dim=%s not" - " equal to embedding size - 3 = %s" % + " equal to embedding dim - 3 = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name] - 3)) if st.get("sparse_embedx_dim") is None: logger.warning( - "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim " + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " "Hence automatically set sparse_embedx_dim = {} - 3.". format(table_name, emb_to_size[table_name], emb_to_size[ @@ -226,12 +223,12 @@ class DistributedAdam(DistributedOptimizerImplBase): if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name]: raise ValueError("fleet config sparse_embedx_dim=%s not" - " equal to embedding size = %s" % + " equal to embedding dim = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name])) if st.get("sparse_embedx_dim") is None: logger.warning( - "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim " + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " "Hence automatically set sparse_embedx_dim = {}.".format( table_name, emb_to_size[table_name], emb_to_size[