提交 6bc445f2 编写于 作者: Y Yibing Liu

refine the interface of decoders in swig

上级 a840f854
......@@ -10,8 +10,8 @@ import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from swig_ctc_beam_search_decoder import *
from swig_scorer import Scorer
from deploy.swig_decoders import *
from swig_scorer import LmScorer
from error_rate import wer
import utils
import time
......@@ -85,7 +85,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument(
"--language_model_path",
default="lm/data/en.00.UNKNOWN.klm",
default="lm/data/common_crawl_00.prune01111.trie.klm",
type=str,
help="Path for language model. (default: %(default)s)")
parser.add_argument(
......@@ -164,19 +164,19 @@ def infer():
]
# external scorer
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
## decode and print
time_begin = time.time()
wer_sum, wer_counter = 0, 0
for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder(
probs.tolist(),
args.beam_size,
data_generator.vocab_list,
len(data_generator.vocab_list),
args.cutoff_prob,
ext_scorer, )
probs_seq=probs,
beam_size=args.beam_size,
vocabulary=data_generator.vocab_list,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
print("\nTarget Transcription:\t%s" % target_transcription[i])
print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1]))
......
......@@ -4,9 +4,9 @@
#include <utility>
#include <cmath>
#include <limits>
#include "ctc_beam_search_decoder.h"
#include "ctc_decoders.h"
typedef float log_prob_type;
typedef double log_prob_type;
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
......@@ -24,8 +24,8 @@ template <typename T>
T log_sum_exp(T x, T y)
{
static T num_min = -std::numeric_limits<T>::max();
if (x <= -num_min) return y;
if (y <= -num_min) return x;
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
}
......@@ -55,17 +55,13 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
}
}
max_idx_vec.push_back(max_idx);
std::cout<<max_idx<<",";
max_prob = 0.0;
max_idx = 0;
}
std::cout<<std::endl;
std::vector<int> idx_vec;
for (int i=0; i<max_idx_vec.size(); i++) {
std::cout<<max_idx_vec[i]<<",";
if ((i == 0) || ((i>0) && max_idx_vec[i]!=max_idx_vec[i-1])) {
std::cout<<max_idx_vec[i]<<",";
idx_vec.push_back(max_idx_vec[i]);
}
}
......@@ -73,7 +69,7 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::string best_path_result;
for (int i=0; i<idx_vec.size(); i++) {
if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[i];
best_path_result += vocabulary[idx_vec[i]];
}
}
return best_path_result;
......@@ -85,21 +81,21 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob,
Scorer *ext_scorer,
LmScorer *ext_scorer,
bool nproc) {
// dimension check
int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) {
std::cout<<"The shape of probs_seq does not match"
<<" with the shape of the vocabulary!"<<std::endl;
std::cout << " The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl;
exit(1);
}
}
// blank_id check
if (blank_id > vocabulary.size()) {
std::cout<<"Invalid blank_id!"<<std::endl;
std::cout << " Invalid blank_id! " << std::endl;
exit(1);
}
......@@ -108,7 +104,7 @@ std::vector<std::pair<double, std::string> >
vocabulary.end(), " ");
int space_id = it - vocabulary.begin();
if(space_id >= vocabulary.size()) {
std::cout<<"The character space is not in the vocabulary!"<<std::endl;
std::cout << " The character space is not in the vocabulary!"<<std::endl;
exit(1);
}
......
......@@ -28,10 +28,19 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary,
int blank_id,
double cutoff_prob=1.0,
Scorer *ext_scorer=NULL,
LmScorer *ext_scorer=NULL,
bool nproc=false
);
/* CTC Best Path Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary);
......
%module swig_ctc_beam_search_decoder
%module swig_ctc_decoders
%{
#include "ctc_beam_search_decoder.h"
#include "ctc_decoders.h"
%}
%include "std_vector.i"
......@@ -19,4 +19,4 @@ namespace std{
}
%import scorer.h
%include "ctc_beam_search_decoder.h"
%include "ctc_decoders.h"
......@@ -34,15 +34,13 @@ if compile_test('lzma.h', 'lzma'):
ARGS.append('-DHAVE_XZLIB')
LIBS.append('lzma')
os.system('swig -python -c++ ./ctc_beam_search_decoder.i')
os.system('swig -python -c++ ./ctc_decoders.i')
ctc_beam_search_decoder_module = [
Extension(
name='_swig_ctc_beam_search_decoder',
sources=FILES + [
'scorer.cpp', 'ctc_beam_search_decoder_wrap.cxx',
'ctc_beam_search_decoder.cpp'
],
name='_swig_ctc_decoders',
sources=FILES +
['scorer.cpp', 'ctc_decoders_wrap.cxx', 'ctc_decoders.cpp'],
language='C++',
include_dirs=['.', './kenlm'],
libraries=LIBS,
......@@ -50,8 +48,8 @@ ctc_beam_search_decoder_module = [
]
setup(
name='swig_ctc_beam_search_decoder',
name='swig_ctc_decoders',
version='0.1',
description="""CTC beam search decoder""",
description="""CTC decoders""",
ext_modules=ctc_beam_search_decoder_module,
py_modules=['swig_ctc_beam_search_decoder'], )
py_modules=['swig_ctc_decoders'], )
......@@ -7,7 +7,7 @@
using namespace lm::ngram;
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
LmScorer::LmScorer(float alpha, float beta, std::string lm_model_path) {
this->_alpha = alpha;
this->_beta = beta;
......@@ -18,7 +18,7 @@ Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
this->_language_model = LoadVirtual(lm_model_path.c_str());
}
Scorer::~Scorer(){
LmScorer::~LmScorer(){
delete (lm::base::Model *)this->_language_model;
}
......@@ -57,7 +57,7 @@ inline void strip(std::string &str, char ch=' ') {
}
}
int Scorer::word_count(std::string sentence) {
int LmScorer::word_count(std::string sentence) {
strip(sentence);
int cnt = 1;
for (int i=0; i<sentence.size(); i++) {
......@@ -68,7 +68,7 @@ int Scorer::word_count(std::string sentence) {
return cnt;
}
double Scorer::language_model_score(std::string sentence) {
double LmScorer::language_model_score(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
......@@ -84,12 +84,12 @@ double Scorer::language_model_score(std::string sentence) {
return log_prob;
}
void Scorer::reset_params(float alpha, float beta) {
void LmScorer::reset_params(float alpha, float beta) {
this->_alpha = alpha;
this->_beta = beta;
}
double Scorer::get_score(std::string sentence, bool log) {
double LmScorer::get_score(std::string sentence, bool log) {
double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence);
......
......@@ -8,10 +8,10 @@
* count and language model scoring.
* Example:
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* LmScorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* double score = ext_scorer.get_score("sentence_to_score");
*/
class Scorer{
class LmScorer{
private:
float _alpha;
float _beta;
......@@ -23,9 +23,9 @@ private:
double language_model_score(std::string);
public:
Scorer(){}
Scorer(float alpha, float beta, std::string lm_model_path);
~Scorer();
LmScorer(){}
LmScorer(float alpha, float beta, std::string lm_model_path);
~LmScorer();
// reset params alpha & beta
void reset_params(float alpha, float beta);
......
"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import swig_ctc_decoders
import multiprocessing
def ctc_best_path_decoder(probs_seq, vocabulary):
"""Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: basestring
"""
return swig_ctc_decoders.ctc_best_path_decoder(probs_seq.tolist(),
vocabulary)
def ctc_beam_search_decoder(
probs_seq,
beam_size,
vocabulary,
blank_id,
cutoff_prob=1.0,
ext_scoring_func=None, ):
"""Wrapper for CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
return swig_ctc_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob,
ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
blank_id,
num_processes,
cutoff_prob=1.0,
ext_scoring_func=None):
"""Wrapper for CTC beam search decoder in batch
"""
# TODO: to resolve PicklingError
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
ext_scoring_func)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
return beam_search_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册