Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
efa8aded
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
efa8aded
编写于
9月 12, 2022
作者:
A
Arash Bakhtiari
提交者:
GitHub
9月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the residual add mp scaling for GPTNeoX (#2310)
上级
a691ec60
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
8 addition
and
8 deletion
+8
-8
csrc/transformer/inference/csrc/gelu.cu
csrc/transformer/inference/csrc/gelu.cu
+8
-8
未找到文件。
csrc/transformer/inference/csrc/gelu.cu
浏览文件 @
efa8aded
...
...
@@ -321,10 +321,10 @@ __global__ void gptj_residual_add(float* input,
if
(
attnbias
)
{
float4
attn_bias
=
attnbias_cast
[
offset
%
intermediate_size
];
data
.
x
+=
attn_bias
.
x
*
mp_scale
;
data
.
y
+=
attn_bias
.
y
*
mp_scale
;
data
.
z
+=
attn_bias
.
z
*
mp_scale
;
data
.
w
+=
attn_bias
.
w
*
mp_scale
;
data
.
x
+=
attn_bias
.
x
;
data
.
y
+=
attn_bias
.
y
;
data
.
z
+=
attn_bias
.
z
;
data
.
w
+=
attn_bias
.
w
;
}
data
.
x
=
out
.
x
+
res_vec
.
x
+
(
data
.
x
+
bias_data
.
x
)
*
mp_scale
;
data
.
y
=
out
.
y
+
res_vec
.
y
+
(
data
.
y
+
bias_data
.
y
)
*
mp_scale
;
...
...
@@ -383,10 +383,10 @@ __global__ void gptj_residual_add(__half* input,
__half2
*
attnbias_half
=
reinterpret_cast
<
__half2
*>
(
&
attn_bias_vec
);
float2
attn_low_bias
=
__half22float2
(
attnbias_half
[
0
]);
float2
attn_high_bias
=
__half22float2
(
attnbias_half
[
1
]);
low_data
.
x
+=
attn_low_bias
.
x
*
mp_scale
;
low_data
.
y
+=
attn_low_bias
.
y
*
mp_scale
;
high_data
.
x
+=
attn_high_bias
.
x
*
mp_scale
;
high_data
.
y
+=
attn_high_bias
.
y
*
mp_scale
;
low_data
.
x
+=
attn_low_bias
.
x
;
low_data
.
y
+=
attn_low_bias
.
y
;
high_data
.
x
+=
attn_high_bias
.
x
;
high_data
.
y
+=
attn_high_bias
.
y
;
}
low_data
.
x
=
low_res
.
x
+
low_out
.
x
+
(
low_data
.
x
+
low_bias
.
x
)
*
mp_scale
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录