Unverified commit 13890311, authored by Yibing Liu, committed by GitHub

Merge pull request #967 from kuke/infer_ckpt

Adapt decoder to the new net config
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "post_decode_faster.h"
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
Decoder::Decoder(std::string word_syms_filename,
                 std::string fst_in_filename,
                 std::string logprior_rxfilename,
                 kaldi::BaseFloat acoustic_scale) {
  const char* usage =
      "Decode, reading log-likelihoods (of transition-ids or whatever symbol "
      "is on the graph) as matrices.";

  kaldi::ParseOptions po(usage);
  binary = true;
  this->acoustic_scale = acoustic_scale;
  allow_partial = true;
  kaldi::FasterDecoderOptions decoder_opts;
  decoder_opts.Register(&po, true);  // true == include obscure settings.

  po.Register("binary", &binary, "Write output in binary mode");
  po.Register("allow-partial",
              &allow_partial,
              "Produce output even when final state was not reached");
  po.Register("acoustic-scale",
              &acoustic_scale,
              "Scaling factor for acoustic likelihoods");

  word_syms = NULL;
  if (word_syms_filename != "") {
    word_syms = fst::SymbolTable::ReadText(word_syms_filename);
    if (!word_syms)
      KALDI_ERR << "Could not read symbol table from file "
                << word_syms_filename;
  }

  std::ifstream is_logprior(logprior_rxfilename);
  logprior.Read(is_logprior, false);

  // It's important that we initialize decode_fst after reading the inputs
  // above, as it can prevent crashes on systems installed without enough
  // virtual memory. It has to do with what happens on UNIX systems if you
  // call fork() on a large process: the page-table entries are duplicated,
  // which requires a lot of virtual memory.
  decode_fst = fst::ReadFstKaldi(fst_in_filename);

  decoder = new kaldi::FasterDecoder(*decode_fst, decoder_opts);
}
Decoder::~Decoder() {
  if (word_syms) delete word_syms;
  delete decode_fst;
  delete decoder;
}
std::string Decoder::decode(
    std::string key,
    const std::vector<std::vector<kaldi::BaseFloat>>& log_probs) {
  size_t num_frames = log_probs.size();
  size_t dim_label = log_probs[0].size();

  kaldi::Matrix<kaldi::BaseFloat> loglikes(
      num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols);
  for (size_t i = 0; i < num_frames; ++i) {
    memcpy(loglikes.Data() + i * dim_label,
           log_probs[i].data(),
           sizeof(kaldi::BaseFloat) * dim_label);
  }

  return decode(key, loglikes);
}
std::vector<std::string> Decoder::decode(std::string posterior_rspecifier) {
  kaldi::SequentialBaseFloatMatrixReader posterior_reader(posterior_rspecifier);
  std::vector<std::string> decoding_results;

  for (; !posterior_reader.Done(); posterior_reader.Next()) {
    std::string key = posterior_reader.Key();
    kaldi::Matrix<kaldi::BaseFloat> loglikes(posterior_reader.Value());
    decoding_results.push_back(decode(key, loglikes));
  }

  return decoding_results;
}
std::string Decoder::decode(std::string key,
                            kaldi::Matrix<kaldi::BaseFloat>& loglikes) {
  std::string decoding_result;

  if (loglikes.NumRows() == 0) {
    KALDI_WARN << "Zero-length utterance: " << key;
  }
  KALDI_ASSERT(loglikes.NumCols() == logprior.Dim());

  loglikes.ApplyLog();
  loglikes.AddVecToRows(-1.0, logprior);

  kaldi::DecodableMatrixScaled decodable(loglikes, acoustic_scale);
  decoder->Decode(&decodable);

  VectorFst<kaldi::LatticeArc> decoded;  // linear FST.
  if ((allow_partial || decoder->ReachedFinal()) &&
      decoder->GetBestPath(&decoded)) {
    if (!decoder->ReachedFinal())
      KALDI_WARN << "Decoder did not reach end-state, outputting partial "
                    "traceback.";

    std::vector<int32> alignment;
    std::vector<int32> words;
    kaldi::LatticeWeight weight;
    GetLinearSymbolSequence(decoded, &alignment, &words, &weight);

    if (word_syms != NULL) {
      for (size_t i = 0; i < words.size(); i++) {
        std::string s = word_syms->Find(words[i]);
        if (s == "")
          KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
        decoding_result += s;
      }
    }
  }

  return decoding_result;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "post_latgen_faster_mapped.h"
#include <limits>
#include "ThreadPool.h"
using namespace kaldi;
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::Fst;
using fst::StdArc;
Decoder::Decoder(std::string trans_model_in_filename,
                 std::string word_syms_filename,
                 std::string fst_in_filename,
                 std::string logprior_in_filename,
                 size_t beam_size,
                 kaldi::BaseFloat acoustic_scale) {
  const char *usage =
      "Generate lattices using neural net model.\n"
      "Usage: post-latgen-faster-mapped [options] <trans-model> "
      "<fst-in|fsts-rspecifier> <logprior> <posts-rspecifier>"
      " <lattice-wspecifier> [ <words-wspecifier> [<alignments-wspecifier>] "
      "]\n";
  ParseOptions po(usage);
  allow_partial = false;
  this->acoustic_scale = acoustic_scale;

  config.Register(&po);

  po.Register("acoustic-scale",
              &acoustic_scale,
              "Scaling factor for acoustic likelihoods");
  po.Register("word-symbol-table",
              &word_syms_filename,
              "Symbol table for words [for debug output]");
  po.Register("allow-partial",
              &allow_partial,
              "If true, produce output even if end state was not reached.");
  int argc = 2;
  std::string beam_opt = "--beam=" + std::to_string(beam_size);
  char *argv[] = {(char *)"post-latgen-faster-mapped",
                  const_cast<char *>(beam_opt.c_str())};
  po.Read(argc, argv);
  std::ifstream is_logprior(logprior_in_filename);
  logprior.Read(is_logprior, false);

  {
    bool binary;
    Input ki(trans_model_in_filename, &binary);
    this->trans_model.Read(ki.Stream(), binary);
  }

  this->determinize = config.determinize_lattice;

  this->word_syms = NULL;
  if (word_syms_filename != "") {
    if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename))) {
      KALDI_ERR << "Could not read symbol table from file "
                << word_syms_filename;
    }
  }

  // Input FST is just one FST, not a table of FSTs.
  this->decode_fst = fst::ReadFstKaldiGeneric(fst_in_filename);

  kaldi::LatticeFasterDecoder *decoder =
      new LatticeFasterDecoder(*decode_fst, config);
  decoder_pool.emplace_back(decoder);

  std::string lattice_wspecifier =
      "ark:|gzip -c > mapped_decoder_data/lat.JOB.gz";
  if (!(determinize ? compact_lattice_writer.Open(lattice_wspecifier)
                    : lattice_writer.Open(lattice_wspecifier)))
    KALDI_ERR << "Could not open table for writing lattices: "
              << lattice_wspecifier;

  words_writer = new Int32VectorWriter("");
  alignment_writer = new Int32VectorWriter("");
}
Decoder::~Decoder() {
  if (this->word_syms) delete this->word_syms;
  delete this->decode_fst;
  for (size_t i = 0; i < decoder_pool.size(); ++i) {
    delete decoder_pool[i];
  }
  delete words_writer;
  delete alignment_writer;
}
void Decoder::decode_from_file(std::string posterior_rspecifier,
                               size_t num_processes) {
  try {
    double tot_like = 0.0;
    kaldi::int64 frame_count = 0;
    // int num_success = 0, num_fail = 0;

    // The input FST is a single graph rather than a table of FSTs; it was
    // validated when it was loaded in the constructor.
    SequentialBaseFloatMatrixReader posterior_reader("ark:" +
                                                     posterior_rspecifier);

    Timer timer;
    timer.Reset();
    double elapsed = 0.0;

    for (size_t n = decoder_pool.size(); n < num_processes; ++n) {
      kaldi::LatticeFasterDecoder *decoder =
          new LatticeFasterDecoder(*decode_fst, config);
      decoder_pool.emplace_back(decoder);
    }
    elapsed = timer.Elapsed();

    ThreadPool thread_pool(num_processes);

    while (!posterior_reader.Done()) {
      timer.Reset();
      std::vector<std::future<std::string>> que;
      for (size_t i = 0; i < num_processes && !posterior_reader.Done(); ++i) {
        std::string utt = posterior_reader.Key();
        Matrix<BaseFloat> &loglikes(posterior_reader.Value());
        que.emplace_back(thread_pool.enqueue(std::bind(
            &Decoder::decode_internal, this, decoder_pool[i], utt, loglikes)));
        posterior_reader.Next();
      }
      timer.Reset();
      for (size_t i = 0; i < que.size(); ++i) {
        std::cout << que[i].get() << std::endl;
      }
    }
  } catch (const std::exception &e) {
    std::cerr << e.what();
  }
}
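// Copy a row-major vector of per-frame score vectors into a contiguous Kaldi
// matrix (one row per frame), so it can be wrapped by a DecodableInterface.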
inline kaldi::Matrix<kaldi::BaseFloat> vector2kaldi_mat(
    const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) {
  size_t num_frames = log_probs.size();
  size_t dim_label = log_probs[0].size();
  kaldi::Matrix<kaldi::BaseFloat> loglikes(
      num_frames, dim_label, kaldi::kSetZero, kaldi::kStrideEqualNumCols);
  for (size_t i = 0; i < num_frames; ++i) {
    memcpy(loglikes.Data() + i * dim_label,
           log_probs[i].data(),
           sizeof(kaldi::BaseFloat) * dim_label);
  }
  return loglikes;
}
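// Dispatch utterances to the decoder pool in chunks of num_processes, then
// collect the futures in submission order so results line up with the keys.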
std::vector<std::string> Decoder::decode_batch(
    std::vector<std::string> keys,
    const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
        &log_probs_batch,
    size_t num_processes) {
  ThreadPool thread_pool(num_processes);
  std::vector<std::string> decoding_results;  //(keys.size(), "");

  for (size_t n = decoder_pool.size(); n < num_processes; ++n) {
    kaldi::LatticeFasterDecoder *decoder =
        new LatticeFasterDecoder(*decode_fst, config);
    decoder_pool.emplace_back(decoder);
  }

  size_t index = 0;
  while (index < keys.size()) {
    std::vector<std::future<std::string>> res_in_que;
    for (size_t t = 0; t < num_processes && index < keys.size(); ++t) {
      kaldi::Matrix<kaldi::BaseFloat> loglikes =
          vector2kaldi_mat(log_probs_batch[index]);
      res_in_que.emplace_back(
          thread_pool.enqueue(std::bind(&Decoder::decode_internal,
                                        this,
                                        decoder_pool[t],
                                        keys[index],
                                        loglikes)));
      index++;
    }
    for (size_t i = 0; i < res_in_que.size(); ++i) {
      decoding_results.emplace_back(res_in_que[i].get());
    }
  }
  return decoding_results;
}
std::string Decoder::decode(
    std::string key,
    const std::vector<std::vector<kaldi::BaseFloat>> &log_probs) {
  kaldi::Matrix<kaldi::BaseFloat> loglikes = vector2kaldi_mat(log_probs);
  return decode_internal(decoder_pool[0], key, loglikes);
}
std::string Decoder::decode_internal(
    LatticeFasterDecoder *decoder,
    std::string key,
    kaldi::Matrix<kaldi::BaseFloat> &loglikes) {
  if (loglikes.NumRows() == 0) {
    KALDI_WARN << "Zero-length utterance: " << key;
    // num_fail++;
  }
  KALDI_ASSERT(loglikes.NumCols() == logprior.Dim());
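  // Convert the network's posteriors into pseudo log-likelihoods: take the
  // log and subtract the class log-priors (i.e. divide by the priors), as is
  // standard in hybrid DNN-HMM decoding.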
  loglikes.ApplyLog();
  loglikes.AddVecToRows(-1.0, logprior);

  DecodableMatrixScaledMapped matrix_decodable(
      trans_model, loglikes, acoustic_scale);
  double like;
  return this->DecodeUtteranceLatticeFaster(
      decoder, matrix_decodable, key, &like);
}
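// Mirrors Kaldi's DecodeUtteranceLatticeFaster(), but returns the best-path
// transcription as a string; lattice output is disabled for now (see the
// commented writer calls below).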
std::string Decoder::DecodeUtteranceLatticeFaster(
    LatticeFasterDecoder *decoder,
    DecodableInterface &decodable,  // not const but is really an input.
    std::string utt,
    double *like_ptr) {  // puts utterance's likelihood in like_ptr on success.
  using fst::VectorFst;
  std::string ret = utt + ' ';

  if (!decoder->Decode(&decodable)) {
    KALDI_WARN << "Failed to decode file " << utt;
    return ret;
  }
  if (!decoder->ReachedFinal()) {
    if (allow_partial) {
      KALDI_WARN << "Outputting partial output for utterance " << utt
                 << " since no final-state reached\n";
    } else {
      KALDI_WARN << "Not producing output for utterance " << utt
                 << " since no final-state reached and "
                 << "--allow-partial=false.\n";
      return ret;
    }
  }

  double likelihood;
  LatticeWeight weight;
  int32 num_frames;
  {  // First do some stuff with word-level traceback...
    VectorFst<LatticeArc> decoded;
    if (!decoder->GetBestPath(&decoded))
      // Shouldn't really reach this point as already checked success.
      KALDI_ERR << "Failed to get traceback for utterance " << utt;

    std::vector<int32> alignment;
    std::vector<int32> words;
    GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
    num_frames = alignment.size();
    // if (alignment_writer->IsOpen()) alignment_writer->Write(utt, alignment);
    if (word_syms != NULL) {
      for (size_t i = 0; i < words.size(); i++) {
        std::string s = word_syms->Find(words[i]);
        if (s == "")
          KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
        ret += s + ' ';
      }
    }
    likelihood = -(weight.Value1() + weight.Value2());
  }

  // Get lattice, and do determinization if requested.
  Lattice lat;
  decoder->GetRawLattice(&lat);
  if (lat.NumStates() == 0)
    KALDI_ERR << "Unexpected problem getting lattice for utterance " << utt;
  fst::Connect(&lat);
  if (determinize) {
    CompactLattice clat;
    if (!DeterminizeLatticePhonePrunedWrapper(
            trans_model,
            &lat,
            decoder->GetOptions().lattice_beam,
            &clat,
            decoder->GetOptions().det_opts))
      KALDI_WARN << "Determinization finished earlier than the beam for "
                 << "utterance " << utt;
    // We'll write the lattice without acoustic scaling.
    if (acoustic_scale != 0.0)
      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &clat);
    // disable output lattice temporarily
    // compact_lattice_writer.Write(utt, clat);
  } else {
    // We'll write the lattice without acoustic scaling.
    if (acoustic_scale != 0.0)
      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &lat);
    // lattice_writer.Write(utt, lat);
  }

  *like_ptr = likelihood;
  return ret;
}
......@@ -17,42 +17,64 @@ limitations under the License. */
#include "base/kaldi-common.h"
#include "base/timer.h"
#include "decoder/decodable-matrix.h"
#include "decoder/faster-decoder.h"
#include "fstext/fstext-lib.h"
#include "decoder/decoder-wrappers.h"
#include "fstext/kaldi-fst-io.h"
#include "hmm/transition-model.h"
#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc
#include "tree/context-dep.h"
#include "util/common-utils.h"
class Decoder {
public:
Decoder(std::string word_syms_filename,
Decoder(std::string trans_model_in_filename,
std::string word_syms_filename,
std::string fst_in_filename,
std::string logprior_rxfilename,
std::string logprior_in_filename,
size_t beam_size,
kaldi::BaseFloat acoustic_scale);
~Decoder();
// Interface to accept the scores read from specifier and return
// the batch decoding results
std::vector<std::string> decode(std::string posterior_rspecifier);
// Interface to accept the scores read from specifier and print
// the decoding results directly
void decode_from_file(std::string posterior_rspecifier,
size_t num_processes = 1);
// Accept the scores of one utterance and return the decoding result
std::string decode(
std::string key,
const std::vector<std::vector<kaldi::BaseFloat>> &log_probs);
// Accept the scores of utterances in batch and return the decoding results
std::vector<std::string> decode_batch(
std::vector<std::string> key,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>
&log_probs_batch,
size_t num_processes = 1);
private:
// For decoding one utterance
std::string decode(std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes);
std::string decode_internal(kaldi::LatticeFasterDecoder *decoder,
std::string key,
kaldi::Matrix<kaldi::BaseFloat> &loglikes);
std::string DecodeUtteranceLatticeFaster(kaldi::LatticeFasterDecoder *decoder,
kaldi::DecodableInterface &decodable,
std::string utt,
double *like_ptr);
fst::SymbolTable *word_syms;
fst::VectorFst<fst::StdArc> *decode_fst;
kaldi::FasterDecoder *decoder;
fst::Fst<fst::StdArc> *decode_fst;
std::vector<kaldi::LatticeFasterDecoder *> decoder_pool;
kaldi::Vector<kaldi::BaseFloat> logprior;
kaldi::TransitionModel trans_model;
kaldi::LatticeFasterDecoderConfig config;
kaldi::CompactLatticeWriter compact_lattice_writer;
kaldi::LatticeWriter lattice_writer;
kaldi::Int32VectorWriter *words_writer;
kaldi::Int32VectorWriter *alignment_writer;
bool binary;
bool determinize;
kaldi::BaseFloat acoustic_scale;
bool allow_partial;
};
......@@ -15,25 +15,37 @@ limitations under the License. */
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "post_decode_faster.h"
#include "post_latgen_faster_mapped.h"
namespace py = pybind11;
PYBIND11_MODULE(post_decode_faster, m) {
PYBIND11_MODULE(post_latgen_faster_mapped, m) {
m.doc() = "Decoder for Deep ASR model";
py::class_<Decoder>(m, "Decoder")
.def(py::init<std::string, std::string, std::string, kaldi::BaseFloat>())
.def("decode",
(std::vector<std::string> (Decoder::*)(std::string)) &
Decoder::decode,
.def(py::init<std::string,
std::string,
std::string,
std::string,
size_t,
kaldi::BaseFloat>())
.def("decode_from_file",
(void (Decoder::*)(std::string, size_t)) & Decoder::decode_from_file,
"Decode for the probability matrices in specifier "
"and return the transcriptions.")
"and print the transcriptions.")
.def(
"decode",
(std::string (Decoder::*)(
std::string, const std::vector<std::vector<kaldi::BaseFloat>>&)) &
Decoder::decode,
"Decode one input probability matrix "
"and return the transcription.");
"and return the transcription.")
.def("decode_batch",
(std::vector<std::string> (Decoder::*)(
std::vector<std::string>,
const std::vector<std::vector<std::vector<kaldi::BaseFloat>>>&,
size_t num_processes)) &
Decoder::decode_batch,
"Decode one batch of probability matrices "
"and return the transcriptions.");
}
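A minimal usage sketch of the compiled module from Python (the file names and the beam/scale values are illustrative, taken from the run script further below; the constructor arguments follow the py::init signature above):

    from post_latgen_faster_mapped import Decoder

    # transition model, word symbol table, decoding graph, log-prior,
    # beam size, acoustic scale
    decoder = Decoder("final.mdl", "words.txt", "HCLG.fst", "logprior",
                      11, 0.059)

    # per-frame class posteriors for one utterance (rows = frames)
    probs = [[0.25, 0.25, 0.5], [0.1, 0.8, 0.1]]
    print(decoder.decode("utt1", probs))

    # decode a batch of utterances with 4 worker threads
    print(decoder.decode_batch(["utt1", "utt2"], [probs, probs], 4))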
......@@ -24,7 +24,7 @@ except:
"install kaldi and export KALDI_ROOT=<kaldi's root dir> .")
args = [
'-std=c++11', '-Wno-sign-compare', '-Wno-unused-variable',
'-std=c++11', '-fopenmp', '-Wno-sign-compare', '-Wno-unused-variable',
'-Wno-unused-local-typedefs', '-Wno-unused-but-set-variable',
'-Wno-deprecated-declarations', '-Wno-unused-function'
]
......@@ -49,11 +49,11 @@ LIB_DIRS = [os.path.abspath(path) for path in LIB_DIRS]
ext_modules = [
    Extension(
-        'post_decode_faster',
-        ['pybind.cc', 'post_decode_faster.cc'],
+        'post_latgen_faster_mapped',
+        ['pybind.cc', 'post_latgen_faster_mapped.cc'],
        include_dirs=[
            'pybind11/include', '.', os.path.join(kaldi_root, 'src'),
-            os.path.join(kaldi_root, 'tools/openfst/src/include')
+            os.path.join(kaldi_root, 'tools/openfst/src/include'), 'ThreadPool'
        ],
        language='c++',
        libraries=LIBS,
......@@ -63,8 +63,8 @@ ext_modules = [
]

setup(
-    name='post_decode_faster',
-    version='0.0.1',
+    name='post_latgen_faster_mapped',
+    version='0.1.0',
    author='Paddle',
    author_email='',
    description='Decoder for Deep ASR model',
......
......@@ -4,4 +4,9 @@ if [ ! -d pybind11 ]; then
    git clone https://github.com/pybind/pybind11.git
fi

+if [ ! -d ThreadPool ]; then
+    git clone https://github.com/progschj/ThreadPool.git
+    echo -e "\n"
+fi

python setup.py build_ext -i
decode_to_path=./decoding_result.txt
export CUDA_VISIBLE_DEVICES=2,3,4,5
python -u ../../infer_by_ckpt.py --batch_size 96 \
--checkpoint checkpoints/deep_asr.pass_20.checkpoint \
--infer_feature_lst data/test_feature.lst \
--mean_var data/global_mean_var \
--frame_dim 80 \
--class_num 3040 \
--num_threads 24 \
--beam_size 11 \
--decode_to_path $decode_to_path \
--trans_model mapped_decoder_data/exp/tri5a/final.mdl \
--log_prior mapped_decoder_data/logprior \
--vocabulary mapped_decoder_data/exp/tri5a/graph/words.txt \
--graphs mapped_decoder_data/exp/tri5a/graph/HCLG.fst \
--acoustic_scale 0.059 \
--parallel
ref_txt=data/text.test
hyp_txt=decoding_result.txt
python ../../score_error_rate.py --error_rate_type cer --ref $ref_txt --hyp $hyp_txt
......@@ -14,10 +14,9 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.augmentor.trans_delay as trans_delay
import data_utils.async_data_reader as reader
-from decoder.post_decode_faster import Decoder
-from data_utils.util import lodtensor_to_ndarray
+from data_utils.util import lodtensor_to_ndarray, split_infer_result
from model_utils.model import stacked_lstmp_model
-from data_utils.util import split_infer_result
+from decoder.post_latgen_faster_mapped import Decoder
from tools.error_rate import char_errors
......@@ -28,6 +27,11 @@ def parse_args():
        type=int,
        default=32,
        help='The sequence number of a batch data. (default: %(default)d)')
+    parser.add_argument(
+        '--beam_size',
+        type=int,
+        default=11,
+        help='The beam size for decoding. (default: %(default)d)')
    parser.add_argument(
        '--minimum_batch_size',
        type=int,
......@@ -60,10 +64,10 @@ def parse_args():
        default=1749,
        help='Number of classes in label. (default: %(default)d)')
    parser.add_argument(
-        '--learning_rate',
-        type=float,
-        default=0.00016,
-        help='Learning rate used to train. (default: %(default)f)')
+        '--num_threads',
+        type=int,
+        default=10,
+        help='The number of threads for decoding. (default: %(default)d)')
    parser.add_argument(
        '--device',
        type=str,
......@@ -75,7 +79,7 @@ def parse_args():
    parser.add_argument(
        '--mean_var',
        type=str,
-        default='data/global_mean_var_search26kHr',
+        default='data/global_mean_var',
        help="The path for feature's global mean and variance. "
        "(default: %(default)s)")
    parser.add_argument(
......@@ -83,35 +87,30 @@ def parse_args():
        type=str,
        default='data/infer_feature.lst',
        help='The feature list path for inference. (default: %(default)s)')
-    parser.add_argument(
-        '--infer_label_lst',
-        type=str,
-        default='data/infer_label.lst',
-        help='The label list path for inference. (default: %(default)s)')
-    parser.add_argument(
-        '--ref_txt',
-        type=str,
-        default='data/text.test',
-        help='The reference text for decoding. (default: %(default)s)')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./checkpoint',
        help="The checkpoint path to init model. (default: %(default)s)")
+    parser.add_argument(
+        '--trans_model',
+        type=str,
+        default='./graph/trans_model',
+        help="The path to transition model. (default: %(default)s)")
    parser.add_argument(
        '--vocabulary',
        type=str,
-        default='./decoder/graph/words.txt',
+        default='./graph/words.txt',
        help="The path to vocabulary. (default: %(default)s)")
    parser.add_argument(
        '--graphs',
        type=str,
-        default='./decoder/graph/TLG.fst',
+        default='./graph/TLG.fst',
        help="The path to TLG graphs for decoding. (default: %(default)s)")
    parser.add_argument(
        '--log_prior',
        type=str,
-        default="./decoder/logprior",
+        default="./logprior",
        help="The log prior probs for training data. (default: %(default)s)")
    parser.add_argument(
        '--acoustic_scale',
......@@ -119,10 +118,16 @@ def parse_args():
        type=float,
        default=0.2,
        help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
    parser.add_argument(
-        '--target_trans',
-        type=str,
-        default="./decoder/target_trans.txt",
-        help="The path to target transcription. (default: %(default)s)")
+        '--post_matrix_path',
+        type=str,
+        default=None,
+        help="The path to output post prob matrix. (default: %(default)s)")
+    parser.add_argument(
+        '--decode_to_path',
+        type=str,
+        default='./decoding_result.txt',
+        required=True,
+        help="The path to output the decoding result. (default: %(default)s)")
    args = parser.parse_args()
    return args
......@@ -134,16 +139,47 @@ def print_arguments(args):
    print('------------------------------------------------')

-def get_trg_trans(args):
-    trans_dict = {}
-    with open(args.target_trans) as trg_trans:
-        line = trg_trans.readline()
-        while line:
-            items = line.strip().split()
-            key = items[0]
-            trans_dict[key] = ''.join(items[1:])
-            line = trg_trans.readline()
-    return trans_dict
+class PostMatrixWriter:
+    """ The writer for outputting the post probability matrix """
+
+    def __init__(self, to_path):
+        self._to_path = to_path
+        with open(self._to_path, "w") as post_matrix:
+            post_matrix.seek(0)
+            post_matrix.truncate()
+
+    def write(self, keys, probs):
+        with open(self._to_path, "a") as post_matrix:
+            if isinstance(keys, str):
+                keys, probs = [keys], [probs]
+            for key, prob in zip(keys, probs):
+                post_matrix.write(key + " [\n")
+                for i in range(prob.shape[0]):
+                    for j in range(prob.shape[1]):
+                        post_matrix.write(str(prob[i][j]) + " ")
+                    post_matrix.write("\n")
+                post_matrix.write("]\n")
+class DecodingResultWriter:
+    """ The writer for writing out decoding results """
+
+    def __init__(self, to_path):
+        self._to_path = to_path
+        with open(self._to_path, "w") as decoding_result:
+            decoding_result.seek(0)
+            decoding_result.truncate()
+
+    def write(self, results):
+        with open(self._to_path, "a") as decoding_result:
+            if isinstance(results, str):
+                decoding_result.write(results.encode("utf8") + "\n")
+            else:
+                for result in results:
+                    decoding_result.write(result.encode("utf8") + "\n")
def infer_from_ckpt(args):
......@@ -162,9 +198,10 @@ def infer_from_ckpt(args):
    infer_program = fluid.default_main_program().clone()

    # optimizer, placeholder
    optimizer = fluid.optimizer.Adam(
        learning_rate=fluid.layers.exponential_decay(
-            learning_rate=args.learning_rate,
+            learning_rate=0.0001,
            decay_steps=1879,
            decay_rate=1 / 1.2,
            staircase=True))
......@@ -174,34 +211,38 @@ def infer_from_ckpt(args):
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

-    trg_trans = get_trg_trans(args)
    # load checkpoint.
    fluid.io.load_persistables(exe, args.checkpoint)

    # init decoder
-    decoder = Decoder(args.vocabulary, args.graphs, args.log_prior,
-                      args.acoustic_scale)
+    decoder = Decoder(args.trans_model, args.vocabulary, args.graphs,
+                      args.log_prior, args.beam_size, args.acoustic_scale)

    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
-        trans_splice.TransSplice(), trans_delay.TransDelay(5)
+        trans_splice.TransSplice(5, 5), trans_delay.TransDelay(5)
    ]

    feature_t = fluid.LoDTensor()
    label_t = fluid.LoDTensor()

    # infer data reader
-    infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
-                                               args.infer_label_lst)
+    infer_data_reader = reader.AsyncDataReader(
+        args.infer_feature_lst, drop_frame_len=-1, split_sentence_threshold=-1)
    infer_data_reader.set_transformers(ltrans)

    infer_costs, infer_accs = [], []
-    total_edit_dist, total_ref_len = 0.0, 0
+    decoding_result_writer = DecodingResultWriter(args.decode_to_path)
+    post_matrix_writer = None if args.post_matrix_path is None \
+        else PostMatrixWriter(args.post_matrix_path)

    for batch_id, batch_data in enumerate(
            infer_data_reader.batch_iterator(args.batch_size,
                                             args.minimum_batch_size)):
        # load_data
        (features, labels, lod, name_lst) = batch_data
+        features = np.reshape(features, (-1, 11, 3, args.frame_dim))
+        features = np.transpose(features, (0, 2, 1, 3))
        feature_t.set(features, place)
        feature_t.set_lod([lod])
        label_t.set(labels, place)
......@@ -212,24 +253,17 @@ def infer_from_ckpt(args):
"label": label_t},
fetch_list=[prediction, avg_cost, accuracy],
return_numpy=False)
infer_costs.append(lodtensor_to_ndarray(results[1])[0])
infer_accs.append(lodtensor_to_ndarray(results[2])[0])
probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod)
for index, sample in enumerate(infer_batch):
key = name_lst[index]
ref = trg_trans[key]
hyp = decoder.decode(key, sample)
edit_dist, ref_len = char_errors(ref.decode("utf8"), hyp)
total_edit_dist += edit_dist
total_ref_len += ref_len
print(key + "|Ref:", ref)
print(key + "|Hyp:", hyp.encode("utf8"))
print("Instance CER: ", edit_dist / ref_len)
print("Total CER = %f" % (total_edit_dist / total_ref_len))
print("Decoding batch %d ..." % batch_id)
decoded = decoder.decode_batch(name_lst, infer_batch, args.num_threads)
decoding_result_writer.write(decoded)
if args.post_matrix_path is not None:
post_matrix_writer.write(name_lst, infer_batch)
if __name__ == '__main__':
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from tools.error_rate import char_errors, word_errors
def parse_args():
    parser = argparse.ArgumentParser(
        "Score word/character error rate (WER/CER) "
        "for decoding result.")
    parser.add_argument(
        '--error_rate_type',
        type=str,
        default='cer',
        choices=['cer', 'wer'],
        help="Error rate type. (default: %(default)s)")
    parser.add_argument(
        '--ref', type=str, required=True, help="The ground truth text.")
    parser.add_argument(
        '--hyp', type=str, required=True, help="The decoding result.")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    ref_dict = {}
    sum_errors, sum_ref_len = 0.0, 0
    sent_cnt, not_in_ref_cnt = 0, 0

    with open(args.ref, "r") as ref_txt:
        line = ref_txt.readline()
        while line:
            del_pos = line.find(" ")
            key, sent = line[0:del_pos], line[del_pos + 1:-1].strip()
            ref_dict[key] = sent
            line = ref_txt.readline()

    with open(args.hyp, "r") as hyp_txt:
        line = hyp_txt.readline()
        while line:
            del_pos = line.find(" ")
            key, sent = line[0:del_pos], line[del_pos + 1:-1].strip()
            sent_cnt += 1
            line = hyp_txt.readline()
            if key not in ref_dict:
                not_in_ref_cnt += 1
                continue

            if args.error_rate_type == 'cer':
                errors, ref_len = char_errors(
                    ref_dict[key].decode("utf8"),
                    sent.decode("utf8"),
                    remove_space=True)
            else:
                errors, ref_len = word_errors(ref_dict[key].decode("utf8"),
                                              sent.decode("utf8"))
            sum_errors += errors
            sum_ref_len += ref_len

    print("Error rate[%s] = %f (%d/%d)," %
          (args.error_rate_type, sum_errors / sum_ref_len, int(sum_errors),
           sum_ref_len))
    print("total %d sentences in hyp, %d not presented in ref." %
          (sent_cnt, not_in_ref_cnt))