From ce363d0e06500f1b449b47cb80b648913bca8716 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 7 Dec 2020 12:08:41 -0800
Subject: [PATCH] [build] make builder smarter and configurable wrt compute capabilities + docs (#578)

---
 docs/_tutorials/advanced-install.md | 23 ++++++++++++
 op_builder/builder.py               | 58 +++++++++++++++++++++--------
 op_builder/fused_adam.py            |  5 ++-
 op_builder/fused_lamb.py            |  5 ++-
 4 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md
index ccb38e33..5dd95a67 100644
--- a/docs/_tutorials/advanced-install.md
+++ b/docs/_tutorials/advanced-install.md
@@ -84,6 +84,29 @@
 the nodes listed in your hostfile (either given via --hostfile, or defaults to
 /job/hostfile).
 
+## Building for the correct architectures
+
+If you're getting the following error:
+
+```python
+RuntimeError: CUDA error: no kernel image is available for execution on the device
+```
+when running DeepSpeed, it means that the CUDA extensions weren't built for the card you're trying to use.
+
+When building from source, DeepSpeed will try to support a wide range of architectures, but under JIT mode it will only support the architectures visible at the time of building.
+
+You can build specifically for a desired range of architectures by setting the `TORCH_CUDA_ARCH_LIST` environment variable, like so:
+
+```bash
+TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
+```
+
+Building for only a few architectures also makes the build faster.
+
+Setting this variable is also recommended to ensure your exact architecture is used. For a variety of technical reasons, a distributed PyTorch binary isn't built to fully support all architectures, skipping binary-compatible ones, at the potential cost of underutilizing your card's full compute capabilities. To see which architectures get included during a DeepSpeed build from source, save the build log and grep it for `-gencode` arguments.
+
+The full list of NVIDIA GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus).
+
 ## Feature specific dependencies
 
 Some DeepSpeed features require specific dependencies outside of the general
diff --git a/op_builder/builder.py b/op_builder/builder.py
index ccbcd9aa..ac45c480 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -216,25 +216,53 @@ class OpBuilder(ABC):
 
 class CUDAOpBuilder(OpBuilder):
     def compute_capability_args(self, cross_compile_archs=None):
-        if cross_compile_archs is None:
-            cross_compile_archs = get_default_compute_capatabilities()
+        """
+        Returns nvcc compute capability compile flags.
 
-        args = []
+        1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
+        2. If neither is set, the default compute capabilities will be used.
+        3. Under `jit_mode`, the compute capabilities of all visible cards will be used.
+
+        Format:
+
+        - `TORCH_CUDA_ARCH_LIST` may use `;` or whitespace separators. Examples:
+
+        TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
+        TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
+
+        - `cross_compile_archs` uses the `;` separator.
+
+        """
+
+        ccs = []
         if self.jit_mode:
-            # Compile for underlying architecture since we know it at runtime
-            CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability()
-            compute_capability = f"{CC_MAJOR}{CC_MINOR}"
-            args.append('-gencode')
-            args.append(
-                f'arch=compute_{compute_capability},code=compute_{compute_capability}')
+            # Compile for the underlying architectures since we know those at runtime
+            for i in range(torch.cuda.device_count()):
+                CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
+                cc = f"{CC_MAJOR}.{CC_MINOR}"
+                if cc not in ccs:
+                    ccs.append(cc)
+            ccs = sorted(ccs)
         else:
             # Cross-compile mode, compile for various architectures
-            for compute_capability in cross_compile_archs.split(';'):
-                compute_capability = compute_capability.replace('.', '')
-                args.append('-gencode')
-                args.append(
-                    f'arch=compute_{compute_capability},code=compute_{compute_capability}'
-                )
+            # the env var override takes priority
+            cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
+            if cross_compile_archs_env is not None:
+                if cross_compile_archs is not None:
+                    print(
+                        f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
+                    )
+                cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
+            else:
+                if cross_compile_archs is None:
+                    cross_compile_archs = get_default_compute_capatabilities()
+            ccs = cross_compile_archs.split(';')
+
+        args = []
+        for cc in ccs:
+            cc = cc.replace('.', '')
+            args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
+
         return args
 
     def version_dependent_macros(self):
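For reference, the cross-compile precedence introduced above can be illustrated with a minimal standalone sketch. This is not part of the patch: `build_gencode_args` and `DEFAULT_ARCHS` are hypothetical stand-ins for `compute_capability_args` and `get_default_compute_capatabilities()`.

```python
# A minimal standalone sketch of the cross-compile precedence above.
# `build_gencode_args` and `DEFAULT_ARCHS` are hypothetical stand-ins,
# not part of the patch.
import os

DEFAULT_ARCHS = "6.0;6.1;7.0"  # stand-in for get_default_compute_capatabilities()


def build_gencode_args(cross_compile_archs=None):
    # `TORCH_CUDA_ARCH_LIST` takes priority over the caller-supplied list;
    # whitespace separators are normalized to ';'.
    env = os.environ.get('TORCH_CUDA_ARCH_LIST')
    if env is not None:
        cross_compile_archs = env.replace(' ', ';')
    elif cross_compile_archs is None:
        cross_compile_archs = DEFAULT_ARCHS
    args = []
    for cc in cross_compile_archs.split(';'):
        cc = cc.replace('.', '')  # '8.6' -> '86'
        args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
    return args


if __name__ == '__main__':
    os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.5;8.6'
    # prints ['-gencode=arch=compute_61,code=compute_61', ...]
    print(build_gencode_args())
```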
diff --git a/op_builder/fused_adam.py b/op_builder/fused_adam.py
index 4b43ff7f..e9dd71a5 100644
--- a/op_builder/fused_adam.py
+++ b/op_builder/fused_adam.py
@@ -22,4 +22,7 @@ class FusedAdamBuilder(CUDAOpBuilder):
         return ['-O3'] + self.version_dependent_macros()
 
     def nvcc_args(self):
-        return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
+        return ['-lineinfo',
+                '-O3',
+                '--use_fast_math'
+                ] + self.version_dependent_macros() + self.compute_capability_args()
diff --git a/op_builder/fused_lamb.py b/op_builder/fused_lamb.py
index 272a9772..33a98387 100644
--- a/op_builder/fused_lamb.py
+++ b/op_builder/fused_lamb.py
@@ -22,4 +22,7 @@ class FusedLambBuilder(CUDAOpBuilder):
         return ['-O3'] + self.version_dependent_macros()
 
     def nvcc_args(self):
-        return ['-lineinfo', '-O3', '--use_fast_math'] + self.version_dependent_macros()
+        return ['-lineinfo',
+                '-O3',
+                '--use_fast_math'
+                ] + self.version_dependent_macros() + self.compute_capability_args()
--
GitLab
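Under `jit_mode`, the patched `compute_capability_args` instead enumerates the cards visible at build time. A minimal sketch of that branch follows, assuming a CUDA-enabled PyTorch install with at least one visible GPU; `visible_compute_capabilities` is a hypothetical name, not part of the patch.

```python
# A minimal sketch of the jit_mode branch: collect the deduplicated, sorted
# compute capabilities of all visible cards. Assumes a CUDA-enabled PyTorch
# install; `visible_compute_capabilities` is a hypothetical name.
import torch


def visible_compute_capabilities():
    ccs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cc = f"{major}.{minor}"
        if cc not in ccs:  # deduplicate identical cards
            ccs.append(cc)
    return sorted(ccs)


if __name__ == '__main__':
    # e.g. ['7.0', '8.6'] on a machine with a V100 and an RTX 3090
    print(visible_compute_capabilities())
```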