提交 5ab3daf7 编写于 作者: L liuyibing01

Merge branch 'io' into 'master'

change interface for io.py

See merge request !49
...@@ -20,6 +20,11 @@ import numpy as np ...@@ -20,6 +20,11 @@ import numpy as np
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
def is_main_process():
local_rank = dg.parallel.Env().local_rank
return local_rank == 0
def add_yaml_config_to_args(config): def add_yaml_config_to_args(config):
""" Add args in yaml config to the args parsed by argparse. The argument in """ Add args in yaml config to the args parsed by argparse. The argument in
yaml config will be overwritten by the same argument in argparse if they yaml config will be overwritten by the same argument in argparse if they
...@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config): ...@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
return config return config
def load_latest_checkpoint(checkpoint_dir, rank=0): def _load_latest_checkpoint(checkpoint_dir):
"""Get the iteration number corresponding to the latest saved checkpoint """Get the iteration number corresponding to the latest saved checkpoint
Args: Args:
...@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0): ...@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
Returns: Returns:
int: the latest iteration number. int: the latest iteration number.
""" """
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Create checkpoint index file if not exist. # Create checkpoint index file if not exist.
if (not os.path.isfile(checkpoint_path)) and rank == 0: if (not os.path.isfile(checkpoint_record)):
with open(checkpoint_path, "w") as handle: return 0
handle.write("model_checkpoint_path: step-0")
# Make sure that other process waits until checkpoint file is created
# by process 0.
while not os.path.isfile(checkpoint_path):
time.sleep(1)
# Fetch the latest checkpoint index. # Fetch the latest checkpoint index.
with open(checkpoint_path, "r") as handle: with open(checkpoint_record, "r") as handle:
latest_checkpoint = handle.readline().split()[-1] latest_checkpoint = handle.readline().split()[-1]
iteration = int(latest_checkpoint.split("-")[-1]) iteration = int(latest_checkpoint.split("-")[-1])
return iteration return iteration
def save_latest_checkpoint(checkpoint_dir, iteration): def _save_checkpoint(checkpoint_dir, iteration):
"""Save the iteration number of the latest model to be checkpointed. """Save the iteration number of the latest model to be checkpointed.
Args: Args:
...@@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration): ...@@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
Returns: Returns:
None None
""" """
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index. # Update the latest checkpoint index.
with open(checkpoint_path, "w") as handle: with open(checkpoint_record, "w") as handle:
handle.write("model_checkpoint_path: step-{}".format(iteration)) handle.write("model_checkpoint_path: step-{}".format(iteration))
def load_parameters(checkpoint_dir, def load_parameters(model,
rank,
model,
optimizer=None, optimizer=None,
checkpoint_dir=None,
iteration=None, iteration=None,
file_path=None, checkpoint_path=None,
dtype="float32"): dtype="float32"):
"""Load a specific model checkpoint from disk. """Load a specific model checkpoint from disk.
Args: Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters. model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed. optimizer (obj, optional): optimizer to load states if needed.
Defaults to None. Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
iteration (int, optional): if specified, load the specific checkpoint, iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None. if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint checkpoint_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None. stored in the checkpoint_path. Defaults to None.
dtype (str, optional): precision of the model parameters. dtype (str, optional): precision of the model parameters.
Defaults to float32. Defaults to float32.
Returns: Returns:
None iteration (int): number of iterations that the loaded checkpoint has
been trained.
""" """
if file_path is None: if checkpoint_dir is not None and checkpoint_path is not None:
raise ValueError(
"Load from either from (checkpoint_dir and iteration) \n"
"or checkpoint_path. Do not pass both.")
if iteration is not None and checkpoint_dir is None:
raise ValueError(
"When iteration is specified, checkpoint_dir should not be None")
if checkpoint_dir is not None:
if iteration is None: if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank) iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == 0: checkpoint_path = os.path.join(checkpoint_dir,
return "step-{}".format(iteration))
file_path = "{}/step-{}".format(checkpoint_dir, iteration) if iteration == 0 and not os.path.exists(checkpoint_path):
# if step-0 exist, it is also loaded
model_dict, optimizer_dict = dg.load_dygraph(file_path) return iteration
if dtype == "float16": else:
for k, v in model_dict.items(): # checkpoint is not None
if "conv2d_transpose" in k: iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
model_dict[k] = v.astype("float32")
else: local_rank = dg.parallel.Env().local_rank
model_dict[k] = v.astype(dtype) model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
# cast to desired data type
for k, v in model_dict.items():
model_dict[k] = v.astype(dtype)
model.set_dict(model_dict) model.set_dict(model_dict)
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
local_rank, checkpoint_path))
if optimizer and optimizer_dict: if optimizer and optimizer_dict:
optimizer.set_dict(optimizer_dict) optimizer.set_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format( print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
rank, file_path)) format(local_rank, checkpoint_path))
return iteration
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
"""Checkpoint the latest trained model parameters. """Checkpoint the latest trained model parameters.
Args: Args:
...@@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): ...@@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
Returns: Returns:
None None
""" """
file_path = "{}/step-{}".format(checkpoint_dir, iteration) checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
model_dict = model.state_dict() model_dict = model.state_dict()
dg.save_dygraph(model_dict, file_path) dg.save_dygraph(model_dict, checkpoint_path)
print("[checkpoint] Saved model to {}".format(file_path)) print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))
if optimizer: if optimizer:
opt_dict = optimizer.state_dict() opt_dict = optimizer.state_dict()
dg.save_dygraph(opt_dict, file_path) dg.save_dygraph(opt_dict, checkpoint_path)
print("[checkpoint] Saved optimzier state to {}".format(file_path)) print("[checkpoint] Saved optimzier state to {}.pdopt".format(
checkpoint_path))
_save_checkpoint(checkpoint_dir, iteration)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册