diff --git a/fluid/DeepASR/decoder/decoder.cc b/fluid/DeepASR/decoder/decoder.cc deleted file mode 100644 index a99f972e2fc2341247eb2a6aa564d8d6b5e2905d..0000000000000000000000000000000000000000 --- a/fluid/DeepASR/decoder/decoder.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "decoder.h" - -std::string decode(std::vector> probs_mat) { - // Add decoding logic here - - return "example decoding result"; -} diff --git a/fluid/DeepASR/decoder/post_decode_faster.cc b/fluid/DeepASR/decoder/post_decode_faster.cc new file mode 100644 index 0000000000000000000000000000000000000000..d3f20a6ea38f76c11733373a0c72f0fe769b334d --- /dev/null +++ b/fluid/DeepASR/decoder/post_decode_faster.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "post_decode_faster.h" +#include "base/kaldi-common.h" +#include "base/timer.h" +#include "decoder/decodable-matrix.h" +#include "decoder/faster-decoder.h" +#include "fstext/fstext-lib.h" +#include "hmm/transition-model.h" +#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc +#include "tree/context-dep.h" +#include "util/common-utils.h" + +std::vector decode(std::string word_syms_filename, + std::string fst_in_filename, + std::string logprior_rxfilename, + std::string posterior_rspecifier, + std::string words_wspecifier, + std::string alignment_wspecifier) { + std::vector decoding_results; + + try { + using namespace kaldi; + typedef kaldi::int32 int32; + using fst::SymbolTable; + using fst::VectorFst; + using fst::StdArc; + + const char *usage = + "Decode, reading log-likelihoods (of transition-ids or whatever symbol " + "is on the graph) as matrices."; + ParseOptions po(usage); + bool binary = true; + BaseFloat acoustic_scale = 1.5; + bool allow_partial = true; + FasterDecoderOptions decoder_opts; + decoder_opts.Register(&po, true); // true == include obscure settings. + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("allow-partial", + &allow_partial, + "Produce output even when final state was not reached"); + po.Register("acoustic-scale", + &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + + Int32VectorWriter words_writer(words_wspecifier); + + Int32VectorWriter alignment_writer(alignment_wspecifier); + fst::SymbolTable *word_syms = NULL; + if (word_syms_filename != "") { + word_syms = fst::SymbolTable::ReadText(word_syms_filename); + if (!word_syms) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + } + + SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier); + std::ifstream is_logprior(logprior_rxfilename); + Vector logprior; + logprior.Read(is_logprior, false); + + // It's important that we initialize decode_fst after loglikes_reader, as it + // can prevent crashes on systems installed without enough virtual memory. + // It has to do with what happens on UNIX systems if you call fork() on a + // large process: the page-table entries are duplicated, which requires a + // lot of virtual memory. + VectorFst *decode_fst = fst::ReadFstKaldi(fst_in_filename); + + BaseFloat tot_like = 0.0; + kaldi::int64 frame_count = 0; + int num_success = 0, num_fail = 0; + FasterDecoder decoder(*decode_fst, decoder_opts); + + Timer timer; + + for (; !posterior_reader.Done(); posterior_reader.Next()) { + std::string key = posterior_reader.Key(); + Matrix loglikes(posterior_reader.Value()); + KALDI_LOG << key << " " << loglikes.NumRows() << " x " + << loglikes.NumCols(); + + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << key; + num_fail++; + continue; + } + KALDI_ASSERT(loglikes.NumCols() == logprior.Dim()); + + loglikes.ApplyLog(); + loglikes.AddVecToRows(-1.0, logprior); + + DecodableMatrixScaled decodable(loglikes, acoustic_scale); + decoder.Decode(&decodable); + + VectorFst decoded; // linear FST. + + if ((allow_partial || decoder.ReachedFinal()) && + decoder.GetBestPath(&decoded)) { + num_success++; + if (!decoder.ReachedFinal()) + KALDI_WARN << "Decoder did not reach end-state, outputting partial " + "traceback."; + + std::vector alignment; + std::vector words; + LatticeWeight weight; + frame_count += loglikes.NumRows(); + + GetLinearSymbolSequence(decoded, &alignment, &words, &weight); + + words_writer.Write(key, words); + if (alignment_writer.IsOpen()) alignment_writer.Write(key, alignment); + if (word_syms != NULL) { + std::string res; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms->Find(words[i]); + res += s; + if (s == "") + KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; + std::cerr << s << ' '; + } + decoding_results.push_back(res); + } + BaseFloat like = -weight.Value1() - weight.Value2(); + tot_like += like; + KALDI_LOG << "Log-like per frame for utterance " << key << " is " + << (like / loglikes.NumRows()) << " over " + << loglikes.NumRows() << " frames."; + + } else { + num_fail++; + KALDI_WARN << "Did not successfully decode utterance " << key + << ", len = " << loglikes.NumRows(); + } + } + + double elapsed = timer.Elapsed(); + KALDI_LOG << "Time taken [excluding initialization] " << elapsed + << "s: real-time factor assuming 100 frames/sec is " + << (elapsed * 100.0 / frame_count); + KALDI_LOG << "Done " << num_success << " utterances, failed for " + << num_fail; + KALDI_LOG << "Overall log-likelihood per frame is " + << (tot_like / frame_count) << " over " << frame_count + << " frames."; + + delete word_syms; + delete decode_fst; + } catch (const std::exception &e) { + std::cerr << e.what(); + } + return decoding_results; +} diff --git a/fluid/DeepASR/decoder/decoder.h b/fluid/DeepASR/decoder/post_decode_faster.h similarity index 61% rename from fluid/DeepASR/decoder/decoder.h rename to fluid/DeepASR/decoder/post_decode_faster.h index 4a67fa366ae31dd393d0e3b2f04d27e360596fe5..04983a3b93b38ebd7375cc7d18336b92d7b1c1b3 100644 --- a/fluid/DeepASR/decoder/decoder.h +++ b/fluid/DeepASR/decoder/post_decode_faster.h @@ -15,4 +15,9 @@ limitations under the License. */ #include #include -std::string decode(std::vector> probs_mat); +std::vector decode(std::string word_syms_filename, + std::string fst_in_filename, + std::string logprior_rxfilename, + std::string posterior_respecifier, + std::string words_wspecifier, + std::string alignment_wspecifier = ""); diff --git a/fluid/DeepASR/decoder/pybind.cc b/fluid/DeepASR/decoder/pybind.cc index 8cd65903eae63268d9f4412bd737162c639d8910..a8744ee2ac5f89af3b6b5d36776fc9761dd98f03 100644 --- a/fluid/DeepASR/decoder/pybind.cc +++ b/fluid/DeepASR/decoder/pybind.cc @@ -15,11 +15,11 @@ limitations under the License. */ #include #include -#include "decoder.h" +#include "post_decode_faster.h" namespace py = pybind11; -PYBIND11_MODULE(decoder, m) { +PYBIND11_MODULE(post_decode_faster, m) { m.doc() = "Decode function for Deep ASR model"; m.def("decode", diff --git a/fluid/DeepASR/decoder/setup.py b/fluid/DeepASR/decoder/setup.py index cedd5d644e0dc1ca8855ab7e75ee1b7a30f8fcb1..e1c74fcb0ddca0e1bea3cc88ea1f616ca8f4683f 100644 --- a/fluid/DeepASR/decoder/setup.py +++ b/fluid/DeepASR/decoder/setup.py @@ -13,27 +13,50 @@ # limitations under the License. import os +import glob from distutils.core import setup, Extension from distutils.sysconfig import get_config_vars -args = ['-std=c++11'] +args = [ + '-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable', + '-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable', + '-Wno-deprecated-declarations', '-Wno-unused-function' +] # remove warning about -Wstrict-prototypes (opt, ) = get_config_vars('OPT') os.environ['OPT'] = " ".join(flag for flag in opt.split() if flag != '-Wstrict-prototypes') +os.environ['CC'] = 'g++' + +LIBS = [ + 'fst', 'kaldi-base', 'kaldi-util', 'kaldi-matrix', 'kaldi-tree', + 'kaldi-hmm', 'kaldi-fstext', 'kaldi-decoder', 'kaldi-lat' +] + +LIB_DIRS = [ + 'kaldi/tools/openfst/lib', 'kaldi/src/base', 'kaldi/src/matrix', + 'kaldi/src/util', 'kaldi/src/tree', 'kaldi/src/hmm', 'kaldi/src/fstext', + 'kaldi/src/decoder', 'kaldi/src/lat' +] ext_modules = [ Extension( - 'decoder', - ['pybind.cc', 'decoder.cc'], - include_dirs=['pybind11/include', '.'], + 'post_decode_faster', + ['pybind.cc', 'post_decode_faster.cc'], + include_dirs=[ + 'pybind11/include', '.', 'kaldi/src/', + 'kaldi/tools/openfst/src/include' + ], + libraries=LIBS, language='c++', + library_dirs=LIB_DIRS, + runtime_library_dirs=LIB_DIRS, extra_compile_args=args, ), ] setup( - name='decoder', + name='post_decode_faster', version='0.0.1', author='Paddle', author_email='', diff --git a/fluid/DeepASR/decoder/setup.sh b/fluid/DeepASR/decoder/setup.sh index 71fd6626efe1b7cf72a1e678ab7b74000ebfb8c3..74cec0a48239fcf541abf963da7e447d92cdab32 100644 --- a/fluid/DeepASR/decoder/setup.sh +++ b/fluid/DeepASR/decoder/setup.sh @@ -1,7 +1,11 @@ - +set -e if [ ! -d pybind11 ]; then git clone https://github.com/pybind/pybind11.git fi +if [ ! -d kaldi ]; then + git clone https://github.com/kaldi-asr/kaldi.git +fi + python setup.py build_ext -i