diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh
index a54f80e5db4b137ae54403e3697c00b1d4b2448e..410bff395cff0b671722af6c3ea331507339405d 100644
--- a/examples/tiny/s0/local/data.sh
+++ b/examples/tiny/s0/local/data.sh
@@ -24,7 +24,7 @@ bpeprefix="data/bpe_${bpemode}_${nbpe}"
 # build vocabulary
 python3 ${MAIN_ROOT}/utils/build_vocab.py \
 --unit_type "spm" \
---count_threshold=${nbpe} \
+--vocab_size=${nbpe} \
 --spm_mode ${bpemode} \
 --spm_model_prefix ${bpeprefix} \
 --vocab_path="data/vocab.txt" \
diff --git a/utils/build_vocab.py b/utils/build_vocab.py
index 3ef566b12bd5e46effeebb16100eaecf4fc7f956..591fda33ed165ed5423816bdde4a46af5d53a917 100644
--- a/utils/build_vocab.py
+++ b/utils/build_vocab.py
@@ -35,7 +35,8 @@ parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
-add_arg('count_threshold', int, 0, "Truncation threshold for char/word/spm counts.")
+add_arg('count_threshold', int, 0,
+        "Truncation threshold for char/word counts. Default 0, no truncation.")
 add_arg('vocab_path', str,
         'examples/librispeech/data/vocab.txt',
         "Filepath to write the vocabulary.")
@@ -46,6 +47,7 @@ add_arg('manifest_paths', str,
         nargs='+',
         required=True)
 # bpe
+add_arg('vocab_size', int, 0, "Vocab size for spm.")
 add_arg('spm_mode', str, 'unigram',
         "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
 add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
@@ -72,18 +74,7 @@ def main():
     fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
     fout.write(UNK + '\n')  # must be 1
 
-    if args.unit_type != 'spm':
-        text_feature = TextFeaturizer(args.unit_type, args.vocab_path)
-        counter = Counter()
-
-        for manifest_path in args.manifest_paths:
-            count_manifest(counter, text_feature, manifest_path)
-
-        count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-        for char, count in count_sorted:
-            if count < args.count_threshold: break
-            fout.write(char + '\n')
-    else:
+    if args.unit_type == 'spm':
         # tools/spm_train --input=$wave_data/lang_char/input.txt
         # --vocab_size=${nbpe} --model_type=${bpemode}
         # --model_prefix=${bpemodel} --input_sentence_size=100000000
@@ -96,39 +87,29 @@ def main():
         # train
         spm.SentencePieceTrainer.Train(
             input=fp.name,
-            vocab_size=args.count_threshold,
+            vocab_size=args.vocab_size,
             model_type=args.spm_mode,
             model_prefix=args.spm_model_prefix,
             input_sentence_size=100000000,
             character_coverage=0.9995)
         os.unlink(fp.name)
 
-        # encode
-        text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
-
-        # vocabs = set()
-        # for manifest_path in args.manifest_paths:
-        #     manifest_jsons = read_manifest(manifest_path)
-        #     for line_json in manifest_jsons:
-        #         line = line_json['text']
-        #         enc_line = text_feature.spm_tokenize(line)
-        #         for code in enc_line:
-        #             vocabs.add(code)
-        #         #print(" ".join(enc_line))
-        # vocabs_sorted = sorted(vocabs)
-        # for unit in vocabs_sorted:
-        #     fout.write(unit + "\n")
-
-        counter = Counter()
-
-        for manifest_path in args.manifest_paths:
-            count_manifest(counter, text_feature, manifest_path)
-
-        count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
-        for token, count in count_sorted:
-            fout.write(token + '\n')
-
-        print(f"spm vocab size: {len(count_sorted)}")
+    # encode
+    text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
+    counter = Counter()
+
+    for manifest_path in args.manifest_paths:
+        count_manifest(counter, text_feature, manifest_path)
+
+    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
+    tokens = []
+    for token, count in count_sorted:
+        if count < args.count_threshold: break
+        tokens.append(token)
+
+    tokens = sorted(tokens)
+    for token in tokens:
+        fout.write(token + '\n')
 
     fout.write(SOS + "\n")  # <sos/eos>
     fout.close()
diff --git a/utils/format_data.py b/utils/format_data.py
index f1744d175eaf8e8635edb4256ef664533d645428..d19bed09eb10700491b0e36e3d8f24ab5dd5f49d 100644
--- a/utils/format_data.py
+++ b/utils/format_data.py
@@ -67,16 +67,12 @@ def main():
     vocab_size = text_feature.vocab_size
     print(f"Vocab size: {vocab_size}")
 
+    count = 0
     for manifest_path in args.manifest_paths:
         manifest_jsons = read_manifest(manifest_path)
         for line_json in manifest_jsons:
             line = line_json['text']
-            if args.unit_type == 'char':
-                tokens = text_feature.char_tokenize(line)
-            elif args.unit_type == 'word':
-                tokens = text_feature.word_tokenize(line)
-            else: #spm
-                tokens = text_feature.spm_tokenize(line)
+            tokens = text_feature.tokenize(line)
             tokenids = text_feature.featurize(line)
             line_json['token'] = tokens
             line_json['token_id'] = tokenids
@@ -88,7 +84,9 @@ def main():
             else:  # kaldi
                 raise NotImplemented('no support kaldi feat now!')
             fout.write(json.dumps(line_json) + '\n')
-
+            count += 1
+
+    print(f"Number of examples: {count}")
     fout.close()
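
For reference, a minimal runnable sketch of the vocabulary-building flow that build_vocab.py converges on after this patch: train a SentencePiece model sized by the new --vocab_size, count the spm tokens over the corpus, drop tokens below count_threshold, and write a sorted vocab. This is an illustration under stated assumptions, not the script itself: build_spm_vocab and the inline corpus are hypothetical stand-ins for the manifest handling done by count_manifest/TextFeaturizer, and it assumes the sentencepiece package is installed.

# A sketch only: build_spm_vocab, corpus, and the demo values are hypothetical;
# the real script reads manifests via count_manifest/TextFeaturizer.
import os
import tempfile
from collections import Counter

import sentencepiece as spm


def build_spm_vocab(texts, vocab_size, model_type="unigram",
                    model_prefix="spm_model", count_threshold=0):
    # train: dump the corpus to a temp file, as build_vocab.py does
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as fp:
        fp.write("\n".join(texts))
    spm.SentencePieceTrainer.Train(
        input=fp.name,
        vocab_size=vocab_size,  # the new --vocab_size flag, no longer count_threshold
        model_type=model_type,
        model_prefix=model_prefix,
        character_coverage=0.9995)
    os.unlink(fp.name)

    # encode: count spm tokens over the same corpus
    sp = spm.SentencePieceProcessor(model_file=model_prefix + ".model")
    counter = Counter()
    for line in texts:
        counter.update(sp.encode(line, out_type=str))

    # truncate: keep tokens whose counts reach count_threshold, then sort
    # alphabetically for a stable vocab file, mirroring the new loop above
    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    tokens = []
    for token, count in count_sorted:
        if count < count_threshold:
            break
        tokens.append(token)
    return sorted(tokens)


if __name__ == "__main__":
    corpus = ["hello world", "hello speech recognition", "the quick brown fox"] * 20
    print(build_spm_vocab(corpus, vocab_size=32))

Note the split of the two knobs after this change: --vocab_size bounds the spm model itself, while count_threshold only prunes rarely used pieces from the written vocab file.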