提交 e1d90fc0 编写于 作者: Y Yibing Liu

Use thread pool for parallel decoding

上级 989e6cd5
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "post_latgen_faster_mapped.h" #include "post_latgen_faster_mapped.h"
#include <limits>
#include "ThreadPool.h"
using namespace kaldi; using namespace kaldi;
typedef kaldi::int32 int32; typedef kaldi::int32 int32;
...@@ -34,11 +36,9 @@ Decoder::Decoder(std::string trans_model_in_filename, ...@@ -34,11 +36,9 @@ Decoder::Decoder(std::string trans_model_in_filename,
ParseOptions po(usage); ParseOptions po(usage);
allow_partial = false; allow_partial = false;
this->acoustic_scale = acoustic_scale; this->acoustic_scale = acoustic_scale;
LatticeFasterDecoderConfig config;
config.Register(&po); config.Register(&po);
int32 beam = 11; int32 beam = 11;
po.Register("beam", &beam, "Beam size");
po.Register("acoustic-scale", po.Register("acoustic-scale",
&acoustic_scale, &acoustic_scale,
"Scaling factor for acoustic likelihoods"); "Scaling factor for acoustic likelihoods");
...@@ -49,10 +49,13 @@ Decoder::Decoder(std::string trans_model_in_filename, ...@@ -49,10 +49,13 @@ Decoder::Decoder(std::string trans_model_in_filename,
&allow_partial, &allow_partial,
"If true, produce output even if end state was not reached."); "If true, produce output even if end state was not reached.");
// int argc = 2; int argc = 2;
// char *argv[] = {"post-latgen-faster-mapped", "--beam=11"}; char *argv[] = {(char *)"post-latgen-faster-mapped",
// po.Read(argc, argv); (char *)("--beam=" + std::string("11")).c_str()};
po.Read(argc, argv);
po.PrintConfig(std::cout);
std::ifstream is_logprior(logprior_in_filename); std::ifstream is_logprior(logprior_in_filename);
logprior.Read(is_logprior, false); logprior.Read(is_logprior, false);
...@@ -75,14 +78,16 @@ Decoder::Decoder(std::string trans_model_in_filename, ...@@ -75,14 +78,16 @@ Decoder::Decoder(std::string trans_model_in_filename,
// Input FST is just one FST, not a table of FSTs. // Input FST is just one FST, not a table of FSTs.
this->decode_fst = fst::ReadFstKaldiGeneric(fst_in_filename); this->decode_fst = fst::ReadFstKaldiGeneric(fst_in_filename);
this->decoder = new LatticeFasterDecoder(*decode_fst, config); kaldi::LatticeFasterDecoder *decoder =
new LatticeFasterDecoder(*decode_fst, config);
decoder_pool.emplace_back(decoder);
std::string lattice_wspecifier = std::string lattice_wspecifier =
"ark:|gzip -c > mapped_decoder_data/lat.JOB.gz"; "ark:|gzip -c > mapped_decoder_data/lat.JOB.gz";
if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier)
: lattice_writer.Open(lattice_wspecifier))) : lattice_writer.Open(lattice_wspecifier)))
KALDI_ERR << "Could not open table for writing lattices: "; KALDI_ERR << "Could not open table for writing lattices: "
// << lattice_wspecifier; << lattice_wspecifier;
words_writer = new Int32VectorWriter(""); words_writer = new Int32VectorWriter("");
alignment_writer = new Int32VectorWriter(""); alignment_writer = new Int32VectorWriter("");
...@@ -91,15 +96,16 @@ Decoder::Decoder(std::string trans_model_in_filename, ...@@ -91,15 +96,16 @@ Decoder::Decoder(std::string trans_model_in_filename,
Decoder::~Decoder() { Decoder::~Decoder() {
if (!this->word_syms) delete this->word_syms; if (!this->word_syms) delete this->word_syms;
delete this->decode_fst; delete this->decode_fst;
delete this->decoder; for (size_t i = 0; i < decoder_pool.size(); ++i) {
delete decoder_pool[i];
}
delete words_writer; delete words_writer;
delete alignment_writer; delete alignment_writer;
} }
std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) { void Decoder::decode_from_file(std::string posterior_rspecifier,
std::vector<std::string> ret; size_t num_processes) {
try { try {
double tot_like = 0.0; double tot_like = 0.0;
kaldi::int64 frame_count = 0; kaldi::int64 frame_count = 0;
...@@ -112,40 +118,41 @@ std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) { ...@@ -112,40 +118,41 @@ std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
Timer timer; Timer timer;
timer.Reset(); timer.Reset();
double elapsed = 0.0;
for (size_t n = decoder_pool.size(); n < num_processes; ++n) {
kaldi::LatticeFasterDecoder *decoder =
new LatticeFasterDecoder(*decode_fst, config);
decoder_pool.emplace_back(decoder);
}
elapsed = timer.Elapsed();
ThreadPool thread_pool(num_processes);
{ while (!posterior_reader.Done()) {
for (; !posterior_reader.Done(); posterior_reader.Next()) { timer.Reset();
std::vector<std::future<std::string>> que;
for (size_t i = 0; i < num_processes && !posterior_reader.Done(); ++i) {
std::string utt = posterior_reader.Key(); std::string utt = posterior_reader.Key();
Matrix<BaseFloat> &loglikes(posterior_reader.Value()); Matrix<BaseFloat> &loglikes(posterior_reader.Value());
ret.push_back(decode(utt, loglikes)); que.emplace_back(thread_pool.enqueue(std::bind(
&Decoder::decode_internal, this, decoder_pool[i], utt, loglikes)));
posterior_reader.Next();
}
timer.Reset();
for (size_t i = 0; i < que.size(); ++i) {
std::cout << que[i].get() << std::endl;
} }
} }
double elapsed = timer.Elapsed();
return ret;
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << e.what(); std::cerr << e.what();
return ret;
} }
} }
std::vector<std::string> Decoder::decode_batch( inline kaldi::Matrix<kaldi::BaseFloat> vector2kaldi_mat(
std::vector<std::string> keys,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch) {
std::vector<std::string> decoding_results;
for (size_t i = 0; i < keys.size(); ++i) {
decoding_results.push_back(decode(keys[i], log_probs_batch[i]));
}
return decoding_results;
}
std::string Decoder::decode(
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) { const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) {
size_t num_frames = log_probs.size(); size_t num_frames = log_probs.size();
size_t dim_label = log_probs[0].size(); size_t dim_label = log_probs[0].size();
kaldi::Matrix<kaldi::BaseFloat> loglikes( kaldi::Matrix<kaldi::BaseFloat> loglikes(
num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols); num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols);
for (size_t i = 0; i < num_frames; ++i) { for (size_t i = 0; i < num_frames; ++i) {
...@@ -153,14 +160,56 @@ std::string Decoder::decode( ...@@ -153,14 +160,56 @@ std::string Decoder::decode(
log_probs[i].data(), log_probs[i].data(),
sizeof(kaldi::BaseFloat) * dim_label); sizeof(kaldi::BaseFloat) * dim_label);
} }
return loglikes;
}
std::vector<std::string> Decoder::decode_batch(
std::vector<std::string> keys,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch,
size_t num_processes) {
ThreadPool thread_pool(num_processes);
std::vector<std::string> decoding_results; //(keys.size(), "");
for (size_t n = decoder_pool.size(); n < num_processes; ++n) {
kaldi::LatticeFasterDecoder *decoder =
new LatticeFasterDecoder(*decode_fst, config);
decoder_pool.emplace_back(decoder);
}
return decode(key, loglikes); size_t index = 0;
while (index < keys.size()) {
std::vector<std::future<std::string>> res_in_que;
for (size_t t = 0; t < num_processes && index < keys.size(); ++t) {
kaldi::Matrix<kaldi::BaseFloat> loglikes =
vector2kaldi_mat(log_probs_batch[index]);
res_in_que.emplace_back(
thread_pool.enqueue(std::bind(&Decoder::decode_internal,
this,
decoder_pool[t],
keys[index],
loglikes)));
index++;
}
for (size_t i = 0; i < res_in_que.size(); ++i) {
decoding_results.emplace_back(res_in_que[i].get());
}
}
return decoding_results;
} }
std::string Decoder::decode(
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) {
kaldi::Matrix<kaldi::BaseFloat> loglikes = vector2kaldi_mat(log_probs);
return decode_internal(decoder_pool[0], key, loglikes);
}
std::string Decoder::decode(std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes) { std::string Decoder::decode_internal(
std::string decoding_result; LatticeFasterDecoder *decoder,
std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes) {
if (loglikes.NumRows() == 0) { if (loglikes.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << key; KALDI_WARN << "Zero-length utterance: " << key;
// num_fail++; // num_fail++;
...@@ -173,21 +222,22 @@ std::string Decoder::decode(std::string key, ...@@ -173,21 +222,22 @@ std::string Decoder::decode(std::string key,
DecodableMatrixScaledMapped matrix_decodable( DecodableMatrixScaledMapped matrix_decodable(
trans_model, loglikes, acoustic_scale); trans_model, loglikes, acoustic_scale);
double like; double like;
return this->DecodeUtteranceLatticeFaster(
return this->DecodeUtteranceLatticeFaster(matrix_decodable, key, &like); decoder, matrix_decodable, key, &like);
} }
// Takes care of output. Returns true on success.
std::string Decoder::DecodeUtteranceLatticeFaster( std::string Decoder::DecodeUtteranceLatticeFaster(
LatticeFasterDecoder *decoder,
DecodableInterface &decodable, // not const but is really an input. DecodableInterface &decodable, // not const but is really an input.
std::string utt, std::string utt,
double *like_ptr) { // puts utterance's like in like_ptr on success. double *like_ptr) { // puts utterance's like in like_ptr on success.
using fst::VectorFst; using fst::VectorFst;
std::string ret = utt + ' ';
if (!decoder->Decode(&decodable)) { if (!decoder->Decode(&decodable)) {
KALDI_WARN << "Failed to decode file " << utt; KALDI_WARN << "Failed to decode file " << utt;
return false; return ret;
} }
if (!decoder->ReachedFinal()) { if (!decoder->ReachedFinal()) {
if (allow_partial) { if (allow_partial) {
...@@ -197,14 +247,13 @@ std::string Decoder::DecodeUtteranceLatticeFaster( ...@@ -197,14 +247,13 @@ std::string Decoder::DecodeUtteranceLatticeFaster(
KALDI_WARN << "Not producing output for utterance " << utt KALDI_WARN << "Not producing output for utterance " << utt
<< " since no final-state reached and " << " since no final-state reached and "
<< "--allow-partial=false.\n"; << "--allow-partial=false.\n";
return false; return ret;
} }
} }
double likelihood; double likelihood;
LatticeWeight weight; LatticeWeight weight;
int32 num_frames; int32 num_frames;
std::string ret = utt + ' ';
{ // First do some stuff with word-level traceback... { // First do some stuff with word-level traceback...
VectorFst<LatticeArc> decoded; VectorFst<LatticeArc> decoded;
if (!decoder->GetBestPath(&decoded)) if (!decoder->GetBestPath(&decoded))
...@@ -215,7 +264,7 @@ std::string Decoder::DecodeUtteranceLatticeFaster( ...@@ -215,7 +264,7 @@ std::string Decoder::DecodeUtteranceLatticeFaster(
std::vector<int32> words; std::vector<int32> words;
GetLinearSymbolSequence(decoded, &alignment, &words, &weight); GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
num_frames = alignment.size(); num_frames = alignment.size();
if (alignment_writer->IsOpen()) alignment_writer->Write(utt, alignment); // if (alignment_writer->IsOpen()) alignment_writer->Write(utt, alignment);
if (word_syms != NULL) { if (word_syms != NULL) {
for (size_t i = 0; i < words.size(); i++) { for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms->Find(words[i]); std::string s = word_syms->Find(words[i]);
......
...@@ -32,9 +32,10 @@ public: ...@@ -32,9 +32,10 @@ public:
kaldi::BaseFloat acoustic_scale); kaldi::BaseFloat acoustic_scale);
~Decoder(); ~Decoder();
// Interface to accept the scores read from specifier and return // Interface to accept the scores read from specifier and print
// the batch decoding results // the decoding results directly
std::vector<std::string> decode(std::string posterior_rspecifier); void decode_from_file(std::string posterior_rspecifier,
size_t num_processes = 1);
// Accept the scores of one utterance and return the decoding result // Accept the scores of one utterance and return the decoding result
std::string decode( std::string decode(
...@@ -45,21 +46,26 @@ public: ...@@ -45,21 +46,26 @@ public:
std::vector<std::string> decode_batch( std::vector<std::string> decode_batch(
std::vector<std::string> key, std::vector<std::string> key,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>> const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch); &log_probs_batch,
size_t num_processes = 1);
private: private:
// For decoding one utterance // For decoding one utterance
std::string decode(std::string key, std::string decode_internal(kaldi::LatticeFasterDecoder *decoder,
kaldi::Matrix<kaldi::BaseFloat> &loglikes); std::string key,
std::string DecodeUtteranceLatticeFaster(kaldi::DecodableInterface &decodable, kaldi::Matrix<kaldi::BaseFloat> &loglikes);
std::string DecodeUtteranceLatticeFaster(kaldi::LatticeFasterDecoder *decoder,
kaldi::DecodableInterface &decodable,
std::string utt, std::string utt,
double *like_ptr); double *like_ptr);
fst::SymbolTable *word_syms; fst::SymbolTable *word_syms;
fst::Fst<fst::StdArc> *decode_fst; fst::Fst<fst::StdArc> *decode_fst;
kaldi::LatticeFasterDecoder *decoder; std::vector<kaldi::LatticeFasterDecoder *> decoder_pool;
kaldi::Vector<kaldi::BaseFloat> logprior; kaldi::Vector<kaldi::BaseFloat> logprior;
kaldi::TransitionModel trans_model; kaldi::TransitionModel trans_model;
kaldi::LatticeFasterDecoderConfig config;
kaldi::CompactLatticeWriter compact_lattice_writer; kaldi::CompactLatticeWriter compact_lattice_writer;
kaldi::LatticeWriter lattice_writer; kaldi::LatticeWriter lattice_writer;
......
...@@ -28,16 +28,23 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) { ...@@ -28,16 +28,23 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) {
std::string, std::string,
std::string, std::string,
kaldi::BaseFloat>()) kaldi::BaseFloat>())
.def("decode", .def("decode_from_file",
(std::vector<std::string> (Decoder::*)(std::string)) & (void (Decoder::*)(std::string, size_t)) & Decoder::decode_from_file,
Decoder::decode,
"Decode for the probability matrices in specifier " "Decode for the probability matrices in specifier "
"and return the transcriptions.") "and print the transcriptions.")
.def( .def(
"decode", "decode",
(std::string (Decoder::*)( (std::string (Decoder::*)(
std::string, const std::vector<std::vector<kaldi::BaseFloat>>&)) & std::string, const std::vector<std::vector<kaldi::BaseFloat>>&)) &
Decoder::decode, Decoder::decode,
"Decode one input probability matrix " "Decode one input probability matrix "
"and return the transcription."); "and return the transcription.")
.def("decode_batch",
(std::vector<std::string> (Decoder::*)(
std::string,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>&,
size_t num_processes)) &
Decoder::decode_batch,
"Decode one batch of probability matrices "
"and return the transcriptions.");
} }
...@@ -24,7 +24,7 @@ except: ...@@ -24,7 +24,7 @@ except:
"install kaldi and export KALDI_ROOT=<kaldi's root dir> .") "install kaldi and export KALDI_ROOT=<kaldi's root dir> .")
args = [ args = [
'-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable', '-std=c++11', '-fopenmp', '-Wno-sign-compare', '-Wno-unused-variable',
'-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable', '-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable',
'-Wno-deprecated-declarations', '-Wno-unused-function' '-Wno-deprecated-declarations', '-Wno-unused-function'
] ]
...@@ -53,7 +53,7 @@ ext_modules = [ ...@@ -53,7 +53,7 @@ ext_modules = [
['pybind.cc', 'post_latgen_faster_mapped.cc'], ['pybind.cc', 'post_latgen_faster_mapped.cc'],
include_dirs=[ include_dirs=[
'pybind11/include', '.', os.path.join(kaldi_root, 'src'), 'pybind11/include', '.', os.path.join(kaldi_root, 'src'),
os.path.join(kaldi_root, 'tools/openfst/src/include') os.path.join(kaldi_root, 'tools/openfst/src/include'), 'ThreadPool'
], ],
language='c++', language='c++',
libraries=LIBS, libraries=LIBS,
......
...@@ -4,4 +4,9 @@ if [ ! -d pybind11 ]; then ...@@ -4,4 +4,9 @@ if [ ! -d pybind11 ]; then
git clone https://github.com/pybind/pybind11.git git clone https://github.com/pybind/pybind11.git
fi fi
if [ ! -d ThreadPool ]; then
git clone https://github.com/progschj/ThreadPool.git
echo -e "\n"
fi
python setup.py build_ext -i python setup.py build_ext -i
...@@ -6,7 +6,6 @@ python -u ../../infer_by_ckpt.py --batch_size 48 \ ...@@ -6,7 +6,6 @@ python -u ../../infer_by_ckpt.py --batch_size 48 \
--mean_var data/aishell/global_mean_var \ --mean_var data/aishell/global_mean_var \
--frame_dim 80 \ --frame_dim 80 \
--class_num 3040 \ --class_num 3040 \
--post_matrix_path post_matrix.decoded \
--target_trans data/text.test \ --target_trans data/text.test \
--trans_model mapped_decoder_data/exp/tri5a/final.mdl \ --trans_model mapped_decoder_data/exp/tri5a/final.mdl \
--log_prior mapped_decoder_data/logprior \ --log_prior mapped_decoder_data/logprior \
......
...@@ -238,10 +238,10 @@ def infer_from_ckpt(args): ...@@ -238,10 +238,10 @@ def infer_from_ckpt(args):
probs, lod = lodtensor_to_ndarray(results[0]) probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod) infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch): decoder.decode_batch(name_lst, infer_batch)
key = name_lst[index] if args.post_matrix_path is not None:
ref = trg_trans[key] for index, sample in enumerate(infer_batch):
if args.post_matrix_path is not None: key = name_lst[index]
out_post_matrix(key, sample) out_post_matrix(key, sample)
''' '''
hyp = decoder.decode(key, sample) hyp = decoder.decode(key, sample)
...@@ -252,9 +252,9 @@ def infer_from_ckpt(args): ...@@ -252,9 +252,9 @@ def infer_from_ckpt(args):
print(key + "|Hyp:", hyp.encode("utf8")) print(key + "|Hyp:", hyp.encode("utf8"))
print("Instance CER: ", edit_dist / ref_len) print("Instance CER: ", edit_dist / ref_len)
''' '''
print("batch: ", batch_id) #print("batch: ", batch_id)
print("Total CER = %f" % (total_edit_dist / total_ref_len)) #print("Total CER = %f" % (total_edit_dist / total_ref_len))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册