Commit 175f36f9 authored by Yibing Liu

Expose number of threads for decoding

Parent 93a7392b
......
@@ -41,7 +41,7 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) {
            "and return the transcription.")
       .def("decode_batch",
            (std::vector<std::string> (Decoder::*)(
-               std::string,
+               std::vector<std::string>,
                const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>&,
                size_t num_processes)) &
               Decoder::decode_batch,
......
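With this change the binding takes a batch of utterance keys (a std::vector<std::string> rather than a single std::string) together with a num_processes thread count. Below is a minimal Python-side usage sketch; the module, class, and method names come from the binding above, while the Decoder constructor arguments and the random posteriors are placeholder assumptions.

```python
# Minimal sketch, assuming the post_latgen_faster_mapped extension is built;
# the Decoder constructor arguments and the posterior values are placeholders.
import numpy as np
from post_latgen_faster_mapped import Decoder

# decoder = Decoder(...)  # model/graph/symbol-table paths omitted here

keys = ["utt_0001", "utt_0002"]          # one key per utterance in the batch
posteriors = [
    np.random.rand(50, 1749).tolist(),   # frames x classes for utterance 1
    np.random.rand(60, 1749).tolist(),   # frames x classes for utterance 2
]
num_threads = 8                          # forwarded to num_processes in C++

# transcripts = decoder.decode_batch(keys, posteriors, num_threads)
# for line in transcripts:
#     print(line)
```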
-export CUDA_VISIBLE_DEVICES=0,1
-python -u ../../infer_by_ckpt.py --batch_size 48 \
-    --checkpoint deep_asr.pass_20.checkpoint \
+export CUDA_VISIBLE_DEVICES=2,3,4,5
+python -u ../../infer_by_ckpt.py --batch_size 96 \
+    --checkpoint checkpoints/deep_asr.pass_20.checkpoint \
     --infer_feature_lst data/test_feature.lst \
     --infer_label_lst data/test_label.lst \
     --mean_var data/aishell/global_mean_var \
......
......
@@ -59,6 +59,11 @@ def parse_args():
         type=int,
         default=1749,
         help='Number of classes in label. (default: %(default)d)')
+    parser.add_argument(
+        '--num_threads',
+        type=int,
+        default=10,
+        help='The number of threads for decoding. (default: %(default)d)')
     parser.add_argument(
         '--learning_rate',
         type=float,
......
@@ -189,7 +194,7 @@ def infer_from_ckpt(args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
-    trg_trans = get_trg_trans(args)
+    #trg_trans = get_trg_trans(args)
     # load checkpoint.
     fluid.io.load_persistables(exe, args.checkpoint)
......
@@ -238,7 +243,9 @@ def infer_from_ckpt(args):
         probs, lod = lodtensor_to_ndarray(results[0])
         infer_batch = split_infer_result(probs, lod)
-        decoder.decode_batch(name_lst, infer_batch)
+        decoded = decoder.decode_batch(name_lst, infer_batch, args.num_threads)
+        for res in decoded:
+            print(res.encode("utf8"))
         if args.post_matrix_path is not None:
             for index, sample in enumerate(infer_batch):
                 key = name_lst[index]
......
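For context, args.num_threads controls how many decoding workers the C++ extension spreads the batch across. A conceptual Python sketch of that pattern (not the repo's implementation; decode_one is a hypothetical per-utterance decoder) would be:

```python
# Conceptual sketch only: the real parallel decoding runs inside the C++
# extension; decode_one(key, probs) is a hypothetical per-utterance decoder.
from concurrent.futures import ThreadPoolExecutor

def decode_batch_parallel(keys, prob_batch, num_threads, decode_one):
    """Decode each (key, posteriors) pair using up to num_threads workers."""
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # pool.map preserves input order, so results stay aligned with keys
        return list(pool.map(decode_one, keys, prob_batch))
```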