未验证 提交 61c121cd 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and...

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#34441) (#34454)
上级 75efc0ac
...@@ -180,7 +180,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -180,7 +180,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
if op.type in self.supported_embedding_types: if op.type in self.supported_embedding_types:
if op.attr('is_distributed') is True: if op.attr('is_distributed') is True:
table_name = op.input("W")[0] table_name = op.input("W")[0]
emb_size = local_vars[table_name].shape[1] emb_size = local_vars[table_name].shape[-1]
if d_size.get(table_name) is None: if d_size.get(table_name) is None:
d_size[table_name] = emb_size d_size[table_name] = emb_size
elif d_size[table_name] != emb_size: elif d_size[table_name] != emb_size:
...@@ -195,13 +195,10 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -195,13 +195,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
strategy[table_name] = dict() strategy[table_name] = dict()
st = strategy[table_name] st = strategy[table_name]
accessor = None accessor = "DownpourCtrAccessor"
if st.get("sparse_accessor_class") is not None: if st.get("sparse_accessor_class") is not None:
accessor = st["sparse_accessor_class"] accessor = st["sparse_accessor_class"]
if accessor is None:
accessor = "DownpourCtrAccessor"
# set sparse_embedx_dim in strategy, # set sparse_embedx_dim in strategy,
# user do not have to set it in config_fleet # user do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \ if accessor == "DownpourFeatureValueAccessor" \
...@@ -211,12 +208,12 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -211,12 +208,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not" raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size - 3 = %s" % " equal to embedding dim - 3 = %s" %
(st["sparse_embedx_dim"], (st["sparse_embedx_dim"],
emb_to_size[table_name] - 3)) emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None: if st.get("sparse_embedx_dim") is None:
logger.warning( logger.warning(
"sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim " "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. " "with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 3.". "Hence automatically set sparse_embedx_dim = {} - 3.".
format(table_name, emb_to_size[table_name], emb_to_size[ format(table_name, emb_to_size[table_name], emb_to_size[
...@@ -226,12 +223,12 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -226,12 +223,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]: and st["sparse_embedx_dim"] != emb_to_size[table_name]:
raise ValueError("fleet config sparse_embedx_dim=%s not" raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size = %s" % " equal to embedding dim = %s" %
(st["sparse_embedx_dim"], (st["sparse_embedx_dim"],
emb_to_size[table_name])) emb_to_size[table_name]))
if st.get("sparse_embedx_dim") is None: if st.get("sparse_embedx_dim") is None:
logger.warning( logger.warning(
"sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim " "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. " "with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {}.".format( "Hence automatically set sparse_embedx_dim = {}.".format(
table_name, emb_to_size[table_name], emb_to_size[ table_name, emb_to_size[table_name], emb_to_size[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册