提交 9d5eb740 编写于 作者: H Hui Zhang

fix decode json file

上级 9abe33b4
...@@ -18,8 +18,8 @@ from collections import defaultdict ...@@ -18,8 +18,8 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
...@@ -306,7 +306,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -306,7 +306,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")
......
...@@ -21,8 +21,8 @@ from collections import OrderedDict ...@@ -21,8 +21,8 @@ from collections import OrderedDict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
...@@ -467,7 +467,7 @@ class U2Tester(U2Trainer): ...@@ -467,7 +467,7 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")
......
...@@ -20,8 +20,8 @@ from collections import defaultdict ...@@ -20,8 +20,8 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
...@@ -446,7 +446,7 @@ class U2Tester(U2Trainer): ...@@ -446,7 +446,7 @@ class U2Tester(U2Trainer):
len_refs += len_ref len_refs += len_ref
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")
......
...@@ -20,8 +20,8 @@ from collections import defaultdict ...@@ -20,8 +20,8 @@ from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import jsonlines
import jsonlines
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
...@@ -480,7 +480,7 @@ class U2STTester(U2STTrainer): ...@@ -480,7 +480,7 @@ class U2STTester(U2STTrainer):
len_refs += len(target.split()) len_refs += len(target.split())
num_ins += 1 num_ins += 1
if fout: if fout:
fout.write({"utt": utt, "ref", target, "hyp": result}) fout.write({"utt": utt, "ref": target, "hyp": result})
logger.info(f"Utt: {utt}") logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}") logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}") logger.info(f"Hyp: {result}")
......
# Utils # Utils
* [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils) * [kaldi utils](https://github.com/kaldi-asr/kaldi/blob/cbed4ff688/egs/wsj/s5/utils)
* [espnet utils](https://github.com/espnet/espnet/tree/master/utils)
文件模式从 100644 更改为 100755
#!/usr/bin/env python3 #!/usr/bin/env python3
# Apache 2.0 # Apache 2.0
import argparse import argparse
import codecs import codecs
import sys import sys
...@@ -12,15 +10,13 @@ is_python2 = sys.version_info[0] == 2 ...@@ -12,15 +10,13 @@ is_python2 = sys.version_info[0] == 2
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="filter words in a text file", description="filter words in a text file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
)
parser.add_argument( parser.add_argument(
"--exclude", "--exclude",
"-v", "-v",
dest="exclude", dest="exclude",
action="store_true", action="store_true",
help="exclude filter words", help="exclude filter words", )
)
parser.add_argument("filt", type=str, help="filter list") parser.add_argument("filt", type=str, help="filter list")
parser.add_argument("infile", type=str, help="input file") parser.add_argument("infile", type=str, help="input file")
return parser return parser
...@@ -37,29 +33,20 @@ def filter_file(infile, filt, exclude): ...@@ -37,29 +33,20 @@ def filter_file(infile, filt, exclude):
for line in vocabfile: for line in vocabfile:
vocab.add(line.strip()) vocab.add(line.strip())
sys.stdout = codecs.getwriter("utf-8")( sys.stdout = codecs.getwriter("utf-8")(sys.stdout
sys.stdout if is_python2 else sys.stdout.buffer if is_python2 else sys.stdout.buffer)
)
with codecs.open(infile, "r", encoding="utf-8") as textfile: with codecs.open(infile, "r", encoding="utf-8") as textfile:
for line in textfile: for line in textfile:
if exclude: if exclude:
print( print(" ".join(
" ".join(
map( map(
lambda word: word if word not in vocab else "", lambda word: word if word not in vocab else "",
line.strip().split(), line.strip().split(), )))
)
)
)
else: else:
print( print(" ".join(
" ".join(
map( map(
lambda word: word if word in vocab else "<UNK>", lambda word: word if word in vocab else "<UNK>",
line.strip().split(), line.strip().split(), )))
)
)
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册