ctc_utils.py 7.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
H
Hui Zhang 已提交
14
# Modified from wenet(https://github.com/wenet-e2e/wenet)
15 16 17 18 19
import sys
from pathlib import Path
from typing import List

import numpy as np
import paddle

20 21 22
from paddlespeech.s2t.utils import text_grid
from paddlespeech.s2t.utils import utility
from paddlespeech.s2t.utils.log import Log
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

logger = Log(__name__).getlog()

# Public API of this module; ctc_align is the end-to-end alignment driver.
__all__ = [
    "forced_align", "remove_duplicates_and_blank", "insert_blank", "ctc_align"
]


def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
    """Collapse a CTC alignment into its label id sequence.

    Each run of consecutive equal ids is reduced to one id, then blank ids
    are dropped:

    "abaa-acee-" -> "abaace"

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        List[int]: ids with duplicate runs collapsed and blank ids removed.
    """
    collapsed: List[int] = []
    for idx, token in enumerate(hyp):
        # A token opens a new run unless it repeats its left neighbour.
        starts_run = idx == 0 or token != hyp[idx - 1]
        if starts_run and token != blank_id:
            collapsed.append(token)
    return collapsed


H
Hui Zhang 已提交
54
def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
    """Insert a blank token before, between and after the label tokens.

    "abcdefg" -> "-a-b-c-d-e-f-g-"

    Args:
        label (np.ndarray): label ids, List[int] also accepted, (L).
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        np.ndarray: (2L+1), blanks at even positions, labels at odd ones.
            An empty label yields ``[blank_id]``.
    """
    label = np.expand_dims(label, 1)  # [L, 1]
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    label = np.concatenate([blanks, label], axis=1)  # [L, 2]
    label = label.reshape(-1)  # [2L], -l-l-l
    # Append blank_id itself rather than label[0]: an empty label has no
    # element 0, and for a non-empty label label[0] is this same blank.
    label = np.append(label, blank_id)  # [2L + 1], -l-l-l-
    return label


def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                 blank_id=0) -> List[int]:
    """ctc forced alignment.

    Viterbi search for the best frame-level path through the blank-augmented
    label sequence ``-y1-y2-...-yL-`` (see ``insert_blank``).

    https://distill.pub/2017/ctc/

    Args:
        ctc_probs (paddle.Tensor): per-frame CTC log-probabilities, 2d tensor
            (T, D); scores are combined by addition. numpy arrays also work.
        y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
        blank_id (int): blank symbol index
    Returns:
        List[int]: best alignment result, (T), one token id (a label id or
            ``blank_id``) per frame. Empty when T == 0; all ``blank_id`` when
            L == 0.
    """
    # Run the DP in numpy: per-element paddle indexing/assignment inside a
    # T x (2L+1) Python loop is very slow, and each frame only needs the
    # previous row.
    probs = ctc_probs.numpy() if hasattr(ctc_probs, "numpy") else np.asarray(
        ctc_probs)
    labels = y.numpy() if hasattr(y, "numpy") else np.asarray(y)
    labels = labels.reshape(-1)

    num_frames = probs.shape[0]
    if num_frames == 0:
        return []
    if labels.shape[0] == 0:
        # Nothing to align: every frame can only be blank.
        return [blank_id] * num_frames

    y_insert_blank = insert_blank(labels, blank_id).astype(np.int64)  # (2L+1)
    num_states = y_insert_blank.shape[0]
    states = np.arange(num_states)

    # The skip move s-2 -> s is allowed only onto a non-blank label that
    # differs from the label two states back.
    can_skip = np.zeros(num_states, dtype=bool)
    can_skip[2:] = ((y_insert_blank[2:] != blank_id) &
                    (y_insert_blank[2:] != y_insert_blank[:-2]))

    # Score of each state's token at every frame: (T, 2L+1).
    emit = probs[:, y_insert_blank]
    dtype = np.result_type(emit.dtype, np.float32)

    log_alpha = np.full((num_frames, num_states), -np.inf, dtype=dtype)
    # Back-pointer: best predecessor state at t-1 for state s at t.
    state_path = np.full((num_frames, num_states), -1, dtype=np.int64)

    # init start state
    log_alpha[0, 0] = emit[0, 0]  # State-b, Sb
    log_alpha[0, 1] = emit[0, 1]  # State-nb, Snb

    for t in range(1, num_frames):  # T
        prev = log_alpha[t - 1]
        # candidates[k, s]: score of reaching s from s - k. Invalid moves stay
        # -inf; in particular state 0 can only come from itself (indexing
        # s - 1 == -1 must not wrap around to the last state).
        candidates = np.full((3, num_states), -np.inf, dtype=dtype)
        candidates[0] = prev
        candidates[1, 1:] = prev[:-1]
        candidates[2, 2:] = np.where(can_skip[2:], prev[:-2], -np.inf)
        # argmax keeps the first maximum: ties prefer stay, then s-1, then s-2.
        best = np.argmax(candidates, axis=0)
        log_alpha[t] = candidates[best, states] + emit[t]
        state_path[t] = states - best

    # The path must end in the trailing blank (2L) or the last label (2L-1).
    last = num_states - 1
    prev_state = [last, last - 1]
    state_seq = np.empty(num_frames, dtype=np.int64)
    state_seq[-1] = prev_state[int(
        np.argmax(log_alpha[-1, [last, last - 1]]))]
    for t in range(num_frames - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1]]

    return y_insert_blank[state_seq].tolist()
H
Hui Zhang 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211


def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
              result_file, config=None):
    """ctc alignment.

    For every utterance this runs the encoder and the CTC log-softmax, then
    force-aligns the reference target. The frame ids go to ``result_file``;
    a ``.tier`` and a ``.TextGrid`` file per utterance go to
    ``<dir of result_file>/align``.

    Args:
        model (nn.Layer): U2 Model.
        dataloader (io.DataLoader): dataloader yielding
            (key, feat, feats_length, target, target_length) batches.
        batch_size (int): decoding batchsize, must be 1.
        stride_ms (int): audio feature stride in ms unit.
        token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
        result_file (str): alignment output file, e.g. xxx.align.
        config: config handed to ``utility.get_subsample`` to obtain the
            encoder subsampling rate (presumably the model yaml config --
            TODO confirm). Required; it defaults to None only so the original
            positional signature stays valid.

    Raises:
        SystemExit: if ``batch_size > 1``.
        ValueError: if ``config`` is not given.
    """
    if batch_size > 1:
        logger.fatal('alignment mode must be running with batch_size == 1')
        sys.exit(1)

    assert result_file and result_file.endswith('.align')

    if config is None:
        raise ValueError(
            "ctc_align needs `config` to compute the subsampling rate")

    model.eval()

    logger.info(f"Align Total Examples: {len(dataloader.dataset)}")

    # Loop-invariant setup, done once instead of once per utterance.
    subsample = utility.get_subsample(config)
    second_per_frame = 1. / (1000. / stride_ms)  # 25ms window, 10ms stride
    align_output_path = Path(result_file).parent / "align"
    align_output_path.mkdir(parents=True, exist_ok=True)

    # Inference only: no autograd graph is needed.
    with open(result_file, 'w', encoding='utf-8') as fout, paddle.no_grad():
        # one example in batch
        for batch in dataloader:
            key, feat, feats_length, target, target_length = batch

            # 1. Encoder
            encoder_out, _ = model._forward_encoder(
                feat, feats_length)  # (B, maxlen, encoder_dim)
            ctc_probs = model.ctc.log_softmax(
                encoder_out)  # (1, maxlen, vocab_size)

            # 2. alignment
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = forced_align(ctc_probs, target)

            logger.info(f"align ids: {key[0]} {alignment}")
            fout.write('{} {}\n'.format(key[0], alignment))

            # 3. gen praat
            # segment alignment
            align_segs = text_grid.segment_alignment(alignment)
            logger.info(f"align tokens: {key[0]}, {align_segs}")

            # IntervalTier, List["start end token\n"]
            tierformat = text_grid.align_to_tierformat(align_segs, subsample,
                                                       token_dict)

            # write tier
            tier_path = align_output_path / (key[0] + ".tier")
            with tier_path.open('w', encoding='utf-8') as f:
                f.writelines(tierformat)

            # write textgrid
            textgrid_path = align_output_path / (key[0] + ".TextGrid")
            second_per_example = (
                len(alignment) + 1) * subsample * second_per_frame
            text_grid.generate_textgrid(
                maxtime=second_per_example,
                intervals=tierformat,
                output=str(textgrid_path))