提交 b88f95a2 编写于 作者: Y Yibing Liu

Expose beam size in decoder

上级 c462ab1a
......@@ -26,6 +26,7 @@ Decoder::Decoder(std::string trans_model_in_filename,
std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_in_filename,
size_t beam_size,
kaldi::BaseFloat acoustic_scale) {
const char *usage =
"Generate lattices using neural net model.\n"
......@@ -51,7 +52,7 @@ Decoder::Decoder(std::string trans_model_in_filename,
int argc = 2;
char *argv[] = {(char *)"post-latgen-faster-mapped",
(char *)("--beam=" + std::string("11")).c_str()};
(char *)("--beam=" + std::to_string(beam_size)).c_str()};
po.Read(argc, argv);
......
......@@ -29,6 +29,7 @@ public:
std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_in_filename,
size_t beam_size,
kaldi::BaseFloat acoustic_scale);
~Decoder();
......
......@@ -27,6 +27,7 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) {
std::string,
std::string,
std::string,
size_t,
kaldi::BaseFloat>())
.def("decode_from_file",
(void (Decoder::*)(std::string, size_t)) & Decoder::decode_from_file,
......
......@@ -4,10 +4,11 @@ export CUDA_VISIBLE_DEVICES=2,3,4,5
python -u ../../infer_by_ckpt.py --batch_size 96 \
--checkpoint checkpoints/deep_asr.pass_20.checkpoint \
--infer_feature_lst data/test_feature.lst \
--mean_var data/aishell/global_mean_var \
--mean_var data/global_mean_var \
--frame_dim 80 \
--class_num 3040 \
--num_threads 24 \
--beam_size 11 \
--decode_to_path $decode_to_path \
--trans_model mapped_decoder_data/exp/tri5a/final.mdl \
--log_prior mapped_decoder_data/logprior \
......
......@@ -27,6 +27,11 @@ def parse_args():
type=int,
default=32,
help='The sequence number of a batch data. (default: %(default)d)')
parser.add_argument(
'--beam_size',
type=int,
default=11,
help='The beam size for decoding. (default: %(default)d)')
parser.add_argument(
'--minimum_batch_size',
type=int,
......@@ -211,7 +216,7 @@ def infer_from_ckpt(args):
# init decoder
decoder = Decoder(args.trans_model, args.vocabulary, args.graphs,
args.log_prior, args.acoustic_scale)
args.log_prior, args.beam_size, args.acoustic_scale)
ltrans = [
trans_add_delta.TransAddDelta(2, 2),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册