提交 f3338265 编写于 作者: H Hui Zhang

more utils for train

上级 472cf70e
......@@ -26,9 +26,6 @@ from deepspeech.utils.log import Log
#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog()
########### hcak logging #############
logger.warn = logger.warning
########### hcak paddle #############
paddle.bool = 'bool'
paddle.float16 = 'float16'
......@@ -91,23 +88,23 @@ def convert_dtype_to_string(tensor_dtype):
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
logger.debug("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'log_softmax'):
logger.warn("register user log_softmax to paddle, remove this when fixed!")
logger.debug("register user log_softmax to paddle, remove this when fixed!")
setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user sigmoid to paddle, remove this when fixed!")
logger.debug("register user sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle, 'log_sigmoid'):
logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
logger.debug("register user log_sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
if not hasattr(paddle, 'relu'):
logger.warn("register user relu to paddle, remove this when fixed!")
logger.debug("register user relu to paddle, remove this when fixed!")
setattr(paddle, 'relu', paddle.nn.functional.relu)
......@@ -116,7 +113,7 @@ def cat(xs, dim=0):
if not hasattr(paddle, 'cat'):
logger.warn(
logger.debug(
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
......@@ -127,7 +124,7 @@ def item(x: paddle.Tensor):
if not hasattr(paddle.Tensor, 'item'):
logger.warn(
logger.debug(
"override item of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.item = item
......@@ -138,13 +135,13 @@ def func_long(x: paddle.Tensor):
if not hasattr(paddle.Tensor, 'long'):
logger.warn(
logger.debug(
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
logger.debug(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
......@@ -158,7 +155,7 @@ def new_full(x: paddle.Tensor,
if not hasattr(paddle.Tensor, 'new_full'):
logger.warn(
logger.debug(
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
......@@ -173,13 +170,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
logger.debug(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
if not hasattr(paddle, 'eq'):
logger.warn(
logger.debug(
"override eq of paddle if exists or register, remove this when fixed!")
paddle.eq = eq
......@@ -189,7 +186,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
logger.debug(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
......@@ -206,7 +203,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.warn(
logger.debug(
"override size of paddle.Tensor "
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
......@@ -218,7 +215,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
logger.debug("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
......@@ -227,7 +224,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'view_as'):
logger.warn(
logger.debug(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
......@@ -253,7 +250,7 @@ def masked_fill(xs: paddle.Tensor,
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
logger.debug(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
......@@ -271,7 +268,7 @@ def masked_fill_(xs: paddle.Tensor,
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
logger.debug(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
......@@ -283,7 +280,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'fill_'):
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
logger.debug(
"register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
......@@ -292,22 +290,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
logger.debug(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.warn(
logger.debug(
"register user softmax to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle.Tensor, 'sigmoid'):
logger.warn(
logger.debug(
"register user sigmoid to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
logger.debug("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
......@@ -316,7 +314,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'type_as'):
logger.warn(
logger.debug(
"register user type_as to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'type_as', type_as)
......@@ -332,7 +330,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'to'):
logger.warn("register user to to paddle.Tensor, remove this when fixed!")
logger.debug("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
......@@ -341,7 +339,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'float'):
logger.warn("register user float to paddle.Tensor, remove this when fixed!")
logger.debug(
"register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
......@@ -350,7 +349,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'int'):
logger.warn("register user int to paddle.Tensor, remove this when fixed!")
logger.debug("register user int to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'int', func_int)
......@@ -359,7 +358,7 @@ def tolist(x: paddle.Tensor) -> List[Any]:
if not hasattr(paddle.Tensor, 'tolist'):
logger.warn(
logger.debug(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
......@@ -374,7 +373,7 @@ def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor:
if not hasattr(paddle.nn.functional, 'glu'):
logger.warn(
logger.debug(
"register user glu to paddle.nn.functional, remove this when fixed!")
setattr(paddle.nn.functional, 'glu', glu)
......@@ -425,19 +424,19 @@ def ctc_loss(logits,
return loss_out
logger.warn(
logger.debug(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F.ctc_loss = ctc_loss
########### hcak paddle.nn #############
if not hasattr(paddle.nn, 'Module'):
logger.warn("register user Module to paddle.nn, remove this when fixed!")
logger.debug("register user Module to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'Module', paddle.nn.Layer)
# maybe cause assert isinstance(sublayer, core.Layer)
if not hasattr(paddle.nn, 'ModuleList'):
logger.warn(
logger.debug(
"register user ModuleList to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'ModuleList', paddle.nn.LayerList)
......@@ -454,7 +453,7 @@ class GLU(nn.Layer):
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
logger.debug("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
......@@ -486,12 +485,12 @@ class ConstantPad2d(nn.Layer):
if not hasattr(paddle.nn, 'ConstantPad2d'):
logger.warn(
logger.debug(
"register user ConstantPad2d to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
########### hcak paddle.jit #############
if not hasattr(paddle.jit, 'export'):
logger.warn("register user export to paddle.jit, remove this when fixed!")
logger.debug("register user export to paddle.jit, remove this when fixed!")
setattr(paddle.jit, 'export', paddle.jit.to_static)
......@@ -21,8 +21,8 @@ from yacs.config import CfgNode
from deepspeech.modules.conv import ConvStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.rnn import RNNStack
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
......@@ -222,7 +222,7 @@ class DeepSpeech2Model(nn.Layer):
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
infos = checkpoint.load_parameters(
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
......
......@@ -40,8 +40,8 @@ from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos
......@@ -894,7 +894,7 @@ class U2Model(U2BaseModel):
model = cls.from_config(config)
if checkpoint_path:
infos = checkpoint.load_parameters(
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
......
......@@ -18,8 +18,8 @@ import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.utils import checkpoint
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
__all__ = ["Trainer"]
......@@ -151,7 +151,7 @@ class Trainer():
resume training.
"""
scratch = None
infos = checkpoint.load_parameters(
infos = Checkpoint().load_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import sacrebleu
__all__ = ['bleu', 'char_bleu']
def bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in word-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
:type hypothesis: list[str]
:raises ValueError: If the reference length is zero.
"""
return sacrebleu.corpus_bleu(hypothesis, reference)
def char_bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in char-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
:type hypothesis: list[str]
:raises ValueError: If the reference number is zero.
"""
hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
for ref in reference]
return sacrebleu.corpus_bleu(hypothesis, reference)
......@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import re
from pathlib import Path
from typing import Text
from typing import Union
import paddle
......@@ -25,46 +28,58 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["load_parameters", "save_parameters"]
__all__ = ["Checkpoint"]
def _load_latest_checkpoint(checkpoint_dir: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
if not os.path.isfile(checkpoint_record):
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
class Checkpoint():
def __init__(self, kbest_n: int=5, latest_n: int=1):
self.best_records: Mapping[Path, float] = {}
self.latest_records = []
self.kbest_n = kbest_n
self.latest_n = latest_n
self._save_all = (kbest_n == -1)
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n.
def _save_record(checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to be checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
infos (dict or None)): any info you want to save.
metric_type (str, optional): metric type. Defaults to "val_loss".
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_record, "a+") as handle:
handle.write("model_checkpoint_path:{}\n".format(iteration))
if (metric_type not in infos.keys()):
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
return
#save best
if self._should_save_best(infos[metric_type]):
self._save_best_checkpoint_and_update(
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
#save latest
self._save_latest_checkpoint_and_update(
checkpoint_dir, tag_or_iteration, model, optimizer, infos)
if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(model,
def load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a specific model checkpoint from disk.
checkpoint_path=None,
record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
......@@ -73,21 +88,25 @@ def load_parameters(model,
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
record_file "checkpoint_latest" or "checkpoint_best"
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs = {}
if checkpoint_path is not None:
tag = os.path.basename(checkpoint_path).split(":")[-1]
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
pass
elif checkpoint_dir is not None and record_file is not None:
# load checkpint from record file
checkpoint_record = os.path.join(checkpoint_dir, record_file)
iteration = self._load_checkpoint_idx(checkpoint_record)
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
"At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
)
rank = dist.get_rank()
......@@ -95,13 +114,13 @@ def load_parameters(model,
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
logger.info("Rank {}: Restore model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format(
logger.info("Rank {}: Restore optimizer state from {}".format(
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
......@@ -110,9 +129,139 @@ def load_parameters(model,
configs = json.load(fin)
return configs
def load_latest_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest")
@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir: str,
def load_best_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best")
def _should_save_best(self, metric: float) -> bool:
if not self._best_full():
return True
# already full
worst_record_path = max(self.best_records, key=self.best_records.get)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric = self.best_records[worst_record_path]
return metric < worst_metric
def _best_full(self):
return (not self._save_all) and len(self.best_records) == self.kbest_n
def _latest_full(self):
return len(self.latest_records) == self.latest_n
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
tag_or_iteration, model, optimizer,
infos):
# remove the worst
if self._best_full():
worst_record_path = max(self.best_records,
key=self.best_records.get)
self.best_records.pop(worst_record_path)
if (worst_record_path not in self.latest_records):
logger.info(
"remove the worst checkpoint: {}".format(worst_record_path))
self._del_checkpoint(checkpoint_dir, worst_record_path)
# add the new one
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
self.best_records[tag_or_iteration] = metric
def _save_latest_checkpoint_and_update(
self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
# remove the old
if self._latest_full():
to_del_fn = self.latest_records.pop(0)
if (to_del_fn not in self.best_records.keys()):
logger.info(
"remove the latest checkpoint: {}".format(to_del_fn))
self._del_checkpoint(checkpoint_dir, to_del_fn)
self.latest_records.append(tag_or_iteration)
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration))
for filename in glob.glob(checkpoint_path + ".*"):
os.remove(filename)
logger.info("delete file: {}".format(filename))
def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_path (str): the saved path of checkpoint.
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
"""
if not os.path.isfile(checkpoint_record):
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to be checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_record_latest = os.path.join(checkpoint_dir,
"checkpoint_latest")
checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
with open(checkpoint_record_best, "w") as handle:
for i in self.best_records.keys():
handle.write("model_checkpoint_path:{}\n".format(i))
with open(checkpoint_record_latest, "w") as handle:
for i in self.latest_records:
handle.write("model_checkpoint_path:{}\n".format(i))
@mp_tools.rank_zero_only
def _save_parameters(self,
checkpoint_dir: str,
tag_or_iteration: Union[int, str],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
......@@ -147,6 +296,3 @@ def save_parameters(checkpoint_dir: str,
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
if isinstance(tag_or_iteration, int):
_save_record(checkpoint_dir, tag_or_iteration)
......@@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
# add non-blank into new_hyp
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
# skip repeat label
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def insert_blank(label: np.ndarray, blank_id: int=0):
def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
"""Insert blank token between every two label token.
"abcdefg" -> "-a-b-c-d-e-f-g-"
Args:
label ([np.ndarray]): label ids, (L).
label ([np.ndarray]): label ids, List[int], (L).
blank_id (int, optional): blank id. Defaults to 0.
Returns:
......@@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0):
label = np.expand_dims(label, 1) #[L, 1]
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1) #[L, 2]
label = label.reshape(-1) #[2L]
label = np.append(label, label[0]) #[2L + 1]
label = label.reshape(-1) #[2L], -l-l-l
label = np.append(label, label[0]) #[2L + 1], -l-l-l-
return label
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
blank_id=0) -> list:
blank_id=0) -> List[int]:
"""ctc forced alignment.
https://distill.pub/2017/ctc/
......@@ -77,23 +79,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
List[int]: best alignment result, (T).
"""
y_insert_blank = insert_blank(y, blank_id)
y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
(ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
) # state path
(ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
for t in range(1, ctc_probs.shape[0]): # T
for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = paddle.to_tensor(
......@@ -106,11 +112,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
# TODO(Hui Zhang): zeros not support paddle.int16
# self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16
state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
......@@ -118,11 +126,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[paddle.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1):
for t in range(ctc_probs.shape[0] - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = []
for t in range(0, ctc_probs.size(0)):
for t in range(0, ctc_probs.shape[0]):
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from typing import Any
from typing import Dict
from typing import List
from typing import Text
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import has_tensor
logger = Log(__name__).getlog()
__all__ = ["dynamic_import", "instance_class"]
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'deepspeech.models.u2:U2Model'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ":" not in import_path:
raise ValueError("import_path should be one of {} or "
'include ":", e.g. "deepspeech.models.u2:U2Model" : '
"{}".format(set(alias), import_path))
if ":" not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)
def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
# filter by `valid_keys` and filter `val` is not None
new_args = {
key: val
for key, val in args.items() if (key in valid_keys and val is not None)
}
return new_args
def filter_out_tenosr(args: Dict[Text, Any]):
return {key: val for key, val in args.items() if not has_tensor(val)}
def instance_class(module_class, args: Dict[Text, Any]):
valid_keys = inspect.signature(module_class).parameters.keys()
new_args = filter_valid_args(args, valid_keys)
logger.info(
f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.")
return module_class(**new_args)
......@@ -14,10 +14,13 @@
"""This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level.
"""
import editdistance
import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer']
editdistance.eval("a", "b")
def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference
......@@ -89,6 +92,8 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
hyp_words = list(filter(None, hypothesis.split(delimiter)))
edit_distance = _levenshtein_distance(ref_words, hyp_words)
# `editdistance.eavl precision` less than `_levenshtein_distance`
# edit_distance = editdistance.eval(ref_words, hyp_words)
return float(edit_distance), len(ref_words)
......@@ -119,6 +124,8 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
hypothesis = join_char.join(list(filter(None, hypothesis.split(' '))))
edit_distance = _levenshtein_distance(reference, hypothesis)
# `editdistance.eavl precision` less than `_levenshtein_distance`
# edit_distance = editdistance.eval(reference, hypothesis)
return float(edit_distance), len(reference)
......
......@@ -12,16 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import getpass
import logging
import os
import socket
import sys
FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
logging.basicConfig(
level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)
from loguru import logger
from paddle import inference
def find_log_dir(log_dir=None):
......@@ -96,53 +92,54 @@ def find_log_dir_and_names(program_name=None, log_dir=None):
class Log():
log_name = None
def __init__(self, logger=None):
self.logger = logging.getLogger(logger)
self.logger.setLevel(logging.DEBUG)
file_dir = os.getcwd() + '/log'
if not os.path.exists(file_dir):
os.mkdir(file_dir)
self.log_dir = file_dir
actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
program_name=None, log_dir=self.log_dir)
basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
filename = os.path.join(actual_log_dir, basename)
if Log.log_name is None:
Log.log_name = filename
# Create a symlink to the log file with a canonical name.
symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
try:
if os.path.islink(symlink):
os.unlink(symlink)
os.symlink(os.path.basename(Log.log_name), symlink)
except EnvironmentError:
# If it fails, we're sad but it's no error. Commonly, this
# fails because the symlink was created by another user and so
# we can't modify it
"""Default Logger for all."""
logger.remove()
logger.add(
sys.stdout,
level='INFO',
enqueue=True,
filter=lambda record: record['level'].no >= 20)
_, file_prefix, _ = find_log_dir_and_names()
sink_prefix = os.path.join("exp/log", file_prefix)
sink_path = sink_prefix[:-3] + "{time}.log"
logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB")
def __init__(self, name=None):
pass
if not self.logger.hasHandlers():
formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
fh = logging.FileHandler(Log.log_name)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
self.logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
self.logger.addHandler(ch)
# stop propagate for propagating may print
# log multiple times
self.logger.propagate = False
def getlog(self):
return logger
class Autolog:
"""Just used by fullchain project"""
def __init__(self,
batch_size,
model_name="DeepSpeech",
model_precision="fp32"):
import auto_log
pid = os.getpid()
if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0])
infer_config = inference.Config()
infer_config.enable_use_gpu(100, gpu_id)
else:
gpu_id = None
infer_config = inference.Config()
autolog = auto_log.AutoLogger(
model_name=model_name,
model_precision=model_precision,
batch_size=batch_size,
data_shape="dynamic",
save_path="./output/auto_log.lpg",
inference_config=infer_config,
pids=pid,
process_name=None,
gpu_ids=gpu_id,
time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
warmup=0)
self.autolog = autolog
def getlog(self):
return self.logger
return self.autolog
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0
# A global variable to avoid parsing from string every time.
_profiler_options = None
class ProfilerOptions(object):
'''
Use a string to initialize a ProfilerOptions.
The string should be in the format: "key1=value1;key2=value;key3=value3".
For example:
"profile_path=model.profile"
"batch_range=[50, 60]; profile_path=model.profile"
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
ProfilerOptions supports following key-value pair:
batch_range - a integer list, e.g. [100, 110].
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
sorted_key - a string, the optional values are 'calls', 'total',
'max', 'min' or 'ave.
tracer_option - a string, the optional values are 'Default', 'OpDetail',
'AllOpDetail'.
profile_path - a string, the path to save the serialized profile data,
which can be used to generate a timeline.
exit_on_finished - a boolean.
'''
def __init__(self, options_str):
assert isinstance(options_str, str)
self._options = {
'batch_range': [10, 20],
'state': 'All',
'sorted_key': 'total',
'tracer_option': 'Default',
'profile_path': '/tmp/profile',
'exit_on_finished': True
}
self._parse_from_string(options_str)
def _parse_from_string(self, options_str):
if not options_str:
return
for kv in options_str.replace(' ', '').split(';'):
key, value = kv.split('=')
if key == 'batch_range':
value_list = value.replace('[', '').replace(']', '').split(',')
value_list = list(map(int, value_list))
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
1] > value_list[0]:
self._options[key] = value_list
elif key == 'exit_on_finished':
self._options[key] = value.lower() in ("yes", "true", "t", "1")
elif key in [
'state', 'sorted_key', 'tracer_option', 'profile_path'
]:
self._options[key] = value
def __getitem__(self, name):
if self._options.get(name, None) is None:
raise ValueError(
"ProfilerOptions does not have an option named %s." % name)
return self._options[name]
def add_profiler_step(options_str=None):
'''
Enable the operator-level timing using PaddlePaddle's profiler.
The profiler uses a independent variable to count the profiler steps.
One call of this function is treated as a profiler step.
Args:
profiler_options - a string to initialize the ProfilerOptions.
Default is None, and the profiler is disabled.
'''
if options_str is None:
return
global _profiler_step_id
global _profiler_options
if _profiler_options is None:
_profiler_options = ProfilerOptions(options_str)
logger.info(f"Profiler: {options_str}")
logger.info(f"Profiler: {_profiler_options._options}")
if _profiler_step_id == _profiler_options['batch_range'][0]:
paddle.utils.profiler.start_profiler(_profiler_options['state'],
_profiler_options['tracer_option'])
elif _profiler_step_id == _profiler_options['batch_range'][1]:
paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
_profiler_options['profile_path'])
if _profiler_options['exit_on_finished']:
sys.exit(0)
_profiler_step_id += 1
......@@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler,
rng = random.Random(random_seed)
samples = rng.sample(manifest, num_test_cases)
for idx, sample in enumerate(samples):
print("Warm-up Test Case %d: %s", idx, sample['audio_filepath'])
print("Warm-up Test Case %d: %s" % (idx, sample['feat']))
start_time = time.time()
transcript = audio_process_handler(sample['audio_filepath'])
transcript = audio_process_handler(sample['feat'])
finish_time = time.time()
print("Response Time: %f, Transcript: %s" %
(finish_time - start_time, transcript))
......
......@@ -19,11 +19,25 @@ import paddle
from deepspeech.utils.log import Log
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]
logger = Log(__name__).getlog()
def has_tensor(val):
if isinstance(val, (list, tuple)):
for item in val:
if has_tensor(item):
return True
elif isinstance(val, dict):
for k, v in val.items():
print(k)
if has_tensor(v):
return True
else:
return paddle.is_tensor(val)
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
......@@ -69,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
max_len = max([s.size(0) for s in sequences])
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
......@@ -77,12 +91,27 @@ def pad_sequence(sequences: List[paddle.Tensor],
out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
logger.info(
f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
)
if batch_first:
out_tensor[i, :length, ...] = tensor
# TODO (Hui Zhang): set_value op not supprot `end==start`
# TODO (Hui Zhang): set_value op not support int16
# TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length] = tensor
else:
out_tensor[i, length] = tensor
else:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i] = tensor
else:
out_tensor[:length, i, ...] = tensor
out_tensor[length, i] = tensor
return out_tensor
......@@ -125,7 +154,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.size(0)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
......@@ -151,8 +180,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text
import textgrid
def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence.
e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: token align, segment aligment id sequence.
e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
# convert alignment to a praat format, which is a doing phonetics
# by computer and helps analyzing alignment
align_segs = []
# get frames level duration for each token
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == blank_id: # blank
end += 1
if end == len(alignment):
align_segs[-1].extend(alignment[start:])
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[
end]: # repeat label
end += 1
align_segs.append(alignment[start:end])
start = end
return align_segs
def align_to_tierformat(align_segs: List[List[int]],
subsample: int,
token_dict: Dict[int, Text],
blank_id=0) -> List[Text]:
"""Generate textgrid.Interval format from alignment segmentations.
Args:
align_segs (List[List[int]]): segmented ctc alignment ids.
subsample (int): 25ms frame_length, 10ms hop_length, 1/subsample
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval text, str(start, end, text).
"""
hop_length = 10 # ms
second_ms = 1000 # ms
frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length
second_per_frame = 1.0 / frame_per_second
begin = 0
duration = 0
tierformat = []
for idx, tokens in enumerate(align_segs):
token_len = len(tokens)
token = tokens[-1]
# time duration in second
duration = token_len * subsample * second_per_frame
if idx < len(align_segs) - 1:
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
else:
for i in tokens:
if i != blank_id:
token = i
break
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
begin = begin + duration
return tierformat
def generate_textgrid(maxtime: float,
intervals: List[Text],
output: Text,
name: Text='ali') -> None:
"""Create alignment textgrid file.
Args:
maxtime (float): audio duartion.
intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
output (Text): textgrid filepath.
name (Text, optional): tier or layer name. Defaults to 'ali'.
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average second/token: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
i = 0
for dur in intervals:
s, e, text = dur.split()
tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
tg.append(tier)
tg.write(output)
print("successfully generator textgrid {}.".format(output))
......@@ -15,9 +15,50 @@
import distutils.util
import math
import os
import random
import sys
from contextlib import contextmanager
from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"]
import numpy as np
import paddle
import soundfile
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"all_version", "UpdateConfig", "seed_all", 'print_arguments',
'add_arguments', "log_add"
]
def all_version():
vers = {
"python": sys.version,
"paddle": paddle.__version__,
"paddle_commit": paddle.version.commit,
"soundfile": soundfile.__version__,
}
logger.info("Deps Module Version:")
for k, v in vers.items():
logger.info(f"{k}: {v}")
@contextmanager
def UpdateConfig(config):
"""Update yacs config"""
config.defrost()
yield
config.freeze()
def seed_all(seed: int=210329):
"""freeze random generator seed."""
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def print_arguments(args, info=None):
......@@ -79,3 +120,22 @@ def log_add(args: List[int]) -> float:
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
def get_subsample(config):
"""Subsample rate from config.
Args:
config (yacs.config.CfgNode): yaml config
Returns:
int: subsample rate.
"""
input_layer = config["model"]["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册