build_vocab.py 1.6 KB
Newer Older
Y
yangyaming 已提交
1
"""Build vocabulary from manifest files.

Each item in vocabulary file is a character.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
X
Xinghai Sun 已提交
10
import functools
11 12 13 14
import codecs
import json
from collections import Counter
import os.path
Y
yangyaming 已提交
15
import _init_paths
16 17
from data_utils.utility import read_manifest
from utils.utility import add_arguments, print_arguments
18

Y
yangyaming 已提交
19
# Command-line interface. The module docstring doubles as the --help text.
parser = argparse.ArgumentParser(description=__doc__)
# Pre-bind the parser so every add_arg() call below registers on it.
add_arg = functools.partial(add_arguments, argparser=parser)
# NOTE(review): the positional order used below looks like
# (name, type, default, help) -- confirm against utils.utility.add_arguments.
# yapf: disable
add_arg('count_threshold',  int,    0,  "Truncation threshold for char counts.")
add_arg('vocab_path',       str,
        'datasets/vocab/zh_vocab.txt',
        "Filepath to write the vocabulary.")
add_arg('manifest_paths',   str,
        None,
        "Filepaths of manifests for building vocabulary. "
        "You can provide multiple manifest files.",
        nargs='+',
        required=True)
# yapf: enable
# Parsed at module level; main() reads the resulting `args` global.
args = parser.parse_args()
34 35 36


def count_manifest(counter, manifest_path):
    """Add the character counts of one manifest file to `counter`.

    The manifest is loaded with `read_manifest`; for every entry, each
    character of the entry's 'text' field is counted.

    :param counter: Counter that is updated in place.
    :type counter: collections.Counter
    :param manifest_path: Path of the manifest file to read.
    :type manifest_path: str
    """
    # `read_manifest` is imported by name from data_utils.utility; no
    # module is bound to `utils`, so `utils.read_manifest` raised NameError.
    manifest_jsons = read_manifest(manifest_path)
    for line_json in manifest_jsons:
        for char in line_json['text']:
            counter.update(char)


def main():
    """Count characters over all manifests and write the vocabulary file.

    Characters are written to `args.vocab_path` one per line, ordered by
    descending count. Writing stops at the first character whose count is
    below `args.count_threshold`.
    """
    print_arguments(args)

    char_counts = Counter()
    for path in args.manifest_paths:
        count_manifest(char_counts, path)

    # most_common() without an argument returns every (char, count) pair
    # sorted by count, highest first.
    ranked = char_counts.most_common()
    with codecs.open(args.vocab_path, 'w', 'utf-8') as vocab_file:
        for symbol, freq in ranked:
            # Counts are descending, so everything after this is rarer too.
            if freq < args.count_threshold:
                break
            vocab_file.write(symbol + '\n')
55 56 57 58


if __name__ == '__main__':
    main()