提交 e1d90fc0 编写于 作者: Y Yibing Liu

Use thread pool for parallel decoding

上级 989e6cd5
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "post_latgen_faster_mapped.h"
#include <limits>
#include "ThreadPool.h"
using namespace kaldi;
typedef kaldi::int32 int32;
......@@ -34,11 +36,9 @@ Decoder::Decoder(std::string trans_model_in_filename,
ParseOptions po(usage);
allow_partial = false;
this->acoustic_scale = acoustic_scale;
LatticeFasterDecoderConfig config;
config.Register(&po);
int32 beam = 11;
po.Register("beam", &beam, "Beam size");
po.Register("acoustic-scale",
&acoustic_scale,
"Scaling factor for acoustic likelihoods");
......@@ -49,10 +49,13 @@ Decoder::Decoder(std::string trans_model_in_filename,
&allow_partial,
"If true, produce output even if end state was not reached.");
// int argc = 2;
// char *argv[] = {"post-latgen-faster-mapped", "--beam=11"};
// po.Read(argc, argv);
int argc = 2;
char *argv[] = {(char *)"post-latgen-faster-mapped",
(char *)("--beam=" + std::string("11")).c_str()};
po.Read(argc, argv);
po.PrintConfig(std::cout);
std::ifstream is_logprior(logprior_in_filename);
logprior.Read(is_logprior, false);
......@@ -75,14 +78,16 @@ Decoder::Decoder(std::string trans_model_in_filename,
// Input FST is just one FST, not a table of FSTs.
this->decode_fst = fst::ReadFstKaldiGeneric(fst_in_filename);
this->decoder = new LatticeFasterDecoder(*decode_fst, config);
kaldi::LatticeFasterDecoder *decoder =
new LatticeFasterDecoder(*decode_fst, config);
decoder_pool.emplace_back(decoder);
std::string lattice_wspecifier =
"ark:|gzip -c > mapped_decoder_data/lat.JOB.gz";
if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier)
: lattice_writer.Open(lattice_wspecifier)))
KALDI_ERR << "Could not open table for writing lattices: ";
// << lattice_wspecifier;
KALDI_ERR << "Could not open table for writing lattices: "
<< lattice_wspecifier;
words_writer = new Int32VectorWriter("");
alignment_writer = new Int32VectorWriter("");
......@@ -91,15 +96,16 @@ Decoder::Decoder(std::string trans_model_in_filename,
Decoder::~Decoder() {
if (!this->word_syms) delete this->word_syms;
delete this->decode_fst;
delete this->decoder;
for (size_t i = 0; i < decoder_pool.size(); ++i) {
delete decoder_pool[i];
}
delete words_writer;
delete alignment_writer;
}
std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
std::vector<std::string> ret;
void Decoder::decode_from_file(std::string posterior_rspecifier,
size_t num_processes) {
try {
double tot_like = 0.0;
kaldi::int64 frame_count = 0;
......@@ -112,40 +118,41 @@ std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
Timer timer;
timer.Reset();
double elapsed = 0.0;
{
for (; !posterior_reader.Done(); posterior_reader.Next()) {
for (size_t n = decoder_pool.size(); n < num_processes; ++n) {
kaldi::LatticeFasterDecoder *decoder =
new LatticeFasterDecoder(*decode_fst, config);
decoder_pool.emplace_back(decoder);
}
elapsed = timer.Elapsed();
ThreadPool thread_pool(num_processes);
while (!posterior_reader.Done()) {
timer.Reset();
std::vector<std::future<std::string>> que;
for (size_t i = 0; i < num_processes && !posterior_reader.Done(); ++i) {
std::string utt = posterior_reader.Key();
Matrix<BaseFloat> &loglikes(posterior_reader.Value());
ret.push_back(decode(utt, loglikes));
que.emplace_back(thread_pool.enqueue(std::bind(
&Decoder::decode_internal, this, decoder_pool[i], utt, loglikes)));
posterior_reader.Next();
}
timer.Reset();
for (size_t i = 0; i < que.size(); ++i) {
std::cout << que[i].get() << std::endl;
}
}
double elapsed = timer.Elapsed();
return ret;
} catch (const std::exception &e) {
std::cerr << e.what();
return ret;
}
}
std::vector<std::string> Decoder::decode_batch(
std::vector<std::string> keys,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch) {
std::vector<std::string> decoding_results;
for (size_t i = 0; i < keys.size(); ++i) {
decoding_results.push_back(decode(keys[i], log_probs_batch[i]));
}
return decoding_results;
}
std::string Decoder::decode(
std::string key,
inline kaldi::Matrix<kaldi::BaseFloat> vector2kaldi_mat(
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) {
size_t num_frames = log_probs.size();
size_t dim_label = log_probs[0].size();
kaldi::Matrix<kaldi::BaseFloat> loglikes(
num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols);
for (size_t i = 0; i < num_frames; ++i) {
......@@ -153,14 +160,56 @@ std::string Decoder::decode(
log_probs[i].data(),
sizeof(kaldi::BaseFloat) * dim_label);
}
return loglikes;
}
std::vector<std::string> Decoder::decode_batch(
std::vector<std::string> keys,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch,
size_t num_processes) {
ThreadPool thread_pool(num_processes);
std::vector<std::string> decoding_results; //(keys.size(), "");
for (size_t n = decoder_pool.size(); n < num_processes; ++n) {
kaldi::LatticeFasterDecoder *decoder =
new LatticeFasterDecoder(*decode_fst, config);
decoder_pool.emplace_back(decoder);
}
return decode(key, loglikes);
size_t index = 0;
while (index < keys.size()) {
std::vector<std::future<std::string>> res_in_que;
for (size_t t = 0; t < num_processes && index < keys.size(); ++t) {
kaldi::Matrix<kaldi::BaseFloat> loglikes =
vector2kaldi_mat(log_probs_batch[index]);
res_in_que.emplace_back(
thread_pool.enqueue(std::bind(&Decoder::decode_internal,
this,
decoder_pool[t],
keys[index],
loglikes)));
index++;
}
for (size_t i = 0; i < res_in_que.size(); ++i) {
decoding_results.emplace_back(res_in_que[i].get());
}
}
return decoding_results;
}
std::string Decoder::decode(
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) {
kaldi::Matrix<kaldi::BaseFloat> loglikes = vector2kaldi_mat(log_probs);
return decode_internal(decoder_pool[0], key, loglikes);
}
std::string Decoder::decode(std::string key,
std::string Decoder::decode_internal(
LatticeFasterDecoder *decoder,
std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes) {
std::string decoding_result;
if (loglikes.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << key;
// num_fail++;
......@@ -173,21 +222,22 @@ std::string Decoder::decode(std::string key,
DecodableMatrixScaledMapped matrix_decodable(
trans_model, loglikes, acoustic_scale);
double like;
return this->DecodeUtteranceLatticeFaster(matrix_decodable, key, &like);
return this->DecodeUtteranceLatticeFaster(
decoder, matrix_decodable, key, &like);
}
// Takes care of output. Returns true on success.
std::string Decoder::DecodeUtteranceLatticeFaster(
LatticeFasterDecoder *decoder,
DecodableInterface &decodable, // not const but is really an input.
std::string utt,
double *like_ptr) { // puts utterance's like in like_ptr on success.
using fst::VectorFst;
std::string ret = utt + ' ';
if (!decoder->Decode(&decodable)) {
KALDI_WARN << "Failed to decode file " << utt;
return false;
return ret;
}
if (!decoder->ReachedFinal()) {
if (allow_partial) {
......@@ -197,14 +247,13 @@ std::string Decoder::DecodeUtteranceLatticeFaster(
KALDI_WARN << "Not producing output for utterance " << utt
<< " since no final-state reached and "
<< "--allow-partial=false.\n";
return false;
return ret;
}
}
double likelihood;
LatticeWeight weight;
int32 num_frames;
std::string ret = utt + ' ';
{ // First do some stuff with word-level traceback...
VectorFst<LatticeArc> decoded;
if (!decoder->GetBestPath(&decoded))
......@@ -215,7 +264,7 @@ std::string Decoder::DecodeUtteranceLatticeFaster(
std::vector<int32> words;
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
num_frames = alignment.size();
if (alignment_writer->IsOpen()) alignment_writer->Write(utt, alignment);
// if (alignment_writer->IsOpen()) alignment_writer->Write(utt, alignment);
if (word_syms != NULL) {
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms->Find(words[i]);
......
......@@ -32,9 +32,10 @@ public:
kaldi::BaseFloat acoustic_scale);
~Decoder();
// Interface to accept the scores read from specifier and return
// the batch decoding results
std::vector<std::string> decode(std::string posterior_rspecifier);
// Interface to accept the scores read from specifier and print
// the decoding results directly
void decode_from_file(std::string posterior_rspecifier,
size_t num_processes = 1);
// Accept the scores of one utterance and return the decoding result
std::string decode(
......@@ -45,21 +46,26 @@ public:
std::vector<std::string> decode_batch(
std::vector<std::string> key,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch);
&log_probs_batch,
size_t num_processes = 1);
private:
// For decoding one utterance
std::string decode(std::string key,
std::string decode_internal(kaldi::LatticeFasterDecoder *decoder,
std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes);
std::string DecodeUtteranceLatticeFaster(kaldi::DecodableInterface &decodable,
std::string DecodeUtteranceLatticeFaster(kaldi::LatticeFasterDecoder *decoder,
kaldi::DecodableInterface &decodable,
std::string utt,
double *like_ptr);
fst::SymbolTable *word_syms;
fst::Fst<fst::StdArc> *decode_fst;
kaldi::LatticeFasterDecoder *decoder;
std::vector<kaldi::LatticeFasterDecoder *> decoder_pool;
kaldi::Vector<kaldi::BaseFloat> logprior;
kaldi::TransitionModel trans_model;
kaldi::LatticeFasterDecoderConfig config;
kaldi::CompactLatticeWriter compact_lattice_writer;
kaldi::LatticeWriter lattice_writer;
......
......@@ -28,16 +28,23 @@ PYBIND11_MODULE(post_latgen_faster_mapped, m) {
std::string,
std::string,
kaldi::BaseFloat>())
.def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode,
.def("decode_from_file",
(void (Decoder::*)(std::string, size_t)) & Decoder::decode_from_file,
"Decode for the probability matrices in specifier "
"and return the transcriptions.")
"and print the transcriptions.")
.def(
"decode",
(std::string (Decoder::*)(
std::string, const std::vector<std::vector<kaldi::BaseFloat>>&)) &
Decoder::decode,
"Decode one input probability matrix "
"and return the transcription.");
"and return the transcription.")
.def("decode_batch",
(std::vector<std::string> (Decoder::*)(
std::string,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>&,
size_t num_processes)) &
Decoder::decode_batch,
"Decode one batch of probability matrices "
"and return the transcriptions.");
}
......@@ -24,7 +24,7 @@ except:
"install kaldi and export KALDI_ROOT=<kaldi's root dir> .")
args = [
'-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable',
'-std=c++11', '-fopenmp', '-Wno-sign-compare', '-Wno-unused-variable',
'-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable',
'-Wno-deprecated-declarations', '-Wno-unused-function'
]
......@@ -53,7 +53,7 @@ ext_modules = [
['pybind.cc', 'post_latgen_faster_mapped.cc'],
include_dirs=[
'pybind11/include', '.', os.path.join(kaldi_root, 'src'),
os.path.join(kaldi_root, 'tools/openfst/src/include')
os.path.join(kaldi_root, 'tools/openfst/src/include'), 'ThreadPool'
],
language='c++',
libraries=LIBS,
......
......@@ -4,4 +4,9 @@ if [ ! -d pybind11 ]; then
git clone https://github.com/pybind/pybind11.git
fi
if [ ! -d ThreadPool ]; then
git clone https://github.com/progschj/ThreadPool.git
echo -e "\n"
fi
python setup.py build_ext -i
......@@ -6,7 +6,6 @@ python -u ../../infer_by_ckpt.py --batch_size 48 \
--mean_var data/aishell/global_mean_var \
--frame_dim 80 \
--class_num 3040 \
--post_matrix_path post_matrix.decoded \
--target_trans data/text.test \
--trans_model mapped_decoder_data/exp/tri5a/final.mdl \
--log_prior mapped_decoder_data/logprior \
......
......@@ -238,10 +238,10 @@ def infer_from_ckpt(args):
probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod)
decoder.decode_batch(name_lst, infer_batch)
if args.post_matrix_path is not None:
for index, sample in enumerate(infer_batch):
key = name_lst[index]
ref = trg_trans[key]
if args.post_matrix_path is not None:
out_post_matrix(key, sample)
'''
hyp = decoder.decode(key, sample)
......@@ -252,9 +252,9 @@ def infer_from_ckpt(args):
print(key + "|Hyp:", hyp.encode("utf8"))
print("Instance CER: ", edit_dist / ref_len)
'''
print("batch: ", batch_id)
#print("batch: ", batch_id)
print("Total CER = %f" % (total_edit_dist / total_ref_len))
#print("Total CER = %f" % (total_edit_dist / total_ref_len))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册