提交 6bc445f2 编写于 作者: Y Yibing Liu

refine the interface of decoders in swig

上级 a840f854
...@@ -10,8 +10,8 @@ import multiprocessing ...@@ -10,8 +10,8 @@ import multiprocessing
import paddle.v2 as paddle import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from swig_ctc_beam_search_decoder import * from deploy.swig_decoders import *
from swig_scorer import Scorer from swig_scorer import LmScorer
from error_rate import wer from error_rate import wer
import utils import utils
import time import time
...@@ -85,7 +85,7 @@ parser.add_argument( ...@@ -85,7 +85,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)") help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="lm/data/en.00.UNKNOWN.klm", default="lm/data/common_crawl_00.prune01111.trie.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -164,19 +164,19 @@ def infer(): ...@@ -164,19 +164,19 @@ def infer():
] ]
# external scorer # external scorer
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
## decode and print ## decode and print
time_begin = time.time() time_begin = time.time()
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
beam_result = ctc_beam_search_decoder( beam_result = ctc_beam_search_decoder(
probs.tolist(), probs_seq=probs,
args.beam_size, beam_size=args.beam_size,
data_generator.vocab_list, vocabulary=data_generator.vocab_list,
len(data_generator.vocab_list), blank_id=len(data_generator.vocab_list),
args.cutoff_prob, cutoff_prob=args.cutoff_prob,
ext_scorer, ) ext_scoring_func=ext_scorer, )
print("\nTarget Transcription:\t%s" % target_transcription[i]) print("\nTarget Transcription:\t%s" % target_transcription[i])
print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1])) print("Beam %d: %f \t%s" % (0, beam_result[0][0], beam_result[0][1]))
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
#include <utility> #include <utility>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include "ctc_beam_search_decoder.h" #include "ctc_decoders.h"
typedef float log_prob_type; typedef double log_prob_type;
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
...@@ -24,8 +24,8 @@ template <typename T> ...@@ -24,8 +24,8 @@ template <typename T>
T log_sum_exp(T x, T y) T log_sum_exp(T x, T y)
{ {
static T num_min = -std::numeric_limits<T>::max(); static T num_min = -std::numeric_limits<T>::max();
if (x <= -num_min) return y; if (x <= num_min) return y;
if (y <= -num_min) return x; if (y <= num_min) return x;
T xmax = std::max(x, y); T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
} }
...@@ -55,17 +55,13 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, ...@@ -55,17 +55,13 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
} }
} }
max_idx_vec.push_back(max_idx); max_idx_vec.push_back(max_idx);
std::cout<<max_idx<<",";
max_prob = 0.0; max_prob = 0.0;
max_idx = 0; max_idx = 0;
} }
std::cout<<std::endl;
std::vector<int> idx_vec; std::vector<int> idx_vec;
for (int i=0; i<max_idx_vec.size(); i++) { for (int i=0; i<max_idx_vec.size(); i++) {
std::cout<<max_idx_vec[i]<<",";
if ((i == 0) || ((i>0) && max_idx_vec[i]!=max_idx_vec[i-1])) { if ((i == 0) || ((i>0) && max_idx_vec[i]!=max_idx_vec[i-1])) {
std::cout<<max_idx_vec[i]<<",";
idx_vec.push_back(max_idx_vec[i]); idx_vec.push_back(max_idx_vec[i]);
} }
} }
...@@ -73,7 +69,7 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, ...@@ -73,7 +69,7 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::string best_path_result; std::string best_path_result;
for (int i=0; i<idx_vec.size(); i++) { for (int i=0; i<idx_vec.size(); i++) {
if (idx_vec[i] != blank_id) { if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[i]; best_path_result += vocabulary[idx_vec[i]];
} }
} }
return best_path_result; return best_path_result;
...@@ -85,21 +81,21 @@ std::vector<std::pair<double, std::string> > ...@@ -85,21 +81,21 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob, double cutoff_prob,
Scorer *ext_scorer, LmScorer *ext_scorer,
bool nproc) { bool nproc) {
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) { for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) { if (probs_seq[i].size() != vocabulary.size()+1) {
std::cout<<"The shape of probs_seq does not match" std::cout << " The shape of probs_seq does not match"
<<" with the shape of the vocabulary!"<<std::endl; << " with the shape of the vocabulary!" << std::endl;
exit(1); exit(1);
} }
} }
// blank_id check // blank_id check
if (blank_id > vocabulary.size()) { if (blank_id > vocabulary.size()) {
std::cout<<"Invalid blank_id!"<<std::endl; std::cout << " Invalid blank_id! " << std::endl;
exit(1); exit(1);
} }
...@@ -108,7 +104,7 @@ std::vector<std::pair<double, std::string> > ...@@ -108,7 +104,7 @@ std::vector<std::pair<double, std::string> >
vocabulary.end(), " "); vocabulary.end(), " ");
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
if(space_id >= vocabulary.size()) { if(space_id >= vocabulary.size()) {
std::cout<<"The character space is not in the vocabulary!"<<std::endl; std::cout << " The character space is not in the vocabulary!"<<std::endl;
exit(1); exit(1);
} }
......
...@@ -28,10 +28,19 @@ std::vector<std::pair<double, std::string> > ...@@ -28,10 +28,19 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob=1.0, double cutoff_prob=1.0,
Scorer *ext_scorer=NULL, LmScorer *ext_scorer=NULL,
bool nproc=false bool nproc=false
); );
/* CTC Best Path Decoder /* CTC Best Path Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* The decoding result as a single string.
*/ */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary); std::vector<std::string> vocabulary);
......
%module swig_ctc_beam_search_decoder %module swig_ctc_decoders
%{ %{
#include "ctc_beam_search_decoder.h" #include "ctc_decoders.h"
%} %}
%include "std_vector.i" %include "std_vector.i"
...@@ -19,4 +19,4 @@ namespace std{ ...@@ -19,4 +19,4 @@ namespace std{
} }
%import scorer.h %import scorer.h
%include "ctc_beam_search_decoder.h" %include "ctc_decoders.h"
...@@ -34,15 +34,13 @@ if compile_test('lzma.h', 'lzma'): ...@@ -34,15 +34,13 @@ if compile_test('lzma.h', 'lzma'):
ARGS.append('-DHAVE_XZLIB') ARGS.append('-DHAVE_XZLIB')
LIBS.append('lzma') LIBS.append('lzma')
os.system('swig -python -c++ ./ctc_beam_search_decoder.i') os.system('swig -python -c++ ./ctc_decoders.i')
ctc_beam_search_decoder_module = [ ctc_beam_search_decoder_module = [
Extension( Extension(
name='_swig_ctc_beam_search_decoder', name='_swig_ctc_decoders',
sources=FILES + [ sources=FILES +
'scorer.cpp', 'ctc_beam_search_decoder_wrap.cxx', ['scorer.cpp', 'ctc_decoders_wrap.cxx', 'ctc_decoders.cpp'],
'ctc_beam_search_decoder.cpp'
],
language='C++', language='C++',
include_dirs=['.', './kenlm'], include_dirs=['.', './kenlm'],
libraries=LIBS, libraries=LIBS,
...@@ -50,8 +48,8 @@ ctc_beam_search_decoder_module = [ ...@@ -50,8 +48,8 @@ ctc_beam_search_decoder_module = [
] ]
setup( setup(
name='swig_ctc_beam_search_decoder', name='swig_ctc_decoders',
version='0.1', version='0.1',
description="""CTC beam search decoder""", description="""CTC decoders""",
ext_modules=ctc_beam_search_decoder_module, ext_modules=ctc_beam_search_decoder_module,
py_modules=['swig_ctc_beam_search_decoder'], ) py_modules=['swig_ctc_decoders'], )
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
using namespace lm::ngram; using namespace lm::ngram;
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) { LmScorer::LmScorer(float alpha, float beta, std::string lm_model_path) {
this->_alpha = alpha; this->_alpha = alpha;
this->_beta = beta; this->_beta = beta;
...@@ -18,7 +18,7 @@ Scorer::Scorer(float alpha, float beta, std::string lm_model_path) { ...@@ -18,7 +18,7 @@ Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
this->_language_model = LoadVirtual(lm_model_path.c_str()); this->_language_model = LoadVirtual(lm_model_path.c_str());
} }
Scorer::~Scorer(){ LmScorer::~LmScorer(){
delete (lm::base::Model *)this->_language_model; delete (lm::base::Model *)this->_language_model;
} }
...@@ -57,7 +57,7 @@ inline void strip(std::string &str, char ch=' ') { ...@@ -57,7 +57,7 @@ inline void strip(std::string &str, char ch=' ') {
} }
} }
int Scorer::word_count(std::string sentence) { int LmScorer::word_count(std::string sentence) {
strip(sentence); strip(sentence);
int cnt = 1; int cnt = 1;
for (int i=0; i<sentence.size(); i++) { for (int i=0; i<sentence.size(); i++) {
...@@ -68,7 +68,7 @@ int Scorer::word_count(std::string sentence) { ...@@ -68,7 +68,7 @@ int Scorer::word_count(std::string sentence) {
return cnt; return cnt;
} }
double Scorer::language_model_score(std::string sentence) { double LmScorer::language_model_score(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model; lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state; State state, out_state;
lm::FullScoreReturn ret; lm::FullScoreReturn ret;
...@@ -84,12 +84,12 @@ double Scorer::language_model_score(std::string sentence) { ...@@ -84,12 +84,12 @@ double Scorer::language_model_score(std::string sentence) {
return log_prob; return log_prob;
} }
void Scorer::reset_params(float alpha, float beta) { void LmScorer::reset_params(float alpha, float beta) {
this->_alpha = alpha; this->_alpha = alpha;
this->_beta = beta; this->_beta = beta;
} }
double Scorer::get_score(std::string sentence, bool log) { double LmScorer::get_score(std::string sentence, bool log) {
double lm_score = language_model_score(sentence); double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence); int word_cnt = word_count(sentence);
......
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
* count and language model scoring. * count and language model scoring.
* Example: * Example:
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm"); * LmScorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* double score = ext_scorer.get_score("sentence_to_score"); * double score = ext_scorer.get_score("sentence_to_score");
*/ */
class Scorer{ class LmScorer{
private: private:
float _alpha; float _alpha;
float _beta; float _beta;
...@@ -23,9 +23,9 @@ private: ...@@ -23,9 +23,9 @@ private:
double language_model_score(std::string); double language_model_score(std::string);
public: public:
Scorer(){} LmScorer(){}
Scorer(float alpha, float beta, std::string lm_model_path); LmScorer(float alpha, float beta, std::string lm_model_path);
~Scorer(); ~LmScorer();
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
......
"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import swig_ctc_decoders
import multiprocessing
def ctc_best_path_decoder(probs_seq, vocabulary):
    """Wrapper for the CTC best path decoder in SWIG.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      exposing ``tolist()`` (e.g. a numpy array) is also
                      accepted and converted to nested lists first.
    :type probs_seq: 2-D list or array-like with ``tolist()``
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # The previous code called ``probs_seq.tolist()`` unconditionally, which
    # raised AttributeError for the plain 2-D list promised by the docstring.
    # Convert only when the input actually provides ``tolist``.
    to_list = getattr(probs_seq, "tolist", None)
    probs_list = to_list() if callable(to_list) else probs_seq
    return swig_ctc_decoders.ctc_best_path_decoder(probs_list, vocabulary)
def ctc_beam_search_decoder(
        probs_seq,
        beam_size,
        vocabulary,
        blank_id,
        cutoff_prob=1.0,
        ext_scoring_func=None, ):
    """Wrapper for the CTC beam search decoder in SWIG.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      exposing ``tolist()`` (e.g. a numpy array) is also
                      accepted and converted to nested lists first.
    :type probs_seq: 2-D list or array-like with ``tolist()``
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scorer for partially decoded
                             sentences, e.g. word count or language model.
                             It is forwarded unchanged to the SWIG decoder.
                             NOTE(review): the C++ declaration takes an
                             ``LmScorer *`` defaulting to NULL, so this is
                             presumably an ``LmScorer`` instance or None
                             rather than an arbitrary callable -- confirm.
    :type ext_scoring_func: LmScorer or None
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # The previous code called ``probs_seq.tolist()`` unconditionally, which
    # raised AttributeError for the plain 2-D list promised by the docstring.
    # Convert only when the input actually provides ``tolist``.
    to_list = getattr(probs_seq, "tolist", None)
    probs_list = to_list() if callable(to_list) else probs_seq
    return swig_ctc_decoders.ctc_beam_search_decoder(
        probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
        ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None):
    """Wrapper for the CTC beam search decoder in batch mode.

    Every element of ``probs_split`` is submitted to a
    ``multiprocessing.Pool`` as the ``probs_seq`` argument of
    ``ctc_beam_search_decoder``; the remaining arguments are shared by
    all tasks.

    :param probs_split: Iterable whose elements are each passed as
                        ``probs_seq`` to ``ctc_beam_search_decoder``.
    :type probs_split: iterable
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of worker processes, must be positive.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scorer forwarded to every
                             ``ctc_beam_search_decoder`` call.
    :return: List with one decoding result per element of ``probs_split``,
             in the same order.
    :rtype: list
    :raises ValueError: If ``num_processes`` is not positive.
    """
    # TODO: to resolve PicklingError
    # Validate before creating the pool so that no workers are spawned for
    # an invalid request.
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        pending = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_list, beam_size, vocabulary, blank_id,
                              cutoff_prob, ext_scoring_func))
            for probs_list in probs_split
        ]
    finally:
        # Always shut the pool down, even if submitting a task failed;
        # otherwise the worker processes would be left running.
        pool.close()
        pool.join()

    # ``get()`` re-raises any exception that occurred inside a worker.
    return [task.get() for task in pending]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册