From f3338265e0f5a79aa8c78e0a5092814f0b08e8c8 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 2 Oct 2021 15:35:12 +0000 Subject: [PATCH] more utils for train --- deepspeech/__init__.py | 75 +++--- deepspeech/models/deepspeech2.py | 4 +- deepspeech/models/u2.py | 4 +- deepspeech/training/trainer.py | 4 +- deepspeech/utils/bleu_score.py | 54 ++++ deepspeech/utils/checkpoint.py | 390 ++++++++++++++++++++--------- deepspeech/utils/ctc_utils.py | 48 ++-- deepspeech/utils/dynamic_import.py | 67 +++++ deepspeech/utils/error_rate.py | 7 + deepspeech/utils/log.py | 103 ++++---- deepspeech/utils/profiler.py | 119 +++++++++ deepspeech/utils/socket_server.py | 4 +- deepspeech/utils/tensor_utils.py | 45 +++- deepspeech/utils/text_grid.py | 127 ++++++++++ deepspeech/utils/utility.py | 62 ++++- requirements.txt | 2 +- 16 files changed, 864 insertions(+), 251 deletions(-) create mode 100644 deepspeech/utils/bleu_score.py create mode 100644 deepspeech/utils/dynamic_import.py create mode 100644 deepspeech/utils/profiler.py create mode 100644 deepspeech/utils/text_grid.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 37531657..299c799c 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -26,9 +26,6 @@ from deepspeech.utils.log import Log #TODO(Hui Zhang): remove fluid import logger = Log(__name__).getlog() -########### hcak logging ############# -logger.warn = logger.warning - ########### hcak paddle ############# paddle.bool = 'bool' paddle.float16 = 'float16' @@ -91,23 +88,23 @@ def convert_dtype_to_string(tensor_dtype): if not hasattr(paddle, 'softmax'): - logger.warn("register user softmax to paddle, remove this when fixed!") + logger.debug("register user softmax to paddle, remove this when fixed!") setattr(paddle, 'softmax', paddle.nn.functional.softmax) if not hasattr(paddle, 'log_softmax'): - logger.warn("register user log_softmax to paddle, remove this when fixed!") + logger.debug("register user log_softmax to paddle, remove this when fixed!") setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) if not hasattr(paddle, 'sigmoid'): - logger.warn("register user sigmoid to paddle, remove this when fixed!") + logger.debug("register user sigmoid to paddle, remove this when fixed!") setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) if not hasattr(paddle, 'log_sigmoid'): - logger.warn("register user log_sigmoid to paddle, remove this when fixed!") + logger.debug("register user log_sigmoid to paddle, remove this when fixed!") setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) if not hasattr(paddle, 'relu'): - logger.warn("register user relu to paddle, remove this when fixed!") + logger.debug("register user relu to paddle, remove this when fixed!") setattr(paddle, 'relu', paddle.nn.functional.relu) @@ -116,7 +113,7 @@ def cat(xs, dim=0): if not hasattr(paddle, 'cat'): - logger.warn( + logger.debug( "override cat of paddle if exists or register, remove this when fixed!") paddle.cat = cat @@ -127,7 +124,7 @@ def item(x: paddle.Tensor): if not hasattr(paddle.Tensor, 'item'): - logger.warn( + logger.debug( "override item of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.item = item @@ -138,13 +135,13 @@ def func_long(x: paddle.Tensor): if not hasattr(paddle.Tensor, 'long'): - logger.warn( + logger.debug( "override long of paddle.Tensor if exists or register, remove this when fixed!" 
) paddle.Tensor.long = func_long if not hasattr(paddle.Tensor, 'numel'): - logger.warn( + logger.debug( "override numel of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.numel = paddle.numel @@ -158,7 +155,7 @@ def new_full(x: paddle.Tensor, if not hasattr(paddle.Tensor, 'new_full'): - logger.warn( + logger.debug( "override new_full of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.new_full = new_full @@ -173,13 +170,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'eq'): - logger.warn( + logger.debug( "override eq of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.eq = eq if not hasattr(paddle, 'eq'): - logger.warn( + logger.debug( "override eq of paddle if exists or register, remove this when fixed!") paddle.eq = eq @@ -189,7 +186,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'contiguous'): - logger.warn( + logger.debug( "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.contiguous = contiguous @@ -206,7 +203,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: #`to_static` do not process `size` property, maybe some `paddle` api dependent on it. -logger.warn( +logger.debug( "override size of paddle.Tensor " "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" ) @@ -218,7 +215,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'view'): - logger.warn("register user view to paddle.Tensor, remove this when fixed!") + logger.debug("register user view to paddle.Tensor, remove this when fixed!") paddle.Tensor.view = view @@ -227,7 +224,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'view_as'): - logger.warn( + logger.debug( "register user view_as to paddle.Tensor, remove this when fixed!") paddle.Tensor.view_as = view_as @@ -253,7 +250,7 @@ def masked_fill(xs: paddle.Tensor, if not hasattr(paddle.Tensor, 'masked_fill'): - logger.warn( + logger.debug( "register user masked_fill to paddle.Tensor, remove this when fixed!") paddle.Tensor.masked_fill = masked_fill @@ -271,7 +268,7 @@ def masked_fill_(xs: paddle.Tensor, if not hasattr(paddle.Tensor, 'masked_fill_'): - logger.warn( + logger.debug( "register user masked_fill_ to paddle.Tensor, remove this when fixed!") paddle.Tensor.masked_fill_ = masked_fill_ @@ -283,7 +280,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'fill_'): - logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user fill_ to paddle.Tensor, remove this when fixed!") paddle.Tensor.fill_ = fill_ @@ -292,22 +290,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'repeat'): - logger.warn( + logger.debug( "register user repeat to paddle.Tensor, remove this when fixed!") paddle.Tensor.repeat = repeat if not hasattr(paddle.Tensor, 'softmax'): - logger.warn( + logger.debug( "register user softmax to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) if not hasattr(paddle.Tensor, 'sigmoid'): - logger.warn( + logger.debug( "register user sigmoid to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) if not 
hasattr(paddle.Tensor, 'relu'): - logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + logger.debug("register user relu to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) @@ -316,7 +314,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'type_as'): - logger.warn( + logger.debug( "register user type_as to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'type_as', type_as) @@ -332,7 +330,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'to'): - logger.warn("register user to to paddle.Tensor, remove this when fixed!") + logger.debug("register user to to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'to', to) @@ -341,7 +339,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'float'): - logger.warn("register user float to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user float to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'float', func_float) @@ -350,7 +349,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'int'): - logger.warn("register user int to paddle.Tensor, remove this when fixed!") + logger.debug("register user int to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'int', func_int) @@ -359,7 +358,7 @@ def tolist(x: paddle.Tensor) -> List[Any]: if not hasattr(paddle.Tensor, 'tolist'): - logger.warn( + logger.debug( "register user tolist to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'tolist', tolist) @@ -374,7 +373,7 @@ def glu(x: paddle.Tensor, axis=-1) -> paddle.Tensor: if not hasattr(paddle.nn.functional, 'glu'): - logger.warn( + logger.debug( "register user glu to paddle.nn.functional, remove this when fixed!") setattr(paddle.nn.functional, 'glu', glu) @@ -425,19 +424,19 @@ def ctc_loss(logits, return loss_out -logger.warn( +logger.debug( "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!" 
) F.ctc_loss = ctc_loss ########### hcak paddle.nn ############# if not hasattr(paddle.nn, 'Module'): - logger.warn("register user Module to paddle.nn, remove this when fixed!") + logger.debug("register user Module to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'Module', paddle.nn.Layer) # maybe cause assert isinstance(sublayer, core.Layer) if not hasattr(paddle.nn, 'ModuleList'): - logger.warn( + logger.debug( "register user ModuleList to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'ModuleList', paddle.nn.LayerList) @@ -454,7 +453,7 @@ class GLU(nn.Layer): if not hasattr(paddle.nn, 'GLU'): - logger.warn("register user GLU to paddle.nn, remove this when fixed!") + logger.debug("register user GLU to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'GLU', GLU) @@ -486,12 +485,12 @@ class ConstantPad2d(nn.Layer): if not hasattr(paddle.nn, 'ConstantPad2d'): - logger.warn( + logger.debug( "register user ConstantPad2d to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d) ########### hcak paddle.jit ############# if not hasattr(paddle.jit, 'export'): - logger.warn("register user export to paddle.jit, remove this when fixed!") + logger.debug("register user export to paddle.jit, remove this when fixed!") setattr(paddle.jit, 'export', paddle.jit.to_static) diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 0ff5514d..5b8f2372 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -21,8 +21,8 @@ from yacs.config import CfgNode from deepspeech.modules.conv import ConvStack from deepspeech.modules.ctc import CTCDecoder from deepspeech.modules.rnn import RNNStack -from deepspeech.utils import checkpoint from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -222,7 +222,7 @@ class DeepSpeech2Model(nn.Layer): rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights) - infos = checkpoint.load_parameters( + infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") layer_tools.summary(model) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 238e2d35..0677b70c 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -40,8 +40,8 @@ from deepspeech.modules.mask import make_pad_mask from deepspeech.modules.mask import mask_finished_preds from deepspeech.modules.mask import mask_finished_scores from deepspeech.modules.mask import subsequent_mask -from deepspeech.utils import checkpoint from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.ctc_utils import remove_duplicates_and_blank from deepspeech.utils.log import Log from deepspeech.utils.tensor_utils import add_sos_eos @@ -894,7 +894,7 @@ class U2Model(U2BaseModel): model = cls.from_config(config) if checkpoint_path: - infos = checkpoint.load_parameters( + infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") layer_tools.summary(model) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 56de3261..11e5f214 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,8 +18,8 @@ import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter -from deepspeech.utils 
import checkpoint
 from deepspeech.utils import mp_tools
+from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log
 
 __all__ = ["Trainer"]
@@ -151,7 +151,7 @@ class Trainer():
             resume training.
         """
         scratch = None
-        infos = checkpoint.load_parameters(
+        infos = Checkpoint().load_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
diff --git a/deepspeech/utils/bleu_score.py b/deepspeech/utils/bleu_score.py
new file mode 100644
index 00000000..09646133
--- /dev/null
+++ b/deepspeech/utils/bleu_score.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module provides functions to calculate BLEU score at different
+levels, e.g. word-level and char-level.
+"""
+import sacrebleu
+
+__all__ = ['bleu', 'char_bleu']
+
+
+def bleu(hypothesis, reference):
+    """Calculate BLEU. BLEU compares reference text and
+    hypothesis text in word-level using sacrebleu.
+
+
+
+    :param reference: The reference sentences.
+    :type reference: list[list[str]]
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: list[str]
+    :raises ValueError: If the reference length is zero.
+    """
+
+    return sacrebleu.corpus_bleu(hypothesis, reference)
+
+
+def char_bleu(hypothesis, reference):
+    """Calculate BLEU. BLEU compares reference text and
+    hypothesis text in char-level using sacrebleu.
+
+
+
+    :param reference: The reference sentences.
+    :type reference: list[list[str]]
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: list[str]
+    :raises ValueError: If the reference number is zero.
+    """
+    hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
+    reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
+                 for ref in reference]
+
+    return sacrebleu.corpus_bleu(hypothesis, reference)
diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py
index 8ede6b8f..8e31edfa 100644
--- a/deepspeech/utils/checkpoint.py
+++ b/deepspeech/utils/checkpoint.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import json
 import os
 import re
+from pathlib import Path
+from typing import Text
 from typing import Union
 
 import paddle
@@ -25,128 +28,271 @@ from deepspeech.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = ["load_parameters", "save_parameters"]
-
-
-def _load_latest_checkpoint(checkpoint_dir: str) -> int:
-    """Get the iteration number corresponding to the latest saved checkpoint.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-    Returns:
-        int: the latest iteration number. -1 for no checkpoint to load.
-    """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
-    if not os.path.isfile(checkpoint_record):
-        return -1
-
-    # Fetch the latest checkpoint index.
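As a hedged usage sketch of the `bleu_score` helpers added above: sacrebleu takes one hypothesis per segment plus a list of reference streams and returns an object whose `score` field is the corpus BLEU; the sentences here are illustrative only.

```python
# Illustrative only: exercising deepspeech.utils.bleu_score from this patch.
from deepspeech.utils.bleu_score import bleu, char_bleu

hyps = ["the cat sat on the mat"]      # one hypothesis per segment
refs = [["the cat is on the mat"]]     # one reference stream
print(bleu(hyps, refs).score)          # word-level corpus BLEU (0-100 scale)

# char-level BLEU splits every sentence into space-joined characters first
print(char_bleu(["今天天气很好"], [["今天天气真好"]]).score)
```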
- with open(checkpoint_record, "rt") as handle: - latest_checkpoint = handle.readlines()[-1].strip() - iteration = int(latest_checkpoint.split(":")[-1]) - return iteration - - -def _save_record(checkpoint_dir: str, iteration: int): - """Save the iteration number of the latest model to be checkpoint record. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - iteration (int): the latest iteration number. - Returns: - None - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - # Update the latest checkpoint index. - with open(checkpoint_record, "a+") as handle: - handle.write("model_checkpoint_path:{}\n".format(iteration)) - - -def load_parameters(model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): - """Load a specific model checkpoint from disk. - Args: - model (Layer): model to load parameters. - optimizer (Optimizer, optional): optimizer to load states if needed. - Defaults to None. - checkpoint_dir (str, optional): the directory where checkpoint is saved. - checkpoint_path (str, optional): if specified, load the checkpoint - stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will - be ignored. Defaults to None. - Returns: - configs (dict): epoch or step, lr and other meta info should be saved. - """ - configs = {} - - if checkpoint_path is not None: - tag = os.path.basename(checkpoint_path).split(":")[-1] - elif checkpoint_dir is not None: - iteration = _load_latest_checkpoint(checkpoint_dir) - if iteration == -1: - return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) - else: - raise ValueError( - "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" - ) - - rank = dist.get_rank() - - params_path = checkpoint_path + ".pdparams" - model_dict = paddle.load(params_path) - model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) - - optimizer_path = checkpoint_path + ".pdopt" - if optimizer and os.path.isfile(optimizer_path): - optimizer_dict = paddle.load(optimizer_path) - optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( - rank, optimizer_path)) - - info_path = re.sub('.pdparams$', '.json', params_path) - if os.path.exists(info_path): - with open(info_path, 'r') as fin: - configs = json.load(fin) - return configs - - -@mp_tools.rank_zero_only -def save_parameters(checkpoint_dir: str, - tag_or_iteration: Union[int, str], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None): - """Checkpoint the latest trained model parameters. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - tag_or_iteration (int or str): the latest iteration(step or epoch) number. - model (Layer): model to be checkpointed. - optimizer (Optimizer, optional): optimizer to be checkpointed. - Defaults to None. - infos (dict or None): any info you want to save. 
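For context, the record files read and written in this module are plain text with one `model_checkpoint_path:<tag>` line per checkpoint: the old code keeps a single `checkpoint` file, while the new class writes `checkpoint_latest` and `checkpoint_best`. A minimal sketch of the format, with a hypothetical parser:

```python
# Hypothetical helper, shown only to document the record file format:
#   model_checkpoint_path:98
#   model_checkpoint_path:99
def read_record_tags(record_path: str) -> list:
    with open(record_path, "rt") as f:
        # keep the tag (step/epoch number) after the colon on each line
        return [line.strip().split(":")[-1] for line in f if line.strip()]
```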
- Returns: - None - """ - checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) - - model_dict = model.state_dict() - params_path = checkpoint_path + ".pdparams" - paddle.save(model_dict, params_path) - logger.info("Saved model to {}".format(params_path)) - - if optimizer: - opt_dict = optimizer.state_dict() +__all__ = ["Checkpoint"] + + +class Checkpoint(): + def __init__(self, kbest_n: int=5, latest_n: int=1): + self.best_records: Mapping[Path, float] = {} + self.latest_records = [] + self.kbest_n = kbest_n + self.latest_n = latest_n + self._save_all = (kbest_n == -1) + + def add_checkpoint(self, + checkpoint_dir, + tag_or_iteration: Union[int, Text], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None, + metric_type="val_loss"): + """Save checkpoint in best_n and latest_n. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag. + model (Layer): model to be checkpointed. + optimizer (Optimizer, optional): optimizer to be checkpointed. + infos (dict or None)): any info you want to save. + metric_type (str, optional): metric type. Defaults to "val_loss". + """ + if (metric_type not in infos.keys()): + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + return + + #save best + if self._should_save_best(infos[metric_type]): + self._save_best_checkpoint_and_update( + infos[metric_type], checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + #save latest + self._save_latest_checkpoint_and_update( + checkpoint_dir, tag_or_iteration, model, optimizer, infos) + + if isinstance(tag_or_iteration, int): + self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) + + def load_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None, + record_file="checkpoint_latest"): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + record_file "checkpoint_latest" or "checkpoint_best" + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + configs = {} + + if checkpoint_path is not None: + pass + elif checkpoint_dir is not None and record_file is not None: + # load checkpint from record file + checkpoint_record = os.path.join(checkpoint_dir, record_file) + iteration = self._load_checkpoint_idx(checkpoint_record) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!" 
+ ) + + rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + logger.info("Rank {}: Restore model from {}".format(rank, params_path)) + optimizer_path = checkpoint_path + ".pdopt" - paddle.save(opt_dict, optimizer_path) - logger.info("Saved optimzier state to {}".format(optimizer_path)) + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + logger.info("Rank {}: Restore optimizer state from {}".format( + rank, optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs + + def load_latest_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") + + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. 
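A hedged sketch of driving the new `Checkpoint` API from a training loop; `model`, `optimizer`, and the metric values are stand-ins for real objects:

```python
# Sketch only: model, optimizer and evaluate() are assumed to exist elsewhere.
from deepspeech.utils.checkpoint import Checkpoint

ckpt = Checkpoint(kbest_n=5, latest_n=1)
for epoch in range(10):
    val_loss = evaluate()  # hypothetical validation step returning a float
    ckpt.add_checkpoint("exp/ckpt", epoch, model, optimizer,
                        infos={"epoch": epoch, "val_loss": val_loss})

# resume training from the newest tag, or pick the best one by val_loss
Checkpoint().load_latest_parameters(model, optimizer, checkpoint_dir="exp/ckpt")
Checkpoint().load_best_parameters(model, checkpoint_dir="exp/ckpt")
```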
+ """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): + # remove the worst + if self._best_full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + if (worst_record_path not in self.latest_records): + logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) + self._del_checkpoint(checkpoint_dir, worst_record_path) + + # add the new one + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + self.best_records[tag_or_iteration] = metric + + def _save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): + # remove the old + if self._latest_full(): + to_del_fn = self.latest_records.pop(0) + if (to_del_fn not in self.best_records.keys()): + logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) + self._del_checkpoint(checkpoint_dir, to_del_fn) + self.latest_records.append(tag_or_iteration) + + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + + def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): + os.remove(filename) + logger.info("delete file: {}".format(filename)) + + def _load_checkpoint_idx(self, checkpoint_record: str) -> int: + """Get the iteration number corresponding to the latest saved checkpoint. + Args: + checkpoint_path (str): the saved path of checkpoint. + Returns: + int: the latest iteration number. -1 for no checkpoint to load. + """ + if not os.path.isfile(checkpoint_record): + return -1 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "rt") as handle: + latest_checkpoint = handle.readlines()[-1].strip() + iteration = int(latest_checkpoint.split(":")[-1]) + return iteration + + def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int): + """Save the iteration number of the latest model to be checkpoint record. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. 
+ Returns: + None + """ + checkpoint_record_latest = os.path.join(checkpoint_dir, + "checkpoint_latest") + checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") + + with open(checkpoint_record_best, "w") as handle: + for i in self.best_records.keys(): + handle.write("model_checkpoint_path:{}\n".format(i)) + with open(checkpoint_record_latest, "w") as handle: + for i in self.latest_records: + handle.write("model_checkpoint_path:{}\n".format(i)) + + @mp_tools.rank_zero_only + def _save_parameters(self, + checkpoint_dir: str, + tag_or_iteration: Union[int, str], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): + """Checkpoint the latest trained model parameters. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + tag_or_iteration (int or str): the latest iteration(step or epoch) number. + model (Layer): model to be checkpointed. + optimizer (Optimizer, optional): optimizer to be checkpointed. + Defaults to None. + infos (dict or None): any info you want to save. + Returns: + None + """ + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + + model_dict = model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + logger.info("Saved model to {}".format(params_path)) - info_path = re.sub('.pdparams$', '.json', params_path) - infos = {} if infos is None else infos - with open(info_path, 'w') as fout: - data = json.dumps(infos) - fout.write(data) + if optimizer: + opt_dict = optimizer.state_dict() + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + logger.info("Saved optimzier state to {}".format(optimizer_path)) - if isinstance(tag_or_iteration, int): - _save_record(checkpoint_dir, tag_or_iteration) + info_path = re.sub('.pdparams$', '.json', params_path) + infos = {} if infos is None else infos + with open(info_path, 'w') as fout: + data = json.dumps(infos) + fout.write(data) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 73669fea..70d99e6c 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): + # add non-blank into new_hyp if hyp[cur] != blank_id: new_hyp.append(hyp[cur]) + # skip repeat label prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: cur += 1 return new_hyp -def insert_blank(label: np.ndarray, blank_id: int=0): +def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray: """Insert blank token between every two label token. "abcdefg" -> "-a-b-c-d-e-f-g-" Args: - label ([np.ndarray]): label ids, (L). + label ([np.ndarray]): label ids, List[int], (L). blank_id (int, optional): blank id. Defaults to 0. Returns: @@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0): label = np.expand_dims(label, 1) #[L, 1] blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id label = np.concatenate([blanks, label], axis=1) #[L, 2] - label = label.reshape(-1) #[2L] - label = np.append(label, label[0]) #[2L + 1] + label = label.reshape(-1) #[2L], -l-l-l + label = np.append(label, label[0]) #[2L + 1], -l-l-l- return label def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, - blank_id=0) -> list: + blank_id=0) -> List[int]: """ctc forced alignment. 
https://distill.pub/2017/ctc/ @@ -77,23 +79,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, y (paddle.Tensor): label id sequence tensor, 1d tensor (L) blank_id (int): blank symbol index Returns: - paddle.Tensor: best alignment result, (T). + List[int]: best alignment result, (T). """ - y_insert_blank = insert_blank(y, blank_id) + y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) + (ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero + + # TODO(Hui Zhang): zeros not support paddle.int16 + # self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 - ) # state path + (ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1 + ) # state path, Tuple((T, 2L+1)) # init start state - log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb - log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): - for s in range(len(y_insert_blank)): + for t in range(1, ctc_probs.shape[0]): # T + for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: candidates = paddle.to_tensor( @@ -106,11 +112,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ - y_insert_blank[s]] + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( + y_insert_blank[s])] state_path[t, s] = prev_state[paddle.argmax(candidates)] - - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) + # TODO(Hui Zhang): zeros not support paddle.int16 + # self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16 + state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb @@ -118,11 +126,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, ]) prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] state_seq[-1] = prev_state[paddle.argmax(candidates)] - for t in range(ctc_probs.size(0) - 2, -1, -1): + for t in range(ctc_probs.shape[0] - 2, -1, -1): state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] output_alignment = [] - for t in range(0, ctc_probs.size(0)): + for t in range(0, ctc_probs.shape[0]): output_alignment.append(y_insert_blank[state_seq[t, 0]]) return output_alignment diff --git a/deepspeech/utils/dynamic_import.py b/deepspeech/utils/dynamic_import.py new file mode 100644 index 00000000..533f15ee --- /dev/null +++ b/deepspeech/utils/dynamic_import.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
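A toy, hedged sketch of the forced aligner above; real inputs are `(T, V)` log-softmax CTC posteriors, and collapsing the forced path recovers the label sequence:

```python
# Toy sketch only: shapes are tiny and the posteriors random.
import numpy as np
import paddle
from deepspeech.utils.ctc_utils import forced_align, remove_duplicates_and_blank

T, V = 6, 4                                    # frames, vocab size (0 = blank)
ctc_probs = paddle.nn.functional.log_softmax(
    paddle.randn([T, V]), axis=-1)             # stand-in (T, V) log posteriors
y = paddle.to_tensor(np.array([2, 3]))         # label id sequence (L,)
path = forced_align(ctc_probs, y, blank_id=0)  # length-T alignment ids
print(remove_duplicates_and_blank([int(i) for i in path]))  # -> [2, 3]
```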
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +from typing import Any +from typing import Dict +from typing import List +from typing import Text + +from deepspeech.utils.log import Log +from deepspeech.utils.tensor_utils import has_tensor + +logger = Log(__name__).getlog() + +__all__ = ["dynamic_import", "instance_class"] + + +def dynamic_import(import_path, alias=dict()): + """dynamic import module and class + + :param str import_path: syntax 'module_name:class_name' + e.g., 'deepspeech.models.u2:U2Model' + :param dict alias: shortcut for registered class + :return: imported class + """ + if import_path not in alias and ":" not in import_path: + raise ValueError("import_path should be one of {} or " + 'include ":", e.g. "deepspeech.models.u2:U2Model" : ' + "{}".format(set(alias), import_path)) + if ":" not in import_path: + import_path = alias[import_path] + + module_name, objname = import_path.split(":") + m = importlib.import_module(module_name) + return getattr(m, objname) + + +def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]): + # filter by `valid_keys` and filter `val` is not None + new_args = { + key: val + for key, val in args.items() if (key in valid_keys and val is not None) + } + return new_args + + +def filter_out_tenosr(args: Dict[Text, Any]): + return {key: val for key, val in args.items() if not has_tensor(val)} + + +def instance_class(module_class, args: Dict[Text, Any]): + valid_keys = inspect.signature(module_class).parameters.keys() + new_args = filter_valid_args(args, valid_keys) + logger.info( + f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.") + return module_class(**new_args) diff --git a/deepspeech/utils/error_rate.py b/deepspeech/utils/error_rate.py index b6399bab..81f458b6 100644 --- a/deepspeech/utils/error_rate.py +++ b/deepspeech/utils/error_rate.py @@ -14,10 +14,13 @@ """This module provides functions to calculate error rate in different level. e.g. wer for word-level, cer for char-level. 
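A hedged usage sketch for `dynamic_import`/`instance_class`; the alias table is illustrative:

```python
# Sketch only: resolve classes from "module:Class" strings or short aliases.
from deepspeech.utils.dynamic_import import dynamic_import, instance_class

U2Model = dynamic_import("deepspeech.models.u2:U2Model")

alias = {"u2": "deepspeech.models.u2:U2Model"}   # illustrative alias map
U2Model = dynamic_import("u2", alias=alias)      # same class via the shortcut

# instance_class keeps only kwargs the constructor accepts and drops None values
# model = instance_class(U2Model, {"configs": cfg, "not_a_ctor_arg": None})
```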
""" +import editdistance import numpy as np __all__ = ['word_errors', 'char_errors', 'wer', 'cer'] +editdistance.eval("a", "b") + def _levenshtein_distance(ref, hyp): """Levenshtein distance is a string metric for measuring the difference @@ -89,6 +92,8 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): hyp_words = list(filter(None, hypothesis.split(delimiter))) edit_distance = _levenshtein_distance(ref_words, hyp_words) + # `editdistance.eavl precision` less than `_levenshtein_distance` + # edit_distance = editdistance.eval(ref_words, hyp_words) return float(edit_distance), len(ref_words) @@ -119,6 +124,8 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): hypothesis = join_char.join(list(filter(None, hypothesis.split(' ')))) edit_distance = _levenshtein_distance(reference, hypothesis) + # `editdistance.eavl precision` less than `_levenshtein_distance` + # edit_distance = editdistance.eval(reference, hypothesis) return float(edit_distance), len(reference) diff --git a/deepspeech/utils/log.py b/deepspeech/utils/log.py index 499b1872..7e8de600 100644 --- a/deepspeech/utils/log.py +++ b/deepspeech/utils/log.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import getpass -import logging import os import socket import sys -FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' -DATE_FMT_STR = '%Y/%m/%d %H:%M:%S' - -logging.basicConfig( - level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR) +from loguru import logger +from paddle import inference def find_log_dir(log_dir=None): @@ -96,53 +92,54 @@ def find_log_dir_and_names(program_name=None, log_dir=None): class Log(): + """Default Logger for all.""" + logger.remove() + logger.add( + sys.stdout, + level='INFO', + enqueue=True, + filter=lambda record: record['level'].no >= 20) + _, file_prefix, _ = find_log_dir_and_names() + sink_prefix = os.path.join("exp/log", file_prefix) + sink_path = sink_prefix[:-3] + "{time}.log" + logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB") + + def __init__(self, name=None): + pass - log_name = None - - def __init__(self, logger=None): - self.logger = logging.getLogger(logger) - self.logger.setLevel(logging.DEBUG) - - file_dir = os.getcwd() + '/log' - if not os.path.exists(file_dir): - os.mkdir(file_dir) - self.log_dir = file_dir - - actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( - program_name=None, log_dir=self.log_dir) - - basename = '%s.DEBUG.%d' % (file_prefix, os.getpid()) - filename = os.path.join(actual_log_dir, basename) - if Log.log_name is None: - Log.log_name = filename - - # Create a symlink to the log file with a canonical name. - symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG') - try: - if os.path.islink(symlink): - os.unlink(symlink) - os.symlink(os.path.basename(Log.log_name), symlink) - except EnvironmentError: - # If it fails, we're sad but it's no error. 
diff --git a/deepspeech/utils/profiler.py b/deepspeech/utils/profiler.py
new file mode 100644
index 00000000..5733f8ed
--- /dev/null
+++ b/deepspeech/utils/profiler.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+import paddle
+
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+# A global variable to record the number of calling times for profiler
+# functions. It is used to specify the tracing range of training steps.
+_profiler_step_id = 0
+
+# A global variable to avoid parsing from string every time.
+_profiler_options = None
+
+
+class ProfilerOptions(object):
+    '''
+    Use a string to initialize a ProfilerOptions.
+    The string should be in the format: "key1=value1;key2=value2;key3=value3".
+    For example:
+      "profile_path=model.profile"
+      "batch_range=[50, 60]; profile_path=model.profile"
+      "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
+    ProfilerOptions supports the following key-value pairs:
+      batch_range      - an integer list, e.g. [100, 110].
+      state            - a string, the optional values are 'CPU', 'GPU' or 'All'.
+      sorted_key       - a string, the optional values are 'calls', 'total',
+                         'max', 'min' or 'ave'.
+      tracer_option    - a string, the optional values are 'Default', 'OpDetail',
+                         'AllOpDetail'.
+      profile_path     - a string, the path to save the serialized profile data,
+                         which can be used to generate a timeline.
+      exit_on_finished - a boolean.
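A hedged sketch of wiring `add_profiler_step` (defined later in this file) into a step loop with an options string in the documented format; the loader and step function are stand-ins:

```python
# Sketch only: `loader` and `train_one_step` are assumed to exist elsewhere.
from deepspeech.utils.profiler import add_profiler_step

opts = "batch_range=[50, 60]; state=GPU; profile_path=exp/model.profile"
for batch in loader:
    add_profiler_step(opts)  # starts at step 50; stops, dumps, and (by
                             # default) exits the process at step 60
    train_one_step(batch)
```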
+ ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + if not options_str: + return + + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + profiler_options - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. + ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + logger.info(f"Profiler: {options_str}") + logger.info(f"Profiler: {_profiler_options._options}") + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler(_profiler_options['state'], + _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/deepspeech/utils/socket_server.py b/deepspeech/utils/socket_server.py index adcbf3bb..45c659f6 100644 --- a/deepspeech/utils/socket_server.py +++ b/deepspeech/utils/socket_server.py @@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler, rng = random.Random(random_seed) samples = rng.sample(manifest, num_test_cases) for idx, sample in enumerate(samples): - print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) + print("Warm-up Test Case %d: %s" % (idx, sample['feat'])) start_time = time.time() - transcript = audio_process_handler(sample['audio_filepath']) + transcript = audio_process_handler(sample['feat']) finish_time = time.time() print("Response Time: %f, Transcript: %s" % (finish_time - start_time, transcript)) diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 7679d9e1..0cc03b19 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -19,11 +19,25 @@ import paddle from deepspeech.utils.log import Log -__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"] +__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"] logger = Log(__name__).getlog() +def has_tensor(val): + if isinstance(val, (list, tuple)): + for item in val: + if has_tensor(item): + return True + elif isinstance(val, dict): + for k, v in val.items(): + 
print(k)  # debug leftover; consider removing
+            if has_tensor(v):
+                return True
+    else:
+        return paddle.is_tensor(val)
+
+
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,
                  padding_value: float=0.0) -> paddle.Tensor:
@@ -69,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
     # (TODO Hui Zhang): slice not supprot `end==start`
     # trailing_dims = max_size[1:]
     trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
-    max_len = max([s.size(0) for s in sequences])
+    max_len = max([s.shape[0] for s in sequences])
     if batch_first:
         out_dims = (len(sequences), max_len) + trailing_dims
     else:
@@ -77,12 +91,27 @@ def pad_sequence(sequences: List[paddle.Tensor],
     out_tensor = sequences[0].new_full(out_dims, padding_value)
     for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
+        length = tensor.shape[0]
         # use index notation to prevent duplicate references to the tensor
+        logger.info(
+            f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
+        )
         if batch_first:
-            out_tensor[i, :length, ...] = tensor
+            # TODO (Hui Zhang): set_value op not support `end==start`
+            # TODO (Hui Zhang): set_value op not support int16
+            # TODO (Hui Zhang): set_varbase 2 rank not support [0,0,...]
+            # out_tensor[i, :length, ...] = tensor
+            if length != 0:
+                out_tensor[i, :length] = tensor
+            else:
+                out_tensor[i, length] = tensor
         else:
-            out_tensor[:length, i, ...] = tensor
+            # TODO (Hui Zhang): set_value op not support `end==start`
+            # out_tensor[:length, i, ...] = tensor
+            if length != 0:
+                out_tensor[:length, i] = tensor
+            else:
+                out_tensor[length, i] = tensor
 
     return out_tensor
 
@@ -125,7 +154,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
     #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
     #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
     #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
-    B = ys_pad.size(0)
+    B = ys_pad.shape[0]
     _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
     _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
     ys_in = paddle.cat([_sos, ys_pad], dim=1)
@@ -151,8 +180,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     Returns:
         float: Accuracy value (0.0 - 1.0).
     """
-    pad_pred = pad_outputs.view(
-        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
+    pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
+                                pad_outputs.shape[1]).argmax(2)
     mask = pad_targets != ignore_label
     #TODO(Hui Zhang): sum not support bool type
     # numerator = paddle.sum(
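A hedged sketch of the two padding helpers touched above, assuming the usual wenet-style `add_sos_eos(ys_pad, sos, eos, ignore_id)` signature:

```python
# Sketch only: pad a ragged batch, then add sos/eos columns.
import paddle
from deepspeech.utils.tensor_utils import add_sos_eos, pad_sequence

seqs = [paddle.to_tensor([1, 2, 3]), paddle.to_tensor([4, 5])]
batch = pad_sequence(seqs, batch_first=True, padding_value=0)  # shape [2, 3]

ys_in, ys_out = add_sos_eos(batch, sos=7, eos=8, ignore_id=-1)
print(ys_in.numpy()[0])  # [7 1 2 3]: sos prepended to every row
```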
diff --git a/deepspeech/utils/text_grid.py b/deepspeech/utils/text_grid.py
new file mode 100644
index 00000000..3af58c9b
--- /dev/null
+++ b/deepspeech/utils/text_grid.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+from typing import List
+from typing import Text
+
+import textgrid
+
+
+def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
+    """Segment ctc alignment ids by continuous blank and repeated label.
+
+    Args:
+        alignment (List[int]): ctc alignment id sequence.
+            e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
+        blank_id (int, optional): blank id. Defaults to 0.
+
+    Returns:
+        List[List[int]]: segmented alignment id sequences, one per token.
+            e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
+    """
+    # convert the alignment to a praat-friendly segmentation; praat is a
+    # program for doing phonetics by computer and helps analyze alignments
+    align_segs = []
+    # get frame-level duration for each token
+    start = 0
+    end = 0
+    while end < len(alignment):
+        while end < len(alignment) and alignment[end] == blank_id:  # blank
+            end += 1
+        if end == len(alignment):
+            align_segs[-1].extend(alignment[start:])
+            break
+        end += 1
+        while end < len(alignment) and alignment[end - 1] == alignment[
+                end]:  # repeat label
+            end += 1
+        align_segs.append(alignment[start:end])
+        start = end
+    return align_segs
+
+
+def align_to_tierformat(align_segs: List[List[int]],
+                        subsample: int,
+                        token_dict: Dict[int, Text],
+                        blank_id=0) -> List[Text]:
+    """Generate textgrid.Interval format from alignment segmentations.
+
+    Args:
+        align_segs (List[List[int]]): segmented ctc alignment ids.
+        subsample (int): encoder subsample rate, i.e. input frames per output
+            step (frontend: 25ms frame_length, 10ms hop_length).
+        token_dict (Dict[int, Text]): int -> str map.
+
+    Returns:
+        List[Text]: list of textgrid.Interval text, str(start, end, text).
+    """
+    hop_length = 10  # ms
+    second_ms = 1000  # ms
+    frame_per_second = second_ms / hop_length  # 25ms frame_length, 10ms hop_length
+    second_per_frame = 1.0 / frame_per_second
+
+    begin = 0
+    duration = 0
+    tierformat = []
+
+    for idx, tokens in enumerate(align_segs):
+        token_len = len(tokens)
+        token = tokens[-1]
+        # time duration in second
+        duration = token_len * subsample * second_per_frame
+        if idx < len(align_segs) - 1:
+            print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
+            tierformat.append(
+                f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
+        else:
+            for i in tokens:
+                if i != blank_id:
+                    token = i
+                    break
+            print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
+            tierformat.append(
+                f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
+        begin = begin + duration
+
+    return tierformat
+
+
+def generate_textgrid(maxtime: float,
+                      intervals: List[Text],
+                      output: Text,
+                      name: Text='ali') -> None:
+    """Create alignment textgrid file.
+
+    Args:
+        maxtime (float): audio duration.
+        intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
+        output (Text): textgrid filepath.
+        name (Text, optional): tier or layer name. Defaults to 'ali'.
+    """
+    # Download Praat: https://www.fon.hum.uva.nl/praat/
+    avg_interval = maxtime / (len(intervals) + 1)
+    print(f"average second/token: {avg_interval}")
+    margin = 0.0001
+
+    tg = textgrid.TextGrid(maxTime=maxtime)
+    tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
+
+    i = 0
+    for dur in intervals:
+        s, e, text = dur.split()
+        tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
+
+    tg.append(tier)
+
+    tg.write(output)
+    print("successfully generated textgrid {}.".format(output))
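A hedged end-to-end sketch of the three helpers in the new `text_grid` module on a toy alignment; the token ids, map, and file name are illustrative:

```python
# Toy sketch: 6 frames, subsample 4, 10ms hop -> 0.24s of audio.
from deepspeech.utils.text_grid import (align_to_tierformat,
                                        generate_textgrid, segment_alignment)

ali = [0, 0, 2, 2, 0, 3]                   # frame-level ctc alignment ids
segs = segment_alignment(ali, blank_id=0)  # [[0, 0, 2, 2], [0, 3]]
tier = align_to_tierformat(segs, subsample=4,
                           token_dict={2: "ni", 3: "hao"}, blank_id=0)
generate_textgrid(maxtime=0.24, intervals=tier, output="demo.TextGrid")
```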
+ """ + # Download Praat: https://www.fon.hum.uva.nl/praat/ + avg_interval = maxtime / (len(intervals) + 1) + print(f"average second/token: {avg_interval}") + margin = 0.0001 + + tg = textgrid.TextGrid(maxTime=maxtime) + tier = textgrid.IntervalTier(name=name, maxTime=maxtime) + + i = 0 + for dur in intervals: + s, e, text = dur.split() + tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text) + + tg.append(tier) + + tg.write(output) + print("successfully generator textgrid {}.".format(output)) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 64570026..159b686e 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -15,9 +15,50 @@ import distutils.util import math import os +import random +import sys +from contextlib import contextmanager from typing import List -__all__ = ['print_arguments', 'add_arguments', "log_add"] +import numpy as np +import paddle +import soundfile + +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = [ + "all_version", "UpdateConfig", "seed_all", 'print_arguments', + 'add_arguments', "log_add" +] + + +def all_version(): + vers = { + "python": sys.version, + "paddle": paddle.__version__, + "paddle_commit": paddle.version.commit, + "soundfile": soundfile.__version__, + } + logger.info("Deps Module Version:") + for k, v in vers.items(): + logger.info(f"{k}: {v}") + + +@contextmanager +def UpdateConfig(config): + """Update yacs config""" + config.defrost() + yield + config.freeze() + + +def seed_all(seed: int=210329): + """freeze random generator seed.""" + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def print_arguments(args, info=None): @@ -79,3 +120,22 @@ def log_add(args: List[int]) -> float: a_max = max(args) lsp = math.log(sum(math.exp(a - a_max) for a in args)) return a_max + lsp + + +def get_subsample(config): + """Subsample rate from config. + + Args: + config (yacs.config.CfgNode): yaml config + + Returns: + int: subsample rate. + """ + input_layer = config["model"]["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if input_layer == "conv2d": + return 4 + elif input_layer == "conv2d6": + return 6 + elif input_layer == "conv2d8": + return 8 diff --git a/requirements.txt b/requirements.txt index 9ecf6bbd..332b5238 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ coverage +editdistance gpustat jsonlines kaldiio @@ -19,4 +20,3 @@ tqdm typeguard visualdl==2.2.0 yacs -editdistance \ No newline at end of file -- GitLab