提交 89c4a962 编写于 作者: Y yangyaming

Make setup.py to support parallel processing.

上级 34411488
......@@ -25,7 +25,7 @@ git clone https://github.com/progschj/ThreadPool.git
Then run the setup
```shell
python setup.py install
python setup.py install --num_processes 4
cd ..
```
......
#include <iostream>
#include <unistd.h>
#include "lm/config.hh"
#include "lm/state.hh"
#include "lm/model.hh"
#include "scorer.h"
#include "decoder_utils.h"
......@@ -24,7 +27,7 @@ void Scorer::load_LM(const char* filename) {
exit(1);
}
RetriveStrEnumerateVocab enumerate;
Config config;
lm::ngram::Config config;
config.enumerate_vocab = &enumerate;
_language_model = lm::ngram::LoadVirtual(filename, config);
_max_order = static_cast<lm::base::Model*>(_language_model)->Order();
......@@ -43,7 +46,7 @@ void Scorer::load_LM(const char* filename) {
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
double cond_prob;
State state, tmp_state, out_state;
lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin
model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
......
from setuptools import setup, Extension
"""Script to build and install decoder package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from setuptools import setup, Extension, distutils
import glob
import platform
import os
import os, sys
import multiprocessing.pool
import argparse
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--num_processes",
default=1,
type=int,
help="Number of cpu processes to build package. (default: %(default)d)")
args = parser.parse_known_args()
# reconstruct sys.argv to pass to setup below
sys.argv = [sys.argv[0]] + args[1]
# monkey-patch for parallel compilation
# See: https://stackoverflow.com/a/13176803
def parallelCCompile(self,
sources,
output_dir=None,
macros=None,
include_dirs=None,
debug=0,
extra_preargs=None,
extra_postargs=None,
depends=None):
# those lines are copied from distutils.ccompiler.CCompiler directly
macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
output_dir, macros, include_dirs, sources, depends, extra_postargs)
cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
# parallel code
def _single_compile(obj):
try:
src, ext = build[obj]
except KeyError:
return
self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
# convert to list, imap is evaluated on-demand
thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes)
list(thread_pool.imap(_single_compile, objects))
return objects
def compile_test(header, library):
dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
command = "bash -c \"g++ -include " + header \
+ " -l" + library + " -x c++ - <<<'int main() {}' -o " \
+ dummy_path + " >/dev/null 2>/dev/null && rm " \
+ dummy_path + " 2>/dev/null\""
return os.system(command) == 0
FILES = glob.glob('kenlm/util/*.cc') + glob.glob('kenlm/lm/*.cc') + glob.glob(
'kenlm/util/double-conversion/*.cc')
# hack compile to support parallel compiling
distutils.ccompiler.CCompiler.compile = parallelCCompile
FILES = glob.glob('kenlm/util/*.cc') \
+ glob.glob('kenlm/lm/*.cc') \
+ glob.glob('kenlm/util/double-conversion/*.cc')
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
FILES = [
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))
]
......@@ -40,7 +98,7 @@ decoders_module = [
Extension(
name='_swig_decoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='C++',
language='c++',
include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include', 'ThreadPool'],
libraries=LIBS,
extra_compile_args=ARGS)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部