未验证 提交 2665c8b1 编写于 作者: J Jeff Rasley 提交者: GitHub

Fix 1bit extra issue (#1542)

上级 bd3ebddf
......@@ -91,8 +91,8 @@ jobs:
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
git clone https://github.com/huggingface/transformers
git rev-parse --short HEAD
cd transformers
git rev-parse --short HEAD
# scipy/sklearn required for tests, using the 'dev' extra forces torch re-install
pip install .[testing]
# find reqs used in ds integration tests
......
......@@ -50,16 +50,17 @@ def fetch_requirements(path):
install_requires = fetch_requirements('requirements/requirements.txt')
extras_require = {
'1bit_adam': fetch_requirements('requirements/requirements-1bit-adam.txt'),
'1bit_mpi' : fetch_requirements('requirements/requirements-1bit-mpi.txt'),
'1bit': [], # Will add proper cupy version below
'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'),
'dev': fetch_requirements('requirements/requirements-dev.txt'),
}
# If MPI is available add 1bit-adam requirements
# Add specific cupy version to both onebit extension variants
if torch_available and torch.cuda.is_available():
if shutil.which('ompi_info') or shutil.which('mpiname'):
cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
extras_require['1bit_adam'].append(cupy)
cupy = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
extras_require['1bit_mpi'].append(cupy)
extras_require['1bit'].append(cupy)
# Make an [all] extra that installs all needed dependencies
all_extras = set()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册