提交 9d5eb740 编写于 作者: H Hui Zhang

fix decode json file

上级 9abe33b4
......@@ -18,8 +18,8 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
......@@ -306,7 +306,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref", target, "hyp": result})
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
......
......@@ -21,8 +21,8 @@ from collections import OrderedDict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
......@@ -467,7 +467,7 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref", target, "hyp": result})
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
......
......@@ -20,8 +20,8 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
......@@ -446,7 +446,7 @@ class U2Tester(U2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref", target, "hyp": result})
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
......
......@@ -20,8 +20,8 @@ from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import jsonlines
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
......@@ -480,7 +480,7 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split())
num_ins += 1
if fout:
fout.write({"utt": utt, "ref", target, "hyp": result})
fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
......
# Utils
* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
* [espnet utils)(https://github.com/espnet/espnet/tree/master/utils)
文件模式从 100644 更改为 100755
#!/usr/bin/env python3
# Apache 2.0
import argparse
import codecs
import sys
......@@ -12,15 +10,13 @@ is_python2 = sys.version_info[0] == 2
def get_parser():
parser = argparse.ArgumentParser(
description="filter words in a text file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument(
"--exclude",
"-v",
dest="exclude",
action="store_true",
help="exclude filter words",
)
help="exclude filter words", )
parser.add_argument("filt", type=str, help="filter list")
parser.add_argument("infile", type=str, help="input file")
return parser
......@@ -37,29 +33,20 @@ def filter_file(infile, filt, exclude):
for line in vocabfile:
vocab.add(line.strip())
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout
if is_python2 else sys.stdout.buffer)
with codecs.open(infile, "r", encoding="utf-8") as textfile:
for line in textfile:
if exclude:
print(
" ".join(
print(" ".join(
map(
lambda word: word if word not in vocab else "",
line.strip().split(),
)
)
)
line.strip().split(), )))
else:
print(
" ".join(
print(" ".join(
map(
lambda word: word if word in vocab else "<UNK>",
line.strip().split(),
)
)
)
line.strip().split(), )))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册