Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
150eb7d9
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
150eb7d9
编写于
3月 30, 2023
作者:
M
Molly Smith
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
naming. precommit
上级
bc450d48
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
175 addition
and
155 deletion
+175
-155
csrc/transformer/general_kernels.cu
csrc/transformer/general_kernels.cu
+2
-2
csrc/transformer/inference/csrc/softmax.cu
csrc/transformer/inference/csrc/softmax.cu
+173
-153
未找到文件。
csrc/transformer/general_kernels.cu
浏览文件 @
150eb7d9
...
...
@@ -161,7 +161,7 @@ void launch_fused_add2<float>(float* out,
int
total_count
=
batch_size
*
seq_length
*
hidden_dim
/
4
;
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
);
//(batch_size * seq_length);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
//(hidden_dim / 4);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
//(hidden_dim / 4);
fused_add2_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
out
,
inp1
,
inp2
);
}
...
...
@@ -178,7 +178,7 @@ void launch_fused_add2<__half>(__half* out,
int
total_count
=
batch_size
*
seq_length
*
hidden_dim
/
4
;
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
);
//(batch_size * seq_length);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
//(hidden_dim / 4);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
//(hidden_dim / 4);
fused_add2_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
out
,
inp1
,
inp2
);
}
...
...
csrc/transformer/inference/csrc/softmax.cu
浏览文件 @
150eb7d9
...
...
@@ -86,92 +86,110 @@ __global__ void attn_softmax_v2(__half* vals,
// if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
);
bool
check1
=
((
!
triangular
||
(
data_id
<=
seq_id
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
);
bool
low_x_check
=
check1
&&
(
data_id
>
window_stride
);
bool
low_y_check
=
check1
&&
((
data_id
+
reduceWidth
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
)
>
window_stride
);
bool
high_x_check
=
check1
&&
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
2
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
2
)
>
window_stride
);
bool
high_y_check
=
check1
&&
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
3
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
3
)
>
window_stride
);
if
(
mask
&&
alibi
){
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
2
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
3
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]))
:
minus_infinity
;
b
.
sync
();
}
else
if
(
mask
){
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]))
:
minus_infinity
;
b
.
sync
();
}
else
if
(
alibi
){
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
]))
bool
check1
=
((
!
triangular
||
(
data_id
<=
seq_id
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
);
bool
low_x_check
=
check1
&&
(
data_id
>
window_stride
);
bool
low_y_check
=
check1
&&
((
data_id
+
reduceWidth
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
)
>
window_stride
);
bool
high_x_check
=
check1
&&
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
2
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
2
)
>
window_stride
);
bool
high_y_check
=
check1
&&
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
3
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
3
)
>
window_stride
);
if
(
mask
&&
alibi
)
{
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
2
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
3
]))
:
minus_infinity
;
b
.
sync
();
}
else
{
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
2
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
3
]))
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]))
:
minus_infinity
;
b
.
sync
();
}
else
if
(
mask
)
{
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
+
(
__half2float
(
mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]))
:
minus_infinity
;
b
.
sync
();
}
else
if
(
alibi
)
{
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
]))
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
2
]))
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
+
(
__half2float
(
alibi
[
data_id
+
alibi_offset
+
reduceWidth
*
3
]))
:
minus_infinity
;
b
.
sync
();
}
else
{
low_data
[
i
].
x
=
low_x_check
?
__half2float
(
vals
[
data_id
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
low_data
[
i
].
y
=
low_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
x
=
high_x_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
2
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
high_data
[
i
].
y
=
high_y_check
?
__half2float
(
vals
[
data_id
+
reduceWidth
*
3
])
*
layer_scale
:
minus_infinity
;
b
.
sync
();
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val
=
(
low_data
[
i
].
x
>
max_val
?
low_data
[
i
].
x
:
max_val
);
max_val
=
(
low_data
[
i
].
y
>
max_val
?
low_data
[
i
].
y
:
max_val
);
max_val
=
(
high_data
[
i
].
x
>
max_val
?
high_data
[
i
].
x
:
max_val
);
max_val
=
(
high_data
[
i
].
y
>
max_val
?
high_data
[
i
].
y
:
max_val
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
...
...
@@ -227,13 +245,13 @@ __global__ void attn_softmax_v2(__half* vals,
b
.
sync
();
if
((
data_id
+
reduceWidth
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
]
=
__float2half
(
low_data
[
i
].
y
/
sum
);
b
.
sync
();
if
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
2
]
=
__float2half
(
high_data
[
i
].
x
/
sum
);
b
.
sync
();
if
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
3
]
=
__float2half
(
high_data
[
i
].
y
/
sum
);
b
.
sync
();
b
.
sync
();
if
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
2
]
=
__float2half
(
high_data
[
i
].
x
/
sum
);
b
.
sync
();
if
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
3
]
=
__float2half
(
high_data
[
i
].
y
/
sum
);
b
.
sync
();
}
}
}
...
...
@@ -291,54 +309,50 @@ __global__ void attn_softmax_v2(float* vals,
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
);
bool
check1
=
((
!
triangular
||
(
data_id
<=
seq_id
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
);
bool
low_x_check
=
check1
&&
(
data_id
>
window_stride
);
bool
low_y_check
=
check1
&&
((
data_id
+
reduceWidth
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
)
>
window_stride
);
bool
high_x_check
=
check1
&&
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
2
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
2
)
>
window_stride
);
bool
high_y_check
=
check1
&&
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
3
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
3
)
>
window_stride
);
if
(
attn_mask
){
data
[
i
].
x
=
low_x_check
?
vals
[
data_id
]
+
attn_mask
[
data_id
+
mask_offset
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
y
=
low_y_check
?
vals
[
data_id
+
reduceWidth
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
z
=
high_x_check
?
vals
[
data_id
+
reduceWidth
*
2
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
w
=
high_y_check
?
vals
[
data_id
+
reduceWidth
*
3
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]
:
minus_infinity
;
b
.
sync
();
}
else
{
data
[
i
].
x
=
low_x_check
?
vals
[
data_id
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
y
=
low_y_check
?
vals
[
data_id
+
reduceWidth
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
z
=
high_x_check
?
vals
[
data_id
+
reduceWidth
*
2
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
w
=
high_y_check
?
vals
[
data_id
+
reduceWidth
*
3
]
:
minus_infinity
;
b
.
sync
();
bool
check1
=
((
!
triangular
||
(
data_id
<=
seq_id
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
);
bool
x_check
=
check1
&&
(
data_id
>
window_stride
);
bool
y_check
=
check1
&&
((
data_id
+
reduceWidth
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
)
>
window_stride
);
bool
z_check
=
check1
&&
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
2
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
2
)
>
window_stride
);
bool
w_check
=
check1
&&
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
3
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
3
)
>
window_stride
);
if
(
attn_mask
)
{
data
[
i
].
x
=
x_check
?
vals
[
data_id
]
+
attn_mask
[
data_id
+
mask_offset
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
y
=
y_check
?
vals
[
data_id
+
reduceWidth
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
z
=
z_check
?
vals
[
data_id
+
reduceWidth
*
2
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
w
=
w_check
?
vals
[
data_id
+
reduceWidth
*
3
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]
:
minus_infinity
;
b
.
sync
();
}
else
{
data
[
i
].
x
=
x_check
?
vals
[
data_id
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
y
=
y_check
?
vals
[
data_id
+
reduceWidth
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
z
=
z_check
?
vals
[
data_id
+
reduceWidth
*
2
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
w
=
w_check
?
vals
[
data_id
+
reduceWidth
*
3
]
:
minus_infinity
;
b
.
sync
();
}
max_val
=
(
data
[
i
].
x
>
max_val
?
data
[
i
].
x
:
max_val
);
max_val
=
(
data
[
i
].
y
>
max_val
?
data
[
i
].
y
:
max_val
);
max_val
=
(
data
[
i
].
z
>
max_val
?
data
[
i
].
z
:
max_val
);
max_val
=
(
data
[
i
].
w
>
max_val
?
data
[
i
].
w
:
max_val
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
...
...
@@ -395,22 +409,35 @@ __global__ void attn_softmax_v2(float* vals,
b
.
sync
();
if
((
data_id
+
reduceWidth
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
]
=
data
[
i
].
y
/
sum
;
b
.
sync
();
if
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
2
]
=
data
[
i
].
z
/
sum
;
b
.
sync
();
if
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
3
]
=
data
[
i
].
w
/
sum
;
b
.
sync
();
b
.
sync
();
if
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
2
]
=
data
[
i
].
z
/
sum
;
b
.
sync
();
if
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
3
]
=
data
[
i
].
w
/
sum
;
b
.
sync
();
}
}
}
}
#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \
attn_softmax_v2<iterations><<<grid, block, 0, stream>>> \
(vals,mask,alibi,layer_scale,triangular,recompute,local_attention,window_size,total_count,heads, \
sequence_length,num_seq,head_offset,mask_stride,mp_size,reduce_width);
#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \
attn_softmax_v2<iterations><<<grid, block, 0, stream>>>(vals, \
mask, \
alibi, \
layer_scale, \
triangular, \
recompute, \
local_attention, \
window_size, \
total_count, \
heads, \
sequence_length, \
num_seq, \
head_offset, \
mask_stride, \
mp_size, \
reduce_width);
template
<
typename
T
>
void
launch_attn_softmax_v2
(
T
*
vals
,
...
...
@@ -457,30 +484,23 @@ void launch_attn_softmax_v2(T* vals,
dim3
grid
((
total_count
+
partitions
-
1
)
/
partitions
);
dim3
block
(
attn_threads
);
if
(
sequence_length
<=
32768
){
if
(
iterations
==
1
){
if
(
sequence_length
<=
32768
)
{
if
(
iterations
==
1
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
1
);
}
else
if
(
iterations
==
2
){
}
else
if
(
iterations
==
2
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
2
);
}
else
if
(
iterations
==
4
){
}
else
if
(
iterations
==
4
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
4
);
}
else
if
(
iterations
==
8
){
}
else
if
(
iterations
==
8
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
8
);
}
else
if
(
iterations
==
16
){
}
else
if
(
iterations
==
16
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
16
);
}
else
if
(
iterations
==
32
){
}
else
if
(
iterations
==
32
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
32
);
}
else
if
(
iterations
==
64
){
}
else
if
(
iterations
==
64
)
{
LAUNCH_ATTN_SOFTMAX_V2
(
64
);
}
}
else
}
else
throw
std
::
runtime_error
(
"Unsupport Seq_Length!"
);
}
...
...
GitCode官方
@csdn_codechina
mentioned in commit
e73de8ce
·
4月 06, 2023
mentioned in commit
e73de8ce
mentioned in commit e73de8cee89ca1b2296c8eead2ae3d904e24271c
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录