ctc_prefix_score.py 14.2 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import numpy as np
H
Hui Zhang 已提交
5
import paddle
H
Hui Zhang 已提交
6 7 8
import six


H
Hui Zhang 已提交
9
class CTCPrefixScorePD:
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.

    In the comments below, T: input_length, B: batch size, W: beam width,
    O: output dim.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        `margin` is M in eq.(22,23)

        Note that frames of `x` beyond each utterance length are overwritten
        in place (logzero for every label, 0 for blank).

        :param paddle.Tensor x: input label posterior sequences (B, T, O)
        :param paddle.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype

        # Pad the rest of posteriors in the batch: past the end of an
        # utterance only blank is allowed (log-prob 0), everything else is
        # logzero.
        # TODO(takaaki-hori): need a better way without for-loops
        for i, xlen in enumerate(xlens):
            if xlen < self.input_length:
                x[i, xlen:, :] = self.logzero
                x[i, xlen:, blank] = 0
        # Reshape input x
        xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
        # blank posterior broadcast over the label axis
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1,
                                                      self.odim)  # (T,B,O)
        # index 0: label posteriors, index 1: blank posteriors
        self.x = paddle.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = paddle.to_tensor(xlens) - 1  # (B,)

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = paddle.arange(self.input_length, dtype=self.dtype)
        # Base indices for index conversion
        # B idx, hyp idx. shape (B*W, 1); built lazily in __call__
        self.idx_bh = None
        # B idx. shape (B,)
        self.idx_b = paddle.arange(self.batch)
        # B idx, O idx. shape (B, 1)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
            (r_prev, s_prev, f_min_prev, f_max_prev), or None for the
            first step
        :param paddle.Tensor scoring_ids: selected next ids to score (BW, O'), O' <= O
        :param paddle.Tensor att_w: attention weights to decide CTC window
        :return: ctc_local_scores (BW, O) and the new state
            (r, log_psi, f_min, f_max, scoring_idmap)
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(
            -1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            # initial state: only the blank path has probability mass, which
            # is the running sum of blank log-posteriors over frames
            r_prev = paddle.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype, )  # (T, 2, B, W)
            r_prev[:, 1] = paddle.cumsum(self.x[0, :, :, self.blank],
                                         0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)  # (T, 2, BW)
            s_prev = 0.0  # score
            f_min_prev = 0  # eq. 22-23
            f_max_prev = 1  # eq. 22-23
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            # (BW, O): maps a label id to its column in scoring_ids, -1 if
            # the label is not scored for that hypothesis
            scoring_idmap = paddle.full(
                (n_bh, self.odim), -1, dtype=paddle.long)
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = paddle.arange(n_bh).view(-1, 1)  # (BW, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = paddle.arange(snum)
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1,
                                                                 1)  # (BW,1)
            ).view(-1)  # (BWO)
            # x_ shape (2, T, B*W, O)
            x_ = paddle.index_select(
                self.x.view(2, -1, self.batch * self.odim), scoring_idx,
                2).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            # x_ shape (2, T, B*W, O)
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1,
                                                                     n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = paddle.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype, )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = paddle.logsumexp(r_prev, 1)  # (T,BW)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)  # (T, BW, O)
        # a label equal to the prefix's last label may only follow a blank,
        # so its phi uses the blank-ending forward variable alone
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = paddle.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            # no windowing: skip the first `output_length` frames (and always
            # frame 0, which is handled by the initialisation of r above)
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]  # (2 x BW x O')
            rr = paddle.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum)  # (2,2,BW,O')
            r[t] = paddle.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = paddle.concat(
            (log_phi[0].unsqueeze(0), log_phi[:-1]), axis=0) + x_[0]
        # (BW, S): same expression whether or not scoring_ids is given
        log_psi_scored = paddle.logsumexp(
            paddle.concat(
                (log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)),
                axis=0),
            axis=0, )
        if scoring_ids is not None:
            # scatter the scored columns back into the full label space;
            # labels that were not scored stay at logzero
            log_psi = paddle.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype)
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_scored[si]
        else:
            log_psi = log_psi_scored

        # eos score: total probability of the prefix itself at the last
        # valid frame of the utterance the hypothesis belongs to
        for si in range(n_bh):
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state    : CTC state as returned by __call__
            (r, log_psi, f_min, f_max, scoring_idmap)
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state: (r_new, s_new, f_min, f_max), the 4-tuple
            form that __call__ accepts as `state`
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b *
                            (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = paddle.index_select(s.view(-1), vidx, 0)
        # replicate each selected score across the label axis
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim +
                       (self.idx_b * n_hyps).view(-1, 1)).view(-1)
            label_ids = paddle.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # labels that were not scored (-1) fall back to column 0
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = paddle.index_select(r.view(-1, 2, n_bh * snum), vidx, 2).view(
            -1, 2, n_bh)
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        Replaces the stored posteriors with `x` when `x` has more frames
        than what is stored, keeping the already stored frames unchanged.

        :param paddle.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            # NOTE(review): xlens holds a single length, so end_frames below
            # gets one entry -- this looks like it assumes B == 1; confirm.
            xlens = [x.size(1)]
            for i, xlen in enumerate(xlens):
                if xlen < self.input_length:
                    x[i, xlen:, :] = self.logzero
                    x[i, xlen:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose([1, 0, 2])  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = paddle.stack([xn, xb])  # (2, T, B, O)
            # keep the frames that were already stored
            self.x[:, :tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = paddle.to_tensor(xlens) - 1
            # NOTE(review): self.frame_ids (set in __init__ when margin > 0)
            # is not refreshed here -- confirm whether windowing is expected
            # to work after an extension.

    def extend_state(self, state):
        """Compute CTC prefix state.

        Pads the forward variables of `state` up to the current
        `self.input_length`: existing frames are copied, and for the new
        frames only the blank-ending variable is accumulated.

        :param state    : CTC state (r_prev, s_prev, f_min_prev, f_max_prev)
        :return ctc_state: the state with extended r, or None if state is None
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = paddle.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype, )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            for t in range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :,
                                                                 self.blank]

            return (r_prev_new, s_prev, f_min_prev, f_max_prev)


class CTCPrefixScore:
    """Compute CTC label sequence scores

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        """Construct the scorer for a single utterance

        :param x    : log label posteriors, indexed as x[t, label] (T, O)
        :param blank: blank label id
        :param eos  : end-of-sequence label id
        :param xp   : array module used for all computations (e.g. numpy)
        """
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x  # (T, O)

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        # r shape (T, 2)
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        # only the all-blank path is possible for the empty prefix, so the
        # blank column is the running sum of blank log-posteriors
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y     : prefix label sequence
        :param cs    : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        # r shape (T, 2, n_labels)
        # Start from logzero everywhere so that frames the recursion below
        # never writes hold a defined value instead of uninitialised memory.
        r = self.xp.full(
            (self.input_length, 2, len(cs)), self.logzero, dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(r_prev[:, 0],
                                  r_prev[:, 1])  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            # a candidate equal to the last label may only follow a blank
            log_phi = self.xp.ndarray(
                (self.input_length, len(cs)), dtype=np.float32)
            for i, label in enumerate(cs):
                log_phi[:, i] = r_sum if label != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        # copy: log_psi is edited below (eos / blank) and must not alias the
        # returned state when the frame loop does not run
        log_psi = r[start - 1, 0].copy()
        for t in range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) +
                       self.x[t, self.blank])
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        # log_psi shape (n_labels,), state shape (n_labels, T, 2)
        return log_psi, self.xp.rollaxis(r, 2)