未验证 提交 15cb05c8 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and...

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#36215)

* [CPU-PSLIB] Add consistency insepection of use_var_list and data_generator data

* [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py
上级 5f4af11a
...@@ -206,12 +206,21 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -206,12 +206,21 @@ class DistributedAdam(DistributedOptimizerImplBase):
or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor": or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == True \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not" raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 3 = %s" % " equal to embedding dim - 3 = %s" %
(st["sparse_embedx_dim"], (st["sparse_embedx_dim"],
emb_to_size[table_name] - 3)) emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None: if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == False \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 1 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 1))
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == True:
logger.warning( logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. " "with same sparse table name is not set in config_fleet.py. "
...@@ -219,6 +228,15 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -219,6 +228,15 @@ class DistributedAdam(DistributedOptimizerImplBase):
format(table_name, emb_to_size[table_name], emb_to_size[ format(table_name, emb_to_size[table_name], emb_to_size[
table_name])) table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 3 st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == False:
logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 1.".
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 1
elif accessor == "DownpourSparseValueAccessor": elif accessor == "DownpourSparseValueAccessor":
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]: and st["sparse_embedx_dim"] != emb_to_size[table_name]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册