#!/usr/bin/env python3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build vocabulary from manifest files.
16 17 18
Each item in vocabulary file is a character.
"""
import argparse
import functools
import os
import tempfile
from collections import Counter

from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.utility import BLANK
from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.utility import SOS
from deepspeech.frontend.utility import SPACE
from deepspeech.frontend.utility import UNK
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm.")
add_arg('count_threshold', int, 0,
        "Truncation threshold for char/word counts. Default 0: no truncation.")
add_arg('vocab_path', str,
        'examples/librispeech/data/vocab.txt',
        "Filepath to write the vocabulary.")
add_arg('manifest_paths', str,
        None,
        "Filepaths of manifests for building vocabulary. "
        "You can provide multiple manifest files.",
        nargs='+',
        required=True)
add_arg('text_keys', str,
        'text',
        "Keys of the text in manifest for building vocabulary. "
        "You can provide multiple keys.",
        nargs='+')
# bpe
add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram',
        "spm model type, e.g. unigram, bpe, char, word. "
        "Only needed when `unit_type` is spm.")
add_arg('spm_model_prefix', str, "",
        "spm model prefix, e.g. spm_model_{spm_mode}_{count_threshold}. "
        "Only needed when `unit_type` is spm.")
# yapf: enable
args = parser.parse_args()


def count_manifest(counter, text_feature, manifest_path):
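    """Tokenize the 'text' field of every manifest entry and update
    `counter` with the resulting token counts."""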
    manifest_jsons = read_manifest(manifest_path)
    for line_json in manifest_jsons:
        line = text_feature.tokenize(line_json['text'], replace_space=False)
        counter.update(line)

def dump_text_manifest(fileobj, manifest_path, key='text'):
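    """Write the `key` field of every manifest entry to `fileobj`, one
    utterance per line."""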
    manifest_jsons = read_manifest(manifest_path)
    for line_json in manifest_jsons:
        fileobj.write(line_json[key] + "\n")

def main():
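    """Write the vocabulary file: BLANK first (id 0, the CTC blank), UNK
    second (id 1), then the sorted corpus tokens, and SOS (<sos/eos>) last."""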
    print_arguments(args, globals())

    fout = open(args.vocab_path, 'w', encoding='utf-8')
    fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
    fout.write(UNK + '\n')  # <unk> must be 1

    if args.unit_type == 'spm':
        # tools/spm_train --input=$wave_data/lang_char/input.txt
        # --vocab_size=${nbpe} --model_type=${bpemode}
        # --model_prefix=${bpemodel} --input_sentence_size=100000000
        import sentencepiece as spm

        # dump all transcript text to a temp file that sentencepiece reads by
        # name; delete=False so it can be closed (flushed) before training
        fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
        # argparse yields a plain string for the default and a list when
        # values are given on the command line
        text_keys = [args.text_keys] if not isinstance(
            args.text_keys, list) else args.text_keys
        for manifest_path in args.manifest_paths:
            for text_key in text_keys:
                dump_text_manifest(fp, manifest_path, key=text_key)
        fp.close()
        # train a sentencepiece model on the collected text
        spm.SentencePieceTrainer.Train(
            input=fp.name,
            vocab_size=args.spm_vocab_size,
            model_type=args.spm_mode,
            model_prefix=args.spm_model_prefix,
            input_sentence_size=100000000,
            character_coverage=0.9995)
        os.unlink(fp.name)
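        # sentencepiece has written `<spm_model_prefix>.model` (and .vocab);
        # the TextFeaturizer below uses that model for spm tokenization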

    # tokenize and count tokens over all manifests; the vocab file does not
    # exist yet, so the featurizer gets an empty vocab path and is used only
    # for tokenization
    text_feature = TextFeaturizer(args.unit_type, "", args.spm_model_prefix)
    counter = Counter()

    for manifest_path in args.manifest_paths:
        count_manifest(counter, text_feature, manifest_path)

    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    tokens = []
    for token, count in count_sorted:
        if count < args.count_threshold:
            break  # counts are in descending order, so the rest are smaller
        # replace space by `<space>`
        token = SPACE if token == ' ' else token
        tokens.append(token)

    tokens = sorted(tokens)  # sort for a stable, reproducible vocab order
    for token in tokens:
        fout.write(token + '\n')

    fout.write(SOS + "\n")  # <sos/eos>
    fout.close()
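
# A minimal sketch of reading the vocabulary back, assuming one token per
# line as written above (hypothetical consumer code; the pipeline itself
# loads it through TextFeaturizer):
#
#     with open(args.vocab_path, encoding='utf-8') as fin:
#         token2id = {tok.rstrip('\n'): i for i, tok in enumerate(fin)}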


if __name__ == '__main__':
    main()