未验证 提交 9574bcd7 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and...

[CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py (#36753)

* [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py

* [CPU-PSLIB] Fix bug for consistency insepection of op's embedding name and sparse table name in config_fleet.py
上级 9303b095
...@@ -335,19 +335,27 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -335,19 +335,27 @@ class DistributedAdam(DistributedOptimizerImplBase):
if st.get("sparse_accessor_class") is not None: if st.get("sparse_accessor_class") is not None:
accessor = st["sparse_accessor_class"] accessor = st["sparse_accessor_class"]
# set sparse_embedx_dim in strategy, # set sparse_embedx_dim in the strategy according to accessor and use_cvm config
# user do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \ if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrAccessor" \ or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor": or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == True \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3: and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not" raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 3 = %s" % " equal to embedding dim - 3 = %s" %
(st["sparse_embedx_dim"], (st["sparse_embedx_dim"],
emb_to_size[table_name] - 3)) emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None: if st.get("sparse_embedx_dim") is not None \
and strategy.get("use_cvm") == False \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 1:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding dim - 1 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 1))
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == True:
logger.warning( logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim " "sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. " "with same sparse table name is not set in config_fleet.py. "
...@@ -355,6 +363,15 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -355,6 +363,15 @@ class DistributedAdam(DistributedOptimizerImplBase):
format(table_name, emb_to_size[table_name], emb_to_size[ format(table_name, emb_to_size[table_name], emb_to_size[
table_name])) table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 3 st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
if st.get("sparse_embedx_dim") is None \
and strategy.get("use_cvm") == False:
logger.warning(
"sparse embedding dim for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 1.".
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 1
elif accessor == "DownpourSparseValueAccessor": elif accessor == "DownpourSparseValueAccessor":
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]: and st["sparse_embedx_dim"] != emb_to_size[table_name]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册