Unverified commit fe8a14dd, authored by Hui Zhang, committed by GitHub

Merge pull request #1740 from zh794390558/fix

[speechx] fix nnet input and output name
@@ -12,6 +12,8 @@ exclude =
     .git,
     # python cache
     __pycache__,
+    # third party
+    utils/compute-wer.py,
     third_party/,
 # Provide a comma-separate list of glob patterns to include for checks.
 filename =
......
@@ -40,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
 __all__ = ['ASRExecutor']

+
 @cli_register(
     name='paddlespeech.asr', description='Speech to text infer command.')
 class ASRExecutor(BaseExecutor):
@@ -148,7 +149,7 @@ class ASRExecutor(BaseExecutor):
             os.path.dirname(os.path.abspath(self.cfg_path)))
         logger.info(self.cfg_path)
         logger.info(self.ckpt_path)

         #Init body.
         self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
@@ -278,7 +279,8 @@ class ASRExecutor(BaseExecutor):
             self._outputs["result"] = result_transcripts[0]
         elif "conformer" in model_type or "transformer" in model_type:
-            logger.info(f"we will use the transformer like model : {model_type}")
+            logger.info(
+                f"we will use the transformer like model : {model_type}")
             try:
                 result_transcripts = self.model.decode(
                     audio,
......
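For context, this is the decode path behind the Python API. A minimal usage sketch (the model tag and audio file are illustrative; the pretrained model is downloaded on first use):

```python
# Hedged sketch: invoke the CLI executor from Python; argument values
# are assumptions following the paddlespeech CLI conventions.
from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
# A "conformer" model tag exercises the transformer-like branch above.
result = asr(audio_file="zh.wav", model="conformer_wenetspeech", lang="zh")
print(result)
```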
@@ -279,7 +279,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
             # TODO(Hui Zhang): if end_flag.sum() == running_size:
             if end_flag.cast(paddle.int64).sum() == running_size:
                 break
             # 2.1 Forward decoder step
             hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                 running_size, 1, 1).to(device)  # (B*N, i, i)
......
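The cast in this hunk exists because the Paddle releases targeted here could not reduce a bool tensor with `sum()`, hence the `TODO` to drop it later. A minimal sketch of the idea (shapes and values are illustrative):

```python
import paddle

running_size = 4  # batch_size * beam_size in the surrounding beam search
end_flag = paddle.to_tensor([True, True, True, True])  # every beam ended

# Workaround shown in the diff: cast bool -> int64 before reducing.
if end_flag.cast(paddle.int64).sum() == running_size:
    print("all beams finished, stop decoding")
```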
@@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
         # init once
         if self._ext_scorer is not None:
             return

         if language_model_path != '':
             logger.info("begin to initialize the external scorer "
                         "for decoding")
......
@@ -47,4 +47,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml
 ```
 paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
-```
\ No newline at end of file
+```
@@ -48,4 +48,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml
 ```
 paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
-```
\ No newline at end of file
+```
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import defaultdict
+
 import paddle
+
 from paddlespeech.cli.log import logger
 from paddlespeech.s2t.utils.utility import log_add
......
@@ -52,7 +52,7 @@ def evaluate(args):
     # acoustic model
     am_name = args.am[:args.am.rindex('_')]
     am_dataset = args.am[args.am.rindex('_') + 1:]

     am_inference = get_am_inference(
         am=args.am,
         am_config=am_config,
......
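The `rindex('_')` slicing assumes acoustic model tags of the form `<name>_<dataset>`. A worked example (the tag is illustrative):

```python
# Split an acoustic model tag on its last underscore, mirroring evaluate().
am = "fastspeech2_csmsc"  # illustrative tag: <am_name>_<am_dataset>
am_name = am[:am.rindex('_')]         # "fastspeech2"
am_dataset = am[am.rindex('_') + 1:]  # "csmsc"
print(am_name, am_dataset)
```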
@@ -20,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
 import argparse
 import copy
 import warnings
+from distutils.util import strtobool

 import numpy as np
 import scipy
 import sklearn
-from distutils.util import strtobool
 from scipy import linalg
 from scipy import sparse
 from scipy.sparse.csgraph import connected_components
......
@@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
 DEFINE_int32(downsampling_rate,
              4,
              "two CNN(kernel=5) module downsampling rate.");
+DEFINE_string(
+    model_input_names,
+    "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
+    "model input names");
 DEFINE_string(model_output_names,
-              "save_infer_model/scale_0.tmp_1,save_infer_model/"
-              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
-              "scale_3.tmp_1",
+              "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
               "model output names");
 DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@@ -76,6 +78,7 @@ int main(int argc, char* argv[]) {
     model_opts.model_path = model_path;
     model_opts.params_path = model_params;
     model_opts.cache_shape = FLAGS_model_cache_names;
+    model_opts.input_names = FLAGS_model_input_names;
     model_opts.output_names = FLAGS_model_output_names;
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
......
@@ -48,7 +48,6 @@ if [ ! -f $lm ]; then
     popd
 fi
-
 feat_wspecifier=$exp_dir/feats.ark
 cmvn=$exp_dir/cmvn.ark
@@ -57,7 +56,7 @@ export GLOG_logtostderr=1
 # dump json cmvn to kaldi
 cmvn-json2kaldi \
     --json_file $ckpt_dir/data/mean_std.json \
-    --cmvn_write_path $exp_dir/cmvn.ark \
+    --cmvn_write_path $cmvn \
     --binary=false
 echo "convert json cmvn to kaldi ark."
@@ -66,7 +65,7 @@ echo "convert json cmvn to kaldi ark."
 linear-spectrogram-wo-db-norm-ol \
     --wav_rspecifier=scp:$data/wav.scp \
     --feature_wspecifier=ark,t:$feat_wspecifier \
-    --cmvn_file=$exp_dir/cmvn.ark
+    --cmvn_file=$cmvn
 echo "compute linear spectrogram feature."

 # run ctc beam search decoder as streaming
......
@@ -37,10 +37,12 @@ DEFINE_int32(receptive_field_length,
 DEFINE_int32(downsampling_rate,
              4,
              "two CNN(kernel=5) module downsampling rate.");
+DEFINE_string(
+    model_input_names,
+    "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
+    "model input names");
 DEFINE_string(model_output_names,
-              "save_infer_model/scale_0.tmp_1,save_infer_model/"
-              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
-              "scale_3.tmp_1",
+              "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
               "model output names");
 DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@@ -79,6 +81,7 @@ int main(int argc, char* argv[]) {
     model_opts.model_path = model_graph;
     model_opts.params_path = model_params;
     model_opts.cache_shape = FLAGS_model_cache_names;
+    model_opts.input_names = FLAGS_model_input_names;
     model_opts.output_names = FLAGS_model_output_names;
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
......
@@ -9,4 +9,4 @@ target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags g
 set(bin_name cmvn-json2kaldi)
 add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
 target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
\ No newline at end of file
+target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog ${DEPS})
@@ -14,18 +14,20 @@
 // Note: Do not print/log ondemand object.

+#include "base/common.h"
 #include "base/flags.h"
 #include "base/log.h"
 #include "kaldi/matrix/kaldi-matrix.h"
 #include "kaldi/util/kaldi-io.h"
 #include "utils/file_utils.h"
-#include "utils/simdjson.h"
+// #include "boost/json.hpp"
+#include <boost/json/src.hpp>

 DEFINE_string(json_file, "", "cmvn json file");
 DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
 DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");

-using namespace simdjson;
+using namespace boost::json;  // from <boost/json.hpp>

 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
@@ -33,49 +35,51 @@ int main(int argc, char* argv[]) {
     LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;

-    try {
-        padded_string json = padded_string::load(FLAGS_json_file);
-
-        ondemand::parser parser;
-        ondemand::document doc = parser.iterate(json);
-        ondemand::value val = doc;
-
-        ondemand::array mean_stat = val["mean_stat"];
-        std::vector<kaldi::BaseFloat> mean_stat_vec;
-        for (double x : mean_stat) {
-            mean_stat_vec.push_back(x);
-        }
-        // LOG(INFO) << mean_stat; this line will casue
-        // simdjson::simdjson_error("Objects and arrays can only be iterated
-        // when
-        // they are first encountered")
-
-        ondemand::array var_stat = val["var_stat"];
-        std::vector<kaldi::BaseFloat> var_stat_vec;
-        for (double x : var_stat) {
-            var_stat_vec.push_back(x);
-        }
-
-        kaldi::int32 frame_num = uint64_t(val["frame_num"]);
-        LOG(INFO) << "nframe: " << frame_num;
-
-        size_t mean_size = mean_stat_vec.size();
-        kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
-        for (size_t idx = 0; idx < mean_size; ++idx) {
-            cmvn_stats(0, idx) = mean_stat_vec[idx];
-            cmvn_stats(1, idx) = var_stat_vec[idx];
-        }
-        cmvn_stats(0, mean_size) = frame_num;
-        LOG(INFO) << cmvn_stats;
-
-        kaldi::WriteKaldiObject(
-            cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
-        LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
-        LOG(INFO) << "Binary: " << FLAGS_binary;
-    } catch (simdjson::simdjson_error& err) {
-        LOG(ERROR) << err.what();
-    }
+    auto ifs = std::ifstream(FLAGS_json_file);
+    std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
+    auto value = boost::json::parse(json_str);
+
+    if (!value.is_object()) {
+        LOG(ERROR) << "Input json file format error.";
+    }
+
+    for (auto obj : value.as_object()) {
+        if (obj.key() == "mean_stat") {
+            LOG(INFO) << "mean_stat:" << obj.value();
+        }
+        if (obj.key() == "var_stat") {
+            LOG(INFO) << "var_stat: " << obj.value();
+        }
+        if (obj.key() == "frame_num") {
+            LOG(INFO) << "frame_num: " << obj.value();
+        }
+    }
+
+    boost::json::array mean_stat = value.at("mean_stat").as_array();
+    std::vector<kaldi::BaseFloat> mean_stat_vec;
+    for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
+        mean_stat_vec.push_back(it->as_double());
+    }
+
+    boost::json::array var_stat = value.at("var_stat").as_array();
+    std::vector<kaldi::BaseFloat> var_stat_vec;
+    for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
+        var_stat_vec.push_back(it->as_double());
+    }
+
+    kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
+    LOG(INFO) << "nframe: " << frame_num;
+
+    size_t mean_size = mean_stat_vec.size();
+    kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
+    for (size_t idx = 0; idx < mean_size; ++idx) {
+        cmvn_stats(0, idx) = mean_stat_vec[idx];
+        cmvn_stats(1, idx) = var_stat_vec[idx];
+    }
+    cmvn_stats(0, mean_size) = frame_num;
+    LOG(INFO) << cmvn_stats;
+
+    kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
+    LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
+    LOG(INFO) << "Binary: " << FLAGS_binary;

     return 0;
 }
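The matrix written above follows the Kaldi CMVN stats convention: a 2 x (D+1) matrix whose first row holds the per-dimension feature sums with the frame count in the last column, and whose second row holds the per-dimension sums of squares. A NumPy sketch with made-up numbers:

```python
import numpy as np

# Hypothetical 3-dim stats mirroring the cmvn_stats layout built above.
mean_stat = np.array([10.0, 20.0, 30.0])   # per-dim sum of features
var_stat = np.array([60.0, 250.0, 560.0])  # per-dim sum of squares
frame_num = 5

cmvn_stats = np.zeros((2, mean_stat.size + 1))
cmvn_stats[0, :-1] = mean_stat
cmvn_stats[1, :-1] = var_stat
cmvn_stats[0, -1] = frame_num

# How a CMVN consumer recovers mean and variance from the stats.
mean = cmvn_stats[0, :-1] / frame_num
var = cmvn_stats[1, :-1] / frame_num - mean**2
print(mean, var)  # [2. 4. 6.] [ 8. 34. 76.]
```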
@@ -2,6 +2,7 @@
 import argparse
 from collections import Counter

+
 def main(args):
     counter = Counter()
     with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout:
@@ -12,7 +13,7 @@ def main(args):
             words = text.split()
         else:
             words = line.split()

         counter.update(words)

     for word in counter:
@@ -20,21 +21,16 @@ def main(args):
         fout.write(f"{word}\t{val}\n")
     fout.flush()

+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).')
     parser.add_argument(
-        '--has_key',
-        default=True,
-        help='text path, with utt or not')
+        '--has_key', default=True, help='text path, with utt or not')
     parser.add_argument(
-        '--text',
-        required=True,
-        help='text path. line: utt1 中国 人 or 中国 人')
+        '--text', required=True, help='text path. line: utt1 中国 人 or 中国 人')
     parser.add_argument(
-        '--lexicon',
-        required=True,
-        help='lexicon path. line:中国 中 国')
+        '--lexicon', required=True, help='lexicon path. line:中国 中 国')
     args = parser.parse_args()
     print(args)
......
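Per the script's description, each word becomes a word-to-characters lexicon entry. A small sketch of the core transformation (the per-character join is inferred from the docstring, and the input line is illustrative):

```python
from collections import Counter

lines = ["utt1 中国 人"]  # --has_key format: utterance key, then words

counter = Counter()
for line in lines:
    _, text = line.split(maxsplit=1)  # drop the utterance key
    counter.update(text.split())

for word in counter:
    val = " ".join(word)    # "中国" -> "中 国" (inferred behavior)
    print(f"{word}\t{val}")
```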
@@ -183,4 +183,4 @@ data/
 ├── lexiconp_disambig.txt
 ├── lexiconp.txt
 └── units.list
-```
\ No newline at end of file
+```
@@ -26,9 +26,9 @@
 import argparse
 import os
 import re
 import subprocess
+from distutils.util import strtobool

 import numpy as np
-from distutils.util import strtobool

 FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
 SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
......
This diff is collapsed.
-import os
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
+
 import jsonlines

-def trans_hyp(origin_hyp,
-              trans_hyp = None,
-              trans_hyp_sclite = None):
+
+def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
     """
     Args:
         origin_hyp: The input json file which contains the model output
@@ -17,19 +27,18 @@ def trans_hyp(origin_hyp,
     with open(origin_hyp, "r+", encoding="utf8") as f:
         for item in jsonlines.Reader(f):
             input_dict[item["utt"]] = item["hyps"][0]
+
     if trans_hyp is not None:
         with open(trans_hyp, "w+", encoding="utf8") as f:
             for key in input_dict.keys():
                 f.write(key + " " + input_dict[key] + "\n")
+
     if trans_hyp_sclite is not None:
         with open(trans_hyp_sclite, "w+") as f:
             for key in input_dict.keys():
-                line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
+                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                 f.write(line)

-def trans_ref(origin_ref,
-              trans_ref = None,
-              trans_ref_sclite = None):
+
+def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
     """
     Args:
         origin_hyp: The input json file which contains the model output
@@ -49,42 +58,48 @@ def trans_ref(origin_ref,
     if trans_ref_sclite is not None:
         with open(trans_ref_sclite, "w") as f:
             for key in input_dict.keys():
-                line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
+                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
                 f.write(line)

+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True)
+    parser = argparse.ArgumentParser(
+        prog='format hyp file for compute CER/WER', add_help=True)
     parser.add_argument(
-        '--origin_hyp',
-        type=str,
-        default = None,
-        help='origin hyp file')
+        '--origin_hyp', type=str, default=None, help='origin hyp file')
     parser.add_argument(
-        '--trans_hyp', type=str, default = None, help='hyp file for caculating CER/WER')
+        '--trans_hyp',
+        type=str,
+        default=None,
+        help='hyp file for caculating CER/WER')
     parser.add_argument(
-        '--trans_hyp_sclite', type=str, default = None, help='hyp file for caculating CER/WER by sclite')
+        '--trans_hyp_sclite',
+        type=str,
+        default=None,
+        help='hyp file for caculating CER/WER by sclite')
     parser.add_argument(
-        '--origin_ref',
-        type=str,
-        default = None,
-        help='origin ref file')
+        '--origin_ref', type=str, default=None, help='origin ref file')
     parser.add_argument(
-        '--trans_ref', type=str, default = None, help='ref file for caculating CER/WER')
+        '--trans_ref',
+        type=str,
+        default=None,
+        help='ref file for caculating CER/WER')
     parser.add_argument(
-        '--trans_ref_sclite', type=str, default = None, help='ref file for caculating CER/WER by sclite')
+        '--trans_ref_sclite',
+        type=str,
+        default=None,
+        help='ref file for caculating CER/WER by sclite')
     parser_args = parser.parse_args()

     if parser_args.origin_hyp is not None:
         trans_hyp(
-            origin_hyp = parser_args.origin_hyp,
-            trans_hyp = parser_args.trans_hyp,
-            trans_hyp_sclite = parser_args.trans_hyp_sclite, )
+            origin_hyp=parser_args.origin_hyp,
+            trans_hyp=parser_args.trans_hyp,
+            trans_hyp_sclite=parser_args.trans_hyp_sclite, )

     if parser_args.origin_ref is not None:
         trans_ref(
-            origin_ref = parser_args.origin_ref,
-            trans_ref = parser_args.trans_ref,
-            trans_ref_sclite = parser_args.trans_ref_sclite, )
+            origin_ref=parser_args.origin_ref,
+            trans_ref=parser_args.trans_ref,
+            trans_ref_sclite=parser_args.trans_ref_sclite, )
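A usage sketch of the reformatted helper above (file names are illustrative); the sclite variant wraps each transcript as `text(utt.wav)`:

```python
# Illustrative call: convert a jsonlines hypothesis file into a plain
# "utt transcript" file plus an sclite-ready .trn file.
trans_hyp(
    origin_hyp="exp/hyp.jsonl",
    trans_hyp="exp/hyp.txt",
    trans_hyp_sclite="exp/hyp.trn", )
```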
@@ -35,7 +35,7 @@ def main(args):
     # used to filter polyphone and invalid word
     lexicon_table = set()
     in_n = 0 # in lexicon word count
     out_n = 0 # out lexicon word cout
     with open(args.in_lexicon, 'r') as fin, \
         open(args.out_lexicon, 'w') as fout:
         for line in fin:
@@ -82,7 +82,10 @@ def main(args):
             lexicon_table.add(word)
             out_n += 1

-    print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}")
+    print(
+        f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}"
+    )

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
......