提交 32047c72 编写于 作者: Y Yibing Liu

refine wrapper for swig and simplify setup

上级 9ff48b05
......@@ -10,8 +10,7 @@ import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from deploy.swig_decoders import *
from swig_scorer import Scorer
from deploy.swig_decoders_wrapper import *
from error_rate import wer
import utils
import time
......@@ -164,7 +163,8 @@ def infer():
]
# external scorer
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
ext_scorer = Scorer(
alpha=args.alpha, beta=args.beta, model_path=args.language_model_path)
## decode and print
time_begin = time.time()
......
### Installation
The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/). First, clone it into the current directory (i.e., `deep_speech_2/deploy`):
The setup of the decoder for deployment depends on the source code of [kenlm](https://github.com/kpu/kenlm/) and [openfst](http://www.openfst.org/twiki/bin/view/FST/WebHome). First, clone kenlm and download openfst into the current directory (i.e., `deep_speech_2/deploy`):
```shell
git clone https://github.com/kpu/kenlm.git
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
```
Then run the setup
```shell
sh setup.sh
```
After the installation succeeds, go back to the parent directory
```
python setup.py install
cd ..
```
......
%module swig_ctc_decoders
%module swig_decoders
%{
#include "scorer.h"
#include "ctc_decoders.h"
%}
......@@ -18,6 +19,6 @@ namespace std{
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
}
%import scorer.h
%import decoder_utils.h
%include "scorer.h"
%include "ctc_decoders.h"
%module swig_scorer
%{
#include "scorer.h"
%}
%include "std_string.i"
%include "scorer.h"
from setuptools import setup, Extension
import glob
import platform
import os
def compile_test(header, library):
    """Check whether g++ can compile with ``header`` and link ``library``.

    A trivial ``int main() {}`` program is fed to g++ on stdin with
    ``header`` force-included and ``library`` linked. The resulting binary
    is written next to this file as ``dummy`` and removed afterwards.

    The compiler is invoked with an argument list instead of a shell
    command string, so names containing spaces or shell metacharacters are
    passed through verbatim and no bash-specific here-string is required.

    :param header: Header file name passed to ``g++ -include``.
    :type header: basestring
    :param library: Library name passed to ``g++ -l``.
    :type library: basestring
    :return: True if the probe compiled, linked, and the temporary binary
             was removed; False otherwise (including when g++ is missing).
    :rtype: bool
    """
    # Local import: the surrounding file does not import subprocess.
    import subprocess

    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = [
        "g++", "-include", header, "-l" + library, "-x", "c++", "-", "-o",
        dummy_path
    ]
    with open(os.devnull, "w") as devnull:
        try:
            process = subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=devnull,
                stderr=devnull)
        except OSError:
            # g++ could not be started at all; treat as "not available".
            return False
        process.communicate(b"int main() {}\n")
    if process.returncode != 0:
        return False
    # Mirror the original `&& rm dummy`: success requires the cleanup too.
    try:
        os.remove(dummy_path)
    except OSError:
        return False
    return True
# Gather the kenlm sources to build into the extension, leaving out files
# whose names end in main.cc or test.cc.
FILES = []
for pattern in ('kenlm/util/*.cc', 'kenlm/lm/*.cc',
                'kenlm/util/double-conversion/*.cc'):
    FILES += glob.glob(pattern)
FILES = [fn for fn in FILES if not fn.endswith(('main.cc', 'test.cc'))]

# 'rt' is linked unless platform.system() reports 'Darwin'.
LIBS = ['stdc++'] if platform.system() == 'Darwin' else ['stdc++', 'rt']

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']

# For each optional compression library that the compiler probe accepts,
# add the matching -DHAVE_* define and link the library.
for header, library, define in (('zlib.h', 'z', '-DHAVE_ZLIB'),
                                ('bzlib.h', 'bz2', '-DHAVE_BZLIB'),
                                ('lzma.h', 'lzma', '-DHAVE_XZLIB')):
    if compile_test(header, library):
        ARGS.append(define)
        LIBS.append(library)

# Run SWIG on the interface file before building; presumably this is what
# produces the scorer_wrap.cxx listed in the sources below.
os.system('swig -python -c++ ./scorer.i')

scorer_extension = Extension(
    name='_swig_scorer',
    sources=FILES + ['scorer_wrap.cxx', 'scorer.cpp'],
    language='C++',
    include_dirs=['.', './kenlm'],
    libraries=LIBS,
    extra_compile_args=ARGS)
ext_modules = [scorer_extension]

setup(
    name='swig_scorer',
    version='0.1',
    ext_modules=ext_modules,
    include_package_data=True,
    py_modules=['swig_scorer'])
......@@ -20,7 +20,7 @@ LIBS = ['stdc++']
if platform.system() != 'Darwin':
LIBS.append('rt')
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6']
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']
if compile_test('zlib.h', 'z'):
ARGS.append('-DHAVE_ZLIB')
......@@ -34,24 +34,21 @@ if compile_test('lzma.h', 'lzma'):
ARGS.append('-DHAVE_XZLIB')
LIBS.append('lzma')
os.system('swig -python -c++ ./ctc_decoders.i')
os.system('swig -python -c++ ./decoders.i')
ctc_beam_search_decoder_module = [
Extension(
name='_swig_ctc_decoders',
sources=FILES + [
'scorer.cpp', 'ctc_decoders_wrap.cxx', 'ctc_decoders.cpp',
'decoder_utils.cpp'
],
name='_swig_decoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='C++',
include_dirs=['.', './kenlm'],
include_dirs=['.', './kenlm', './openfst-1.6.3/src/include'],
libraries=LIBS,
extra_compile_args=ARGS)
]
setup(
name='swig_ctc_decoders',
name='swig_decoders',
version='0.1',
description="""CTC decoders""",
ext_modules=ctc_beam_search_decoder_module,
py_modules=['swig_ctc_decoders'], )
py_modules=['swig_decoders'], )
echo "Run decoder setup ..."
python decoder_setup.py install
rm -r ./build
echo "Run scorer setup ..."
python scorer_setup.py install
rm -r ./build
echo "Finish the installation of decoder and scorer."
......@@ -3,9 +3,25 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import swig_ctc_decoders
#import multiprocessing
from pathos.multiprocessing import Pool
import swig_decoders
import multiprocessing
class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    Python-level subclass of ``swig_decoders.Scorer`` whose constructor
    forwards its three arguments unchanged to the base class.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        # Delegate directly to the base class constructor.
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path)
def ctc_best_path_decoder(probs_seq, vocabulary):
......@@ -20,8 +36,7 @@ def ctc_best_path_decoder(probs_seq, vocabulary):
:return: Decoding result string.
:rtype: basestring
"""
return swig_ctc_decoders.ctc_best_path_decoder(probs_seq.tolist(),
vocabulary)
return swig_decoders.ctc_best_path_decoder(probs_seq.tolist(), vocabulary)
def ctc_beam_search_decoder(
......@@ -54,9 +69,9 @@ def ctc_beam_search_decoder(
results, in descending order of the probability.
:rtype: list
"""
return swig_ctc_decoders.ctc_beam_search_decoder(
probs_seq.tolist(), beam_size, vocabulary, blank_id, cutoff_prob,
ext_scoring_func)
return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size,
vocabulary, blank_id,
cutoff_prob, ext_scoring_func)
def ctc_beam_search_decoder_batch(probs_split,
......@@ -86,25 +101,4 @@ def ctc_beam_search_decoder_batch(probs_split,
pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
"""
len_args = len(probs_split)
beam_search_results = pool.map(ctc_beam_search_decoder,
probs_split,
[beam_size for i in xrange(len_args)],
[vocabulary for i in xrange(len_args)],
[blank_id for i in xrange(len_args)],
[cutoff_prob for i in xrange(len_args)],
[ext_scoring_func for i in xrange(len_args)]
)
"""
'''
processes = [mp.Process(target=ctc_beam_search_decoder,
args=(probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
ext_scoring_func) for probs_list in probs_split]
for p in processes:
p.start()
for p in processes:
p.join()
beam_search_results = []
'''
return beam_search_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册