align_mandarin.py 5.5 KB
Newer Older
O
oyjxer 已提交
1 2
#!/usr/bin/env python
""" Usage:
    align_mandarin.py wavfile trsfile outwordfile outphonefile
"""
小湉湉's avatar
小湉湉 已提交
5
import multiprocessing as mp
O
oyjxer 已提交
6 7 8
import os
import sys

小湉湉's avatar
小湉湉 已提交
9
from tqdm import tqdm
O
oyjxer 已提交
10 11 12 13 14 15 16 17 18 19

# Directory holding the aligner model files that alignment_zh passes to the
# HTK tools: dict, monophones, 16000/config, 16000/macros and 16000/hmmdefs.
# The path is relative, so it is resolved against the current working
# directory.
MODEL_DIR = 'tools/aligner/mandarin'
# HTK executables invoked by alignment_zh: HVite produces the '.aligned'
# file, HCopy converts the resampled wav into the '.plp' feature file.
HVITE = 'tools/htk/HTKTools/HVite'
HCOPY = 'tools/htk/HTKTools/HCopy'


def prep_txt(line, tmpbase, dictfile):

    words = []
    line = line.strip()
小湉湉's avatar
小湉湉 已提交
20 21 22 23
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
O
oyjxer 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
        line = line.replace(pun, ' ')
    for wrd in line.split():
        if (wrd[-1] == '-'):
            wrd = wrd[:-1]
        if (wrd[0] == "'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    ds = set([])
    with open(dictfile, 'r') as fid:
        for line in fid:
            ds.add(line.split()[0])

    unk_words = set([])
    with open(tmpbase + '.txt', 'w') as fwid:
        for wrd in words:
            if (wrd not in ds):
                unk_words.add(wrd)
            fwid.write(wrd + ' ')
        fwid.write('\n')
    return unk_words

小湉湉's avatar
小湉湉 已提交
47

O
oyjxer 已提交
48 49 50 51 52 53 54 55 56 57 58 59
def prep_mlf(txt, tmpbase):
    """Write the label file ``tmpbase + '.mlf'`` for one transcript.

    The file starts with the ``#!MLF!#`` header and the quoted name
    ``tmpbase + '.lab'``, followed by one label per line: an ``sp`` entry,
    then every whitespace-separated word of ``txt`` in upper case, each
    followed by another ``sp`` entry.  A single ``.`` line ends the file.

    Args:
        txt: Transcript text; split on whitespace.
        tmpbase: Path prefix used for the output file and the label name.
    """
    with open(tmpbase + '.mlf', 'w') as out:
        header = ['#!MLF!#\n', '"' + tmpbase + '.lab"\n', 'sp\n']
        labels = [
            piece
            for token in txt.split()
            for piece in (token.upper() + '\n', 'sp\n')
        ]
        out.write(''.join(header + labels + ['.\n']))

小湉湉's avatar
小湉湉 已提交
60

O
oyjxer 已提交
61 62 63 64 65 66 67 68 69 70 71 72
def gen_res(tmpbase, outfile1, outfile2):
    """Turn an alignment file into word-level and phone-level label files.

    Reads the transcript words from ``tmpbase + '.txt'`` and the alignment
    from ``tmpbase + '.aligned'``.  The first two lines of the alignment are
    skipped.  Of the remaining lines only those with at least four fields
    and a start field different from the end field are used; a line with
    exactly five fields additionally starts a new word whose label is its
    last field.

    Args:
        tmpbase: Path prefix of the ``.txt`` and ``.aligned`` input files.
        outfile1: Output path for word-level lines ``start end word``.  A
            word labelled ``sp`` is written as ``SIL``; every other word is
            written with the spelling taken from the ``.txt`` file.
        outfile2: Output path for phone-level lines ``start end phone``.

    Raises:
        SystemExit: With status 1 (after printing a message) when the number
            of non-``sp`` aligned words differs from the number of
            transcript words.
    """
    aligned_path = tmpbase + '.aligned'

    def to_seconds(stamp):
        # Presumably converts HTK 100 ns time stamps to seconds with a
        # fixed 12.5 ms offset -- confirm against the model's frame config.
        return (int(stamp) / 1000 + 125) / 10000

    with open(tmpbase + '.txt', 'r') as fid:
        words = fid.readline().strip().split()
    # Reversed so that pop() hands the words out in transcript order.
    words.reverse()

    with open(aligned_path, 'r') as fid:
        lines = fid.readlines()
    fields = [entry.split() for entry in lines]
    total = len(lines)

    times1 = []  # [word label, start, end]
    times2 = []  # [phone, start, end]
    for i in range(2, total):
        cur = fields[i]
        if len(cur) < 4 or cur[0] == cur[1]:
            continue
        times2.append([cur[2], to_seconds(cur[0]), to_seconds(cur[1])])
        if len(cur) == 5:
            # The word lasts until the line before the next five-field line
            # or the terminating '.' line; the bound check keeps a file
            # without that terminator from running off the end.
            j = i + 1
            while j < total and lines[j] != '.\n' and len(fields[j]) != 5:
                j += 1
            times1.append(
                [cur[-1].strip(), to_seconds(cur[0]),
                 to_seconds(fields[j - 1][1])])

    matched = True
    with open(outfile1, 'w') as fwid:
        for label, start, end in times1:
            if label == 'sp':
                text = 'SIL'
            elif words:
                text = words.pop()
            else:
                # More aligned words than transcript words.
                matched = False
                break
            fwid.write(str(start) + ' ' + str(end) + ' ' + text + '\n')
    if words or not matched:
        print('not matched::' + aligned_path)
        sys.exit(1)

    with open(outfile2, 'w') as fwid:
        for phone, start, end in times2:
            fwid.write(str(start) + ' ' + str(end) + ' ' + phone + '\n')


def alignment_zh(wav_path, text_string):
    """Align a transcript to a wav file with the external HTK tools.

    Steps, all working on files named ``/tmp/<user>_<pid>.*``:

    1. ``sox`` resamples ``wav_path`` to 16 kHz / 16 bit (``.wav``).
    2. ``prep_txt`` writes the cleaned transcript (``.txt``); words missing
       from ``MODEL_DIR + '/dict'`` are only reported, processing continues.
    3. ``prep_mlf`` writes the label file (``.mlf``).
    4. ``HCopy`` produces the feature file (``.plp``).
    5. ``HVite`` writes the alignment (``.aligned``), which is then parsed.

    The temporary files are left in place after the call.

    Args:
        wav_path: Path of the input audio file.
        text_string: Transcript of the audio.

    Returns:
        ``None`` if any preparation step or external tool fails (a short
        message is printed).  Otherwise a tuple ``(times2, word2phns)``:
        ``times2`` is a list of ``[phone, start, end]`` entries and
        ``word2phns`` maps ``'<index>_<word label>'`` to the space-joined
        phones of that word, where the index counts the word entries
        (``sp`` entries included) in order of appearance.
    """
    # Local import so this function does not depend on edits to the
    # module-level import block.
    import subprocess

    def run_tool(cmd, **kwargs):
        """Run an external command; True only if it exits with status 0."""
        try:
            return subprocess.run(cmd, **kwargs).returncode == 0
        except OSError:
            # Executable missing or not runnable.
            return False

    # USER may be unset (e.g. in containers); fall back to a fixed tag
    # instead of raising KeyError.
    tmpbase = ('/tmp/' + os.environ.get('USER', 'unknown') + '_' +
               str(os.getpid()))

    # prepare wav and trs files
    # Argument lists (no shell) keep paths with spaces or shell
    # metacharacters from being split or interpreted.
    if not run_tool([
            'sox', wav_path, '-r', '16000', '-b', '16', tmpbase + '.wav',
            'remix', '-'
    ]):
        print('sox error!')
        return None

    # prepare clean_transcript file
    try:
        unk_words = prep_txt(text_string, tmpbase, MODEL_DIR + '/dict')
        if unk_words:
            print('Error! Please add the following words to dictionary:')
            for unk in unk_words:
                print("非法words: ", unk)
    except Exception:
        print('prep_txt error!')
        return None

    # prepare mlf file
    try:
        with open(tmpbase + '.txt', 'r') as fid:
            txt = fid.readline()
        prep_mlf(txt, tmpbase)
    except Exception:
        print('prep_mlf error!')
        return None

    # prepare scp
    if not run_tool([
            HCOPY, '-C', MODEL_DIR + '/16000/config', tmpbase + '.wav',
            tmpbase + '.plp'
    ]):
        print('HCopy error!')
        return None

    # run alignment
    # HVite's standard output is discarded; its error stream is left
    # attached to the console so diagnostics stay visible.
    if not run_tool(
        [
            HVITE, '-a', '-m', '-t', '10000.0', '10000.0', '100000.0', '-I',
            tmpbase + '.mlf', '-H', MODEL_DIR + '/16000/macros', '-H',
            MODEL_DIR + '/16000/hmmdefs', '-i', tmpbase + '.aligned',
            MODEL_DIR + '/dict', MODEL_DIR + '/monophones', tmpbase + '.plp'
        ],
            stdout=subprocess.DEVNULL):
        print('HVite error!')
        return None

    with open(tmpbase + '.aligned', 'r') as fid:
        lines = fid.readlines()

    times2 = []
    word2phns = {}
    current_word = ''
    index = 0
    # The first two lines of the alignment file are skipped.
    for raw in lines[2:]:
        splited_line = raw.strip().split()
        # Only lines with at least four fields and start != end count.
        if len(splited_line) < 4 or splited_line[0] == splited_line[1]:
            continue
        phn = splited_line[2]
        # Presumably converts HTK 100 ns time stamps to seconds with a
        # fixed 12.5 ms offset -- confirm against the model's frame config.
        pst = (int(splited_line[0]) / 1000 + 125) / 10000
        pen = (int(splited_line[1]) / 1000 + 125) / 10000
        times2.append([phn, pst, pen])
        if len(splited_line) == 5:
            # A five-field line starts a new word ('sp' entries included);
            # the running index keeps repeated words apart.
            current_word = str(index) + '_' + splited_line[-1]
            word2phns[current_word] = phn
            index += 1
        elif len(splited_line) == 4:
            # A four-field line is a further phone of the current word.
            word2phns[current_word] += ' ' + phn
    return times2, word2phns