diff --git a/parakeet/utils/io.py b/parakeet/utils/io.py index e6124008d9e5964b8cce91ae6bd66a69cf061d06..e9e12405791bc57c2e2f2e602b3d59734ac11942 100644 --- a/parakeet/utils/io.py +++ b/parakeet/utils/io.py @@ -20,6 +20,11 @@ import numpy as np import paddle.fluid.dygraph as dg +def is_main_process(): + local_rank = dg.parallel.Env().local_rank + return local_rank == 0 + + def add_yaml_config_to_args(config): """ Add args in yaml config to the args parsed by argparse. The argument in yaml config will be overwritten by the same argument in argparse if they @@ -41,7 +46,7 @@ def add_yaml_config_to_args(config): return config -def load_latest_checkpoint(checkpoint_dir, rank=0): +def _load_latest_checkpoint(checkpoint_dir): """Get the iteration number corresponding to the latest saved checkpoint Args: @@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0): Returns: int: the latest iteration number. """ - checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") # Create checkpoint index file if not exist. - if (not os.path.isfile(checkpoint_path)) and rank == 0: - with open(checkpoint_path, "w") as handle: - handle.write("model_checkpoint_path: step-0") - - # Make sure that other process waits until checkpoint file is created - # by process 0. - while not os.path.isfile(checkpoint_path): - time.sleep(1) + if (not os.path.isfile(checkpoint_record)): + return 0 # Fetch the latest checkpoint index. - with open(checkpoint_path, "r") as handle: + with open(checkpoint_record, "r") as handle: latest_checkpoint = handle.readline().split()[-1] iteration = int(latest_checkpoint.split("-")[-1]) return iteration -def save_latest_checkpoint(checkpoint_dir, iteration): +def _save_checkpoint(checkpoint_dir, iteration): """Save the iteration number of the latest model to be checkpointed. Args: @@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration): Returns: None """ - checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") # Update the latest checkpoint index. - with open(checkpoint_path, "w") as handle: + with open(checkpoint_record, "w") as handle: handle.write("model_checkpoint_path: step-{}".format(iteration)) -def load_parameters(checkpoint_dir, - rank, - model, +def load_parameters(model, optimizer=None, + checkpoint_dir=None, iteration=None, - file_path=None, + checkpoint_path=None, dtype="float32"): """Load a specific model checkpoint from disk. Args: - checkpoint_dir (str): the directory where checkpoint is saved. - rank (int): the rank of the process in multi-process setting. model (obj): model to load parameters. optimizer (obj, optional): optimizer to load states if needed. Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. iteration (int, optional): if specified, load the specific checkpoint, if not specified, load the latest one. Defaults to None. - file_path (str, optional): if specified, load the checkpoint - stored in the file_path. Defaults to None. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path. Defaults to None. dtype (str, optional): precision of the model parameters. Defaults to float32. Returns: - None + iteration (int): number of iterations that the loaded checkpoint has + been trained. """ - if file_path is None: + if checkpoint_dir is not None and checkpoint_path is not None: + raise ValueError( + "Load from either from (checkpoint_dir and iteration) \n" + "or checkpoint_path. Do not pass both.") + if iteration is not None and checkpoint_dir is None: + raise ValueError( + "When iteration is specified, checkpoint_dir should not be None") + + if checkpoint_dir is not None: if iteration is None: - iteration = load_latest_checkpoint(checkpoint_dir, rank) - if iteration == 0: - return - file_path = "{}/step-{}".format(checkpoint_dir, iteration) - - model_dict, optimizer_dict = dg.load_dygraph(file_path) - if dtype == "float16": - for k, v in model_dict.items(): - if "conv2d_transpose" in k: - model_dict[k] = v.astype("float32") - else: - model_dict[k] = v.astype(dtype) + iteration = _load_latest_checkpoint(checkpoint_dir) + checkpoint_path = os.path.join(checkpoint_dir, + "step-{}".format(iteration)) + if iteration == 0 and not os.path.exists(checkpoint_path): + # if step-0 exist, it is also loaded + return iteration + else: + # checkpoint is not None + iteration = int(os.path.basename(checkpoint_path).split("-")[-1]) + + local_rank = dg.parallel.Env().local_rank + model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path) + + # cast to desired data type + for k, v in model_dict.items(): + model_dict[k] = v.astype(dtype) + model.set_dict(model_dict) - print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) + print("[checkpoint] Rank {}: loaded model from {}.pdparams".format( + local_rank, checkpoint_path)) + if optimizer and optimizer_dict: optimizer.set_dict(optimizer_dict) - print("[checkpoint] Rank {}: loaded optimizer state from {}".format( - rank, file_path)) + print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt". + format(local_rank, checkpoint_path)) + return iteration -def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): + +def save_parameters(checkpoint_dir, iteration, model, optimizer=None): """Checkpoint the latest trained model parameters. Args: @@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): Returns: None """ - file_path = "{}/step-{}".format(checkpoint_dir, iteration) + checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration)) model_dict = model.state_dict() - dg.save_dygraph(model_dict, file_path) - print("[checkpoint] Saved model to {}".format(file_path)) + dg.save_dygraph(model_dict, checkpoint_path) + print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path)) if optimizer: opt_dict = optimizer.state_dict() - dg.save_dygraph(opt_dict, file_path) - print("[checkpoint] Saved optimzier state to {}".format(file_path)) + dg.save_dygraph(opt_dict, checkpoint_path) + print("[checkpoint] Saved optimzier state to {}.pdopt".format( + checkpoint_path)) + + _save_checkpoint(checkpoint_dir, iteration)