未验证 提交 75efc0ac 编写于 作者: F Fan Zhang 提交者: GitHub

[CPU-PSLIB] Add consistency insepection of op's embedding name and sparse...

[CPU-PSLIB] Add consistency insepection of op's embedding name and sparse table name in config_fleet.py (#34287)
上级 6e4c2c5a
......@@ -22,6 +22,7 @@ from google.protobuf import text_format
from collections import OrderedDict
from .node import DownpourWorker, DownpourServer
from . import ps_pb2 as pslib
import logging
# this dict is for store info about pull/push sparse ops.
FLEET_GLOBAL_DICT = {
......@@ -37,6 +38,10 @@ FLEET_GLOBAL_DICT = {
"scale_sparse_grad": None,
}
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
class DistributedOptimizerImplBase(object):
"""
......@@ -167,6 +172,74 @@ class DistributedAdam(DistributedOptimizerImplBase):
ret_list.append(x[0])
return ret_list
def _gen_distributed_emb_to_size_dict(self, program):
d_size = dict()
local_vars = program.current_block().vars
for op in program.global_block().ops:
if op.type in self.supported_embedding_types:
if op.attr('is_distributed') is True:
table_name = op.input("W")[0]
emb_size = local_vars[table_name].shape[1]
if d_size.get(table_name) is None:
d_size[table_name] = emb_size
elif d_size[table_name] != emb_size:
raise ValueError("embedding size error: %s vs %s" %
(emb_size, d_size[table_name]))
return d_size
def _check_config_fleet_with_program_op(self, strategy, table_name,
emb_to_size):
if strategy.get(table_name) is None:
strategy[table_name] = dict()
st = strategy[table_name]
accessor = None
if st.get("sparse_accessor_class") is not None:
accessor = st["sparse_accessor_class"]
if accessor is None:
accessor = "DownpourCtrAccessor"
# set sparse_embedx_dim in strategy,
# user do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size - 3 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name] - 3))
if st.get("sparse_embedx_dim") is None:
logger.warning(
"sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {} - 3.".
format(table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
elif accessor == "DownpourSparseValueAccessor":
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[table_name]:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size = %s" %
(st["sparse_embedx_dim"],
emb_to_size[table_name]))
if st.get("sparse_embedx_dim") is None:
logger.warning(
"sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
"with same sparse table name is not set in config_fleet.py. "
"Hence automatically set sparse_embedx_dim = {}.".format(
table_name, emb_to_size[table_name], emb_to_size[
table_name]))
st["sparse_embedx_dim"] = emb_to_size[table_name]
return strategy
def _minimize(self,
losses,
startup_program=None,
......@@ -226,6 +299,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
sparse_table_to_index[tn] = sparse_table_index
sparse_table_index += 1
# get {table_name: emb_size} dict from program ops
emb_to_size = self._gen_distributed_emb_to_size_dict(
loss.block.program)
# get inputs_dict
inputs_dict = self._find_distributed_lookup_table_inputs(
loss.block.program, sparse_table)
......@@ -340,8 +417,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
# ServerParameter add all sparse tables
for tn in sparse_table_to_index:
sparse_table_index = sparse_table_to_index[tn]
if strategy.get(tn) is not None:
server.add_sparse_table(sparse_table_index, strategy[tn])
st = self._check_config_fleet_with_program_op(strategy, tn,
emb_to_size)
if st.get(tn) is not None:
server.add_sparse_table(sparse_table_index, st[tn])
else:
server.add_sparse_table(sparse_table_index, None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册