Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5ee00943
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5ee00943
编写于
8月 17, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/cuda): fix ptx mma algo compute bugs
GitOrigin-RevId: 19628d0c94e93ff1072db2eb04547e6f8db5f809
上级
d4bf57d6
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
157 addition
and
195 deletion
+157
-195
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu
...cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu
+14
-12
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu
.../cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu
+32
-25
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu
...a/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu
+32
-25
dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu
dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu
+14
-11
dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu
dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu
+14
-11
dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu
...src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu
+32
-30
dnn/src/cuda/ptx/uint4_int4/macro.cuh
dnn/src/cuda/ptx/uint4_int4/macro.cuh
+19
-81
未找到文件。
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu
浏览文件 @
5ee00943
...
...
@@ -476,6 +476,20 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
}
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
}
// read fuse_z
int2
reg_fuse_z
[
reg_m
]
=
{
make_int2
(
z_zero_point
,
z_zero_point
),
make_int2
(
z_zero_point
,
z_zero_point
),
...
...
@@ -595,18 +609,7 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
/// output
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
...
...
@@ -617,7 +620,6 @@ extern "C" __global__ void __launch_bounds__(256)
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
I2F_4x8
(
reg_acc
,
y
,
0
);
FMA_4x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
FUSE_Z_4x8
(
reg_acc
,
y
,
0
,
reg_fuse_z
,
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_4x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
...
...
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu
浏览文件 @
5ee00943
...
...
@@ -657,6 +657,20 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
}
size_t
oc
=
bidy
*
BM
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
}
// read fuse_z
int2
reg_fuse_z
[
reg_m
]
=
{
make_int2
(
z_zero_point
,
z_zero_point
),
make_int2
(
z_zero_point
,
z_zero_point
),
...
...
@@ -712,6 +726,14 @@ extern "C" __global__ void __launch_bounds__(256)
reg_flt
[
0
][
j
]
=
make_int4
(
x
,
y
,
z
,
w
);
}
/// output
if
(
oc
<
param
.
oc
)
{
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
mul_v4
(
load_bias3
,
load_bias3
,
beta
);
}
// compute
#pragma unroll
for
(
int
k_inner
=
0
;
k_inner
<
BKd32
;
k_inner
++
)
{
...
...
@@ -773,35 +795,20 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
/// output
size_t
oc
=
bidy
*
BM
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
mul_v4
(
load_bias3
,
load_bias3
,
beta
);
}
int8_t
*
__restrict__
g_dst_ptr
=
dst
+
d_offset
;
FMA_1x8
(
reg_acc
,
0
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
fuse_z_1x8
(
reg_acc
[
0
],
0
,
reg_fuse_z
[
0
],
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_1x8
(
reg_acc
,
0
,
0
,
relu
,
dst_zero_point
);
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
I2F_4x8
(
reg_acc
,
y
,
0
);
FMA_4x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
FUSE_Z_4x8
(
reg_acc
,
y
,
0
,
reg_fuse_z
,
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_4x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_AFTER_LDG_4x1
(
g_offset
,
reg_acc
,
y
,
0
);
for
(
int
y
=
1
;
y
<
reg_m
;
y
+=
1
)
{
FMA_1x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
fuse_z_1x8
(
reg_acc
[
y
],
0
,
reg_fuse_z
[
y
],
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_1x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_AFTER_LDG
(
g_offset
[
y
-
1
],
reg_acc
[
y
-
1
][
0
],
stg_guard
[
y
-
1
]);
}
STG_AFTER_LDG
(
g_offset
[
7
],
reg_acc
[
7
][
0
],
stg_guard
[
7
]);
#endif
}
}
// namespace
...
...
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu
浏览文件 @
5ee00943
...
...
@@ -437,7 +437,7 @@ extern "C" __global__ void __launch_bounds__(256)
cp_async_fence
();
}
bool
only_one_stage
=
(
stage
==
1
)
?
true
:
false
;
bool
only_one_stage
=
(
stage
==
1
);
if
(
stage
>=
2
)
{
cp_async_wait
(
stages
-
2
);
}
else
{
...
...
@@ -844,6 +844,20 @@ extern "C" __global__ void __launch_bounds__(256)
cp_async_wait
(
stages
-
2
);
}
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
}
if
(
!
only_one_stage
)
{
#pragma unroll // low
for
(
int
i
=
0
;
i
<
reg_nd4
;
++
i
)
{
...
...
@@ -975,6 +989,13 @@ extern "C" __global__ void __launch_bounds__(256)
reg_flt
[
0
][
j
]
=
make_int4
(
x
,
y
,
z
,
w
);
}
if
(
oc
<
param
.
oc
)
{
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
mul_v4
(
load_bias3
,
load_bias3
,
beta
);
}
// compute
#pragma unroll
for
(
int
k_inner
=
0
;
k_inner
<
BKd32
;
k_inner
++
)
{
...
...
@@ -1038,34 +1059,20 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
/// output
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
mul_v4
(
load_bias3
,
load_bias3
,
beta
);
}
int8_t
*
__restrict__
g_dst_ptr
=
dst
+
d_offset
;
FMA_1x8
(
reg_acc
,
0
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
fuse_z_1x8
(
reg_acc
[
0
],
0
,
reg_fuse_z
[
0
],
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_1x8
(
reg_acc
,
0
,
0
,
relu
,
dst_zero_point
);
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
I2F_4x8
(
reg_acc
,
y
,
0
);
FMA_4x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
FUSE_Z_4x8
(
reg_acc
,
y
,
0
,
reg_fuse_z
,
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_4x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_AFTER_LDG_4x1
(
g_offset
,
reg_acc
,
y
,
0
);
for
(
int
y
=
1
;
y
<
reg_m
;
y
+=
1
)
{
FMA_1x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
fuse_z_1x8
(
reg_acc
[
y
],
0
,
reg_fuse_z
[
y
],
gamma
,
z_zero_point
);
PACK_F2I_WITH_RELU_1x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_AFTER_LDG
(
g_offset
[
y
-
1
],
reg_acc
[
y
-
1
][
0
],
stg_guard
[
y
-
1
]);
}
STG_AFTER_LDG
(
g_offset
[
7
],
reg_acc
[
7
][
0
],
stg_guard
[
7
]);
#endif
}
}
// namespace
...
...
dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu
浏览文件 @
5ee00943
...
...
@@ -475,6 +475,20 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
}
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
}
guard
=
iter
<
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
reg_nd4
;
++
i
)
{
...
...
@@ -574,18 +588,8 @@ extern "C" __global__ void __launch_bounds__(256)
size_t
nhw_post3
=
nhw_post0
+
24
;
size_t
stg_oc
=
bidy
*
BM
+
(
warp_y
<<
6
);
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
...
...
@@ -599,7 +603,6 @@ extern "C" __global__ void __launch_bounds__(256)
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
I2F_4x8
(
reg_acc
,
y
,
0
);
FMA_4x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
PACK_F2I_WITH_RELU_4x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_4x1
(
stg_ptr
,
reg_acc
,
y
,
0
);
...
...
dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu
浏览文件 @
5ee00943
...
...
@@ -659,6 +659,20 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
}
size_t
oc
=
bidy
*
BM
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
}
guard
=
iter
<
0
;
#pragma unroll // low
for
(
int
i
=
0
;
i
<
reg_nd4
;
++
i
)
{
...
...
@@ -755,18 +769,8 @@ extern "C" __global__ void __launch_bounds__(256)
size_t
nhw_post3
=
nhw_post0
+
24
;
size_t
stg_oc
=
bidy
*
BM
;
size_t
oc
=
bidy
*
BM
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
...
...
@@ -779,7 +783,6 @@ extern "C" __global__ void __launch_bounds__(256)
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
I2F_4x8
(
reg_acc
,
y
,
0
);
FMA_4x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
PACK_F2I_WITH_RELU_4x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_4x1
(
stg_ptr
,
reg_acc
,
y
,
0
);
...
...
dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu
浏览文件 @
5ee00943
...
...
@@ -449,15 +449,15 @@ extern "C" __global__ void __launch_bounds__(256)
bool
stg_guard
[
8
];
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
COMPUTE_OFFSET_4x1
(
reg_fuse_z
,
g_offset
,
y
)
COMPUTE_OFFSET_4x1
(
g_offset
,
y
);
nhw_post0
+=
32
;
nhw_post0
+=
32
;
nhw_post1
+=
32
;
nhw_post2
+=
32
;
nhw_post3
+=
32
;
}
bool
only_one_stage
=
(
stage
==
1
)
?
true
:
false
;
bool
only_one_stage
=
(
stage
==
1
);
if
(
stage
>=
2
)
{
cp_async_wait
(
stages
-
2
);
}
else
{
...
...
@@ -835,6 +835,20 @@ extern "C" __global__ void __launch_bounds__(256)
cp_async_wait
(
stages
-
2
);
}
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
}
if
(
!
only_one_stage
)
{
#pragma unroll // low
for
(
int
i
=
0
;
i
<
reg_nd4
;
++
i
)
{
...
...
@@ -965,6 +979,13 @@ extern "C" __global__ void __launch_bounds__(256)
reg_flt
[
0
][
j
]
=
make_int4
(
x
,
y
,
z
,
w
);
}
if
(
oc
<
param
.
oc
)
{
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
mul_v4
(
load_bias3
,
load_bias3
,
beta
);
}
// compute
#pragma unroll
for
(
int
k_inner
=
0
;
k_inner
<
BKd32
;
k_inner
++
)
{
...
...
@@ -1028,38 +1049,19 @@ extern "C" __global__ void __launch_bounds__(256)
__syncthreads
();
/// output
size_t
oc
=
bidy
*
BM
+
(
warp_y
<<
6
)
+
16
*
idx_in_quad
;
const
float
*
bias_ptr
=
bias
+
oc
;
int4
load_bias0
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias1
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias2
=
make_int4
(
0
,
0
,
0
,
0
);
int4
load_bias3
=
make_int4
(
0
,
0
,
0
,
0
);
if
(
oc
<
param
.
oc
)
{
load_bias0
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
));
load_bias1
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
4
));
load_bias2
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
8
));
load_bias3
=
*
(
reinterpret_cast
<
const
int4
*>
(
bias_ptr
+
12
));
mul_v4
(
load_bias0
,
load_bias0
,
beta
);
mul_v4
(
load_bias1
,
load_bias1
,
beta
);
mul_v4
(
load_bias2
,
load_bias2
,
beta
);
mul_v4
(
load_bias3
,
load_bias3
,
beta
);
}
int8_t
*
__restrict__
g_dst_ptr
=
dst
+
d_offset
;
#pragma unroll
for
(
int
y
=
0
;
y
<
reg_m
;
y
+=
4
)
{
I2F_4x8
(
reg_acc
,
y
,
0
);
FMA_4x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
PACK_F2I_WITH_RELU_4x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_AFTER_LDG_4x1
(
g_offset
,
reg_acc
,
y
,
0
);
FMA_1x8
(
reg_acc
,
0
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
PACK_F2I_WITH_RELU_1x8
(
reg_acc
,
0
,
0
,
relu
,
dst_zero_point
);
nhw_post0
+=
32
;
nhw_post1
+=
32
;
nhw_post2
+=
32
;
nhw_post3
+=
32
;
#pragma unroll
for
(
int
y
=
1
;
y
<
reg_m
;
y
+=
1
)
{
FMA_1x8
(
reg_acc
,
y
,
0
,
alpha
,
load_bias0
,
load_bias1
,
load_bias2
,
load_bias3
);
PACK_F2I_WITH_RELU_1x8
(
reg_acc
,
y
,
0
,
relu
,
dst_zero_point
);
STG_AFTER_LDG
(
g_offset
[
y
-
1
],
reg_acc
[
y
-
1
][
0
],
stg_guard
[
y
-
1
]);
}
STG_AFTER_LDG
(
g_offset
[
7
],
reg_acc
[
7
][
0
],
stg_guard
[
7
]);
#endif
}
}
// namespace
...
...
dnn/src/cuda/ptx/uint4_int4/macro.cuh
浏览文件 @
5ee00943
...
...
@@ -23,78 +23,26 @@ __device__ __forceinline__ void mul_v4<float>(
__device__
__forceinline__
void
fma2
(
int2
&
c0
,
const
int2
a0
,
int2
&
c1
,
const
int2
a1
,
const
float
alpha
,
const
int4
b
)
{
asm
(
"fma.rz.f32 %0, %1, %2, %3;"
:
"=f"
(((
float
*
)
&
c0
)[
0
])
:
"f"
(((
float
*
)
&
a0
)[
0
]),
"f"
(
alpha
),
"f"
(((
float
*
)
&
b
)[
0
]));
asm
(
"fma.rz.f32 %0, %1, %2, %3;"
:
"=f"
(((
float
*
)
&
c0
)[
1
])
:
"f"
(((
float
*
)
&
a0
)[
1
]),
"f"
(
alpha
),
"f"
(((
float
*
)
&
b
)[
1
]));
asm
(
"fma.rz.f32 %0, %1, %2, %3;"
:
"=f"
(((
float
*
)
&
c1
)[
0
])
:
"f"
(((
float
*
)
&
a1
)[
0
]),
"f"
(
alpha
),
"f"
(((
float
*
)
&
b
)[
2
]));
asm
(
"fma.rz.f32 %0, %1, %2, %3;"
:
"=f"
(((
float
*
)
&
c1
)[
1
])
:
"f"
(((
float
*
)
&
a1
)[
1
]),
"f"
(
alpha
),
"f"
(((
float
*
)
&
b
)[
3
]));
}
__device__
__forceinline__
void
fuse_z_1x8
(
int4
*
a
,
const
int
&
j
,
const
int4
&
fuse_z
,
const
float
&
gamma
,
const
int32_t
&
zero_point
)
{
const
int2
z
[
2
]
=
{
*
reinterpret_cast
<
const
int2
*>
(
&
fuse_z
),
*
(
reinterpret_cast
<
const
int2
*>
(
&
fuse_z
)
+
1
)};
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
f
=
((
z
[
0
].
x
>>
(
k
*
8
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
]))[
0
]
+=
(
f
-
zero_point
)
*
gamma
;
f
=
((
z
[
0
].
x
>>
(
k
*
8
+
4
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
]))[
1
]
+=
(
f
-
zero_point
)
*
gamma
;
f
=
((
z
[
1
].
x
>>
(
k
*
8
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
]))[
2
]
+=
(
f
-
zero_point
)
*
gamma
;
f
=
((
z
[
1
].
x
>>
(
k
*
8
+
4
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
]))[
3
]
+=
(
f
-
zero_point
)
*
gamma
;
}
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
f
=
((
z
[
0
].
y
>>
(
k
*
8
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
0
]
+=
(
f
-
zero_point
)
*
gamma
;
f
=
((
z
[
0
].
y
>>
(
k
*
8
+
4
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
1
]
+=
(
f
-
zero_point
)
*
gamma
;
f
=
((
z
[
1
].
y
>>
(
k
*
8
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
2
]
+=
(
f
-
zero_point
)
*
gamma
;
f
=
((
z
[
1
].
y
>>
(
k
*
8
+
4
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
3
]
+=
(
f
-
zero_point
)
*
gamma
;
}
((
float
*
)
&
c0
)[
0
]
=
a0
.
x
*
alpha
+
((
float
*
)
&
b
)[
0
];
((
float
*
)
&
c0
)[
1
]
=
a0
.
y
*
alpha
+
((
float
*
)
&
b
)[
1
];
((
float
*
)
&
c1
)[
0
]
=
a1
.
x
*
alpha
+
((
float
*
)
&
b
)[
2
];
((
float
*
)
&
c1
)[
1
]
=
a1
.
y
*
alpha
+
((
float
*
)
&
b
)[
3
];
}
__device__
__forceinline__
void
fuse_z_1x8
(
int2
*
a
,
const
int
&
j
,
const
int2
&
fuse_z
,
const
float
&
gamma
,
const
int32_t
&
zero_point
)
{
float
x
=
zero_point
*
gamma
;
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
f
=
((
fuse_z
.
x
>>
(
k
*
8
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
]))[
0
]
+=
(
f
-
zero_point
)
*
gamma
;
((
float
*
)
&
(
a
[
j
+
k
]))[
0
]
+=
f
*
gamma
-
x
;
f
=
((
fuse_z
.
x
>>
(
k
*
8
+
4
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
]))[
1
]
+=
(
f
-
zero_point
)
*
gamma
;
}
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
f
=
((
fuse_z
.
y
>>
(
k
*
8
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
0
]
+=
(
f
-
zero_point
)
*
gamma
;
((
float
*
)
&
(
a
[
j
+
k
]))[
1
]
+=
f
*
gamma
-
x
;
f
=
((
fuse_z
.
y
>>
(
k
*
8
))
&
15
);
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
0
]
+=
f
*
gamma
-
x
;
f
=
((
fuse_z
.
y
>>
(
k
*
8
+
4
))
&
15
);
f
=
(
f
<<
28
)
>>
28
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
1
]
+=
(
f
-
zero_point
)
*
gamma
;
((
float
*
)
&
(
a
[
j
+
k
+
4
]))[
1
]
+=
f
*
gamma
-
x
;
}
}
...
...
@@ -282,12 +230,6 @@ __device__ __forceinline__ void pack_f2i_with_relu(
fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \
fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \
fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \
fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
// 1x8 1x(2x8 int2) to 2 int2
#define PACK_F2I_1x8(a, i, j) \
pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \
...
...
@@ -316,24 +258,20 @@ __device__ __forceinline__ void pack_f2i_with_relu(
stg_guard[i + 2]) \
LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
#define COMPUTE_OFFSET(
d, s, idx, n_reuse, hw_reuse, g)
\
#define COMPUTE_OFFSET(
s, idx, n_reuse, hw_reuse, g)
\
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw;
#define COMPUTE_OFFSET_4x1(d, s, i) \
COMPUTE_OFFSET( \
d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
COMPUTE_OFFSET( \
d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
COMPUTE_OFFSET( \
d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
COMPUTE_OFFSET( \
d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, \
stg_guard[i + 3])
#define COMPUTE_OFFSET_4x1(s, i) \
COMPUTE_OFFSET(s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
COMPUTE_OFFSET( \
s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, stg_guard[i + 1]) \
COMPUTE_OFFSET( \
s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, stg_guard[i + 2]) \
COMPUTE_OFFSET( \
s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
#define STG_AFTER_LDG(d, s, g) \
if (stg_oc < param.oc && g) { \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录