提交 f1a10cce 编写于 作者: T tangwei12

enable lookup table to inference

上级 5c537941
......@@ -1319,7 +1319,13 @@ class Program(object):
self._seed = 0
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = []
# for distributed training
self._is_distributed = False
self._is_chief = False
self._slice_vars_and_atts = []
self._endpoints = []
self._distributed_lookup_table = None
@property
def op_role(self):
......
......@@ -666,11 +666,22 @@ def save_inference_model(dirname,
save_persistables(executor, dirname, inference_program, params_filename)
# If the program is distributed and has a distributed lookup table, the chief trainer notifies all pservers to save it.
if main_program._is_distributed and main_program._is_chief:
if main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
_save_lookup_tables_by_notify(
executor, lookup_table_filename,
main_program._distributed_lookup_table, main_program._endpoints)
def load_inference_model(dirname,
executor,
model_filename=None,
params_filename=None):
params_filename=None,
training_role=None,
role_id=None,
pserver_endpoints=None):
"""
Load inference model from a directory
......@@ -736,6 +747,12 @@ def load_inference_model(dirname,
program = Program.parse_from_string(program_desc_str)
load_persistables(executor, dirname, program, params_filename)
if pserver_endpoints:
_endpoints_replacement(program, pserver_endpoints)
if training_role == "PSERVER":
_load_lookup_table_vars(executor, dirname, program, role_id)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [
......@@ -745,6 +762,118 @@ def load_inference_model(dirname,
return [program, feed_target_names, fetch_targets]
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
                                  pserver_endpoints):
    """
    Send a checkpoint notify message to all the pservers.

    A temporary program holding a single ``checkpoint_notify`` op is built
    and run on ``executor``. The op carries the pserver endpoints
    (``epmap``), the save location (``dir``) and the lookup table name
    (``lookup_table``).

    Args:
        executor(Executor): The executor used to run the notify program.
        dirname(str): The location passed to the pservers as ``dir``.
        lookup_table(str): The lookup table name. When a distributed lookup
            table is used, it can be obtained from
            DistributeTranspiler.table_name.
        pserver_endpoints(list|str): The parameter server ``ip:port``
            endpoints, either as a list or as one comma-separated string.

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            table_name = "share_w"
            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]

            _save_lookup_tables_by_notify(executor=exe,
                    dirname=param_path, lookup_table=table_name,
                    pserver_endpoints=ps_endpoints)
    """
    # Program._endpoints is a list, so calling .split() on it unconditionally
    # raises AttributeError. Only split when a comma-separated string is
    # given; otherwise take the endpoints as they are.
    if hasattr(pserver_endpoints, "split"):
        endpoint_list = pserver_endpoints.split(",")
    else:
        endpoint_list = list(pserver_endpoints)

    pserver_notify_program = Program()
    pserver_notify_block = pserver_notify_program.global_block()

    attrs = {
        'epmap': endpoint_list,
        'dir': dirname,
        'lookup_table': lookup_table,
    }

    pserver_notify_block.append_op(
        type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
    executor.run(pserver_notify_program)
def _load_lookup_table_vars(executor, dirname, program, pserver_id):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
_load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
LOOKUP_TABLE_TYPE = "lookup_table"
lookup_table_var_name = None
for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attrs['is_distributed'] is True:
if lookup_table_var_name is None:
lookup_table_var_name = op.input("W")[0]
if lookup_table_var_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
lookup_table_var = program.global_block().vars[lookup_table_var_name]
if lookup_table_var is None:
return
lookup_table_dir = os.path.join(dirname, "__lookup_table__")
table_file = "{}.{}".format(lookup_table_var.name, pserver_id)
load_prog = Program()
load_block = load_prog.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [lookup_table_var]},
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
executor.run(load_prog)
def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap"
for op in program.global_block().ops:
if op.attrs.has_key(ENDPOINT_MAP):
op.attrs[ENDPOINT_MAP] = endpoints
def get_parameter_value(para, executor):
"""
Get the LoDTensor value of the given parameter.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册