From 1b7e0449850cac4ef753b0e89822c5666ee53dbf Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 6 Jun 2018 21:26:41 -0700 Subject: [PATCH] Adapt the decoder to the new label --- fluid/DeepASR/decoder/post_decode_faster.cc | 145 --------------- .../decoder/post_latgen_faster_mapped.cc | 172 ++++++++++++++++++ ...e_faster.h => post_latgen_faster_mapped.h} | 22 ++- fluid/DeepASR/decoder/pybind.cc | 10 +- fluid/DeepASR/decoder/setup.py | 8 +- fluid/DeepASR/infer_by_ckpt.py | 41 ++++- 6 files changed, 231 insertions(+), 167 deletions(-) delete mode 100644 fluid/DeepASR/decoder/post_decode_faster.cc create mode 100644 fluid/DeepASR/decoder/post_latgen_faster_mapped.cc rename fluid/DeepASR/decoder/{post_decode_faster.h => post_latgen_faster_mapped.h} (75%) diff --git a/fluid/DeepASR/decoder/post_decode_faster.cc b/fluid/DeepASR/decoder/post_decode_faster.cc deleted file mode 100644 index ce2b45bc..00000000 --- a/fluid/DeepASR/decoder/post_decode_faster.cc +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "post_decode_faster.h" - -typedef kaldi::int32 int32; -using fst::SymbolTable; -using fst::VectorFst; -using fst::StdArc; - -Decoder::Decoder(std::string word_syms_filename, - std::string fst_in_filename, - std::string logprior_rxfilename, - kaldi::BaseFloat acoustic_scale) { - const char* usage = - "Decode, reading log-likelihoods (of transition-ids or whatever symbol " - "is on the graph) as matrices."; - - kaldi::ParseOptions po(usage); - binary = true; - this->acoustic_scale = acoustic_scale; - allow_partial = true; - kaldi::FasterDecoderOptions decoder_opts; - decoder_opts.Register(&po, true); // true == include obscure settings. - po.Register("binary", &binary, "Write output in binary mode"); - po.Register("allow-partial", - &allow_partial, - "Produce output even when final state was not reached"); - po.Register("acoustic-scale", - &acoustic_scale, - "Scaling factor for acoustic likelihoods"); - - word_syms = NULL; - if (word_syms_filename != "") { - word_syms = fst::SymbolTable::ReadText(word_syms_filename); - if (!word_syms) - KALDI_ERR << "Could not read symbol table from file " - << word_syms_filename; - } - - std::ifstream is_logprior(logprior_rxfilename); - logprior.Read(is_logprior, false); - - // It's important that we initialize decode_fst after loglikes_reader, as it - // can prevent crashes on systems installed without enough virtual memory. - // It has to do with what happens on UNIX systems if you call fork() on a - // large process: the page-table entries are duplicated, which requires a - // lot of virtual memory. - decode_fst = fst::ReadFstKaldi(fst_in_filename); - - decoder = new kaldi::FasterDecoder(*decode_fst, decoder_opts); -} - - -Decoder::~Decoder() { - if (!word_syms) delete word_syms; - delete decode_fst; - delete decoder; -} - -std::string Decoder::decode( - std::string key, - const std::vector>& log_probs) { - size_t num_frames = log_probs.size(); - size_t dim_label = log_probs[0].size(); - - kaldi::Matrix loglikes( - num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols); - for (size_t i = 0; i < num_frames; ++i) { - memcpy(loglikes.Data() + i * dim_label, - log_probs[i].data(), - sizeof(kaldi::BaseFloat) * dim_label); - } - - return decode(key, loglikes); -} - - -std::vector Decoder::decode(std::string posterior_rspecifier) { - kaldi::SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier); - std::vector decoding_results; - - for (; !posterior_reader.Done(); posterior_reader.Next()) { - std::string key = posterior_reader.Key(); - kaldi::Matrix loglikes(posterior_reader.Value()); - - decoding_results.push_back(decode(key, loglikes)); - } - - return decoding_results; -} - - -std::string Decoder::decode(std::string key, - kaldi::Matrix& loglikes) { - std::string decoding_result; - - if (loglikes.NumRows() == 0) { - KALDI_WARN << "Zero-length utterance: " << key; - } - KALDI_ASSERT(loglikes.NumCols() == logprior.Dim()); - - loglikes.ApplyLog(); - loglikes.AddVecToRows(-1.0, logprior); - - kaldi::DecodableMatrixScaled decodable(loglikes, acoustic_scale); - decoder->Decode(&decodable); - - VectorFst decoded; // linear FST. - - if ((allow_partial || decoder->ReachedFinal()) && - decoder->GetBestPath(&decoded)) { - if (!decoder->ReachedFinal()) - KALDI_WARN << "Decoder did not reach end-state, outputting partial " - "traceback."; - - std::vector alignment; - std::vector words; - kaldi::LatticeWeight weight; - - GetLinearSymbolSequence(decoded, &alignment, &words, &weight); - - if (word_syms != NULL) { - for (size_t i = 0; i < words.size(); i++) { - std::string s = word_syms->Find(words[i]); - decoding_result += s; - if (s == "") - KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; - } - } - } - - return decoding_result; -} diff --git a/fluid/DeepASR/decoder/post_latgen_faster_mapped.cc b/fluid/DeepASR/decoder/post_latgen_faster_mapped.cc new file mode 100644 index 00000000..19d5dbea --- /dev/null +++ b/fluid/DeepASR/decoder/post_latgen_faster_mapped.cc @@ -0,0 +1,172 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "post_latgen_faster_mapped.h" + +using namespace kaldi; +typedef kaldi::int32 int32; +using fst::SymbolTable; +using fst::Fst; +using fst::StdArc; + +Decoder::Decoder(std::string trans_model_in_filename, + std::string word_syms_filename, + std::string fst_in_filename, + std::string logprior_in_filename, + kaldi::BaseFloat acoustic_scale) { + const char *usage = + "Generate lattices using neural net model.\n" + "Usage: post-latgen-faster-mapped [options] " + " " + " [ [] " + "]\n"; + ParseOptions po(usage); + allow_partial = false; + this->acoustic_scale = acoustic_scale; + LatticeFasterDecoderConfig config; + + config.Register(&po); + int32 beam = 11; + po.Register("beam", &beam, "Beam size"); + po.Register("acoustic-scale", + &acoustic_scale, + "Scaling factor for acoustic likelihoods"); + po.Register("word-symbol-table", + &word_syms_filename, + "Symbol table for words [for debug output]"); + po.Register("allow-partial", + &allow_partial, + "If true, produce output even if end state was not reached."); + + // int argc = 2; + // char *argv[] = {"post-latgen-faster-mapped", "--beam=11"}; + // po.Read(argc, argv); + + std::ifstream is_logprior(logprior_in_filename); + logprior.Read(is_logprior, false); + + { + bool binary; + Input ki(trans_model_in_filename, &binary); + this->trans_model.Read(ki.Stream(), binary); + } + + this->determinize = config.determinize_lattice; + + this->word_syms = NULL; + if (word_syms_filename != "") { + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) { + KALDI_ERR << "Could not read symbol table from file " + << word_syms_filename; + } + } + + // Input FST is just one FST, not a table of FSTs. + this->decode_fst = fst::ReadFstKaldiGeneric(fst_in_filename); + + this->decoder = new LatticeFasterDecoder(*decode_fst, config); + + std::string lattice_wspecifier = + "ark:|gzip -c > mapped_decoder_data/lat.JOB.gz"; + if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier) + : lattice_writer.Open(lattice_wspecifier))) + KALDI_ERR << "Could not open table for writing lattices: "; + // << lattice_wspecifier; + + words_writer = new Int32VectorWriter(""); + alignment_writer = new Int32VectorWriter(""); +} + +Decoder::~Decoder() { + if (!this->word_syms) delete this->word_syms; + delete this->decode_fst; + delete this->decoder; + delete words_writer; + delete alignment_writer; +} + + +std::string Decoder::decode(std::string key, + kaldi::Matrix &loglikes) { + std::string decoding_result; + if (loglikes.NumRows() == 0) { + KALDI_WARN << "Zero-length utterance: " << key; + // num_fail++; + } + KALDI_ASSERT(loglikes.NumCols() == logprior.Dim()); + + loglikes.ApplyLog(); + loglikes.AddVecToRows(-1.0, logprior); + + DecodableMatrixScaledMapped matrix_decodable( + trans_model, loglikes, acoustic_scale); + double like; + + if (DecodeUtteranceLatticeFaster(*decoder, + matrix_decodable, + trans_model, + word_syms, + key, + acoustic_scale, + determinize, + allow_partial, + alignment_writer, + words_writer, + &compact_lattice_writer, + &lattice_writer, + &like)) { + // tot_like += like; + // frame_count += loglikes.NumRows(); + // num_success++; + decoding_result = "succeed!"; + } else { // else num_fail++; + decoding_result = "fail!"; + } + return decoding_result; +} + +std::vector Decoder::decode(std::string posterior_rspecifier) { + std::vector ret; + + try { + double tot_like = 0.0; + kaldi::int64 frame_count = 0; + // int num_success = 0, num_fail = 0; + + KALDI_ASSERT(ClassifyRspecifier(fst_in_filename, NULL, NULL) == + kNoRspecifier); + SequentialBaseFloatMatrixReader posterior_reader("ark:" + + posterior_rspecifier); + + Timer timer; + timer.Reset(); + + { + for (; !posterior_reader.Done(); posterior_reader.Next()) { + std::string utt = posterior_reader.Key(); + Matrix &loglikes(posterior_reader.Value()); + KALDI_LOG << utt << " " << loglikes.NumRows() << " x " + << loglikes.NumCols(); + ret.push_back(decode(utt, loglikes)); + } + } + + double elapsed = timer.Elapsed(); + return ret; + } catch (const std::exception &e) { + std::cerr << e.what(); + // ret.push_back("error"); + return ret; + } +} diff --git a/fluid/DeepASR/decoder/post_decode_faster.h b/fluid/DeepASR/decoder/post_latgen_faster_mapped.h similarity index 75% rename from fluid/DeepASR/decoder/post_decode_faster.h rename to fluid/DeepASR/decoder/post_latgen_faster_mapped.h index 8bade8d6..4adbf6ba 100644 --- a/fluid/DeepASR/decoder/post_decode_faster.h +++ b/fluid/DeepASR/decoder/post_latgen_faster_mapped.h @@ -17,19 +17,18 @@ limitations under the License. */ #include "base/kaldi-common.h" #include "base/timer.h" #include "decoder/decodable-matrix.h" -#include "decoder/faster-decoder.h" -#include "fstext/fstext-lib.h" +#include "decoder/decoder-wrappers.h" +#include "fstext/kaldi-fst-io.h" #include "hmm/transition-model.h" -#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc #include "tree/context-dep.h" #include "util/common-utils.h" - class Decoder { public: - Decoder(std::string word_syms_filename, + Decoder(std::string trans_model_in_filename, + std::string word_syms_filename, std::string fst_in_filename, - std::string logprior_rxfilename, + std::string logprior_in_filename, kaldi::BaseFloat acoustic_scale); ~Decoder(); @@ -48,11 +47,18 @@ private: kaldi::Matrix &loglikes); fst::SymbolTable *word_syms; - fst::VectorFst *decode_fst; - kaldi::FasterDecoder *decoder; + fst::Fst *decode_fst; + kaldi::LatticeFasterDecoder *decoder; kaldi::Vector logprior; + kaldi::TransitionModel trans_model; + + kaldi::CompactLatticeWriter compact_lattice_writer; + kaldi::LatticeWriter lattice_writer; + kaldi::Int32VectorWriter *words_writer; + kaldi::Int32VectorWriter *alignment_writer; bool binary; + bool determinize; kaldi::BaseFloat acoustic_scale; bool allow_partial; }; diff --git a/fluid/DeepASR/decoder/pybind.cc b/fluid/DeepASR/decoder/pybind.cc index 90ea38ff..e99050e6 100644 --- a/fluid/DeepASR/decoder/pybind.cc +++ b/fluid/DeepASR/decoder/pybind.cc @@ -15,15 +15,19 @@ limitations under the License. */ #include #include -#include "post_decode_faster.h" +#include "post_latgen_faster_mapped.h" namespace py = pybind11; -PYBIND11_MODULE(post_decode_faster, m) { +PYBIND11_MODULE(post_latgen_faster_mapped, m) { m.doc() = "Decoder for Deep ASR model"; py::class_(m, "Decoder") - .def(py::init()) + .def(py::init()) .def("decode", (std::vector (Decoder::*)(std::string)) & Decoder::decode, diff --git a/fluid/DeepASR/decoder/setup.py b/fluid/DeepASR/decoder/setup.py index a98c0b4c..74e8aa00 100644 --- a/fluid/DeepASR/decoder/setup.py +++ b/fluid/DeepASR/decoder/setup.py @@ -49,8 +49,8 @@ LIB_DIRS = [os.path.abspath(path) for path in LIB_DIRS] ext_modules = [ Extension( - 'post_decode_faster', - ['pybind.cc', 'post_decode_faster.cc'], + 'post_latgen_faster_mapped', + ['pybind.cc', 'post_latgen_faster_mapped.cc'], include_dirs=[ 'pybind11/include', '.', os.path.join(kaldi_root, 'src'), os.path.join(kaldi_root, 'tools/openfst/src/include') @@ -63,8 +63,8 @@ ext_modules = [ ] setup( - name='post_decode_faster', - version='0.0.1', + name='post_latgen_faster_mapped', + version='0.1.0', author='Paddle', author_email='', description='Decoder for Deep ASR model', diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py index 83158192..36681e9a 100644 --- a/fluid/DeepASR/infer_by_ckpt.py +++ b/fluid/DeepASR/infer_by_ckpt.py @@ -14,7 +14,7 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta import data_utils.augmentor.trans_splice as trans_splice import data_utils.augmentor.trans_delay as trans_delay import data_utils.async_data_reader as reader -from decoder.post_decode_faster import Decoder +from decoder.post_latgen_faster_mapped import Decoder from data_utils.util import lodtensor_to_ndarray from model_utils.model import stacked_lstmp_model from data_utils.util import split_infer_result @@ -98,20 +98,25 @@ def parse_args(): type=str, default='./checkpoint', help="The checkpoint path to init model. (default: %(default)s)") + parser.add_argument( + '--trans_model', + type=str, + default='./graph/trans_model', + help="The path to vocabulary. (default: %(default)s)") parser.add_argument( '--vocabulary', type=str, - default='./decoder/graph/words.txt', + default='./graph/words.txt', help="The path to vocabulary. (default: %(default)s)") parser.add_argument( '--graphs', type=str, - default='./decoder/graph/TLG.fst', + default='./graph/TLG.fst', help="The path to TLG graphs for decoding. (default: %(default)s)") parser.add_argument( '--log_prior', type=str, - default="./decoder/logprior", + default="./logprior", help="The log prior probs for training data. (default: %(default)s)") parser.add_argument( '--acoustic_scale', @@ -123,6 +128,11 @@ def parse_args(): type=str, default="./decoder/target_trans.txt", help="The path to target transcription. (default: %(default)s)") + parser.add_argument( + '--post_matrix_path', + type=str, + default=None, + help="The path to output post prob matrix. (default: %(default)s)") args = parser.parse_args() return args @@ -146,6 +156,16 @@ def get_trg_trans(args): return trans_dict +def out_post_matrix(key, prob): + with open(args.post_matrix_path, "a") as post_matrix: + post_matrix.write(key + " [\n") + for i in range(prob.shape[0]): + for j in range(prob.shape[1]): + post_matrix.write(str(prob[i][j]) + " ") + post_matrix.write("\n") + post_matrix.write("]\n") + + def infer_from_ckpt(args): """Inference by using checkpoint.""" @@ -174,13 +194,13 @@ def infer_from_ckpt(args): fluid.io.load_persistables(exe, args.checkpoint) # init decoder - decoder = Decoder(args.vocabulary, args.graphs, args.log_prior, - args.acoustic_scale) + decoder = Decoder(args.trans_model, args.vocabulary, args.graphs, + args.log_prior, args.acoustic_scale) ltrans = [ trans_add_delta.TransAddDelta(2, 2), trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var), - trans_splice.TransSplice(), trans_delay.TransDelay(5) + trans_splice.TransSplice(5, 5), trans_delay.TransDelay(5) ] feature_t = fluid.LoDTensor() @@ -197,6 +217,8 @@ def infer_from_ckpt(args): args.minimum_batch_size)): # load_data (features, labels, lod, name_lst) = batch_data + features = np.reshape(features, (-1, 11, 3, args.frame_dim)) + features = np.transpose(features, (0, 2, 1, 3)) feature_t.set(features, place) feature_t.set_lod([lod]) label_t.set(labels, place) @@ -216,6 +238,9 @@ def infer_from_ckpt(args): for index, sample in enumerate(infer_batch): key = name_lst[index] ref = trg_trans[key] + if args.post_matrix_path is not None: + out_post_matrix(key, sample) + ''' hyp = decoder.decode(key, sample) edit_dist, ref_len = char_errors(ref.decode("utf8"), hyp) total_edit_dist += edit_dist @@ -223,6 +248,8 @@ def infer_from_ckpt(args): print(key + "|Ref:", ref) print(key + "|Hyp:", hyp.encode("utf8")) print("Instance CER: ", edit_dist / ref_len) + ''' + print("batch: ", batch_id) print("Total CER = %f" % (total_edit_dist / total_ref_len)) -- GitLab