提交 5ab3daf7 编写于 作者: L liuyibing01

Merge branch 'io' into 'master'

change interface for io.py

See merge request !49
......@@ -20,6 +20,11 @@ import numpy as np
import paddle.fluid.dygraph as dg
def is_main_process():
local_rank = dg.parallel.Env().local_rank
return local_rank == 0
def add_yaml_config_to_args(config):
""" Add args in yaml config to the args parsed by argparse. The argument in
yaml config will be overwritten by the same argument in argparse if they
......@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
return config
def load_latest_checkpoint(checkpoint_dir, rank=0):
def _load_latest_checkpoint(checkpoint_dir):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
......@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
Returns:
int: the latest iteration number.
"""
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Create checkpoint index file if not exist.
if (not os.path.isfile(checkpoint_path)) and rank == 0:
with open(checkpoint_path, "w") as handle:
handle.write("model_checkpoint_path: step-0")
# Make sure that other process waits until checkpoint file is created
# by process 0.
while not os.path.isfile(checkpoint_path):
time.sleep(1)
if (not os.path.isfile(checkpoint_record)):
return 0
# Fetch the latest checkpoint index.
with open(checkpoint_path, "r") as handle:
with open(checkpoint_record, "r") as handle:
latest_checkpoint = handle.readline().split()[-1]
iteration = int(latest_checkpoint.split("-")[-1])
return iteration
def save_latest_checkpoint(checkpoint_dir, iteration):
def _save_checkpoint(checkpoint_dir, iteration):
"""Save the iteration number of the latest model to be checkpointed.
Args:
......@@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_path, "w") as handle:
with open(checkpoint_record, "w") as handle:
handle.write("model_checkpoint_path: step-{}".format(iteration))
def load_parameters(checkpoint_dir,
rank,
model,
def load_parameters(model,
optimizer=None,
checkpoint_dir=None,
iteration=None,
file_path=None,
checkpoint_path=None,
dtype="float32"):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
iteration (int): number of iterations that the loaded checkpoint has
been trained.
"""
if file_path is None:
if checkpoint_dir is not None and checkpoint_path is not None:
raise ValueError(
"Load from either from (checkpoint_dir and iteration) \n"
"or checkpoint_path. Do not pass both.")
if iteration is not None and checkpoint_dir is None:
raise ValueError(
"When iteration is specified, checkpoint_dir should not be None")
if checkpoint_dir is not None:
if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank)
if iteration == 0:
return
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict, optimizer_dict = dg.load_dygraph(file_path)
if dtype == "float16":
for k, v in model_dict.items():
if "conv2d_transpose" in k:
model_dict[k] = v.astype("float32")
else:
model_dict[k] = v.astype(dtype)
iteration = _load_latest_checkpoint(checkpoint_dir)
checkpoint_path = os.path.join(checkpoint_dir,
"step-{}".format(iteration))
if iteration == 0 and not os.path.exists(checkpoint_path):
# if step-0 exist, it is also loaded
return iteration
else:
# checkpoint is not None
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
local_rank = dg.parallel.Env().local_rank
model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
# cast to desired data type
for k, v in model_dict.items():
model_dict[k] = v.astype(dtype)
model.set_dict(model_dict)
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
local_rank, checkpoint_path))
if optimizer and optimizer_dict:
optimizer.set_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
rank, file_path))
print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
format(local_rank, checkpoint_path))
return iteration
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
"""Checkpoint the latest trained model parameters.
Args:
......@@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
Returns:
None
"""
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
model_dict = model.state_dict()
dg.save_dygraph(model_dict, file_path)
print("[checkpoint] Saved model to {}".format(file_path))
dg.save_dygraph(model_dict, checkpoint_path)
print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))
if optimizer:
opt_dict = optimizer.state_dict()
dg.save_dygraph(opt_dict, file_path)
print("[checkpoint] Saved optimzier state to {}".format(file_path))
dg.save_dygraph(opt_dict, checkpoint_path)
print("[checkpoint] Saved optimzier state to {}.pdopt".format(
checkpoint_path))
_save_checkpoint(checkpoint_dir, iteration)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册