提交 175f36f9 编写于 作者: Y Yibing Liu

Expose number of threads for decoding

上级 93a7392b
...@@ -41,7 +41,7 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) { ...@@ -41,7 +41,7 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) {
"and return the transcription.") "and return the transcription.")
.def("decode_batch", .def("decode_batch",
(std::vector<std::string> (Decoder::*)( (std::vector<std::string> (Decoder::*)(
std::string, std::vector<std::string>,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>&, const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>&,
size_t num_processes)) & size_t num_processes)) &
Decoder::decode_batch, Decoder::decode_batch,
......
export CUDA_VISIBLE_DEVICES=0,1 export CUDA_VISIBLE_DEVICES=2,3,4,5
python -u ../../infer_by_ckpt.py --batch_size 48 \ python -u ../../infer_by_ckpt.py --batch_size 96 \
--checkpoint deep_asr.pass_20.checkpoint \ --checkpoint checkpoints/deep_asr.pass_20.checkpoint \
--infer_feature_lst data/test_feature.lst \ --infer_feature_lst data/test_feature.lst \
--infer_label_lst data/test_label.lst \ --infer_label_lst data/test_label.lst \
--mean_var data/aishell/global_mean_var \ --mean_var data/aishell/global_mean_var \
......
...@@ -59,6 +59,11 @@ def parse_args(): ...@@ -59,6 +59,11 @@ def parse_args():
type=int, type=int,
default=1749, default=1749,
help='Number of classes in label. (default: %(default)d)') help='Number of classes in label. (default: %(default)d)')
parser.add_argument(
'--num_threads',
type=int,
default=10,
help='The number of threads for decoding. (default: %(default)d)')
parser.add_argument( parser.add_argument(
'--learning_rate', '--learning_rate',
type=float, type=float,
...@@ -189,7 +194,7 @@ def infer_from_ckpt(args): ...@@ -189,7 +194,7 @@ def infer_from_ckpt(args):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
trg_trans = get_trg_trans(args) #trg_trans = get_trg_trans(args)
# load checkpoint. # load checkpoint.
fluid.io.load_persistables(exe, args.checkpoint) fluid.io.load_persistables(exe, args.checkpoint)
...@@ -238,7 +243,9 @@ def infer_from_ckpt(args): ...@@ -238,7 +243,9 @@ def infer_from_ckpt(args):
probs, lod = lodtensor_to_ndarray(results[0]) probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod) infer_batch = split_infer_result(probs, lod)
decoder.decode_batch(name_lst, infer_batch) decoded = decoder.decode_batch(name_lst, infer_batch, args.num_threads)
for res in decoded:
print(res.encode("utf8"))
if args.post_matrix_path is not None: if args.post_matrix_path is not None:
for index, sample in enumerate(infer_batch): for index, sample in enumerate(infer_batch):
key = name_lst[index] key = name_lst[index]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册