setup.py 3.9 KB
Newer Older
H
Hui Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15 16
"""Script to build and install decoder package."""

from setuptools import setup, Extension, distutils
Y
Yibing Liu 已提交
17 18
import glob
import platform
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
import os, sys
import multiprocessing.pool
import argparse

# Build-only options are consumed here; everything argparse does not
# recognise is forwarded untouched to setuptools.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num_processes", type=int, default=1,
                    help="Number of cpu processes to build package. "
                    "(default: %(default)d)")

# parse_known_args() returns a (namespace, leftover_args) pair: our own
# option lives in args[0], the rest is kept in args[1] for setup() below.
args = parser.parse_known_args()
sys.argv = [sys.argv[0]] + args[1]


# monkey-patch for parallel compilation
# See: https://stackoverflow.com/a/13176803
def parallelCCompile(self,
                     sources,
                     output_dir=None,
                     macros=None,
                     include_dirs=None,
                     debug=0,
                     extra_preargs=None,
                     extra_postargs=None,
                     depends=None):
    """Drop-in replacement for ``CCompiler.compile`` that compiles in parallel.

    The setup steps are copied from ``distutils.ccompiler.CCompiler``; only
    the per-object compile loop is replaced by a thread pool whose size is
    the ``--num_processes`` command-line option (``args[0].num_processes``).

    Returns:
        list: the object file names, exactly as ``CCompiler.compile`` does.
    """
    # those lines are copied from distutils.ccompiler.CCompiler directly
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    def _single_compile(obj):
        # Objects without an entry in ``build`` have nothing to compile.
        try:
            src, ext = build[obj]
        except KeyError:
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes)
    try:
        # imap is evaluated on demand: list() drives every job to completion
        # and propagates a compile error, if any, to the caller.
        list(thread_pool.imap(_single_compile, objects))
    finally:
        # Shut the pool down so its worker threads do not outlive this call
        # (also on the error path, where pending jobs are discarded).
        thread_pool.terminate()
        thread_pool.join()
    return objects


def compile_test(header, library):
    """Return True if g++ can build a trivial program using *header* and *library*.

    A one-line ``int main() {}`` program is fed to ``g++`` on stdin. It is
    compiled with ``-include <header>`` and linked with ``-l<library>``.
    Success means both are usable on this machine.

    The compiler is run with an argument list, not a shell string. Paths
    with spaces or shell metacharacters are therefore handled correctly, and
    no shell (``bash``) is needed. The temporary output binary is always
    removed.

    Args:
        header: header name passed to ``-include`` (e.g. ``'zlib.h'``).
        library: library name passed as ``-l<library>`` (e.g. ``'z'``).

    Returns:
        bool: True if compiling and linking succeeded. False otherwise,
        including when ``g++`` cannot be executed at all.
    """
    import subprocess

    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = [
        "g++", "-include", header, "-l" + library,
        "-x", "c++", "-", "-o", dummy_path,
    ]
    try:
        with open(os.devnull, "wb") as devnull:
            proc = subprocess.Popen(
                command, stdin=subprocess.PIPE, stdout=devnull, stderr=devnull)
            proc.communicate(b"int main() {}\n")
        return proc.returncode == 0
    except OSError:
        # g++ is missing or not executable: treat as "not available".
        return False
    finally:
        try:
            os.remove(dummy_path)
        except OSError:
            pass  # nothing was produced (failed build) or already removed


# Route CCompiler.compile through parallelCCompile's thread pool. The patch
# is on the base class, so it applies to every compiler subclass that does
# not override compile() itself.
setattr(distutils.ccompiler.CCompiler, 'compile', parallelCCompile)

# C++ sources of the bundled kenlm and openfst libraries, in pattern order.
# Files ending in main.cc / test.cc / unittest.cc are excluded from the build.
FILES = [
    fn
    for pattern in (
        'kenlm/util/*.cc',
        'kenlm/lm/*.cc',
        'kenlm/util/double-conversion/*.cc',
        'openfst-1.6.3/src/lib/*.cc',
    )
    for fn in glob.glob(pattern)
    if not fn.endswith(('main.cc', 'test.cc', 'unittest.cc'))
]

# Libraries linked into the extension. librt is added on every platform
# except macOS (Darwin).
LIBS = ['stdc++']
if platform.system() != 'Darwin':
    LIBS.append('rt')

# Base compiler flags. KENLM_MAX_ORDER presumably caps the n-gram order kenlm
# is compiled for (TODO confirm against the kenlm build docs).
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']

# Optional compression backends. Each one (preprocessor define + library) is
# enabled only when compile_test() shows g++ can use its header and library.
for _header, _library, _define in (
        ('zlib.h', 'z', '-DHAVE_ZLIB'),
        ('bzlib.h', 'bz2', '-DHAVE_BZLIB'),
        ('lzma.h', 'lzma', '-DHAVE_XZLIB'),
):
    if compile_test(_header, _library):
        ARGS.append(_define)
        LIBS.append(_library)
del _header, _library, _define

# Generate the Python wrapper for decoders.i with SWIG. Abort with a clear
# message if SWIG is missing or fails. Otherwise the build would go on
# without a freshly generated wrapper and fail later with a confusing error.
if os.system('swig -python -c++ ./decoders.i') != 0:
    sys.exit('error: "swig -python -c++ ./decoders.i" failed; '
             'make sure SWIG is installed and decoders.i is valid.')

# The single C++ extension built by setup() below. It combines the bundled
# kenlm/openfst sources (FILES) with every *.cxx and *.cpp file in the current
# directory. Presumably these are the SWIG-generated wrapper and the decoder
# sources; verify.
decoders_module = [
    Extension(
        name='_swig_decoders',
        language='c++',
        sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
        include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include',
                      'ThreadPool'],
        libraries=LIBS,
        extra_compile_args=ARGS,
    ),
]

# Build the extension and install the pure-Python module ``swig_decoders``
# next to it. That module is presumably the SWIG-generated proxy for
# ``_swig_decoders``; confirm.
setup(
    name='swig_decoders',
    version='1.1',
    description='CTC decoders',
    py_modules=['swig_decoders'],
    ext_modules=decoders_module,
)