Greenplum / DeepSpeed
Commit b8416282 (unverified)
Authored Dec 06, 2022 by Connor Holmes; committed via GitHub on Dec 06, 2022

Drop Maxwell Support (#2574)
* Officially drop Maxwell support
* Formatting
* Comparison mismatch fix
Parent: 06938835

Showing 8 changed files with 72 additions and 110 deletions
csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu    +0   -4
csrc/transformer/inference/csrc/dequantize.cu              +0   -3
csrc/transformer/inference/csrc/gelu.cu                   +26  -41
csrc/transformer/inference/csrc/relu.cu                   +14  -39
csrc/transformer/inference/csrc/softmax.cu                 +0   -3
csrc/transformer/inference/csrc/transform.cu               +0  -19
op_builder/builder.py                                     +15   -1
op_builder/transformer_inference.py                       +17   -0
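Taken together, the kernel-side hunks below simply delete the per-kernel #if __CUDA_ARCH__ >= 700 and #ifdef HALF_PRECISION_AVAILABLE fallback guards, while the op_builder hunks move that decision to build time: a new filter_ccs() hook prunes pre-Pascal compute capabilities before any -gencode flags are emitted, so sm_5x binaries are never produced. The condensed Python sketch below shows how those pieces fit together; the *Sketch class names are illustrative stand-ins, not DeepSpeed source.

# Condensed sketch of the flow this commit introduces; the *Sketch classes
# are illustrative stand-ins, not DeepSpeed source.
from typing import List

class CUDAOpBuilderSketch:
    name = "sketch"

    def warning(self, msg: str) -> None:
        print(f"WARNING: {msg}")

    def filter_ccs(self, ccs: List[str]) -> List[str]:
        # Base hook: keep every compute capability; subclasses may prune.
        return ccs

    def compute_capability_args(self, cross_compile_archs: str) -> List[str]:
        ccs = self.filter_ccs(cross_compile_archs.split(';'))
        if len(ccs) == 0:
            raise RuntimeError(
                f"Unable to load {self.name} op due to no compute "
                "capabilities remaining after filtering")
        args = []
        for cc in ccs:
            num = cc[0] + cc[2]  # "8.6" -> "86"
            args.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        return args

class InferenceBuilderSketch(CUDAOpBuilderSketch):
    name = "transformer_inference"

    def filter_ccs(self, ccs: List[str]) -> List[str]:
        # Prune Maxwell (5.x) and older, mirroring the override added below.
        retained = [cc for cc in ccs if int(cc[0]) >= 6]
        pruned = [cc for cc in ccs if int(cc[0]) < 6]
        if pruned:
            self.warning(f"Filtered compute capabilities {pruned}")
        return retained

# Maxwell's 5.2 is dropped; Pascal and newer compile with the guards gone.
print(InferenceBuilderSketch().compute_capability_args("5.2;6.0;8.0"))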
csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu

@@ -67,7 +67,6 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
                                      unsigned total_count,
                                      int max_out_tokens)
 {
-#if __CUDA_ARCH__ >= 700
     cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -102,7 +101,6 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
             lane += WARP_SIZE;
         }
     }
-#endif
 }

 __global__ void apply_rotary_pos_emb1(float* mixed_query,
                                       float* key_layer,
@@ -159,7 +157,6 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
                                       unsigned total_count,
                                       int max_out_tokens)
 {
-#if __CUDA_ARCH__ >= 700
     cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -205,7 +202,6 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
             lane += WARP_SIZE;
         }
     }
-#endif
 }

 template <typename T>
csrc/transformer/inference/csrc/dequantize.cu

@@ -50,8 +50,6 @@ __global__ void dequantize_kernel(__half* output,
                                   unsigned groups,
                                   unsigned merge_count)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     unsigned merge_hidden = hidden_dim >> merge_count;
     unsigned quantization_stride = (merge_hidden * output_size) / groups;
@@ -75,7 +73,6 @@ __global__ void dequantize_kernel(__half* output,
         output[q_index] = __float2half(scale_data * (float)q);
         tid += blockDim.x;
     }
-#endif
 }

 template <typename T>
csrc/transformer/inference/csrc/gelu.cu

@@ -17,6 +17,9 @@ inline __device__ float gelu(const float x)
     return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
 }

+/*
+In-place gelu(biasAdd(x)) for channels last
+*/
 template <typename T>
 __global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size)
 {
@@ -64,63 +67,51 @@ void launch_bias_gelu(T* input,
 template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
 template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);

-// Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
-__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
-{
-    constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(float);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
-
-    if (offset < total_count) {
-        float data[vals_per_access];
-        float bias_data[vals_per_access];
-
-        mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
-
-#pragma unroll
-        for (int i = 0; i < vals_per_access; i++) { data[i] += bias_data[i]; }
-
-        mem_access::store_global<granularity>(input + offset, data);
-    }
-}
-
-__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
+/*
+In-place channels-last bias add
+*/
+template <typename T>
+__global__ void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size)
 {
-#ifdef HALF_PRECISION_AVAILABLE
+    // Input restriction: intermediate_size % vals_per_access == 0
     constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(__half);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
+    constexpr int values_per_access = granularity / sizeof(T);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;

     if (offset < total_count) {
-        __half2 data[vals_per_access / 2];
-        __half2 bias_data[vals_per_access / 2];
+        T data[values_per_access];
+        T data_bias[values_per_access];
         mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
+        mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));

 #pragma unroll
-        for (int i = 0; i < vals_per_access / 2; i++) {
-            float2 data_f = __half22float2(data[i]);
-            float2 bias_f = __half22float2(bias_data[i]);
-            data[i] = __floats2half2_rn(data_f.x + bias_f.x, data_f.y + bias_f.y);
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            float bias_f = conversion::to<float>(data_bias[i]);
+            data[i] = conversion::to<T>(data_f + bias_f);
         }

         mem_access::store_global<granularity>(input + offset, data);
     }
-#endif
 }

 template <typename T>
-void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
+void launch_bias_add(T* input, const T* bias, int intermediate_size, int batch_size, cudaStream_t stream)
 {
     constexpr int threads = 1024;
     constexpr int granularity = 16;
-    const int total_count = batch_size * hidden_size;
+    const int total_count = batch_size * intermediate_size;
     const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
     dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);

-    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size);
+    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, intermediate_size);
 }

 template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
@@ -181,8 +172,6 @@ __global__ void fused_bias_residual(__half* residual,
                                     const float mp_scale,
                                     const bool preln)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
     const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
     const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
@@ -241,7 +230,6 @@ __global__ void fused_bias_residual(__half* residual,
         res_fl2_ptr[offset] = res_fl2;
     }
-#endif
 }

 template <typename T>
@@ -325,8 +313,6 @@ __global__ void gptj_residual_add(__half* residual,
                                   const int intermediate_size,
                                   const float mp_scale)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
     const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
     const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
@@ -379,7 +365,6 @@ __global__ void gptj_residual_add(__half* residual,
         res_fl2_ptr[offset] = res_fl2;
     }
-#endif
 }

 template <typename T>
csrc/transformer/inference/csrc/relu.cu

@@ -2,6 +2,7 @@
 Copyright 2022 The Microsoft DeepSpeed Team
 */

+#include "conversion_utils.h"
 #include "inference_cuda_layers.h"
 #include "memory_access_utils.h"
@@ -11,58 +12,32 @@ namespace cg = cooperative_groups;

 inline __device__ float relu(const float x) { return x < 0 ? 0 : x; }

-__global__ void fused_bias_relu(float* input,
-                                const float* bias,
-                                int total_count,
-                                int intermediate_size)
+/*
+In-place relu(biasAdd(x)) for channels last
+*/
+template <typename T>
+__global__ void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size)
 {
     // Input restriction: intermediate_size % vals_per_access == 0
     constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(float);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
+    constexpr int values_per_access = granularity / sizeof(T);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;

     if (offset < total_count) {
-        float data[vals_per_access];
-        float data_bias[vals_per_access];
+        T data[values_per_access];
+        T data_bias[values_per_access];
         mem_access::load_global<granularity>(data, input + offset);
         mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));

 #pragma unroll
-        for (int i = 0; i < vals_per_access; i++) { data[i] = relu(data[i] + data_bias[i]); }
-
-        mem_access::store_global<granularity>(input + offset, data);
-    }
-}
-
-__global__ void fused_bias_relu(__half* input,
-                                const __half* bias,
-                                int total_count,
-                                int intermediate_size)
-{
-    // Input restriction: intermediate_size % vals_per_access == 0
-    // This kernel doubles the per-thread ALU workload as compared to the float implementation
-#ifdef HALF_PRECISION_AVAILABLE
-    constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(__half);
-    int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
-
-    if (offset < total_count) {
-        // Divide by 2 since we store two values per __half2
-        __half2 data[vals_per_access / 2];
-        __half2 bias_data[vals_per_access / 2];
-        mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));
-
-#pragma unroll
-        for (int i = 0; i < vals_per_access / 2; i++) {
-            float2 data_f = __half22float2(data[i]);
-            float2 bias_f = __half22float2(bias_data[i]);
-            data[i] = __floats2half2_rn(relu(data_f.x + bias_f.x), relu(data_f.y + bias_f.y));
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            float bias_f = conversion::to<float>(data_bias[i]);
+            data[i] = conversion::to<T>(relu(data_f + bias_f));
         }

         mem_access::store_global<granularity>(input + offset, data);
     }
-#endif
 }

 template <typename T>
csrc/transformer/inference/csrc/softmax.cu

@@ -48,8 +48,6 @@ __global__ void attn_softmax_v2(__half* vals,
                                 int iterations,
                                 int reduceWidth)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     cg::thread_block b = cg::this_thread_block();
     cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -232,7 +230,6 @@ __global__ void attn_softmax_v2(__half* vals,
             }
         }
     }
-#endif
 }

 __global__ void attn_softmax_v2(float* vals,
csrc/transformer/inference/csrc/transform.cu

@@ -90,8 +90,6 @@ __global__ void bias_add_transform_0213(__half* output, // q
                                         int head_ext,
                                         int max_out_tokens)
 {
-#if __CUDA_ARCH__ >= 700
     unsigned half_dim = (rotary_dim << 3) >> 1;
     int d0_stride = hidden_dim * seq_length;
     int d1_stride = hidden_dim;
@@ -146,8 +144,6 @@ __global__ void bias_add_transform_0213(__half* output, // q
         output_vec[d3] = q;
     } else
         output_vec[d3] = vals_vec[d3];
-#endif
 }

 // [B S C*H] - > C * [B A S N]
@@ -269,7 +265,6 @@ __global__ void pad_add_transform_0213(__half* output,
                                        int heads,
                                        int padded_head_size)
 {
-#if __CUDA_ARCH__ >= 700
     float4 ZERO;
     const __half2 zero_h = __float2half2_rn(0.f);
     __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
@@ -303,8 +298,6 @@ __global__ void pad_add_transform_0213(__half* output,
         output_vec[d3] = vals_vec[d3];
     else
         output_vec[d3] = ZERO;
-#endif
 }

 template <typename T>
@@ -409,8 +402,6 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
                                                 int heads,
                                                 int head_ext)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     int d0_stride = hidden_dim * seq_length;
     int d1_stride = hidden_dim;
     int d2_stride = hidden_dim / heads;
@@ -455,8 +446,6 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     output_half[2] = vals_half[2] + bias_half[2];
     output_half[3] = vals_half[3] + bias_half[3];
     output_vec[d3] = output_arr;
-#endif
 }

 __global__ void bias_add_transform_0213_v2(__half* output,
@@ -466,7 +455,6 @@ __global__ void bias_add_transform_0213_v2(__half* output,
                                            int seq_length,
                                            int heads)
 {
-#ifdef HALF_PRECISION_AVAILABLE
     __shared__ float4 in_data[3072];
     int d0_stride = hidden_dim * seq_length;
@@ -528,7 +516,6 @@ __global__ void bias_add_transform_0213_v2(__half* output,
         output_vec[out_index + iter_offset] =
             in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
     }
-#endif
 }

 template <typename T>
@@ -580,8 +567,6 @@ __global__ void transform4d_0213<__half>(__half* out,
                                          int hidden_dim,
                                          int head_ext)
 {
-#if __CUDA_ARCH__ >= 700
     int d0_stride = hidden_dim * (seq_length / head_ext);
     int d1_stride = hidden_dim;
     int d2_stride = hidden_dim / heads;
@@ -606,8 +591,6 @@ __global__ void transform4d_0213<__half>(__half* out,
     out_vec += (d2 * d1_stride * gridDim.y);
     out_vec[d3] = in_vec[d3];
-#endif
 }

 __global__ void transform4d_0213_v2(__half* out,
@@ -616,7 +599,6 @@ __global__ void transform4d_0213_v2(__half* out,
                                     int seq_length,
                                     int hidden_dim)
 {
-#if __CUDA_ARCH__ >= 700
     __shared__ float4 in_data[3072];
     int d0_stride = hidden_dim * seq_length;
@@ -657,7 +639,6 @@ __global__ void transform4d_0213_v2(__half* out,
         int iter_id = iter * iteration_stride + iter_index;
         out_vec[output_offset + iter_id] = in_data[iter_id];
     }
-#endif
 }

 // 3 * [B A S N] - > [B S C*H]
op_builder/builder.py

@@ -15,6 +15,7 @@ import distutils.log
 import distutils.sysconfig
 from distutils.errors import CompileError, LinkError
 from abc import ABC, abstractmethod
+from typing import List

 YELLOW = '\033[93m'
 END = '\033[0m'
@@ -524,7 +525,7 @@ class CUDAOpBuilder(OpBuilder):
         - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
           TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
-          TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
+          TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
         - `cross_compile_archs` uses ; separator.
@@ -554,6 +555,12 @@ class CUDAOpBuilder(OpBuilder):
             cross_compile_archs = get_default_compute_capabilities()

         ccs = cross_compile_archs.split(';')
+        ccs = self.filter_ccs(ccs)
+        if len(ccs) == 0:
+            raise RuntimeError(
+                f"Unable to load {self.name} op due to no compute capabilities remaining after filtering"
+            )
+
         args = []
         for cc in ccs:
             num = cc[0] + cc[2]
@@ -563,6 +570,13 @@ class CUDAOpBuilder(OpBuilder):
         return args

+    def filter_ccs(self, ccs: List[str]):
+        """
+        Prune any compute capabilities that are not compatible with the builder. Should log
+        which CCs have been pruned.
+        """
+        return ccs
+
     def version_dependent_macros(self):
         # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
         version_ge_1_1 = []
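Beyond removing Maxwell's 5.2 from the documented example, the docstring above also notes that TORCH_CUDA_ARCH_LIST accepts either ';' or whitespace separators. A tiny sketch of that parsing rule (parse_arch_list is a hypothetical helper, not DeepSpeed code):

# Hypothetical helper illustrating the separator rule documented above:
# TORCH_CUDA_ARCH_LIST entries may be split by ';' or by whitespace.
def parse_arch_list(arch_list: str) -> list:
    # Normalize ';' to spaces, then split on whitespace.
    return arch_list.replace(';', ' ').split()

print(parse_arch_list("6.1;7.5;8.6"))                  # ['6.1', '7.5', '8.6']
print(parse_arch_list("6.0 6.1 7.0 7.5 8.0 8.6+PTX"))  # ['6.0', ..., '8.6+PTX']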
op_builder/transformer_inference.py

@@ -25,6 +25,11 @@ class InferenceBuilder(CUDAOpBuilder):
             sys_cuda_major, _ = installed_cuda_version()
             torch_cuda_major = int(torch.version.cuda.split('.')[0])
             cuda_capability = torch.cuda.get_device_properties(0).major
+            if cuda_capability < 6:
+                self.warning(
+                    "NVIDIA Inference is only supported on Pascal and newer architectures"
+                )
+                cuda_okay = False
             if cuda_capability >= 8:
                 if torch_cuda_major < 11 or sys_cuda_major < 11:
                     self.warning(
@@ -32,6 +37,18 @@ class InferenceBuilder(CUDAOpBuilder):
                     cuda_okay = False
         return super().is_compatible(verbose) and cuda_okay

+    def filter_ccs(self, ccs):
+        ccs_retained = []
+        ccs_pruned = []
+        for cc in ccs:
+            if int(cc[0]) >= 6:
+                ccs_retained.append(cc)
+            else:
+                ccs_pruned.append(cc)
+        if len(ccs_pruned) > 0:
+            self.warning(f"Filtered compute capabilities {ccs_pruned}")
+        return ccs_retained
+
     def sources(self):
         return [
             'csrc/transformer/inference/csrc/pt_binding.cpp',
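For reference, here is the new filter_ccs() override traced on sample input. This is a standalone restatement of the method body on plain data; the real method runs inside InferenceBuilder and routes its message through self.warning(), for which print() stands in here.

# Standalone restatement of InferenceBuilder.filter_ccs() on plain data;
# print() is a stand-in for self.warning().
def filter_ccs(ccs):
    ccs_retained, ccs_pruned = [], []
    for cc in ccs:
        if int(cc[0]) >= 6:   # Pascal (6.x) and newer are kept
            ccs_retained.append(cc)
        else:                 # Maxwell (5.x) and older are pruned
            ccs_pruned.append(cc)
    if len(ccs_pruned) > 0:
        print(f"Filtered compute capabilities {ccs_pruned}")
    return ccs_retained

print(filter_ccs(['5.2', '6.0', '7.0', '8.6']))
# Filtered compute capabilities ['5.2']
# ['6.0', '7.0', '8.6']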