Unverified commit b8416282, authored by Connor Holmes, committed by GitHub

Drop Maxwell Support (#2574)

* Officially drop Maxwell support

* Formatting

* Comparison mismatch fix
Parent 06938835
@@ -67,7 +67,6 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
                                      unsigned total_count,
                                      int max_out_tokens)
 {
-#if __CUDA_ARCH__ >= 700
     cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -102,7 +101,6 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
             lane += WARP_SIZE;
         }
     }
-#endif
 }

 __global__ void apply_rotary_pos_emb1(float* mixed_query,
                                       float* key_layer,
@@ -159,7 +157,6 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
                                       unsigned total_count,
                                       int max_out_tokens)
 {
-#if __CUDA_ARCH__ >= 700
     cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -205,7 +202,6 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
             lane += WARP_SIZE;
         }
     }
-#endif
 }

 template <typename T>
......
@@ -50,8 +50,6 @@ __global__ void dequantize_kernel(__half* output,
                                   unsigned groups,
                                   unsigned merge_count)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     unsigned merge_hidden = hidden_dim >> merge_count;
     unsigned quantization_stride = (merge_hidden * output_size) / groups;
@@ -75,7 +73,6 @@ __global__ void dequantize_kernel(__half* output,
         output[q_index] = __float2half(scale_data * (float)q);
         tid += blockDim.x;
     }
-#endif
 }

 template <typename T>
......
@@ -17,6 +17,9 @@ inline __device__ float gelu(const float x)
     return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
 }

+/*
+In-place gelu(biasAdd(x)) for channels last
+*/
 template <typename T>
 __global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size)
 {
@@ -64,63 +67,51 @@ void launch_bias_gelu(T* input,
 template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
 template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);

-// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
-__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
-{
-    constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(float);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
-    if (offset < total_count) {
-        float data[vals_per_access];
-        float bias_data[vals_per_access];
-        mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
-#pragma unroll
-        for (int i = 0; i < vals_per_access; i++) { data[i] += bias_data[i]; }
-        mem_access::store_global<granularity>(input + offset, data);
-    }
-}
-__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
+/*
+In-place channels-last bias add
+*/
+template <typename T>
+__global__ void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size)
 {
-#ifdef HALF_PRECISION_AVAILABLE
+    // Input restriction: intermediate_size % vals_per_access == 0
     constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(__half);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
+    constexpr int values_per_access = granularity / sizeof(T);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
     if (offset < total_count) {
-        __half2 data[vals_per_access / 2];
-        __half2 bias_data[vals_per_access / 2];
+        T data[values_per_access];
+        T data_bias[values_per_access];
         mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
+        mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
 #pragma unroll
-        for (int i = 0; i < vals_per_access / 2; i++) {
-            float2 data_f = __half22float2(data[i]);
-            float2 bias_f = __half22float2(bias_data[i]);
-            data[i] = __floats2half2_rn(data_f.x + bias_f.x, data_f.y + bias_f.y);
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            float bias_f = conversion::to<float>(data_bias[i]);
+            data[i] = conversion::to<T>(data_f + bias_f);
         }
         mem_access::store_global<granularity>(input + offset, data);
     }
-#endif
 }

 template <typename T>
-void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
+void launch_bias_add(T* input,
+                     const T* bias,
+                     int intermediate_size,
+                     int batch_size,
+                     cudaStream_t stream)
 {
     constexpr int threads = 1024;
     constexpr int granularity = 16;
-    const int total_count = batch_size * hidden_size;
+    const int total_count = batch_size * intermediate_size;
     const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
     dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
-    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size);
+    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(
+        input, bias, total_count, intermediate_size);
 }

 template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
@@ -181,8 +172,6 @@ __global__ void fused_bias_residual(__half* residual,
                                     const float mp_scale,
                                     const bool preln)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
     const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
     const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
@@ -241,7 +230,6 @@ __global__ void fused_bias_residual(__half* residual,
         res_fl2_ptr[offset] = res_fl2;
     }
-#endif
 }

 template <typename T>
@@ -325,8 +313,6 @@ __global__ void gptj_residual_add(__half* residual,
                                   const int intermediate_size,
                                   const float mp_scale)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
     const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
     const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
@@ -379,7 +365,6 @@ __global__ void gptj_residual_add(__half* residual,
         res_fl2_ptr[offset] = res_fl2;
     }
-#endif
 }

 template <typename T>
......
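The hunks above replace the separate float and __half fused_bias_add kernels with a single template dispatched through launch_bias_add. As a quick orientation, here is a minimal host-side sketch of how the updated launcher could be driven; only the launch_bias_add signature comes from the diff, while the wrapper name, pointer names, and the forward declaration are illustrative assumptions.

// Hypothetical host-side driver for the templated launch_bias_add updated above.
// The template declaration mirrors the signature shown in the diff; everything else
// (function name, buffer names) is illustrative only.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

template <typename T>
void launch_bias_add(T* input, const T* bias, int intermediate_size, int batch_size,
                     cudaStream_t stream);  // defined in the kernel file patched above

void add_bias_inplace_fp16(__half* activations,   // device ptr, [batch_size x intermediate_size]
                           const __half* bias,    // device ptr, [intermediate_size]
                           int intermediate_size, // multiple of 8 (16 bytes / sizeof(__half))
                           int batch_size,
                           cudaStream_t stream)
{
    // One launch covers the whole activation tensor; the bias is broadcast per row, in place.
    launch_bias_add<__half>(activations, bias, intermediate_size, batch_size, stream);
}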
@@ -2,6 +2,7 @@
 Copyright 2022 The Microsoft DeepSpeed Team
 */

+#include "conversion_utils.h"
 #include "inference_cuda_layers.h"
 #include "memory_access_utils.h"
@@ -11,58 +12,32 @@ namespace cg = cooperative_groups;
 inline __device__ float relu(const float x) { return x < 0 ? 0 : x; }

-__global__ void fused_bias_relu(float* input,
-                                const float* bias,
-                                int total_count,
-                                int intermediate_size)
+/*
+In-place relu(biasAdd(x)) for channels last
+*/
+template <typename T>
+__global__ void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size)
 {
     // Input restriction: intermediate_size % vals_per_access == 0
     constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(float);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
+    constexpr int values_per_access = granularity / sizeof(T);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
     if (offset < total_count) {
-        float data[vals_per_access];
-        float data_bias[vals_per_access];
+        T data[values_per_access];
+        T data_bias[values_per_access];
         mem_access::load_global<granularity>(data, input + offset);
         mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
 #pragma unroll
-        for (int i = 0; i < vals_per_access; i++) { data[i] = relu(data[i] + data_bias[i]); }
-        mem_access::store_global<granularity>(input + offset, data);
-    }
-}
-__global__ void fused_bias_relu(__half* input,
-                                const __half* bias,
-                                int total_count,
-                                int intermediate_size)
-{
-    // Input restriction: intermediate_size % vals_per_access == 0
-    // This kernel doubles the per-thread ALU workload as compared to the float implementation
-#ifdef HALF_PRECISION_AVAILABLE
-    constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(__half);
-    int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
-    if (offset < total_count) {
-        // Divide by 2 since we store two values per __half2
-        __half2 data[vals_per_access / 2];
-        __half2 bias_data[vals_per_access / 2];
-        mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));
-#pragma unroll
-        for (int i = 0; i < vals_per_access / 2; i++) {
-            float2 data_f = __half22float2(data[i]);
-            float2 bias_f = __half22float2(bias_data[i]);
-            data[i] = __floats2half2_rn(relu(data_f.x + bias_f.x), relu(data_f.y + bias_f.y));
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            float bias_f = conversion::to<float>(data_bias[i]);
+            data[i] = conversion::to<T>(relu(data_f + bias_f));
         }
         mem_access::store_global<granularity>(input + offset, data);
     }
-#endif
 }

 template <typename T>
......
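The same refactoring pattern runs through both activation files: the old pair of float and __half2 kernels becomes one template that widens each element to float, applies the op, and narrows back on store. Below is a self-contained sketch of that pattern for readers who want to see it in isolation. conversion::to<T> is DeepSpeed's helper from conversion_utils.h (added in the hunk above); the to_float/from_float helpers, the kernel name, and the bias width are simplified stand-ins, not the real API.

// Illustrative stand-in for the widen-to-float pattern used above; not DeepSpeed code.
#include <cuda_fp16.h>

__device__ inline float to_float(float v) { return v; }
__device__ inline float to_float(__half v) { return __half2float(v); }

template <typename T>
__device__ inline T from_float(float v);
template <>
__device__ inline float from_float<float>(float v) { return v; }
template <>
__device__ inline __half from_float<__half>(float v) { return __float2half(v); }

template <typename T>
__global__ void bias_relu_sketch(T* x, const T* bias, int total_count, int bias_width)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total_count) {
        // Do the arithmetic in fp32 regardless of T, then narrow back on the store.
        const float v = to_float(x[i]) + to_float(bias[i % bias_width]);
        x[i] = from_float<T>(v < 0.f ? 0.f : v);
    }
}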
@@ -48,8 +48,6 @@ __global__ void attn_softmax_v2(__half* vals,
                                 int iterations,
                                 int reduceWidth)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -232,7 +230,6 @@ __global__ void attn_softmax_v2(__half* vals,
             }
         }
     }
-#endif
 }

 __global__ void attn_softmax_v2(float* vals,
......
@@ -90,8 +90,6 @@ __global__ void bias_add_transform_0213(__half* output,  // q
                                         int head_ext,
                                         int max_out_tokens)
 {
-#if __CUDA_ARCH__ >= 700
     unsigned half_dim = (rotary_dim << 3) >> 1;
     int d0_stride = hidden_dim * seq_length;
     int d1_stride = hidden_dim;
@@ -146,8 +144,6 @@ __global__ void bias_add_transform_0213(__half* output,  // q
             output_vec[d3] = q;
         } else
             output_vec[d3] = vals_vec[d3];
-#endif
 }

 // [B S C*H] - > C * [B A S N]
@@ -269,7 +265,6 @@ __global__ void pad_add_transform_0213(__half* output,
                                        int heads,
                                        int padded_head_size)
 {
-#if __CUDA_ARCH__ >= 700
     float4 ZERO;
     const __half2 zero_h = __float2half2_rn(0.f);
     __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
@@ -303,8 +298,6 @@ __global__ void pad_add_transform_0213(__half* output,
         output_vec[d3] = vals_vec[d3];
     else
         output_vec[d3] = ZERO;
-#endif
 }

 template <typename T>
@@ -409,8 +402,6 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
                                                 int heads,
                                                 int head_ext)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     int d0_stride = hidden_dim * seq_length;
     int d1_stride = hidden_dim;
     int d2_stride = hidden_dim / heads;
@@ -455,8 +446,6 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     output_half[2] = vals_half[2] + bias_half[2];
     output_half[3] = vals_half[3] + bias_half[3];
     output_vec[d3] = output_arr;
-#endif
 }

 __global__ void bias_add_transform_0213_v2(__half* output,
@@ -466,7 +455,6 @@ __global__ void bias_add_transform_0213_v2(__half* output,
                                            int seq_length,
                                            int heads)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     __shared__ float4 in_data[3072];
     int d0_stride = hidden_dim * seq_length;
@@ -528,7 +516,6 @@ __global__ void bias_add_transform_0213_v2(__half* output,
         output_vec[out_index + iter_offset] =
             in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
     }
-#endif
 }

 template <typename T>
@@ -580,8 +567,6 @@ __global__ void transform4d_0213<__half>(__half* out,
                                          int hidden_dim,
                                          int head_ext)
 {
-#if __CUDA_ARCH__ >= 700
     int d0_stride = hidden_dim * (seq_length / head_ext);
     int d1_stride = hidden_dim;
     int d2_stride = hidden_dim / heads;
@@ -606,8 +591,6 @@ __global__ void transform4d_0213<__half>(__half* out,
     out_vec += (d2 * d1_stride * gridDim.y);
     out_vec[d3] = in_vec[d3];
-#endif
 }

 __global__ void transform4d_0213_v2(__half* out,
@@ -616,7 +599,6 @@ __global__ void transform4d_0213_v2(__half* out,
                                     int seq_length,
                                     int hidden_dim)
 {
-#if __CUDA_ARCH__ >= 700
     __shared__ float4 in_data[3072];
     int d0_stride = hidden_dim * seq_length;
@@ -657,7 +639,6 @@ __global__ void transform4d_0213_v2(__half* out,
         int iter_id = iter * iteration_stride + iter_index;
         out_vec[output_offset + iter_id] = in_data[iter_id];
     }
-#endif
 }

 // 3 * [B A S N] - > [B S C*H]
......
@@ -15,6 +15,7 @@ import distutils.log
 import distutils.sysconfig
 from distutils.errors import CompileError, LinkError
 from abc import ABC, abstractmethod
+from typing import List

 YELLOW = '\033[93m'
 END = '\033[0m'
@@ -524,7 +525,7 @@ class CUDAOpBuilder(OpBuilder):
         - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
           TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
-          TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
+          TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
         - `cross_compile_archs` uses ; separator.
@@ -554,6 +555,12 @@ class CUDAOpBuilder(OpBuilder):
                 cross_compile_archs = get_default_compute_capabilities()
             ccs = cross_compile_archs.split(';')

+        ccs = self.filter_ccs(ccs)
+        if len(ccs) == 0:
+            raise RuntimeError(
+                f"Unable to load {self.name} op due to no compute capabilities remaining after filtering"
+            )
+
         args = []
         for cc in ccs:
             num = cc[0] + cc[2]
@@ -563,6 +570,13 @@ class CUDAOpBuilder(OpBuilder):
         return args

+    def filter_ccs(self, ccs: List[str]):
+        """
+        Prune any compute capabilities that are not compatible with the builder. Should log
+        which CCs have been pruned.
+        """
+        return ccs
+
     def version_dependent_macros(self):
         # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
         version_ge_1_1 = []
......
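To make the new hook concrete, here is a small self-contained sketch of how filter_ccs feeds the nvcc arch flags that compute_capability_args builds. The Mock classes, the gencode string, and the example capability list are illustrative assumptions; only the filtering flow mirrors the diff.

# Self-contained sketch (not the real OpBuilder API) of the filter_ccs -> arch-flags flow.
from typing import List


class MockBuilder:
    name = "mock_op"

    def warning(self, msg: str) -> None:
        print(f"WARNING: {msg}")

    def filter_ccs(self, ccs: List[str]) -> List[str]:
        # Base behavior from the diff: keep every compute capability.
        return ccs

    def compute_capability_args(self, cross_compile_archs: str) -> List[str]:
        ccs = self.filter_ccs(cross_compile_archs.split(';'))
        if len(ccs) == 0:
            raise RuntimeError(
                f"Unable to load {self.name} op due to no compute capabilities "
                "remaining after filtering")
        # "7.5" -> "75"; emit a typical nvcc gencode flag per remaining capability.
        return [
            f"-gencode=arch=compute_{cc[0] + cc[2]},code=sm_{cc[0] + cc[2]}" for cc in ccs
        ]


class MockInferenceBuilder(MockBuilder):
    def filter_ccs(self, ccs: List[str]) -> List[str]:
        # Mirrors the InferenceBuilder override below: drop anything older than Pascal (6.x).
        pruned = [cc for cc in ccs if int(cc[0]) < 6]
        if pruned:
            self.warning(f"Filtered compute capabilities {pruned}")
        return [cc for cc in ccs if int(cc[0]) >= 6]


if __name__ == "__main__":
    # Maxwell (5.2) is pruned; flags are emitted only for 6.0, 7.0 and 8.0.
    print(MockInferenceBuilder().compute_capability_args("5.2;6.0;7.0;8.0"))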
@@ -25,6 +25,11 @@ class InferenceBuilder(CUDAOpBuilder):
         sys_cuda_major, _ = installed_cuda_version()
         torch_cuda_major = int(torch.version.cuda.split('.')[0])
         cuda_capability = torch.cuda.get_device_properties(0).major
+        if cuda_capability < 6:
+            self.warning(
+                "NVIDIA Inference is only supported on Pascal and newer architectures"
+            )
+            cuda_okay = False
         if cuda_capability >= 8:
             if torch_cuda_major < 11 or sys_cuda_major < 11:
                 self.warning(
@@ -32,6 +37,18 @@ class InferenceBuilder(CUDAOpBuilder):
                 cuda_okay = False
         return super().is_compatible(verbose) and cuda_okay

+    def filter_ccs(self, ccs):
+        ccs_retained = []
+        ccs_pruned = []
+        for cc in ccs:
+            if int(cc[0]) >= 6:
+                ccs_retained.append(cc)
+            else:
+                ccs_pruned.append(cc)
+        if len(ccs_pruned) > 0:
+            self.warning(f"Filtered compute capabilities {ccs_pruned}")
+        return ccs_retained
+
     def sources(self):
         return [
             'csrc/transformer/inference/csrc/pt_binding.cpp',
......
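For users wondering whether their GPU clears the new bar, a quick check along these lines (illustrative, using standard PyTorch calls) reports whether the device is Pascal (compute capability 6.0) or newer:

import torch

if torch.cuda.is_available():
    major = torch.cuda.get_device_properties(0).major
    status = "supported" if major >= 6 else "unsupported (Maxwell or older)"
    print(f"Compute capability major version: {major} -> DeepSpeed inference {status}")
else:
    print("No CUDA device visible.")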