提交 20117d99 编写于 作者: H Hui Zhang

fix ckpt load

上级 43b52082
......@@ -599,26 +599,26 @@ class U2BaseModel(nn.Module):
best_index = i
return hyps[best_index][0]
@jit.export
#@jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
#@jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
#@jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
#@jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
......@@ -654,12 +654,14 @@ class U2BaseModel(nn.Module):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
# @jit.export([
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output
xs (paddle.Tensor): encoder output, (B, T, D)
Returns:
paddle.Tensor: activation before ctc
"""
......@@ -894,7 +896,7 @@ class U2Model(U2BaseModel):
model = cls.from_config(config)
if checkpoint_path:
infos = checkpoint.load_parameters(
infos = checkpoint.Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
......
......@@ -17,6 +17,7 @@ import os
import re
from pathlib import Path
from typing import Union
from typing import Text
import paddle
from paddle import distributed as dist
......@@ -30,7 +31,7 @@ logger = Log(__name__).getlog()
__all__ = ["Checkpoint"]
class Checkpoint(object):
class Checkpoint():
def __init__(self, kbest_n: int=5, latest_n: int=1):
self.best_records: Mapping[Path, float] = {}
self.latest_records = []
......@@ -40,11 +41,21 @@ class Checkpoint(object):
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration,
model,
optimizer,
infos,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
infos (dict or None)): any info you want to save.
metric_type (str, optional): metric type. Defaults to "val_loss".
"""
if (metric_type not in infos.keys()):
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
......@@ -61,6 +72,62 @@ class Checkpoint(object):
if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
record_file "checkpoint_latest" or "checkpoint_best"
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs = {}
if checkpoint_path is not None:
pass
elif checkpoint_dir is not None and record_file is not None:
# load checkpint from record file
checkpoint_record = os.path.join(checkpoint_dir, record_file)
iteration = self._load_checkpoint_idx(checkpoint_record)
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
)
rank = dist.get_rank()
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
return configs
def load_latest_parameters(self,
model,
......@@ -192,61 +259,6 @@ class Checkpoint(object):
for i in self.latest_records:
handle.write("model_checkpoint_path:{}\n".format(i))
def _load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
checkpoint_file=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
checkpoint_file "checkpoint_latest" or "checkpoint_best"
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs = {}
if checkpoint_path is not None:
tag = os.path.basename(checkpoint_path).split(":")[-1]
elif checkpoint_dir is not None and checkpoint_file is not None:
checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file)
iteration = self._load_checkpoint_idx(checkpoint_record)
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!"
)
rank = dist.get_rank()
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
return configs
@mp_tools.rank_zero_only
def _save_parameters(self,
checkpoint_dir: str,
......
......@@ -40,5 +40,5 @@ fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册