Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
b7ad2a2d
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b7ad2a2d
编写于
3月 28, 2023
作者:
M
Molly Smith
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fp32
上级
2d7d1749
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
60 addition
and
65 deletion
+60
-65
csrc/transformer/inference/csrc/softmax.cu
csrc/transformer/inference/csrc/softmax.cu
+60
-65
未找到文件。
csrc/transformer/inference/csrc/softmax.cu
浏览文件 @
b7ad2a2d
...
...
@@ -276,11 +276,9 @@ __global__ void attn_softmax_v2(float* vals,
vals
+=
(
iter_offset
*
sequence_length
);
int
batch_idx
=
iter_offset
/
(
num_seq
*
heads
);
int
alibi_offset
=
batch_idx
*
heads
*
mp_size
+
head_offset
;
int
mask_offset
=
batch_idx
*
mask_stride
+
(
iter_offset
%
mask_stride
);
mask_offset
=
mask_offset
*
sequence_length
;
int
seq_id
=
iter_offset
%
num_seq
;
int
seq_id4
=
seq_id
>>
2
;
int
real_seq_id
=
seq_id
+
(
num_seq
==
sequence_length
?
0
:
sequence_length
);
int
window_stride4
=
(
local_attention
&&
(
real_seq_id
>>
2
)
>
(
window_size
>>
2
))
...
...
@@ -292,58 +290,55 @@ __global__ void attn_softmax_v2(float* vals,
float
max_val
=
minus_infinity
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
<<
2
);
if
((
!
triangular
||
((
data_id
>>
2
)
<=
seq_id4
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
)
{
if
((
sequence_length
-
data_id
)
>=
4
)
{
data
[
i
].
x
=
(
data_id
>
window_stride
?
vals
[
data_id
]
:
minus_infinity
);
data
[
i
].
y
=
((
!
triangular
||
((
data_id
+
1
)
<=
seq_id
))
&&
(
data_id
+
1
)
>
window_stride
)
?
vals
[
data_id
+
1
]
:
minus_infinity
;
data
[
i
].
z
=
((
!
triangular
||
((
data_id
+
2
)
<=
seq_id
))
&&
(
data_id
+
2
)
>
window_stride
)
?
vals
[
data_id
+
2
]
:
minus_infinity
;
data
[
i
].
w
=
((
!
triangular
||
((
data_id
+
3
)
<=
seq_id
))
&&
(
data_id
+
3
)
>
window_stride
)
?
vals
[
data_id
+
3
]
:
minus_infinity
;
if
(
attn_mask
)
{
data
[
i
].
x
+=
attn_mask
[
data_id
+
mask_offset
];
data
[
i
].
y
+=
attn_mask
[
data_id
+
mask_offset
+
1
];
data
[
i
].
z
+=
attn_mask
[
data_id
+
mask_offset
+
2
];
data
[
i
].
w
+=
attn_mask
[
data_id
+
mask_offset
+
3
];
}
}
else
{
data
[
i
].
x
=
data_id
>
window_stride
?
vals
[
data_id
]
:
minus_infinity
;
data
[
i
].
y
=
(((
!
triangular
||
(
data_id
+
1
)
<=
seq_id
))
&&
(
data_id
+
1
)
>
window_stride
&&
(
data_id
+
1
)
<
sequence_length
)
?
(
vals
[
data_id
+
1
])
:
minus_infinity
;
data
[
i
].
z
=
(((
!
triangular
||
(
data_id
+
2
)
<=
seq_id
))
&&
(
data_id
+
2
)
>
window_stride
&&
(
data_id
+
2
)
<
sequence_length
)
?
(
vals
[
data_id
+
2
])
:
minus_infinity
;
data
[
i
].
w
=
minus_infinity
;
if
(
attn_mask
)
{
data
[
i
].
x
+=
attn_mask
[
data_id
+
mask_offset
];
if
((
data_id
+
1
)
<
sequence_length
)
data
[
i
].
y
+=
attn_mask
[
data_id
+
mask_offset
+
1
];
if
((
data_id
+
2
)
<
sequence_length
)
data
[
i
].
z
+=
attn_mask
[
data_id
+
mask_offset
+
2
];
}
}
max_val
=
(
data
[
i
].
x
>
max_val
?
data
[
i
].
x
:
max_val
);
max_val
=
(
data
[
i
].
y
>
max_val
?
data
[
i
].
y
:
max_val
);
max_val
=
(
data
[
i
].
z
>
max_val
?
data
[
i
].
z
:
max_val
);
max_val
=
(
data
[
i
].
w
>
max_val
?
data
[
i
].
w
:
max_val
);
}
else
{
data
[
i
].
x
=
minus_infinity
;
data
[
i
].
y
=
minus_infinity
;
data
[
i
].
z
=
minus_infinity
;
data
[
i
].
w
=
minus_infinity
;
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
);
bool
check1
=
((
!
triangular
||
(
data_id
<=
seq_id
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
);
bool
low_x_check
=
check1
&&
(
data_id
>
window_stride
);
bool
low_y_check
=
check1
&&
((
data_id
+
reduceWidth
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
)
>
window_stride
);
bool
high_x_check
=
check1
&&
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
2
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
2
)
>
window_stride
);
bool
high_y_check
=
check1
&&
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
&&
((
!
triangular
||
((
data_id
+
reduceWidth
*
3
)
<=
seq_id
))
&&
(
data_id
+
reduceWidth
*
3
)
>
window_stride
);
if
(
attn_mask
){
data
[
i
].
x
=
low_x_check
?
vals
[
data_id
]
+
attn_mask
[
data_id
+
mask_offset
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
y
=
low_y_check
?
vals
[
data_id
+
reduceWidth
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
z
=
high_x_check
?
vals
[
data_id
+
reduceWidth
*
2
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
*
2
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
w
=
high_y_check
?
vals
[
data_id
+
reduceWidth
*
3
]
+
attn_mask
[
data_id
+
mask_offset
+
reduceWidth
*
3
]
:
minus_infinity
;
b
.
sync
();
}
else
{
data
[
i
].
x
=
low_x_check
?
vals
[
data_id
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
y
=
low_y_check
?
vals
[
data_id
+
reduceWidth
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
z
=
high_x_check
?
vals
[
data_id
+
reduceWidth
*
2
]
:
minus_infinity
;
b
.
sync
();
data
[
i
].
w
=
high_y_check
?
vals
[
data_id
+
reduceWidth
*
3
]
:
minus_infinity
;
b
.
sync
();
}
max_val
=
(
data
[
i
].
x
>
max_val
?
data
[
i
].
x
:
max_val
);
max_val
=
(
data
[
i
].
y
>
max_val
?
data
[
i
].
y
:
max_val
);
max_val
=
(
data
[
i
].
z
>
max_val
?
data
[
i
].
z
:
max_val
);
max_val
=
(
data
[
i
].
w
>
max_val
?
data
[
i
].
w
:
max_val
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
...
...
@@ -394,19 +389,19 @@ __global__ void attn_softmax_v2(float* vals,
sum
+=
1e-6
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
<<
2
);
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
);
if
(
data_id
<
sequence_length
)
{
if
((
sequence_length
-
data_id
)
>=
4
)
{
vals
[
data_id
]
=
data
[
i
].
x
/
sum
;
vals
[
data_id
+
1
]
=
data
[
i
].
y
/
sum
;
vals
[
data_id
+
2
]
=
data
[
i
].
z
/
sum
;
vals
[
data_id
+
3
]
=
data
[
i
].
w
/
sum
;
}
else
{
vals
[
data_id
]
=
data
[
i
].
x
/
sum
;
if
((
data_id
+
1
)
<
sequence_length
)
vals
[
data_id
+
1
]
=
data
[
i
].
y
/
sum
;
if
((
data_id
+
2
)
<
sequence_length
)
vals
[
data_id
+
2
]
=
data
[
i
].
z
/
sum
;
}
vals
[
data_id
]
=
data
[
i
].
x
/
sum
;
b
.
sync
();
if
((
data_id
+
reduceWidth
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
]
=
data
[
i
].
y
/
sum
;
b
.
sync
();
if
((
data_id
+
reduceWidth
*
2
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
2
]
=
data
[
i
].
z
/
sum
;
b
.
sync
();
if
((
data_id
+
reduceWidth
*
3
)
<
sequence_length
)
vals
[
data_id
+
reduceWidth
*
3
]
=
data
[
i
].
w
/
sum
;
b
.
sync
();
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录