提交 05bd9db8 编写于 作者: T tangwei12

add comments in io.py

上级 c073bb3b
...@@ -840,6 +840,12 @@ def save_checkpoint(executor, ...@@ -840,6 +840,12 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The max number of total number of existing max_num_checkpoints(int): The max number of total number of existing
checkpoints. checkpoints.
Default: 3 Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Returns: Returns:
None None
...@@ -856,15 +862,21 @@ def save_checkpoint(executor, ...@@ -856,15 +862,21 @@ def save_checkpoint(executor,
prog = fluid.default_main_program() prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200, trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example "step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_checkpoint(executor=exe, fluid.io.save_checkpoint(executor=exe,
checkpoint_dir=path, checkpoint_dir=path,
trainer_id=0, trainer_id=0,
trainer_args=trainer_args, trainer_args=trainer_args,
main_program=prog, main_program=prog,
max_num_checkpoints=3) max_num_checkpoints=3,
lookup_table=table_name,
ps_endpoint_list = ps_endpoints)
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
assert checkpoint_dir
if trainer_args: if trainer_args:
assert isinstance(trainer_args, dict) assert isinstance(trainer_args, dict)
...@@ -881,6 +893,7 @@ def save_checkpoint(executor, ...@@ -881,6 +893,7 @@ def save_checkpoint(executor,
if is_chief: if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program) save_persist_vars_without_grad(executor, cur_dir, main_program)
if is_chief and lookup_table and ps_endpoint_list: if is_chief and lookup_table and ps_endpoint_list:
save_pserver_vars_by_notify(executor, cur_dir, lookup_table, save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
ps_endpoint_list) ps_endpoint_list)
...@@ -1020,6 +1033,31 @@ def load_persist_vars_without_grad(executor, ...@@ -1020,6 +1033,31 @@ def load_persist_vars_without_grad(executor,
def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/__model__"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for var in program.list_vars(): for var in program.list_vars():
if var.name == table_name: if var.name == table_name:
...@@ -1092,6 +1130,35 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -1092,6 +1130,35 @@ def save_persist_vars_without_grad(executor, dirname, program):
def save_pserver_vars_by_notify(executor, dirname, lookup_table, def save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list): ps_endpoint_list):
""" """
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
""" """
cur_dir = _get_lookuptable_dir(dirname) cur_dir = _get_lookuptable_dir(dirname)
...@@ -1121,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -1121,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert isinstance(trainer_args, list) assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
...@@ -1141,7 +1231,7 @@ def _is_checkpoint_var(var): ...@@ -1141,7 +1231,7 @@ def _is_checkpoint_var(var):
the checkpoint will not save or load all the variables. the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var : param var(Variable)
""" """
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册