提交 989e6cd5 编写于 作者: Y Yibing Liu

Return decoding result instead of output directly

上级 84152a09
......@@ -117,8 +117,6 @@ std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
for (; !posterior_reader.Done(); posterior_reader.Next()) {
std::string utt = posterior_reader.Key();
Matrix<BaseFloat> &loglikes(posterior_reader.Value());
KALDI_LOG << utt << " " << loglikes.NumRows() << " x "
<< loglikes.NumCols();
ret.push_back(decode(utt, loglikes));
}
}
......@@ -127,11 +125,20 @@ std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
return ret;
} catch (const std::exception &e) {
std::cerr << e.what();
// ret.push_back("error");
return ret;
}
}
std::vector<std::string> Decoder::decode_batch(
std::vector<std::string> keys,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch) {
std::vector<std::string> decoding_results;
for (size_t i = 0; i < keys.size(); ++i) {
decoding_results.push_back(decode(keys[i], log_probs_batch[i]));
}
return decoding_results;
}
std::string Decoder::decode(
std::string key,
......@@ -167,25 +174,82 @@ std::string Decoder::decode(std::string key,
trans_model, loglikes, acoustic_scale);
double like;
if (DecodeUtteranceLatticeFaster(*decoder,
matrix_decodable,
trans_model,
word_syms,
key,
acoustic_scale,
determinize,
allow_partial,
alignment_writer,
words_writer,
&compact_lattice_writer,
&lattice_writer,
&like)) {
// tot_like += like;
// frame_count += loglikes.NumRows();
// num_success++;
decoding_result = "succeed!";
} else { // else num_fail++;
decoding_result = "fail!";
return this->DecodeUtteranceLatticeFaster(matrix_decodable, key, &like);
}
// Takes care of output. Returns true on success.
std::string Decoder::DecodeUtteranceLatticeFaster(
DecodableInterface &decodable, // not const but is really an input.
std::string utt,
double *like_ptr) { // puts utterance's like in like_ptr on success.
using fst::VectorFst;
if (!decoder->Decode(&decodable)) {
KALDI_WARN << "Failed to decode file " << utt;
return false;
}
if (!decoder->ReachedFinal()) {
if (allow_partial) {
KALDI_WARN << "Outputting partial output for utterance " << utt
<< " since no final-state reached\n";
} else {
KALDI_WARN << "Not producing output for utterance " << utt
<< " since no final-state reached and "
<< "--allow-partial=false.\n";
return false;
}
}
double likelihood;
LatticeWeight weight;
int32 num_frames;
std::string ret = utt + ' ';
{ // First do some stuff with word-level traceback...
VectorFst<LatticeArc> decoded;
if (!decoder->GetBestPath(&decoded))
// Shouldn't really reach this point as already checked success.
KALDI_ERR << "Failed to get traceback for utterance " << utt;
std::vector<int32> alignment;
std::vector<int32> words;
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
num_frames = alignment.size();
if (alignment_writer->IsOpen()) alignment_writer->Write(utt, alignment);
if (word_syms != NULL) {
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms->Find(words[i]);
ret += s + ' ';
}
}
likelihood = -(weight.Value1() + weight.Value2());
}
// Get lattice, and do determinization if requested.
Lattice lat;
decoder->GetRawLattice(&lat);
if (lat.NumStates() == 0)
KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
fst::Connect(&lat);
if (determinize) {
CompactLattice clat;
if (!DeterminizeLatticePhonePrunedWrapper(
trans_model,
&lat,
decoder->GetOptions().lattice_beam,
&clat,
decoder->GetOptions().det_opts))
KALDI_WARN << "Determinization finished earlier than the beam for "
<< "utterance " << utt;
// We'll write the lattice without acoustic scaling.
if (acoustic_scale != 0.0)
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
compact_lattice_writer.Write(utt, clat);
} else {
// We'll write the lattice without acoustic scaling.
if (acoustic_scale != 0.0)
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat);
lattice_writer.Write(utt, lat);
}
return decoding_result;
return ret;
}
......@@ -41,10 +41,19 @@ public:
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs);
// Accept the scores of utterances in batch and return the decoding results
std::vector<std::string> decode_batch(
std::vector<std::string> key,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch);
private:
// For decoding one utterance
std::string decode(std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes);
std::string DecodeUtteranceLatticeFaster(kaldi::DecodableInterface &decodable,
std::string utt,
double *like_ptr);
fst::SymbolTable *word_syms;
fst::Fst<fst::StdArc> *decode_fst;
......
export CUDA_VISIBLE_DEVICES=0,1
python -u ../../infer_by_ckpt.py --batch_size 64 \
python -u ../../infer_by_ckpt.py --batch_size 48 \
--checkpoint deep_asr.pass_20.checkpoint \
--infer_feature_lst data/test_feature.lst \
--infer_label_lst data/test_label.lst \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册