未验证 提交 240ea97b 编写于 作者: J Jeff Rasley 提交者: GitHub

only add 1bit adam reqs if mpi is installed, update cond build for cpu-adam (#400)

上级 b29229bf
......@@ -10,8 +10,10 @@ The wheel will be located at: dist/*.whl
import os
import torch
import shutil
import subprocess
import warnings
import cpufeature
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension
......@@ -27,47 +29,53 @@ install_requires = fetch_requirements('requirements/requirements.txt')
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
onebit_adam_requires = fetch_requirements('requirements/requirements-1bit-adam.txt')
# If MPI is available add 1bit-adam requirements
if torch.cuda.is_available():
onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
install_requires += onebit_adam_requires
if shutil.which('ompi_info') or shutil.which('mpiname'):
onebit_adam_requires = fetch_requirements(
'requirements/requirements-1bit-adam.txt')
onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
install_requires += onebit_adam_requires
# Constants for each op
LAMB = "lamb"
TRANSFORMER = "transformer"
SPARSE_ATTN = "sparse-attn"
ADAM = "cpu-adam"
CPU_ADAM = "cpu-adam"
# Build environment variables for custom builds
DS_BUILD_LAMB_MASK = 1
DS_BUILD_TRANSFORMER_MASK = 10
DS_BUILD_SPARSE_ATTN_MASK = 100
DS_BUILD_ADAM_MASK = 1000
DS_BUILD_CPU_ADAM_MASK = 1000
DS_BUILD_AVX512_MASK = 10000
# Allow for build_cuda to turn on or off all ops
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_ADAM_MASK | DS_BUILD_AVX512_MASK
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK | DS_BUILD_AVX512_MASK
DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS
# Set default of each op based on if build_cuda is set
OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
DS_BUILD_ADAM = int(os.environ.get('DS_BUILD_ADAM', OP_DEFAULT)) * DS_BUILD_ADAM_MASK
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM',
OP_DEFAULT)) * DS_BUILD_CPU_ADAM_MASK
DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
0)) * DS_BUILD_SPARSE_ATTN_MASK
DS_BUILD_AVX512 = int(os.environ.get('DS_BUILD_AVX512', 0)) * DS_BUILD_AVX512_MASK
OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
DS_BUILD_AVX512 = int(os.environ.get(
'DS_BUILD_AVX512',
cpufeature.CPUFeature['AVX512f'])) * DS_BUILD_AVX512_MASK
# Final effective mask is the bitwise OR of each op
BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
| DS_BUILD_ADAM)
| DS_BUILD_CPU_ADAM)
install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, ADAM], False)
install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, CPU_ADAM], False)
if BUILD_MASK & DS_BUILD_LAMB:
install_ops[LAMB] = True
if BUILD_MASK & DS_BUILD_ADAM:
install_ops[ADAM] = True
if BUILD_MASK & DS_BUILD_CPU_ADAM:
install_ops[CPU_ADAM] = True
if BUILD_MASK & DS_BUILD_TRANSFORMER:
install_ops[TRANSFORMER] = True
if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
......@@ -103,9 +111,7 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
import cpufeature
cpu_info = cpufeature.CPUFeature
SIMD_WIDTH = ''
if cpu_info['AVX512f'] and DS_BUILD_AVX512:
SIMD_WIDTH = '-D__AVX512__'
......@@ -133,7 +139,7 @@ if BUILD_MASK & DS_BUILD_LAMB:
}))
## Adam ##
if BUILD_MASK & DS_BUILD_ADAM:
if BUILD_MASK & DS_BUILD_CPU_ADAM:
ext_modules.append(
CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op',
sources=[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册