未验证 提交 15cb05c8 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and...

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#36215)

* [CPU-PSLIB] Add consistency insepection of use_var_list and data_generator data

* [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py
上级 5f4af11a
......@@ -206,12 +206,21 @@ class DistributedAdam(DistributedOptimizerImplBase):
or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == True \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 3 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None:
if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == False \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 1 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 1))
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == True:
logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
......@@ -219,6 +228,15 @@ class DistributedAdam(DistributedOptimizerImplBase):
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == False:
logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 1.".
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 1
elif accessor == "DownpourSparseValueAccessor":
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册