Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
e14d40e5
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
e14d40e5
编写于
9月 27, 2022
作者:
A
Arash Bakhtiari
提交者:
GitHub
9月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor fused_bias_residual kernels for better readability (#2356)
Co-authored-by:
N
Olatunji Ruwase
<
olruwase@microsoft.com
>
上级
79692af1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
100 additions
and
86 deletions
+100
-86
csrc/transformer/inference/csrc/gelu.cu
csrc/transformer/inference/csrc/gelu.cu
+100
-86
未找到文件。
csrc/transformer/inference/csrc/gelu.cu
浏览文件 @
e14d40e5
...
...
@@ -126,120 +126,127 @@ void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, c
// Explicit instantiations of launch_bias_add for the two supported
// element types (fp32 and fp16). The template definition lives above.
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);
// Fused bias + residual update kernel, fp32 path.
//
// Vectorized over float4: each thread handles 4 consecutive floats.
// `total_count` is the element count in float4 units and
// `intermediate_size` is the hidden dimension in float4 units (the
// launcher passes hidden_dim / 4), so the bias vectors repeat every
// `intermediate_size` float4 elements.
//
// preln == true:
//   residual = (residual + attn + bias + attn_bias) * mp_scale + hidden_state
// preln == false:
//   residual = residual + hidden_state + bias
//
// NOTE(review): the launcher passes mp_scale = 1.0 / mp_size, i.e. the
// model-parallel partition scaling — confirm against callers.
__global__ void fused_bias_residual(float* residual,
                                    const float* hidden_state,
                                    const float* attn,
                                    const float* bias,
                                    const float* attn_bias,
                                    const int total_count,
                                    const int intermediate_size,
                                    const float mp_scale,
                                    const bool preln)
{
    float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
    const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
    const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
    const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
    const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
    const int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 res_fl4 = res_fl4_ptr[offset];
        const float4 hs_fl4 = hs_fl4_ptr[offset];
        const float4 attn_fl4 = attn_fl4_ptr[offset];
        // Biases are per-hidden-unit, so index modulo the (vectorized) width.
        const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
        const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];

        if (preln) {
            // residual = (residual + attention + bias + attention_bias) *
            // mp_scale + hidden_state
            res_fl4.x =
                (res_fl4.x + attn_fl4.x + bias_fl4.x + attn_bias_fl4.x) * mp_scale + (hs_fl4.x);
            res_fl4.y =
                (res_fl4.y + attn_fl4.y + bias_fl4.y + attn_bias_fl4.y) * mp_scale + (hs_fl4.y);
            res_fl4.z =
                (res_fl4.z + attn_fl4.z + bias_fl4.z + attn_bias_fl4.z) * mp_scale + (hs_fl4.z);
            res_fl4.w =
                (res_fl4.w + attn_fl4.w + bias_fl4.w + attn_bias_fl4.w) * mp_scale + (hs_fl4.w);
        } else {
            // residual += hidden_state + bias
            res_fl4.x = res_fl4.x + hs_fl4.x + bias_fl4.x;
            res_fl4.y = res_fl4.y + hs_fl4.y + bias_fl4.y;
            res_fl4.z = res_fl4.z + hs_fl4.z + bias_fl4.z;
            res_fl4.w = res_fl4.w + hs_fl4.w + bias_fl4.w;
        }
        res_fl4_ptr[offset] = res_fl4;
    }
}
// Fused bias + residual update kernel, fp16 path.
//
// Vectorized over float2 (one float2 = one __half2 pair = 4 halves per
// thread). `total_count` is the element count in float2 units and
// `intermediate_size` is the hidden dimension in the same units (the
// launcher passes hidden_dim / 4), so bias vectors repeat every
// `intermediate_size` elements. Arithmetic is performed in fp32 after
// widening via __half22float2, then narrowed back with __float22half2_rn.
//
// preln == true:
//   residual = (residual + attn + bias + attn_bias) * mp_scale + hidden_state
// preln == false:
//   residual = residual + hidden_state + bias
//
// Compiled to an empty body unless HALF_PRECISION_AVAILABLE is defined.
__global__ void fused_bias_residual(__half* residual,
                                    const __half* hidden_state,
                                    const __half* attn,
                                    const __half* bias,
                                    const __half* attn_bias,
                                    const int total_count,
                                    const int intermediate_size,
                                    const float mp_scale,
                                    const bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE
    float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
    const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
    const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
    const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
    const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
    const int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 res_fl2 = res_fl2_ptr[offset];
        const float2 hs_fl2 = hs_fl2_ptr[offset];
        const float2 attn_fl2 = attn_fl2_ptr[offset];
        // Biases are per-hidden-unit, so index modulo the (vectorized) width.
        const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
        const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];

        // Reinterpret each float2 as two __half2 lanes (low/high pairs).
        __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
        const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
        const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
        const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
        const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);

        // Widen to fp32 for the accumulation.
        float2 res_low = __half22float2(res_half2[0]);
        float2 res_high = __half22float2(res_half2[1]);
        const float2 hs_low = __half22float2(hs_half2[0]);
        const float2 hs_high = __half22float2(hs_half2[1]);
        const float2 attn_low = __half22float2(attn_half2[0]);
        const float2 attn_high = __half22float2(attn_half2[1]);
        const float2 bias_low = __half22float2(bias_half2[0]);
        const float2 bias_high = __half22float2(bias_half2[1]);
        const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
        const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);

        if (preln) {
            // residual = (residual + attention + bias + attention_bias) *
            // mp_scale + hidden_state
            res_low.x =
                (res_low.x + attn_low.x + bias_low.x + attn_bias_low.x) * mp_scale + hs_low.x;
            res_low.y =
                (res_low.y + attn_low.y + bias_low.y + attn_bias_low.y) * mp_scale + hs_low.y;
            res_high.x =
                (res_high.x + attn_high.x + bias_high.x + attn_bias_high.x) * mp_scale + hs_high.x;
            res_high.y =
                (res_high.y + attn_high.y + bias_high.y + attn_bias_high.y) * mp_scale + hs_high.y;
        } else {
            // residual += hidden_state + bias
            res_low.x = (res_low.x + hs_low.x + bias_low.x);
            res_low.y = (res_low.y + hs_low.y + bias_low.y);
            res_high.x = (res_high.x + hs_high.x + bias_high.x);
            res_high.y = (res_high.y + hs_high.y + bias_high.y);
        }

        // Narrow back to fp16 and store through the vectorized pointer.
        res_half2[0] = __float22half2_rn(res_low);
        res_half2[1] = __float22half2_rn(res_high);

        res_fl2_ptr[offset] = res_fl2;
    }
#endif
}
template
<
typename
T
>
void
launch_bias_residual
(
T
*
input
,
T
*
output
,
void
launch_bias_residual
(
T
*
residual
,
T
*
hidden_state
,
T
*
attn
,
T
*
bias
,
T
*
attn_bias
,
...
...
@@ -253,8 +260,15 @@ void launch_bias_residual(T* input,
dim3
block_dims
(
1024
);
dim3
grid_dims
((
total_count
-
1
)
/
1024
+
1
);
// (batch_size);
fused_bias_residual
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
input
,
output
,
attn
,
bias
,
attn_bias
,
total_count
,
hidden_dim
/
4
,
1.0
/
mp_size
,
preln
);
fused_bias_residual
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
residual
,
hidden_state
,
attn
,
bias
,
attn_bias
,
total_count
,
hidden_dim
/
4
,
1.0
/
mp_size
,
preln
);
}
template
void
launch_bias_residual
<
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录