提交 32047c72 编写于 作者: Y Yibing Liu

refine wrapper for swig and simplify setup

上级 9ff48b05
...@@ -10,8 +10,7 @@ import multiprocessing ...@@ -10,8 +10,7 @@ import multiprocessing
import paddle.v2 as paddle import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from deploy.swig_decoders import * from deploy.swig_decoders_wrapper import *
from swig_scorer import Scorer
from error_rate import wer from error_rate import wer
import utils import utils
import time import time
...@@ -164,7 +163,8 @@ def infer(): ...@@ -164,7 +163,8 @@ def infer():
] ]
# external scorer # external scorer
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) ext_scorer = Scorer(
alpha=args.alpha, beta=args.beta, model_path=args.language_model_path)
## decode and print ## decode and print
time_begin = time.time() time_begin = time.time()
......
### Installation ### Installation
The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/). First, clone it into the current directory (i.e., `deep_speech_2/deploy`): The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/) and [openfst](http://www.openfst.org/twiki/bin/view/FST/WebHome). First, clone kenlm and download openfst into the current directory (i.e., `deep_speech_2/deploy`):
```shell ```shell
git clone https://github.com/kpu/kenlm.git git clone https://github.com/kpu/kenlm.git
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
``` ```
Then run the setup Then run the setup
```shell ```shell
sh setup.sh python setup.py install
```
After the installation succeeds, go back to the parent directory
```shell
cd .. cd ..
``` ```
......
%module swig_ctc_decoders %module swig_decoders
%{ %{
#include "scorer.h"
#include "ctc_decoders.h" #include "ctc_decoders.h"
%} %}
...@@ -18,6 +19,6 @@ namespace std{ ...@@ -18,6 +19,6 @@ namespace std{
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >; %template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
} }
%import scorer.h
%import decoder_utils.h %import decoder_utils.h
%include "scorer.h"
%include "ctc_decoders.h" %include "ctc_decoders.h"
// SWIG interface file that defines the `swig_scorer` module.
%module swig_scorer
// The %{ ... %} section is inserted as-is into the generated wrapper source,
// so the wrapper code is compiled with scorer.h included.
%{
#include "scorer.h"
%}
// SWIG library file providing the std::string typemaps.
%include "std_string.i"
// Generate wrappers for the declarations found in scorer.h.
%include "scorer.h"
from setuptools import setup, Extension
import glob
import platform
import os
def compile_test(header, library):
    """Check whether a header/library pair is usable with g++.

    A trivial C++ program is fed to ``g++`` on standard input, force-including
    ``header`` and linking against ``library``.

    :param header: Header file name passed to ``g++ -include``.
    :type header: basestring
    :param library: Library name passed to ``g++ -l`` (without the ``lib``
                    prefix).
    :type library: basestring
    :return: True if the program compiled and linked, False otherwise
             (including when ``g++`` cannot be started at all).
    :rtype: bool
    """
    import shutil
    import subprocess
    import tempfile

    # Build into a throw-away directory so nothing is left in (or required to
    # be writable next to) the setup script, whatever the outcome.
    build_dir = tempfile.mkdtemp()
    dummy_path = os.path.join(build_dir, "dummy")
    # An argument list avoids any shell quoting problems with the header,
    # library or output path, and removes the dependency on bash.
    command = [
        "g++", "-include", header, "-l" + library, "-x", "c++", "-", "-o",
        dummy_path
    ]
    try:
        with open(os.devnull, "wb") as devnull:
            process = subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=devnull,
                stderr=devnull)
            process.communicate(b"int main() {}\n")
        return process.returncode == 0
    except OSError:
        # g++ is not installed or not executable: the feature is unavailable.
        return False
    finally:
        shutil.rmtree(build_dir, ignore_errors=True)
# Collect the kenlm sources, skipping files whose names end in main.cc or
# test.cc.
FILES = []
for pattern in ('kenlm/util/*.cc', 'kenlm/lm/*.cc',
                'kenlm/util/double-conversion/*.cc'):
    FILES.extend(glob.glob(pattern))
FILES = [fn for fn in FILES if not fn.endswith(('main.cc', 'test.cc'))]

# 'rt' is linked on every platform except Darwin.
LIBS = ['stdc++'] + ([] if platform.system() == 'Darwin' else ['rt'])

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']

# Optional compression back-ends: (header, library, macro to define).
# Each one is enabled only if the compile probe succeeds.
for header, library, macro in (('zlib.h', 'z', '-DHAVE_ZLIB'),
                               ('bzlib.h', 'bz2', '-DHAVE_BZLIB'),
                               ('lzma.h', 'lzma', '-DHAVE_XZLIB')):
    if compile_test(header, library):
        ARGS.append(macro)
        LIBS.append(library)

# Run swig on the interface file before the extension is built; the
# resulting scorer_wrap.cxx is listed in the sources below.
os.system('swig -python -c++ ./scorer.i')

scorer_extension = Extension(
    name='_swig_scorer',
    sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
    language='C++',
    include_dirs=['.', './kenlm'],
    libraries=LIBS,
    extra_compile_args=ARGS)
ext_modules = [scorer_extension]

setup(
    name='swig_scorer',
    version='0.1',
    ext_modules=ext_modules,
    include_package_data=True,
    py_modules=['swig_scorer'], )
...@@ -20,7 +20,7 @@ LIBS = ['stdc++'] ...@@ -20,7 +20,7 @@ LIBS = ['stdc++']
if platform.system() != 'Darwin': if platform.system() != 'Darwin':
LIBS.append('rt') LIBS.append('rt')
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6'] ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']
if compile_test('zlib.h', 'z'): if compile_test('zlib.h', 'z'):
ARGS.append('-DHAVE_ZLIB') ARGS.append('-DHAVE_ZLIB')
...@@ -34,24 +34,21 @@ if compile_test('lzma.h', 'lzma'): ...@@ -34,24 +34,21 @@ if compile_test('lzma.h', 'lzma'):
ARGS.append('-DHAVE_XZLIB') ARGS.append('-DHAVE_XZLIB')
LIBS.append('lzma') LIBS.append('lzma')
os.system('swig -python -c++ ./ctc_decoders.i') os.system('swig -python -c++ ./decoders.i')
ctc_beam_search_decoder_module = [ ctc_beam_search_decoder_module = [
Extension( Extension(
name='_swig_ctc_decoders', name='_swig_decoders',
sources=FILES + [ sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
'scorer.cpp', 'ctc_decoders_wrap.cxx', 'ctc_decoders.cpp',
'decoder_utils.cpp'
],
language='C++', language='C++',
include_dirs=['.', './kenlm'], include_dirs=['.', './kenlm', './openfst-1.6.3/src/include'],
libraries=LIBS, libraries=LIBS,
extra_compile_args=ARGS) extra_compile_args=ARGS)
] ]
setup( setup(
name='swig_ctc_decoders', name='swig_decoders',
version='0.1', version='0.1',
description="""CTC decoders""", description="""CTC decoders""",
ext_modules=ctc_beam_search_decoder_module, ext_modules=ctc_beam_search_decoder_module,
py_modules=['swig_ctc_decoders'], ) py_modules=['swig_decoders'], )
# Install the decoder and scorer extensions one after the other.
# Stop at the first failing command so the final success message is only
# printed when both installations actually succeeded.
set -e

echo "Run decoder setup ..."
python decoder_setup.py install
# -f: do not fail if no build directory was produced.
rm -rf ./build

echo "Run scorer setup ..."
python scorer_setup.py install
rm -rf ./build

echo "Finish the installation of decoder and scorer."
...@@ -3,9 +3,25 @@ from __future__ import absolute_import ...@@ -3,9 +3,25 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import swig_ctc_decoders import swig_decoders
#import multiprocessing import multiprocessing
from pathos.multiprocessing import Pool
class Scorer(swig_decoders.Scorer):
    """Python wrapper around the ``Scorer`` class exported by the
    ``swig_decoders`` module.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        # Forward the arguments positionally to the base-class constructor.
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path)
def ctc_best_path_decoder(probs_seq, vocabulary): def ctc_best_path_decoder(probs_seq, vocabulary):
...@@ -20,8 +36,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary): ...@@ -20,8 +36,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
:return: Decoding result string. :return: Decoding result string.
:rtype: basestring :rtype: basestring
""" """
return swig_ctc_decoders.ctc_best_path_decoder(probs_seq.tolist(), return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary)
vocabulary)
def ctc_beam_search_decoder( def ctc_beam_search_decoder(
...@@ -54,9 +69,9 @@ def ctc_beam_search_decoder( ...@@ -54,9 +69,9 @@ def ctc_beam_search_decoder(
results, in descending order of the probability. results, in descending order of the probability.
:rtype: list :rtype: list
""" """
return swig_ctc_decoders.ctc_beam_search_decoder( return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size,
probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob, vocabulary, blank_id,
ext_scoring_func) cutoff_prob, ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split, def ctc_beam_search_decoder_batch(probs_split,
...@@ -86,25 +101,4 @@ def ctc_beam_search_decoder_batch(probs_split, ...@@ -86,25 +101,4 @@ def ctc_beam_search_decoder_batch(probs_split,
pool.close() pool.close()
pool.join() pool.join()
beam_search_results = [result.get() for result in results] beam_search_results = [result.get() for result in results]
"""
len_args = len(probs_split)
beam_search_results = pool.map(ctc_beam_search_decoder,
probs_split,
[beam_size for i in xrange(len_args)],
[vocabulary for i in xrange(len_args)],
[blank_id for i in xrange(len_args)],
[cutoff_prob for i in xrange(len_args)],
[ext_scoring_func for i in xrange(len_args)]
)
"""
'''
processes = [mp.Process(target=ctc_beam_search_decoder,
args=(probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
ext_scoring_func) for probs_list in probs_split]
for p in processes:
p.start()
for p in processes:
p.join()
beam_search_results = []
'''
return beam_search_results return beam_search_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册