From 05bd9db84bfb6b0a2beea4c4c79306c5eb127ff7 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 20 Jun 2018 17:25:16 +0800 Subject: [PATCH] add comments in io.py --- python/paddle/fluid/io.py | 94 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e59ac11fd4..32f53ebe38 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -840,6 +840,12 @@ def save_checkpoint(executor, max_num_checkpoints(int): The max number of total number of existing checkpoints. Default: 3 + lookup_table(string|None): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list|None): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. Returns: None @@ -856,15 +862,21 @@ def save_checkpoint(executor, prog = fluid.default_main_program() trainer_args = {"epoch_id": 200, "step_id": 20} # just an example + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + fluid.io.save_checkpoint(executor=exe, checkpoint_dir=path, trainer_id=0, trainer_args=trainer_args, main_program=prog, - max_num_checkpoints=3) + max_num_checkpoints=3, + lookup_table=table_name, + ps_endpoint_list = ps_endpoints) """ if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") + assert checkpoint_dir if trainer_args: assert isinstance(trainer_args, dict) @@ -881,6 +893,7 @@ def save_checkpoint(executor, if is_chief: save_persist_vars_without_grad(executor, cur_dir, main_program) + if is_chief and lookup_table and ps_endpoint_list: save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list) @@ -1020,6 +1033,31 @@ def load_persist_vars_without_grad(executor, def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): + """ + The parameter server will load lookup table's local file in + selectedrows variable. + + Args: + executor(Executor): The executor to run for loading persistable variables + dirname(str): The directory path + main_program(Program): Find the variable named table_name in main_program + pserver_id(int): the serial number in pserver_endpoints list + table_name(str): lookup table name + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + dirname = "./checkpoints/checkpoint_9/__model__" + prog = fluid.default_main_program() + pserver_id = 1 + table_name = "share_w" + fluid.io.load_lookup_table_vars(executor=exe, + dirname=dirname, program=prog, pserver_id=pserver_id, + table_name=table_name) + """ for var in program.list_vars(): if var.name == table_name: @@ -1092,6 +1130,35 @@ def save_persist_vars_without_grad(executor, dirname, program): def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list): """ + This function will send checkpoint notify message from Trainer 0 + to all the pservers. + The checkpoint notify message contains lookup table name, + the absolute path on pserver to save lookup_table. + + Args: + executor(Executor): The executor to run for send checkpoint notify. + dirname(str): The folder where to save checkpoints. + lookup_table(string): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. + Return: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + prog = fluid.default_main_program() + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + + fluid.io.save_pserver_vars_by_notify(executor=exe, + dirname=param_path, lookup_table=table_name, + ps_endpoint_list=ps_endpoints) """ cur_dir = _get_lookuptable_dir(dirname) @@ -1121,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): + """ + trainer will load some args from it's independent directory, + such as epoch_id and step_id. + + Args: + checkpoint_dir(str): The folder where all checkpoints are. + serial(int): The serial of checkpoint you would like to load. + trainer_id(int): current trainer id. + trainer_args(list): list about load trainer args + Return: + None + + Examples: + .. code-block:: python + + param_path = "./checkpoint/" + serial = 7 + trainer_id = 2 + trainer_args = ["epoch_id", "step_id"] + + fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial, + trainer_id=trainer_id, trainer_args=trainer_args) + """ assert isinstance(trainer_args, list) cur_dir = _get_serial_dir(checkpoint_dir, serial) @@ -1141,7 +1231,7 @@ def _is_checkpoint_var(var): the checkpoint will not save or load all the variables. var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. - : param var + : param var(Variable) """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ -- GitLab