diff --git a/deploy.py b/deploy.py index 091d82892bf2efd12b2b1c85b42fbd96eca3cc12..2d29973fbbdd41fb61a4775d300c499c030e9ae8 100644 --- a/deploy.py +++ b/deploy.py @@ -10,8 +10,7 @@ import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 -from deploy.swig_decoders import * -from swig_scorer import Scorer +from deploy.swig_decoders_wrapper import * from error_rate import wer import utils import time @@ -164,7 +163,8 @@ def infer(): ] # external scorer - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + ext_scorer = Scorer( + alpha=args.alpha, beta=args.beta, model_path=args.language_model_path) ## decode and print time_begin = time.time() diff --git a/deploy/README.md b/deploy/README.md index c8dbd1c125ca67b7b46e1a6780c8b19d28377671..cf0c04391cf202ee16102e1f168df8590734613c 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -1,19 +1,16 @@ ### Installation -The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/), first clone it to current directory (i.e., `deep_speech_2/deploy`) +The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/) and [openfst](http://www.openfst.org/twiki/bin/view/FST/WebHome), first clone kenlm and download openfst to current directory (i.e., `deep_speech_2/deploy`) ```shell git clone https://github.com/kpu/kenlm.git +wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz +tar -xzvf openfst-1.6.3.tar.gz ``` Then run the setup ```shell -sh setup.sh -``` - -After the installation succeeds, go back to the parent directory - -``` +python setup.py install cd .. ``` diff --git a/deploy/ctc_decoders.i b/deploy/decoders.i similarity index 91% rename from deploy/ctc_decoders.i rename to deploy/decoders.i index 8c9dd1643d994dfd0e3a713acd96fe27301491aa..04736e09e8047749cfd09d1f91f85701226dfda9 100644 --- a/deploy/ctc_decoders.i +++ b/deploy/decoders.i @@ -1,5 +1,6 @@ -%module swig_ctc_decoders +%module swig_decoders %{ +#include "scorer.h" #include "ctc_decoders.h" %} @@ -18,6 +19,6 @@ namespace std{ %template(PairDoubleStringVector) std::vector >; } -%import scorer.h %import decoder_utils.h +%include "scorer.h" %include "ctc_decoders.h" diff --git a/deploy/scorer.i b/deploy/scorer.i deleted file mode 100644 index 8380e15a609ef23e1e31de2b7fa3b949ff0929a7..0000000000000000000000000000000000000000 --- a/deploy/scorer.i +++ /dev/null @@ -1,8 +0,0 @@ -%module swig_scorer -%{ -#include "scorer.h" -%} - -%include "std_string.i" - -%include "scorer.h" diff --git a/deploy/scorer_setup.py b/deploy/scorer_setup.py deleted file mode 100644 index 3bb582724a23822856ceb50a2f736556216af53f..0000000000000000000000000000000000000000 --- a/deploy/scorer_setup.py +++ /dev/null @@ -1,54 +0,0 @@ -from setuptools import setup, Extension -import glob -import platform -import os - - -def compile_test(header, library): - dummy_path = os.path.join(os.path.dirname(__file__), "dummy") - command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\"" - return os.system(command) == 0 - - -FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob( - 'kenlm/util/double-conversion/*.cc') -FILES = [ - fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')) -] - -LIBS = ['stdc++'] -if platform.system() != 'Darwin': - LIBS.append('rt') - -ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6'] - -if compile_test('zlib.h', 'z'): - ARGS.append('-DHAVE_ZLIB') - LIBS.append('z') - -if compile_test('bzlib.h', 'bz2'): - ARGS.append('-DHAVE_BZLIB') - LIBS.append('bz2') - -if compile_test('lzma.h', 'lzma'): - ARGS.append('-DHAVE_XZLIB') - LIBS.append('lzma') - -os.system('swig -python -c++ ./scorer.i') - -ext_modules = [ - Extension( - name='_swig_scorer', - sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'], - language='C++', - include_dirs=['.', './kenlm'], - libraries=LIBS, - extra_compile_args=ARGS) -] - -setup( - name='swig_scorer', - version='0.1', - ext_modules=ext_modules, - include_package_data=True, - py_modules=['swig_scorer'], ) diff --git a/deploy/decoder_setup.py b/deploy/setup.py similarity index 75% rename from deploy/decoder_setup.py rename to deploy/setup.py index 146538f557f727ae798ab8b9322bac2b646531ca..077cabd0867d2b9245fb42c4667fc6cdd2f8f958 100644 --- a/deploy/decoder_setup.py +++ b/deploy/setup.py @@ -20,7 +20,7 @@ LIBS = ['stdc++'] if platform.system() != 'Darwin': LIBS.append('rt') -ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6'] +ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11'] if compile_test('zlib.h', 'z'): ARGS.append('-DHAVE_ZLIB') @@ -34,24 +34,21 @@ if compile_test('lzma.h', 'lzma'): ARGS.append('-DHAVE_XZLIB') LIBS.append('lzma') -os.system('swig -python -c++ ./ctc_decoders.i') +os.system('swig -python -c++ ./decoders.i') ctc_beam_search_decoder_module = [ Extension( - name='_swig_ctc_decoders', - sources=FILES + [ - 'scorer.cpp', 'ctc_decoders_wrap.cxx', 'ctc_decoders.cpp', - 'decoder_utils.cpp' - ], + name='_swig_decoders', + sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), language='C++', - include_dirs=['.', './kenlm'], + include_dirs=['.', './kenlm', './openfst-1.6.3/src/include'], libraries=LIBS, extra_compile_args=ARGS) ] setup( - name='swig_ctc_decoders', + name='swig_decoders', version='0.1', description="""CTC decoders""", ext_modules=ctc_beam_search_decoder_module, - py_modules=['swig_ctc_decoders'], ) + py_modules=['swig_decoders'], ) diff --git a/deploy/setup.sh b/deploy/setup.sh deleted file mode 100644 index 423f5b8922c89e785231bd84557de0b59a089f53..0000000000000000000000000000000000000000 --- a/deploy/setup.sh +++ /dev/null @@ -1,11 +0,0 @@ -echo "Run decoder setup ..." - -python decoder_setup.py install -rm -r ./build - -echo "Run scorer setup ..." - -python scorer_setup.py install -rm -r ./build - -echo "Finish the installation of decoder and scorer." diff --git a/deploy/swig_decoders.py b/deploy/swig_decoders_wrapper.py similarity index 68% rename from deploy/swig_decoders.py rename to deploy/swig_decoders_wrapper.py index 0247c0c9ea512089c2293f71745b66432d8d6007..54c430147538caf51508dcf297f68d4242e50371 100644 --- a/deploy/swig_decoders.py +++ b/deploy/swig_decoders_wrapper.py @@ -3,9 +3,25 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import swig_ctc_decoders -#import multiprocessing -from pathos.multiprocessing import Pool +import swig_decoders +import multiprocessing + + +class Scorer(swig_decoders.Scorer): + """Wrapper for Scorer. + + :param alpha: Parameter associated with language model. Don't use + language model when alpha = 0. + :type alpha: float + :param beta: Parameter associated with word count. Don't use word + count when beta = 0. + :type beta: float + :model_path: Path to load language model. + :type model_path: basestring + """ + + def __init__(self, alpha, beta, model_path): + swig_decoders.Scorer.__init__(self, alpha, beta, model_path) def ctc_best_path_decoder(probs_seq, vocabulary): @@ -20,8 +36,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary): :return: Decoding result string. :rtype: basestring """ - return swig_ctc_decoders.ctc_best_path_decoder(probs_seq.tolist(), - vocabulary) + return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary) def ctc_beam_search_decoder( @@ -54,9 +69,9 @@ def ctc_beam_search_decoder( results, in descending order of the probability. :rtype: list """ - return swig_ctc_decoders.ctc_beam_search_decoder( - probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob, - ext_scoring_func) + return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size, + vocabulary, blank_id, + cutoff_prob, ext_scoring_func) def ctc_beam_search_decoder_batch(probs_split, @@ -86,25 +101,4 @@ def ctc_beam_search_decoder_batch(probs_split, pool.close() pool.join() beam_search_results = [result.get() for result in results] - """ - len_args = len(probs_split) - beam_search_results = pool.map(ctc_beam_search_decoder, - probs_split, - [beam_size for i in xrange(len_args)], - [vocabulary for i in xrange(len_args)], - [blank_id for i in xrange(len_args)], - [cutoff_prob for i in xrange(len_args)], - [ext_scoring_func for i in xrange(len_args)] - ) - """ - ''' - processes = [mp.Process(target=ctc_beam_search_decoder, - args=(probs_list, beam_size, vocabulary, blank_id, cutoff_prob, - ext_scoring_func) for probs_list in probs_split] - for p in processes: - p.start() - for p in processes: - p.join() - beam_search_results = [] - ''' return beam_search_results