# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import os
import pathlib
import platform
import re
from distutils.file_util import copy_file

from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext as _build_ext

class PrecompiledExtesion(Extension):
    """An ``Extension`` whose binary already exists and is never compiled.

    It is declared with an empty source list. The custom ``build_ext`` command
    in this file detects instances of this class and copies the existing
    shared library instead of compiling anything.
    """

    def __init__(self, name):
        no_sources = []
        super().__init__(name, sources=no_sources)

class build_ext(_build_ext):
    """``build_ext`` command that additionally handles ``PrecompiledExtesion``.

    Ordinary extensions are delegated to setuptools unchanged. A precompiled
    extension is not compiled: the shared library already present in the
    source tree is copied to the path setuptools expects in the build tree.
    """

    def build_extension(self, ext):
        if not isinstance(ext, PrecompiledExtesion):
            return super().build_extension(ext)

        # In an in-place build the destination is the source tree itself, so
        # the prebuilt library is already where it needs to be.
        if self.inplace:
            return

        # Destination file chosen by setuptools inside the build directory.
        dest_path = self.get_ext_fullpath(ext.name)
        pathlib.Path(dest_path).parent.mkdir(parents=True, exist_ok=True)

        # Source file: the dotted module name turned into a path relative to
        # the current working directory, with a plain platform suffix
        # (".pyd" on Windows, ".so" elsewhere).
        suffix = '.pyd' if platform.system() == 'Windows' else '.so'
        parts = self.get_ext_fullname(ext.name).split('.')
        parts[-1] += suffix
        src_path = str(pathlib.Path(*parts).resolve())

        copy_file(src_path, dest_path, verbose=self.verbose, dry_run=self.dry_run)

package_name = 'MegEngine'

# Run megengine/version.py in a scratch namespace and take __version__ from
# it, rather than importing the megengine package itself.
version_ns = {}
with open("megengine/version.py") as version_file:
    version_src = version_file.read()
exec(version_src, version_ns)
__version__ = version_ns['__version__']

email = 'megengine@megvii.com'

# https://www.python.org/dev/peps/pep-0440
# Public version identifiers: [N!]N(.N)*[{a|b|rc}N][.postN][.devN]
# Local version identifiers: <public version identifier>[+<local version label>]

# PUBLIC_VERSION_POSTFIX is appended verbatim to the public version; it is
# used to carry rc or dev info.
public_version_postfix = os.environ.get('PUBLIC_VERSION_POSTFIX')
if public_version_postfix:
    __version__ = '{}{}'.format(__version__, public_version_postfix)

# The local version label is built from dot-joined segments:
#   1. the SDK name (SDK_NAME, default "cpu"), unless STRIP_SDK_INFO is
#      "true" (case-insensitive);
#   2. then LOCAL_VERSION.
# Empty segments are skipped, because an empty label would produce an
# invalid PEP 440 version such as "1.0+".
local_version = []
strip_sdk_info = os.environ.get('STRIP_SDK_INFO', 'False').lower()
sdk_name = os.environ.get('SDK_NAME', 'cpu')
if strip_sdk_info == 'true':
    print('wheel version strip sdk info')
elif sdk_name:
    local_version.append(sdk_name)
local_postfix = os.environ.get('LOCAL_VERSION')
if local_postfix:
    local_version.append(local_postfix)
if local_version:
    __version__ = '{}+{}'.format(__version__, '.'.join(local_version))

# Keep the test package and all of its subpackages out of the distribution
# ('test' alone matches only the top-level package name, not 'test.*').
packages = find_packages(exclude=['test', 'test.*'])
72
megengine_data = [
73 74 75
    str(f.relative_to('megengine'))
    for f in pathlib.Path('megengine', 'core', 'include').glob('**/*')
]
76

77
megengine_data += [
78 79 80
    str(f.relative_to('megengine'))
    for f in pathlib.Path('megengine', 'core', 'lib').glob('**/*')
]
81

82

83 84 85 86 87 88 89
def _read_requirements(path):
    """Return the requirement lines of the text file at ``path``.

    Surrounding whitespace is stripped; blank lines and full-line ``#``
    comments are dropped. The file is read as UTF-8 regardless of locale.
    """
    with open(path, encoding='utf-8') as req_file:
        lines = (line.strip() for line in req_file)
        return [line for line in lines if line and not line.startswith('#')]


requires = _read_requirements('requires.txt')
requires_style = _read_requirements('requires-style.txt')
requires_test = _read_requirements('requires-test.txt')
# The imperative runtime is shipped prebuilt. The custom build_ext copies it
# from megengine/core/_imperative_rt.so (.pyd on Windows) instead of
# compiling it.
prebuild_modules = [PrecompiledExtesion('megengine.core._imperative_rt')]
# Core metadata and build configuration passed to setuptools.setup().
setup_kwargs = dict(
    name=package_name,
    version=__version__,
    description='Framework for numerical evaluation with '
    'auto-differentiation',
    author='Megvii Engine Team',
    author_email=email,
    packages=packages,
    package_data={
        'megengine': megengine_data,
    },
    # Prebuilt native modules; the custom build_ext copies them into place.
    ext_modules=prebuild_modules,
    install_requires=requires,
    extras_require={
        'dev': requires_style + requires_test,
        'ci': requires_test,
    },
    cmdclass={'build_ext': build_ext},
)
# Trove classifiers describing the project.
classifiers = [
    'Development Status :: 3 - Alpha',
    'Intended Audience :: Developers',
    'Intended Audience :: Education',
    'Intended Audience :: Science/Research',
    'License :: OSI Approved :: Apache Software License',
    'Programming Language :: C++',
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.5',
    'Programming Language :: Python :: 3.6',
    'Programming Language :: Python :: 3.7',
    'Programming Language :: Python :: 3.8',
    'Topic :: Scientific/Engineering',
    'Topic :: Scientific/Engineering :: Mathematics',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'Topic :: Software Development',
    'Topic :: Software Development :: Libraries',
    'Topic :: Software Development :: Libraries :: Python Modules',
]

setup_kwargs.update(
    classifiers=classifiers,
    license='Apache 2.0',
    keywords='megengine deep learning',
    # LICENSE and ACKNOWLEDGMENTS live one directory above the working
    # directory setup.py is run from; they are shipped as data files under
    # "megengine".
    data_files=[("megengine", ["../LICENSE", "../ACKNOWLEDGMENTS"])],
)

setup(**setup_kwargs)