未验证 提交 9574bcd7 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and...

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#36753)

* [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py

* [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py
上级 9303b095
......@@ -335,19 +335,27 @@ class DistributedAdam(DistributedOptimizerImplBase):
if st.get("sparse_accessor_class") is not None:
accessor = st["sparse_accessor_class"]
# set sparse_embedx_dim in strategy,
# user do not have to set it in config_fleet
# set sparse_embedx_dim in the strategy according to accessor and use_cvm config
if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == True \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 3 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None:
if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == False \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 1 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 1))
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == True:
logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
......@@ -355,6 +363,15 @@ class DistributedAdam(DistributedOptimizerImplBase):
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == False:
logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 1.".
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 1
elif accessor == "DownpourSparseValueAccessor":
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册