未验证 提交 61c121cd 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and...

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#34441) (#34454)
上级 75efc0ac
......@@ -180,7 +180,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
if op.type in self.supported_embedding_types:
if op.attr('is_distributed') is True:
table_name = op.input("W")[0]
emb_size = local_vars[table_name].shape[1]
emb_size = local_vars[table_name].shape[-1]
if d_size.get(table_name) is None:
d_size[table_name] = emb_size
elif d_size[table_name] != emb_size:
......@@ -195,13 +195,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
strategy[table_name] = dict()
st = strategy[table_name]
accessor = None
accessor = "DownpourCtrAccessor"
if st.get("sparse_accessor_class") is not None:
accessor = st["sparse_accessor_class"]
if accessor is None:
accessor = "DownpourCtrAccessor"
# set sparse_embedx_dim in strategy,
# user do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \
......@@ -211,12 +208,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size - 3 = %s" %
" equal to embedding dim - 3 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None:
logger.warning(
"sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 3.".
format(table_name, emb_to_size[table_name], emb_to_size[
......@@ -226,12 +223,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size = %s" %
" equal to embedding dim = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name]))
if st.get("sparse_embedx_dim") is None:
logger.warning(
"sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {}.".format(
table_name, emb_to_size[table_name], emb_to_size[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册