diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index caaaa88cc4c6435fe46ba8622bebbf16abf33f30..607a3c94f8a4e7c14fa98ebb8de45241d97da437 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -313,7 +313,7 @@ class DistributedAdam(DistributedOptimizerImplBase): if op.type in self.supported_embedding_types: if op.attr('is_distributed') is True: table_name = op.input("W")[0] - emb_size = local_vars[table_name].shape[1] + emb_size = local_vars[table_name].shape[-1] if d_size.get(table_name) is None: d_size[table_name] = emb_size elif d_size[table_name] != emb_size: @@ -328,13 +328,10 @@ class DistributedAdam(DistributedOptimizerImplBase): strategy[table_name] = dict() st = strategy[table_name] - accessor = None + accessor = "DownpourCtrAccessor" if st.get("sparse_accessor_class") is not None: accessor = st["sparse_accessor_class"] - if accessor is None: - accessor = "DownpourCtrAccessor" - # set sparse_embedx_dim in strategy, # user do not have to set it in config_fleet if accessor == "DownpourFeatureValueAccessor" \ @@ -344,12 +341,12 @@ class DistributedAdam(DistributedOptimizerImplBase): if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: raise ValueError("fleet config sparse_embedx_dim=%s not" - " equal to embedding size - 3 = %s" % + " equal to embedding dim - 3 = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name] - 3)) if st.get("sparse_embedx_dim") is None: logger.warning( - "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim " + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " "Hence automatically set sparse_embedx_dim = {} - 3.". format(table_name, emb_to_size[table_name], emb_to_size[ @@ -359,12 +356,12 @@ class DistributedAdam(DistributedOptimizerImplBase): if st.get("sparse_embedx_dim") is not None \ and st["sparse_embedx_dim"] != emb_to_size[table_name]: raise ValueError("fleet config sparse_embedx_dim=%s not" - " equal to embedding size = %s" % + " equal to embedding dim = %s" % (st["sparse_embedx_dim"], emb_to_size[table_name])) if st.get("sparse_embedx_dim") is None: logger.warning( - "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim " + "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "with same sparse table name is not set in config_fleet.py. " "Hence automatically set sparse_embedx_dim = {}.".format( table_name, emb_to_size[table_name], emb_to_size[