提交 553aa359 编写于 作者: H Hui Zhang

refactor data prepare

上级 b7674866
......@@ -24,7 +24,7 @@ bpeprefix="data/bpe_${bpemode}_${nbpe}"
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type "spm" \
--count_threshold=${nbpe} \
--vocab_size=${nbpe} \
--spm_mode ${bpemode} \
--spm_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \
......@@ -35,7 +35,8 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('count_threshold', int, 0, "Truncation threshold for char/word/spm counts.")
add_arg('count_threshold', int, 0,
"Truncation threshold for char/word counts.Default 0, no truncate.")
add_arg('vocab_path', str,
"Filepath to write the vocabulary.")
......@@ -46,6 +47,7 @@ add_arg('manifest_paths', str,
# bpe
add_arg('vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram',
"spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
......@@ -72,18 +74,7 @@ def main():
fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC
fout.write(UNK + '\n') # <unk> must be 1
if args.unit_type != 'spm':
text_feature = TextFeaturizer(args.unit_type, args.vocab_path)
counter = Counter()
for manifest_path in args.manifest_paths:
count_manifest(counter, text_feature, manifest_path)
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
for char, count in count_sorted:
if count < args.count_threshold: break
fout.write(char + '\n')
if args.unit_type == 'spm':
# tools/spm_train --input=$wave_data/lang_char/input.txt
# --vocab_size=${nbpe} --model_type=${bpemode}
# --model_prefix=${bpemodel} --input_sentence_size=100000000
......@@ -96,7 +87,7 @@ def main():
# train
......@@ -105,30 +96,20 @@ def main():
# encode
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)
# vocabs = set()
# for manifest_path in args.manifest_paths:
# manifest_jsons = read_manifest(manifest_path)
# for line_json in manifest_jsons:
# line = line_json['text']
# enc_line = text_feature.spm_tokenize(line)
# for code in enc_line:
# vocabs.add(code)
# #print(" ".join(enc_line))
# vocabs_sorted = sorted(vocabs)
# for unit in vocabs_sorted:
# fout.write(unit + "\n")
counter = Counter()
for manifest_path in args.manifest_paths:
count_manifest(counter, text_feature, manifest_path)
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = []
for token, count in count_sorted:
fout.write(token + '\n')
if count < args.count_threshold: break
print(f"spm vocab size: {len(count_sorted)}")
tokens = sorted(tokens)
for token in tokens:
fout.write(token + '\n')
fout.write(SOS + "\n") # <sos/eos>
......@@ -67,16 +67,12 @@ def main():
vocab_size = text_feature.vocab_size
print(f"Vocab size: {vocab_size}")
count = 0
for manifest_path in args.manifest_paths:
manifest_jsons = read_manifest(manifest_path)
for line_json in manifest_jsons:
line = line_json['text']
if args.unit_type == 'char':
tokens = text_feature.char_tokenize(line)
elif args.unit_type == 'word':
tokens = text_feature.word_tokenize(line)
else: #spm
tokens = text_feature.spm_tokenize(line)
tokens = text_feature.tokenize(line)
tokenids = text_feature.featurize(line)
line_json['token'] = tokens
line_json['token_id'] = tokenids
......@@ -88,7 +84,9 @@ def main():
else: # kaldi
raise NotImplemented('no support kaldi feat now!')
fout.write(json.dumps(line_json) + '\n')
count += 1
print(f"Examples number: {count}")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册