Commit 1f8e930e
Authored July 21, 2022 by Megvii Engine Team
feat(cuda): add int4 ptx 128x128 mma kernel
GitOrigin-RevId: 5a8b9c3f8eab59ed8d1daf9bbaf2c81cdc82ca5b
Parent: 1a2ed8c4

Showing 8 changed files with 2776 additions and 0 deletions (+2776 / -0)
dnn/src/cuda/ptx/uint4_int4/base.cu                                     +39    -0
dnn/src/cuda/ptx/uint4_int4/base.cuh                                    +109   -0
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu    +1096  -0
dnn/src/cuda/ptx/uint4_int4/imma8832_128x128.cuh                        +26    -0
dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu           +1089  -0
dnn/src/cuda/ptx/uint4_int4/kern.cuh                                    +20    -0
dnn/src/cuda/ptx/uint4_int4/macro.cuh                                   +348   -0
dnn/src/cuda/ptx/uint4_int4/tools.cuh                                   +49    -0
dnn/src/cuda/ptx/uint4_int4/base.cu (new file, mode 100644)
#include "./base.cuh"

using namespace convolution;

Uint32Fastdiv::Uint32Fastdiv() {
    memset(this, 0, sizeof(Uint32Fastdiv));
}

Uint32Fastdiv& Uint32Fastdiv::operator=(uint32_t d) {
    m_divisor = d;
    constexpr uint32_t MAX_U32 = ~0u;
    m_inc_dividend = 0;
    m_divisor_is_not_1 = ~0u;
    if (!(d & (d - 1))) {
        // power of 2
        m_mul = 1u << 31;
        int p = 0;
        while ((1u << p) < d)
            ++p;
        m_shift = p ? p - 1 : 0;
        if (d == 1)
            m_divisor_is_not_1 = 0;
        return *this;
    }
    auto n_bound = uint64_t(d / 2 + 1) * MAX_U32;
    uint32_t shift = 32;
    while ((1ull << shift) < n_bound)
        ++shift;
    uint64_t mdst = 1ull << shift;
    int64_t delta = d - mdst % d;
    m_mul = mdst / d + 1;
    if ((uint64_t)delta > d / 2) {
        delta -= d;
        --m_mul;
        m_inc_dividend = 1;
    }
    m_shift = shift - 32;
    return *this;
}
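The constants above implement division by invariant multiplication: for a non-power-of-two divisor d, a 32-bit magic multiplier and shift are chosen so that the high 32 bits of (dividend + m_inc_dividend) * m_mul, shifted right by m_shift, equal the true quotient. A minimal host-side sanity check (not part of the commit; it recomputes the constants exactly as operator= does for the non-power-of-two path) could look like:

#include <cassert>
#include <cstdint>

int main() {
    for (uint32_t d : {3u, 7u, 100u, 12345u}) {  // non-power-of-two divisors
        uint64_t n_bound = uint64_t(d / 2 + 1) * ~0u;
        uint32_t shift = 32;
        while ((1ull << shift) < n_bound)
            ++shift;
        uint64_t mdst = 1ull << shift;
        int64_t delta = d - mdst % d;
        uint32_t mul = uint32_t(mdst / d + 1), inc = 0;
        if ((uint64_t)delta > d / 2) {
            --mul;
            inc = 1;
        }
        uint32_t s = shift - 32;
        for (uint32_t x : {0u, 1u, d - 1, d, d + 1, 1000000u, ~0u - 1}) {
            // __umulhi(dfix, mul) on device == high 32 bits of the product
            uint32_t hi32 = uint32_t(((uint64_t)(x + inc) * mul) >> 32);
            assert((hi32 >> s) == x / d);  // same quotient as plain division
        }
    }
    return 0;
}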
dnn/src/cuda/ptx/uint4_int4/base.cuh (new file, mode 100644)
#pragma once
#include <cuda_runtime.h>
#include <stdint.h>

#if ((__CUDACC_VER_MAJOR__ > 11) || \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#define SM80_SUPPORTED
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define SM80_ENABLED
#endif
#endif

namespace convolution {

class Uint32Fastdiv {
    uint32_t m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift;

public:
    Uint32Fastdiv();
    Uint32Fastdiv(uint32_t d) { operator=(d); }

    //! set the divisor to be d
    Uint32Fastdiv& operator=(uint32_t d);

    //! caller must ensure that dividend would not exceed this number
    static constexpr uint32_t MAX_DIVIDEND = ~0u - 1;

    __device__ __forceinline__ uint32_t divisor() const { return m_divisor; }

    __device__ __forceinline__ uint32_t divide(uint32_t dividend) const {
        uint32_t ans_for_one = dividend & ~m_divisor_is_not_1,
                 dfix = dividend + m_inc_dividend,
#if __CUDA_ARCH__
                 hi32 = __umulhi(dfix, m_mul),
#else
                 hi32 = ((uint64_t)dfix * m_mul) >> 32,
#endif
                 ans = hi32 >> m_shift;
        return (ans & m_divisor_is_not_1) | ans_for_one;
    }
};

static __forceinline__ __device__ uint32_t
operator/(uint32_t a, const Uint32Fastdiv& d) {
    return d.divide(a);
}

static __forceinline__ __device__ uint32_t
operator%(uint32_t a, const Uint32Fastdiv& d) {
    return a - d.divisor() * d.divide(a);
}

struct Conv2dInt4Param {
    uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow;
    uint32_t ibs, ics, ihs;
    uint32_t obs, ocs, ohs;
    uint32_t icfhfw;
    uint32_t nhw;
    Uint32Fastdiv div_ohow;
    Uint32Fastdiv div_ow;

    Conv2dInt4Param(
            uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh,
            uint32_t fw, uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw,
            uint32_t oc, uint32_t oh, uint32_t ow, uint32_t interleaved)
            : n(n), ic(ic), ih(ih), iw(iw), fh(fh), fw(fw), sh(sh), sw(sw),
              ph(ph), pw(pw), oc(oc), oh(oh), ow(ow) {
        constexpr uint32_t size_bits = 4;
        // all stride size in bytes
        ibs = ic * ih * iw * size_bits / 8;
        ics = ih * iw * interleaved * size_bits / 8;
        ihs = iw * interleaved * size_bits / 8;
        obs = oc * oh * ow * size_bits / 8;
        ocs = oh * ow * interleaved * size_bits / 8;
        ohs = ow * interleaved * size_bits / 8;
        icfhfw = ic * fh * fw;
        nhw = n * oh * ow;
        div_ohow = oh * ow;
        div_ow = ow;
    }
};

struct Conv2dConstantOffsetParam {
    int32_t begin;
    int32_t size;
    int32_t max;
    int32_t rewind;
};

#define CONSTANT_BUFFER_SIZE 848

struct Conv2dConstantOffset {
    Conv2dConstantOffsetParam c_offset_param;
    int c_offset[CONSTANT_BUFFER_SIZE];
};

}  // namespace convolution
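The operator/ and operator% overloads above are what let the kernels below split a linear output index into (batch, oh, ow) coordinates without hardware integer division. A device-side sketch of the idiom, mirroring the kernel prologue (illustrative only, not part of the commit):

__device__ void decompose_nhw(
        uint32_t nhw, const convolution::Conv2dInt4Param& p, uint32_t& tn,
        uint32_t& toh, uint32_t& tow) {
    tn = nhw / p.div_ohow;           // batch index via Uint32Fastdiv::divide
    uint32_t hw = nhw % p.div_ohow;  // remainder: a - divisor * (a / divisor)
    toh = hw / p.div_ow;             // output row
    tow = hw % p.div_ow;             // output column
}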
dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu (new file, mode 100644)
#include <cuda_runtime.h>
#include <stdio.h>
#include "./imma8832_128x128.cuh"
#include "./kern.cuh"
#include "./macro.cuh"
#include "./tools.cuh"

using namespace convolution;

namespace {
#ifdef SM80_ENABLED
extern "C" __device__ void g2s_int4(const int4* gm, int4* sm) {
    unsigned sm_addr = get_smem_pointer(sm);
    const int SizeInBytes = 16;
#if ENABLE_L2_PREFETCH
    asm volatile(
            "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(sm_addr),
            "l"(gm), "n"(SizeInBytes));
#else
    asm volatile(
            "cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(sm_addr), "l"(gm),
            "n"(SizeInBytes));
#endif
}

/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
/// not block.
#define cp_async_fence() asm volatile("cp.async.commit_group;\n" ::)

/// Blocks until all but <N> previous cp.async.commit_group operations have
/// committed.
#define cp_async_wait(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))
#endif
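// Added commentary (not in the original commit): the kernel below runs a
// three-stage cp.async pipeline. Each stage issues g2s_int4 copies into one of
// three shared-memory buffers, then cp_async_fence() commits them as a group;
// cp_async_wait(stages - 2) later blocks until all but one group has landed,
// so the mma compute on stage k overlaps the global->shared copy of stage k+1.
// The idiom, in outline:
//     g2s_int4(gmem_ptr, smem_ptr);  // queue an async 16-byte copy
//     cp_async_fence();              // commit the copies as one group
//     cp_async_wait(1);              // keep at most one group in flight
//     __syncthreads();               // copied data now visible to all warps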
extern "C" __global__ void __launch_bounds__(256)
        ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
                const int8_t* __restrict__ src, const int8_t* __restrict__ filter,
                const float* __restrict__ bias, const int8_t* __restrict__ z,
                int8_t* __restrict__ dst, float alpha, float beta, float gamma,
                uint32_t pk_src_zero_point, int32_t z_zero_point,
                float dst_zero_point, uint32_t relu, Conv2dInt4Param param,
                Conv2dConstantOffset conv2d_constant) {
#ifdef SM80_ENABLED
    const int stages = 3;
    const uint32_t tid = threadIdx.x;
    const uint32_t bidx = blockIdx.x;
    const uint32_t bidy = blockIdx.y;
    extern __shared__ int32_t smem[];  // (128 + 128)*128/8*stages
    int2 reg_acc[reg_m][reg_n];
    int4 reg_src[2][reg_nd4];
    int4 reg_flt[2][reg_md4];
    // use in other way, maybe use reg_ser/flt
    int4 reg_src_cache[2];
    int4 reg_filter_cache[4];
    uint32_t tid127 = (tid & 127);
    uint32_t section = (tid127 >> 1);
    uint32_t residue = ((tid127 << 5) & 63);
    uint32_t nhw = bidx * BN + section;
    uint32_t tn, hw, toh, tow;
    int tih, tiw;
    int h_start[2];
    int h_end[2];
    int w_start[2];
    int w_end[2];
    bool g[2];
    const int8_t* __restrict__ g_src_ptr[4];
    for (int i = 0; i < 2; i++) {
        if (i != 0) {
            nhw += 64;
        }
        tn = nhw / param.div_ohow;
        hw = nhw % param.div_ohow;
        toh = hw / param.div_ow;
        tow = hw % param.div_ow;
        tih = toh * param.sh - param.ph;
        tiw = tow * param.sw - param.pw;
        g[i] = tn < param.n;
        h_start[i] = -tih;
        h_end[i] = param.ih - tih;
        w_start[i] = -tiw;
        w_end[i] = param.iw - tiw;
        // param's members have been converted to byte offset and int4 offset
        // need to div 2
        int src_offset = tn * param.ibs + tih * param.ihs +
                         ((int)(tiw * packed_channel + residue) >> 1);
        g_src_ptr[i * 2] = src + src_offset;
        g_src_ptr[i * 2 + 1] = g_src_ptr[i * 2];
    }
    const uint32_t section_section = (section >> 2);
    const uint32_t section_residue = (section & 3);
    const uint32_t section_factor = ((section & 15) >> 2);
    const uint32_t crosswise_offset =
            ((section_residue >> 1) << 4) +
            (((section_residue & 1) ^ (section_factor >> 1)) << 3);
    const uint32_t residue_offset = ((residue >> 5) ^ (section_factor & 1)) << 2;

    // next + 64 * BK / 8
    int32_t* write_src_s[2];
    write_src_s[0] =
            smem + section_section * BK / 2 + crosswise_offset + residue_offset;
    write_src_s[1] = write_src_s[0] + 32;

    int iter = (param.icfhfw >> 6);

    uint32_t tid31 = (tid & 31);
    uint32_t warp_idx = (tid >> 5);
    uint32_t warp_strided = (warp_idx << 2);
    uint32_t htid = (tid31 >> 4);
    const uint32_t flt_strided = bidy * BM / 8 + warp_strided;
    bool guard = flt_strided * 8 < param.oc && iter > htid;
    // icfhfw * 8/2 is a stride
    const int8_t* __restrict__ g_filter_ptr0 =
            filter + flt_strided * (param.icfhfw * 4) + (tid31 << 4);
    const int8_t* __restrict__ g_filter_ptr1 = g_filter_ptr0 + (param.icfhfw * 4);
    const int8_t* __restrict__ g_filter_ptr2 = g_filter_ptr0 + (param.icfhfw * 8);
    const int8_t* __restrict__ g_filter_ptr3 = g_filter_ptr0 + (param.icfhfw * 12);
    // next + BK * 8 / (INT32/INT4)
    uint32_t q = (tid31 >> 3);
    uint32_t r = (tid31 & 7);
    int32_t* write_flt_s = smem + BN * BK / 8 + warp_strided * BK +
                           ((q & 1) << 6) + ((q >> 1) << 5) + (r << 2);
    uint32_t quad_idx = (tid31 >> 2);
    uint32_t idx_in_quad = (tid & 3);
    uint32_t quad_factor = ((tid & 15) >> 2);
    uint32_t crosswise = ((idx_in_quad >> 1) << 4) +
                         (((idx_in_quad & 1) ^ (quad_factor >> 1)) << 3);
    uint32_t warp_x = (warp_idx >> 1);
    uint32_t warp_y = (warp_idx & 1);

    int32_t* read_src_s_0 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) +
                            crosswise + ((0 ^ (quad_factor & 1)) << 2);
    int32_t* read_src_s_1 = smem + (warp_x * 8 * BK) + (quad_idx * BK / 2) +
                            crosswise + ((1 ^ (quad_factor & 1)) << 2);
    int32_t* read_flt_s_0 = smem + BN * BK / 8 + (warp_y * 8 * BK) +
                            (quad_idx * BK / 2) + crosswise +
                            ((0 ^ (quad_factor & 1)) << 2);
    int32_t* read_flt_s_1 = smem + BN * BK / 8 + (warp_y * 8 * BK) +
                            (quad_idx * BK / 2) + crosswise +
                            ((1 ^ (quad_factor & 1)) << 2);

#pragma unroll
    for (int i = 0; i < reg_m; i++) {
#pragma unroll
        for (int j = 0; j < reg_n; j++) {
            reg_acc[i][j] = make_int2(0, 0);
        }
    }

    const int smem_switch = 4096;
    const int smem_switch_back = -smem_switch * (stages - 1);
    int stage = 0;
    uint32_t offset[2] = {0, 2};  // high & low
    int src_step[2], x[2], y[2];

    // global mem --> shared mem, stage 0
    for (int i = 0; i < 2; i++) {
        src_step[i] = conv2d_constant.c_offset[offset[i]];
        uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
                &(conv2d_constant.c_offset[offset[i] + 1])));
        x[i] = (spatial & 0xff);
        y[i] = ((spatial >> 8) & 0xff);
        if (offset[i] < conv2d_constant.c_offset_param.max) {
            offset[i] += 4;
        } else {
            offset[i] += conv2d_constant.c_offset_param.rewind;
        }
    }
    bool guard0[2], guard1[2];
    guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
                y[0] >= w_start[0] && y[0] < w_end[0] && iter > 0;
    guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
                y[1] >= w_start[0] && y[1] < w_end[0] && iter > 1;
    guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
                y[0] >= w_start[1] && y[0] < w_end[1] && iter > 0;
    guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
                y[1] >= w_start[1] && y[1] < w_end[1] && iter > 1;
    g_src_ptr[0] += src_step[0];
    g_src_ptr[1] += src_step[1];
    g_src_ptr[2] += src_step[0];
    g_src_ptr[3] += src_step[1];
    if (guard0[0]) {
        g2s_int4(
                reinterpret_cast<const int4*>(g_src_ptr[0]),
                reinterpret_cast<int4*>(write_src_s[0]));
    } else {
        *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                pk_src_zero_point);
    }
    if (guard0[1]) {
        g2s_int4(
                reinterpret_cast<const int4*>(g_src_ptr[1]),
                reinterpret_cast<int4*>(write_src_s[1]));
    } else {
        *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                pk_src_zero_point);
    }
    if (guard1[0]) {
        g2s_int4(
                reinterpret_cast<const int4*>(g_src_ptr[2]),
                reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
    } else {
        *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                pk_src_zero_point);
    }
    if (guard1[1]) {
        g2s_int4(
                reinterpret_cast<const int4*>(g_src_ptr[3]),
                reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
    } else {
        *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                pk_src_zero_point);
    }
    if (guard) {
        g2s_int4(
                reinterpret_cast<const int4*>(g_filter_ptr0),
                reinterpret_cast<int4*>(write_flt_s));
        g2s_int4(
                reinterpret_cast<const int4*>(g_filter_ptr1),
                reinterpret_cast<int4*>(write_flt_s + BK));
        g2s_int4(
                reinterpret_cast<const int4*>(g_filter_ptr2),
                reinterpret_cast<int4*>(write_flt_s + 2 * BK));
        g2s_int4(
                reinterpret_cast<const int4*>(g_filter_ptr3),
                reinterpret_cast<int4*>(write_flt_s + 3 * BK));
    } else {
        *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
        *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
        *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
        *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
    }
    stage++;
    cp_async_fence();

    // global mem --> shared mem, stage 1 -> stage n
    iter -= 2;
    for (int i = 0; i < 2; i++) {
        src_step[i] = conv2d_constant.c_offset[offset[i]];
        uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
                &(conv2d_constant.c_offset[offset[i] + 1])));
        x[i] = (spatial & 0xff);
        y[i] = ((spatial >> 8) & 0xff);
        if (offset[i] < conv2d_constant.c_offset_param.max) {
            offset[i] += 4;
        } else {
            offset[i] += conv2d_constant.c_offset_param.rewind;
        }
    }
    guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
                y[0] >= w_start[0] && y[0] < w_end[0];
    guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
                y[1] >= w_start[0] && y[1] < w_end[0];
    guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
                y[0] >= w_start[1] && y[0] < w_end[1];
    guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
                y[1] >= w_start[1] && y[1] < w_end[1];
    g_src_ptr[0] += src_step[0];
    g_src_ptr[1] += src_step[1];
    g_src_ptr[2] += src_step[0];
    g_src_ptr[3] += src_step[1];
    g_filter_ptr0 += 8 * 64;
    g_filter_ptr1 += 8 * 64;
    g_filter_ptr2 += 8 * 64;
    g_filter_ptr3 += 8 * 64;
    write_src_s[0] += smem_switch;
    write_src_s[1] += smem_switch;
    write_flt_s += smem_switch;
    for (; iter >= 2 && stage < stages - 1; iter -= 2) {
        if (guard0[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[0]),
                    reinterpret_cast<int4*>(write_src_s[0]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard0[1]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[1]),
                    reinterpret_cast<int4*>(write_src_s[1]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[2]),
                    reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[1]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[3]),
                    reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr0),
                    reinterpret_cast<int4*>(write_flt_s));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr1),
                    reinterpret_cast<int4*>(write_flt_s + BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr2),
                    reinterpret_cast<int4*>(write_flt_s + 2 * BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr3),
                    reinterpret_cast<int4*>(write_flt_s + 3 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
                    make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
                    make_int4(0, 0, 0, 0);
        }
        stage++;
        cp_async_fence();

        for (int i = 0; i < 2; i++) {
            src_step[i] = conv2d_constant.c_offset[offset[i]];
            uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
                    &(conv2d_constant.c_offset[offset[i] + 1])));
            x[i] = (spatial & 0xff);
            y[i] = ((spatial >> 8) & 0xff);
            if (offset[i] < conv2d_constant.c_offset_param.max) {
                offset[i] += 4;
            } else {
                offset[i] += conv2d_constant.c_offset_param.rewind;
            }
        }
        guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
                    y[0] >= w_start[0] && y[0] < w_end[0];
        guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
                    y[1] >= w_start[0] && y[1] < w_end[0];
        guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
                    y[0] >= w_start[1] && y[0] < w_end[1];
        guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
                    y[1] >= w_start[1] && y[1] < w_end[1];
        g_src_ptr[0] += src_step[0];
        g_src_ptr[1] += src_step[1];
        g_src_ptr[2] += src_step[0];
        g_src_ptr[3] += src_step[1];
        g_filter_ptr0 += 8 * 64;
        g_filter_ptr1 += 8 * 64;
        g_filter_ptr2 += 8 * 64;
        g_filter_ptr3 += 8 * 64;
        write_src_s[0] += smem_switch;
        write_src_s[1] += smem_switch;
        write_flt_s += smem_switch;
    }
    bool is_copy = false;
    if (iter == 1 && stage != stages - 1) {
        if (guard0[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[0]),
                    reinterpret_cast<int4*>(write_src_s[0]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard0[1] && iter > 1) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[1]),
                    reinterpret_cast<int4*>(write_src_s[1]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[2]),
                    reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[1] && iter > 1) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[3]),
                    reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard && iter > htid) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr0),
                    reinterpret_cast<int4*>(write_flt_s));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr1),
                    reinterpret_cast<int4*>(write_flt_s + BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr2),
                    reinterpret_cast<int4*>(write_flt_s + 2 * BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr3),
                    reinterpret_cast<int4*>(write_flt_s + 3 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
                    make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
                    make_int4(0, 0, 0, 0);
        }
        stage++;
        is_copy = true;
        cp_async_fence();
    }
    bool only_one_stage = (stage == 1) ? true : false;
    if (stage >= 2) {
        cp_async_wait(stages - 2);
    } else {
        cp_async_wait(0);
    }
    __syncthreads();

    // read fuse_z
    int2 reg_fuse_z[reg_m] = {
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point),
            make_int2(z_zero_point, z_zero_point)};
    int d_offset = (bidy * (BM >> 6) + warp_y) * param.ocs + (idx_in_quad << 3);
    const int8_t* __restrict__ g_z_ptr = z + d_offset;
    section = tid31 >> 2;
    size_t nhw_post0 = bidx * BN + warp_x * 64 + section;
    size_t nhw_post1 = nhw_post0 + 8;
    size_t nhw_post2 = nhw_post0 + 16;
    size_t nhw_post3 = nhw_post0 + 24;
    size_t stg_oc = bidy * BM + (warp_y << 6);
    int* g_offset = ((int*)&reg_filter_cache);
    bool stg_guard[8];
#pragma unroll
    for (int y = 0; y < reg_m; y += 4) {
        LDG_4x1(reg_fuse_z, g_offset, y)
        nhw_post0 += 32;
        nhw_post1 += 32;
        nhw_post2 += 32;
        nhw_post3 += 32;
    }
    for (; iter >= 2; iter -= 2) {
        if (guard0[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[0]),
                    reinterpret_cast<int4*>(write_src_s[0]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard0[1]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[1]),
                    reinterpret_cast<int4*>(write_src_s[1]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[2]),
                    reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[1]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[3]),
                    reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr0),
                    reinterpret_cast<int4*>(write_flt_s));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr1),
                    reinterpret_cast<int4*>(write_flt_s + BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr2),
                    reinterpret_cast<int4*>(write_flt_s + 2 * BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr3),
                    reinterpret_cast<int4*>(write_flt_s + 3 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
                    make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
                    make_int4(0, 0, 0, 0);
        }
        stage++;
        cp_async_fence();

#pragma unroll  // low
        for (int i = 0; i < reg_nd4; ++i) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_src[0][i] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int j = 0; j < reg_md4; ++j) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_flt[0][j] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int k_inner = 0; k_inner < BKd32; k_inner++) {
            int comp = (k_inner & 0x1);
            int load = 1 - comp;
            if (k_inner < BKd32 - 1) {
                int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
                int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
                read_src_s += 32 * ((k_inner + 1) >> 1);
                read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
                for (int i = 0; i < reg_nd4; ++i) {
                    int x, y, z, w;
                    unsigned addr =
                            get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, "
                            "%2, %3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_src[load][i] = make_int4(x, y, z, w);
                }
#pragma unroll
                for (int j = 0; j < reg_md4; ++j) {
                    int x, y, z, w;
                    unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, "
                            "%2, %3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_flt[load][j] = make_int4(x, y, z, w);
                }
            }
            int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
            int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
            for (int x = 0; x < reg_n; x++) {
#pragma unroll
                for (int y = 0; y < reg_m; y++) {
                    int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                    int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                    asm volatile(
                            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4."
                            "s4.s32 "
                            "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
                            : "=r"(D[0]), "=r"(D[1])
                            : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
                }
            }
        }
        if (stage == stages) {
            stage = 0;
            write_src_s[0] += smem_switch_back;
            write_src_s[1] += smem_switch_back;
            write_flt_s += smem_switch_back;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        } else if (stage == stages - 1) {
            write_src_s[0] += smem_switch;
            write_src_s[1] += smem_switch;
            write_flt_s += smem_switch;
            read_src_s_0 += smem_switch_back;
            read_src_s_1 += smem_switch_back;
            read_flt_s_0 += smem_switch_back;
            read_flt_s_1 += smem_switch_back;
        } else {
            write_src_s[0] += smem_switch;
            write_src_s[1] += smem_switch;
            write_flt_s += smem_switch;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        }
        int src_step[2];
        for (int i = 0; i < 2; i++) {
            src_step[i] = conv2d_constant.c_offset[offset[i]];
            uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
                    &(conv2d_constant.c_offset[offset[i] + 1])));
            x[i] = (spatial & 0xff);
            y[i] = ((spatial >> 8) & 0xff);
            if (offset[i] < conv2d_constant.c_offset_param.max) {
                offset[i] += 4;
            } else {
                offset[i] += conv2d_constant.c_offset_param.rewind;
            }
        }
        guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
                    y[0] >= w_start[0] && y[0] < w_end[0];
        guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
                    y[1] >= w_start[0] && y[1] < w_end[0];
        guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
                    y[0] >= w_start[1] && y[0] < w_end[1];
        guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
                    y[1] >= w_start[1] && y[1] < w_end[1];
        g_src_ptr[0] += src_step[0];
        g_src_ptr[1] += src_step[1];
        g_src_ptr[2] += src_step[0];
        g_src_ptr[3] += src_step[1];
        g_filter_ptr0 += 8 * 64;
        g_filter_ptr1 += 8 * 64;
        g_filter_ptr2 += 8 * 64;
        g_filter_ptr3 += 8 * 64;
        cp_async_wait(stages - 2);
        __syncthreads();
    }
    if (iter > 0) {
        if (!is_copy) {
            if (guard0[0]) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[0]),
                        reinterpret_cast<int4*>(write_src_s[0]));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard0[1] && iter > 1) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[1]),
                        reinterpret_cast<int4*>(write_src_s[1]));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard1[0]) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[2]),
                        reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard1[1] && iter > 1) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[3]),
                        reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard && iter > htid) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr0),
                        reinterpret_cast<int4*>(write_flt_s));
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr1),
                        reinterpret_cast<int4*>(write_flt_s + BK));
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr2),
                        reinterpret_cast<int4*>(write_flt_s + 2 * BK));
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr3),
                        reinterpret_cast<int4*>(write_flt_s + 3 * BK));
            } else {
                *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
                *(reinterpret_cast<int4*>(write_flt_s + BK)) =
                        make_int4(0, 0, 0, 0);
                *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
                        make_int4(0, 0, 0, 0);
                *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
                        make_int4(0, 0, 0, 0);
            }
            cp_async_fence();
        }
#pragma unroll  // low
        for (int i = 0; i < reg_nd4; ++i) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_src[0][i] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int j = 0; j < reg_md4; ++j) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_flt[0][j] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int k_inner = 0; k_inner < BKd32; k_inner++) {
            int comp = (k_inner & 0x1);
            int load = 1 - comp;
            if (k_inner < BKd32 - 1) {
                int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
                int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
                read_src_s += 32 * ((k_inner + 1) >> 1);
                read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
                for (int i = 0; i < reg_nd4; ++i) {
                    int x, y, z, w;
                    unsigned addr =
                            get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, "
                            "%2, %3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_src[load][i] = make_int4(x, y, z, w);
                }
#pragma unroll
                for (int j = 0; j < reg_md4; ++j) {
                    int x, y, z, w;
                    unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, "
                            "%2, %3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_flt[load][j] = make_int4(x, y, z, w);
                }
            }
            int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
            int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
            for (int x = 0; x < reg_n; x++) {
#pragma unroll
                for (int y = 0; y < reg_m; y++) {
                    int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                    int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                    asm volatile(
                            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4."
                            "s4.s32 "
                            "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
                            : "=r"(D[0]), "=r"(D[1])
                            : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
                }
            }
        }
        stage++;
        if (stage == stages) {
            stage = 0;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        } else if (stage == stages - 1) {
            read_src_s_0 += smem_switch_back;
            read_src_s_1 += smem_switch_back;
            read_flt_s_0 += smem_switch_back;
            read_flt_s_1 += smem_switch_back;
        } else {
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        }
        cp_async_wait(stages - 2);
    }
    if (!only_one_stage) {
#pragma unroll  // low
        for (int i = 0; i < reg_nd4; ++i) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_src[0][i] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int j = 0; j < reg_md4; ++j) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_flt[0][j] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int k_inner = 0; k_inner < BKd32; k_inner++) {
            int comp = (k_inner & 0x1);
            int load = 1 - comp;
            if (k_inner < BKd32 - 1) {
                int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
                int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
                read_src_s += 32 * ((k_inner + 1) >> 1);
                read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
                for (int i = 0; i < reg_nd4; ++i) {
                    int x, y, z, w;
                    unsigned addr =
                            get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, "
                            "%2, %3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_src[load][i] = make_int4(x, y, z, w);
                }
#pragma unroll
                for (int j = 0; j < reg_md4; ++j) {
                    int x, y, z, w;
                    unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, "
                            "%2, %3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_flt[load][j] = make_int4(x, y, z, w);
                }
            }
            int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
            int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
            for (int x = 0; x < reg_n; x++) {
#pragma unroll
                for (int y = 0; y < reg_m; y++) {
                    int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                    int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                    asm volatile(
                            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4."
                            "s4.s32 "
                            "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
                            : "=r"(D[0]), "=r"(D[1])
                            : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
                }
            }
        }
        stage++;
        if (stage == stages) {
            stage = 0;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        } else if (stage == stages - 1) {
            read_src_s_0 += smem_switch_back;
            read_src_s_1 += smem_switch_back;
            read_flt_s_0 += smem_switch_back;
            read_flt_s_1 += smem_switch_back;
        } else {
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        }
        cp_async_wait(0);
    }
    guard = iter < 0;
#pragma unroll  // low
    for (int i = 0; i < reg_nd4; ++i) {
        int x, y, z, w;
        unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
        asm volatile(
                "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                "[%4];"
                : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                : "r"(addr));
        reg_src[0][i] = make_int4(x, y, z, w);
    }
#pragma unroll
    for (int j = 0; j < reg_md4; ++j) {
        int x, y, z, w;
        unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
        asm volatile(
                "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                "[%4];"
                : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                : "r"(addr));
        reg_flt[0][j] = make_int4(x, y, z, w);
    }
    // compute
#pragma unroll
    for (int k_inner = 0; k_inner < BKd32; k_inner++) {
        int comp = (k_inner & 0x1);
        int load = 1 - comp;
        if (k_inner < BKd32 - 1 && !(k_inner == 1 && guard)) {
            int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
            int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
            read_src_s += 32 * ((k_inner + 1) >> 1);
            read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
            for (int i = 0; i < reg_nd4; ++i) {
                int x, y, z, w;
                unsigned addr =
                        get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                asm volatile(
                        "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                        "%3}, "
                        "[%4];"
                        : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                        : "r"(addr));
                reg_src[load][i] = make_int4(x, y, z, w);
            }
#pragma unroll
            for (int j = 0; j < reg_md4; ++j) {
                int x, y, z, w;
                unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                asm volatile(
                        "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                        "%3}, "
                        "[%4];"
                        : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                        : "r"(addr));
                reg_flt[load][j] = make_int4(x, y, z, w);
            }
        }
        int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
        int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
        for (int x = 0; x < reg_n; x++) {
#pragma unroll
            for (int y = 0; y < reg_m; y++) {
                int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                asm volatile(
                        "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
                        "s32 "
                        "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
                        : "=r"(D[0]), "=r"(D[1])
                        : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
            }
        }
        if (k_inner == 1 && guard) {
            break;
        }
    }
    __syncthreads();

    /// output
    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
    const float* bias_ptr = bias + oc;

    int4 load_bias0 = make_int4(0, 0, 0, 0);
    int4 load_bias1 = make_int4(0, 0, 0, 0);
    int4 load_bias2 = make_int4(0, 0, 0, 0);
    int4 load_bias3 = make_int4(0, 0, 0, 0);
    if (oc < param.oc) {
        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
        mul_v4(load_bias0, load_bias0, beta);
        mul_v4(load_bias1, load_bias1, beta);
        mul_v4(load_bias2, load_bias2, beta);
        mul_v4(load_bias3, load_bias3, beta);
    }

    int8_t* __restrict__ g_dst_ptr = dst + d_offset;

#pragma unroll
    for (int y = 0; y < reg_m; y += 4) {
        I2F_4x8(reg_acc, y, 0);
        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2,
                load_bias3);
        FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
    }
#endif
}
}  // namespace

namespace megdnn {
namespace cuda {
namespace ptx {
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params) {
#ifdef SM80_SUPPORTED
    cudaFuncSetAttribute(
            ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu,
            cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
    ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu<<<
            grid, block, 49152, stream>>>(
            *((int8_t**)params[0]), *((int8_t**)params[1]), *((float**)params[2]),
            *((int8_t**)params[3]), *((int8_t**)params[4]), *((float*)params[5]),
            *((float*)params[6]), *((float*)params[7]), *((uint32_t*)params[8]),
            *((uint32_t*)params[9]), *((float*)params[10]),
            *((uint32_t*)params[11]), *((Conv2dInt4Param*)params[12]),
            *((Conv2dConstantOffset*)params[13]));
#endif
}
}  // namespace ptx
}  // namespace cuda
}  // namespace megdnn
// vim: syntax=cpp.doxygen
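The launcher above dereferences a void** array whose slots follow the kernel's parameter order. A hypothetical host-side packing sketch (names and the grid heuristic are assumptions, not from this commit; the grid layout is inferred from the kernel's use of blockIdx.x to tile BN=128 over n*oh*ow and blockIdx.y to tile BM=128 over oc):

void launch_fuse_z_example(
        int8_t* d_src, int8_t* d_filter, float* d_bias, int8_t* d_z,
        int8_t* d_dst, float alpha, float beta, float gamma,
        uint32_t pk_src_zp, uint32_t z_zp, float dst_zp, uint32_t relu,
        convolution::Conv2dInt4Param param,
        convolution::Conv2dConstantOffset c_offset, cudaStream_t stream) {
    void* params[14] = {&d_src,      &d_filter, &d_bias, &d_z,   &d_dst,
                        &alpha,      &beta,     &gamma,  &pk_src_zp,
                        &z_zp,       &dst_zp,   &relu,   &param, &c_offset};
    dim3 block(256);  // matches __launch_bounds__(256)
    dim3 grid((param.nhw + 127) / 128, (param.oc + 127) / 128);  // BN = BM = 128
    megdnn::cuda::ptx::
            run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
                    grid, block, stream, params);
}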
dnn/src/cuda/ptx/uint4_int4/imma8832_128x128.cuh (new file, mode 100644)
#pragma once
#include "./base.cuh"
#define TX 128
#define TY 1
#define BM 128
#define BN 128
#define BK 128
#define mma_m 16
#define mma_n 8
#define mma_k 64
#define reg_m 8
#define reg_n 8
#define packed_channel 64
#define BKd32 (BK / 32)
#define BKd64 (BK / 64)
#define reg_md4 (reg_m >> 2)
#define WARPS (TX / 32)
#define cache_per_warp 128
#define reg_nd4 (reg_n >> 2)
#define ldg_src (BN * BK / (16 * TX))
#define ldg_filter (BM * BK / (16 * TX))
#define ldg_width 16
// vim: syntax=cpp.doxygen
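Not part of the commit, but a useful compile-time cross-check of these tile sizes against the 49152-byte dynamic shared-memory request in the launchers: each pipeline stage holds one BN x BK source tile plus one BM x BK filter tile of 4-bit values, i.e. (BN + BM) * BK / 2 = 16384 bytes, and three stages give 49152.

static_assert((BN + BM) * BK / 2 * 3 == 49152, "smem footprint matches launcher");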
dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu (new file, mode 100644)
#include <cuda_runtime.h>
#include <stdio.h>
#include "./imma8832_128x128.cuh"
#include "./kern.cuh"
#include "./macro.cuh"
#include "./tools.cuh"
using
namespace
convolution
;
namespace
{
#ifdef SM80_ENABLED
extern
"C"
__device__
void
g2s_int4
(
const
int4
*
gm
,
int4
*
sm
)
{
unsigned
sm_addr
=
get_smem_pointer
(
sm
);
const
int
SizeInBytes
=
16
;
#if ENABLE_L2_PREFETCH
asm
volatile
(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;
\n
"
::
"r"
(
sm_addr
),
"l"
(
gm
),
"n"
(
SizeInBytes
));
#else
asm
volatile
(
"cp.async.cg.shared.global [%0], [%1], %2;
\n
"
::
"r"
(
sm_addr
),
"l"
(
gm
),
"n"
(
SizeInBytes
));
#endif
}
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
/// not block.
#define cp_async_fence() asm volatile("cp.async.commit_group;\n" ::)
/// Blocks until all but <N> previous cp.async.commit_group operations have
/// committed.
#define cp_async_wait(N) asm volatile("cp.async.wait_group %0;\n" ::"n"(N))
#endif
extern
"C"
__global__
void
__launch_bounds__
(
256
)
ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu
(
const
int8_t
*
__restrict__
src
,
int8_t
*
__restrict__
filter
,
const
float
*
__restrict__
bias
,
int8_t
*
__restrict__
dst
,
float
alpha
,
float
beta
,
uint32_t
pk_src_zero_point
,
float
dst_zero_point
,
uint32_t
relu
,
Conv2dInt4Param
param
,
Conv2dConstantOffset
conv2d_constant
)
{
#ifdef SM80_ENABLED
const
int
stages
=
3
;
const
uint32_t
tid
=
threadIdx
.
x
;
const
uint32_t
bidx
=
blockIdx
.
x
;
const
uint32_t
bidy
=
blockIdx
.
y
;
extern
__shared__
int32_t
smem
[];
// (128+128)*128/8*stages
int2
reg_acc
[
reg_m
][
reg_n
];
int4
reg_src
[
2
][
reg_nd4
];
int4
reg_flt
[
2
][
reg_md4
];
// use in other way, maybe use reg_ser/flt
int4
reg_src_cache
[
2
];
int4
reg_filter_cache
[
4
];
uint32_t
tid127
=
(
tid
&
127
);
uint32_t
section
=
(
tid127
>>
1
);
uint32_t
residue
=
((
tid127
<<
5
)
&
63
);
uint32_t
nhw
=
bidx
*
BN
+
section
;
uint32_t
tn
,
hw
,
toh
,
tow
;
int
tih
,
tiw
;
int
h_start
[
2
];
int
h_end
[
2
];
int
w_start
[
2
];
int
w_end
[
2
];
bool
g
[
2
];
const
int8_t
*
__restrict__
g_src_ptr
[
4
];
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
if
(
i
!=
0
)
{
nhw
+=
64
;
}
tn
=
nhw
/
param
.
div_ohow
;
hw
=
nhw
%
param
.
div_ohow
;
toh
=
hw
/
param
.
div_ow
;
tow
=
hw
%
param
.
div_ow
;
tih
=
toh
*
param
.
sh
-
param
.
ph
;
tiw
=
tow
*
param
.
sw
-
param
.
pw
;
g
[
i
]
=
tn
<
param
.
n
;
h_start
[
i
]
=
-
tih
;
h_end
[
i
]
=
param
.
ih
-
tih
;
w_start
[
i
]
=
-
tiw
;
w_end
[
i
]
=
param
.
iw
-
tiw
;
// param's members have been converted to byte offset and int4 offset need to
// div 2
int
src_offset
=
tn
*
param
.
ibs
+
tih
*
param
.
ihs
+
((
int
)(
tiw
*
packed_channel
+
residue
)
>>
1
);
g_src_ptr
[
i
*
2
]
=
src
+
src_offset
;
g_src_ptr
[
i
*
2
+
1
]
=
g_src_ptr
[
i
*
2
];
}
const
uint32_t
section_section
=
(
section
>>
2
);
const
uint32_t
section_residue
=
(
section
&
3
);
const
uint32_t
section_factor
=
((
section
&
15
)
>>
2
);
const
uint32_t
crosswise_offset
=
((
section_residue
>>
1
)
<<
4
)
+
(((
section_residue
&
1
)
^
(
section_factor
>>
1
))
<<
3
);
const
uint32_t
residue_offset
=
((
residue
>>
5
)
^
(
section_factor
&
1
))
<<
2
;
// next + 64 * BK / 8
int32_t
*
write_src_s
[
2
];
write_src_s
[
0
]
=
smem
+
section_section
*
BK
/
2
+
crosswise_offset
+
residue_offset
;
write_src_s
[
1
]
=
write_src_s
[
0
]
+
32
;
int
iter
=
(
param
.
icfhfw
>>
6
);
uint32_t
tid31
=
(
tid
&
31
);
uint32_t
warp_idx
=
(
tid
>>
5
);
uint32_t
warp_strided
=
(
warp_idx
<<
2
);
uint32_t
htid
=
(
tid31
>>
4
);
const
uint32_t
flt_strided
=
bidy
*
BM
/
8
+
warp_strided
;
bool
guard
=
flt_strided
*
8
<
param
.
oc
&&
iter
>
htid
;
// icfhfw * 8/2 is a stride
int8_t
*
__restrict__
g_filter_ptr0
=
filter
+
flt_strided
*
(
param
.
icfhfw
*
4
)
+
(
tid31
<<
4
);
int8_t
*
__restrict__
g_filter_ptr1
=
g_filter_ptr0
+
(
param
.
icfhfw
*
4
);
int8_t
*
__restrict__
g_filter_ptr2
=
g_filter_ptr0
+
(
param
.
icfhfw
*
8
);
int8_t
*
__restrict__
g_filter_ptr3
=
g_filter_ptr0
+
(
param
.
icfhfw
*
12
);
// next + BK * 8 / (INT32/INT4)
uint32_t
q
=
(
tid31
>>
3
);
uint32_t
r
=
(
tid31
&
7
);
int32_t
*
write_flt_s
=
smem
+
BN
*
BK
/
8
+
warp_strided
*
BK
+
((
q
&
1
)
<<
6
)
+
((
q
>>
1
)
<<
5
)
+
(
r
<<
2
);
uint32_t
quad_idx
=
(
tid31
>>
2
);
uint32_t
idx_in_quad
=
(
tid
&
3
);
uint32_t
quad_factor
=
((
tid
&
15
)
>>
2
);
uint32_t
crosswise
=
((
idx_in_quad
>>
1
)
<<
4
)
+
(((
idx_in_quad
&
1
)
^
(
quad_factor
>>
1
))
<<
3
);
uint32_t
warp_x
=
(
warp_idx
>>
1
);
uint32_t
warp_y
=
(
warp_idx
&
1
);
int32_t
*
read_src_s_0
=
smem
+
(
warp_x
*
8
*
BK
)
+
(
quad_idx
*
BK
/
2
)
+
crosswise
+
((
0
^
(
quad_factor
&
1
))
<<
2
);
int32_t
*
read_src_s_1
=
smem
+
(
warp_x
*
8
*
BK
)
+
(
quad_idx
*
BK
/
2
)
+
crosswise
+
((
1
^
(
quad_factor
&
1
))
<<
2
);
int32_t
*
read_flt_s_0
=
smem
+
BN
*
BK
/
8
+
(
warp_y
*
8
*
BK
)
+
(
quad_idx
*
BK
/
2
)
+
crosswise
+
((
0
^
(
quad_factor
&
1
))
<<
2
);
int32_t
*
read_flt_s_1
=
smem
+
BN
*
BK
/
8
+
(
warp_y
*
8
*
BK
)
+
(
quad_idx
*
BK
/
2
)
+
crosswise
+
((
1
^
(
quad_factor
&
1
))
<<
2
);
#pragma unroll
for
(
int
i
=
0
;
i
<
reg_m
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
reg_n
;
j
++
)
{
reg_acc
[
i
][
j
]
=
make_int2
(
0
,
0
);
}
}
const
int
smem_switch
=
4096
;
const
int
smem_switch_back
=
-
smem_switch
*
(
stages
-
1
);
int
stage
=
0
;
uint32_t
offset
[
2
]
=
{
0
,
2
};
// high & low
int
src_step
[
2
],
x
[
2
],
y
[
2
];
// global mem --> shared mem, stage 0
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
src_step
[
i
]
=
conv2d_constant
.
c_offset
[
offset
[
i
]];
uint32_t
spatial
=
*
(
reinterpret_cast
<
const
uint32_t
*>
(
&
(
conv2d_constant
.
c_offset
[
offset
[
i
]
+
1
])));
x
[
i
]
=
(
spatial
&
0xff
);
y
[
i
]
=
((
spatial
>>
8
)
&
0xff
);
if
(
offset
[
i
]
<
conv2d_constant
.
c_offset_param
.
max
)
{
offset
[
i
]
+=
4
;
}
else
{
offset
[
i
]
+=
conv2d_constant
.
c_offset_param
.
rewind
;
}
}
bool
guard0
[
2
],
guard1
[
2
];
guard0
[
0
]
=
g
[
0
]
&&
x
[
0
]
>=
h_start
[
0
]
&&
x
[
0
]
<
h_end
[
0
]
&&
y
[
0
]
>=
w_start
[
0
]
&&
y
[
0
]
<
w_end
[
0
]
&&
iter
>
0
;
guard0
[
1
]
=
g
[
0
]
&&
x
[
1
]
>=
h_start
[
0
]
&&
x
[
1
]
<
h_end
[
0
]
&&
y
[
1
]
>=
w_start
[
0
]
&&
y
[
1
]
<
w_end
[
0
]
&&
iter
>
1
;
guard1
[
0
]
=
g
[
1
]
&&
x
[
0
]
>=
h_start
[
1
]
&&
x
[
0
]
<
h_end
[
1
]
&&
y
[
0
]
>=
w_start
[
1
]
&&
y
[
0
]
<
w_end
[
1
]
&&
iter
>
0
;
guard1
[
1
]
=
g
[
1
]
&&
x
[
1
]
>=
h_start
[
1
]
&&
x
[
1
]
<
h_end
[
1
]
&&
y
[
1
]
>=
w_start
[
1
]
&&
y
[
1
]
<
w_end
[
1
]
&&
iter
>
1
;
g_src_ptr
[
0
]
+=
src_step
[
0
];
g_src_ptr
[
1
]
+=
src_step
[
1
];
g_src_ptr
[
2
]
+=
src_step
[
0
];
g_src_ptr
[
3
]
+=
src_step
[
1
];
if
(
guard0
[
0
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
0
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard0
[
1
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
1
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard1
[
0
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
2
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]
+
8
*
BK
));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]
+
8
*
BK
))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard1
[
1
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
3
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]
+
8
*
BK
));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]
+
8
*
BK
))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard
)
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr0
),
reinterpret_cast
<
int4
*>
(
write_flt_s
));
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr1
),
reinterpret_cast
<
int4
*>
(
write_flt_s
+
BK
));
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr2
),
reinterpret_cast
<
int4
*>
(
write_flt_s
+
2
*
BK
));
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr3
),
reinterpret_cast
<
int4
*>
(
write_flt_s
+
3
*
BK
));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
))
=
make_int4
(
0
,
0
,
0
,
0
);
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
+
BK
))
=
make_int4
(
0
,
0
,
0
,
0
);
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
+
2
*
BK
))
=
make_int4
(
0
,
0
,
0
,
0
);
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
+
3
*
BK
))
=
make_int4
(
0
,
0
,
0
,
0
);
}
stage
++
;
cp_async_fence
();
// global mem --> shared mem, stage 1 -> stage n
iter
-=
2
;
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
src_step
[
i
]
=
conv2d_constant
.
c_offset
[
offset
[
i
]];
uint32_t
spatial
=
*
(
reinterpret_cast
<
const
uint32_t
*>
(
&
(
conv2d_constant
.
c_offset
[
offset
[
i
]
+
1
])));
x
[
i
]
=
(
spatial
&
0xff
);
y
[
i
]
=
((
spatial
>>
8
)
&
0xff
);
if
(
offset
[
i
]
<
conv2d_constant
.
c_offset_param
.
max
)
{
offset
[
i
]
+=
4
;
}
else
{
offset
[
i
]
+=
conv2d_constant
.
c_offset_param
.
rewind
;
}
}
guard0
[
0
]
=
g
[
0
]
&&
x
[
0
]
>=
h_start
[
0
]
&&
x
[
0
]
<
h_end
[
0
]
&&
y
[
0
]
>=
w_start
[
0
]
&&
y
[
0
]
<
w_end
[
0
];
guard0
[
1
]
=
g
[
0
]
&&
x
[
1
]
>=
h_start
[
0
]
&&
x
[
1
]
<
h_end
[
0
]
&&
y
[
1
]
>=
w_start
[
0
]
&&
y
[
1
]
<
w_end
[
0
];
guard1
[
0
]
=
g
[
1
]
&&
x
[
0
]
>=
h_start
[
1
]
&&
x
[
0
]
<
h_end
[
1
]
&&
y
[
0
]
>=
w_start
[
1
]
&&
y
[
0
]
<
w_end
[
1
];
guard1
[
1
]
=
g
[
1
]
&&
x
[
1
]
>=
h_start
[
1
]
&&
x
[
1
]
<
h_end
[
1
]
&&
y
[
1
]
>=
w_start
[
1
]
&&
y
[
1
]
<
w_end
[
1
];
g_src_ptr
[
0
]
+=
src_step
[
0
];
g_src_ptr
[
1
]
+=
src_step
[
1
];
g_src_ptr
[
2
]
+=
src_step
[
0
];
g_src_ptr
[
3
]
+=
src_step
[
1
];
g_filter_ptr0
+=
8
*
64
;
g_filter_ptr1
+=
8
*
64
;
g_filter_ptr2
+=
8
*
64
;
g_filter_ptr3
+=
8
*
64
;
write_src_s
[
0
]
+=
smem_switch
;
write_src_s
[
1
]
+=
smem_switch
;
write_flt_s
+=
smem_switch
;
for
(;
iter
>=
2
&&
stage
<
stages
-
1
;
iter
-=
2
)
{
if
(
guard0
[
0
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
0
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard0
[
1
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
1
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard1
[
0
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
2
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]
+
8
*
BK
));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
0
]
+
8
*
BK
))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard1
[
1
])
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_src_ptr
[
3
]),
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]
+
8
*
BK
));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_src_s
[
1
]
+
8
*
BK
))
=
make_int4
(
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
,
pk_src_zero_point
);
}
if
(
guard
)
{
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr0
),
reinterpret_cast
<
int4
*>
(
write_flt_s
));
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr1
),
reinterpret_cast
<
int4
*>
(
write_flt_s
+
BK
));
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr2
),
reinterpret_cast
<
int4
*>
(
write_flt_s
+
2
*
BK
));
g2s_int4
(
reinterpret_cast
<
const
int4
*>
(
g_filter_ptr3
),
reinterpret_cast
<
int4
*>
(
write_flt_s
+
3
*
BK
));
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
))
=
make_int4
(
0
,
0
,
0
,
0
);
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
+
BK
))
=
make_int4
(
0
,
0
,
0
,
0
);
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
+
2
*
BK
))
=
make_int4
(
0
,
0
,
0
,
0
);
*
(
reinterpret_cast
<
int4
*>
(
write_flt_s
+
3
*
BK
))
=
make_int4
(
0
,
0
,
0
,
0
);
}
stage
++
;
cp_async_fence
();
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
src_step
[
i
]
=
conv2d_constant
.
c_offset
[
offset
[
i
]];
uint32_t
spatial
=
*
(
reinterpret_cast
<
const
uint32_t
*>
(
&
(
conv2d_constant
.
c_offset
[
offset
[
i
]
+
1
])));
x
[
i
]
=
(
spatial
&
0xff
);
y
[
i
]
=
((
spatial
>>
8
)
&
0xff
);
if
(
offset
[
i
]
<
conv2d_constant
.
c_offset_param
.
max
)
{
offset
[
i
]
+=
4
;
}
else
{
offset
[
i
]
+=
conv2d_constant
.
c_offset_param
.
rewind
;
}
}
guard0
[
0
]
=
g
[
0
]
&&
x
[
0
]
>=
h_start
[
0
]
&&
x
[
0
]
<
h_end
[
0
]
&&
y
[
0
]
>=
w_start
[
0
]
&&
y
[
0
]
<
w_end
[
0
];
guard0
[
1
]
=
g
[
0
]
&&
x
[
1
]
>=
h_start
[
0
]
&&
x
[
1
]
<
h_end
[
0
]
&&
y
[
1
]
>=
w_start
[
0
]
&&
y
[
1
]
<
w_end
[
0
];
guard1
[
0
]
=
g
[
1
]
&&
x
[
0
]
>=
h_start
[
1
]
&&
x
[
0
]
<
h_end
[
1
]
&&
y
[
0
]
>=
w_start
[
1
]
&&
y
[
0
]
<
w_end
[
1
];
guard1
[
1
]
=
g
[
1
]
&&
x
[
1
]
>=
h_start
[
1
]
&&
x
[
1
]
<
h_end
[
1
]
&&
y
[
1
]
>=
w_start
[
1
]
&&
y
[
1
]
<
w_end
[
1
];
g_src_ptr
[
0
]
+=
src_step
[
0
];
g_src_ptr
[
1
]
+=
src_step
[
1
];
g_src_ptr
[
2
]
+=
src_step
[
0
];
g_src_ptr
[
3
]
+=
src_step
[
1
];
g_filter_ptr0
+=
8
*
64
;
g_filter_ptr1
+=
8
*
64
;
g_filter_ptr2
+=
8
*
64
;
g_filter_ptr3
+=
8
*
64
;
write_src_s
[
0
]
+=
smem_switch
;
write_src_s
[
1
]
+=
smem_switch
;
write_flt_s
+=
smem_switch
;
}
    bool is_copy = false;
    if (iter == 1 && stage != stages - 1) {
        if (guard0[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[0]),
                    reinterpret_cast<int4*>(write_src_s[0]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard0[1] && iter > 1) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[1]),
                    reinterpret_cast<int4*>(write_src_s[1]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[2]),
                    reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[1] && iter > 1) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[3]),
                    reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard && iter > htid) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr0),
                    reinterpret_cast<int4*>(write_flt_s));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr1),
                    reinterpret_cast<int4*>(write_flt_s + BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr2),
                    reinterpret_cast<int4*>(write_flt_s + 2 * BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr3),
                    reinterpret_cast<int4*>(write_flt_s + 3 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
        }
        stage++;
        is_copy = true;
        cp_async_fence();
    }
    // compute offset
    int d_offset = (bidy * (BM >> 6) + warp_y) * param.ocs + (idx_in_quad << 3);
    section = tid31 >> 2;
    size_t nhw_post0 = bidx * BN + warp_x * 64 + section;
    size_t nhw_post1 = nhw_post0 + 8;
    size_t nhw_post2 = nhw_post0 + 16;
    size_t nhw_post3 = nhw_post0 + 24;
    size_t stg_oc = bidy * BM + (warp_y << 6);
    int* g_offset = ((int*)&reg_filter_cache);
    bool stg_guard[8];
#pragma unroll
    for (int y = 0; y < reg_m; y += 4) {
        COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y)
        nhw_post0 += 32;
        nhw_post1 += 32;
        nhw_post2 += 32;
        nhw_post3 += 32;
    }
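    // Readability note (not in the original commit): nhw_post0..3 are the
    // linearized output-pixel indices handled by each accumulator row group;
    // COMPUTE_OFFSET_4x1 (macro.cuh) converts them into global dst offsets in
    // g_offset plus bounds guards in stg_guard, so the store loop at the end
    // does not redo the division by param.div_ohow.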
    bool only_one_stage = (stage == 1) ? true : false;
    if (stage >= 2) {
        cp_async_wait(stages - 2);
    } else {
        cp_async_wait(0);
    }
    __syncthreads();
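    // Main-loop structure (comment added for readability; not in the original
    // commit): each iteration consumes two filter taps.  cp.async prefetches
    // the tiles for the next taps while ldmatrix.x4 pulls four 8x8 b16
    // fragments per warp out of shared memory and mma.sync.m8n8k32
    // (u4 activations x s4 weights, s32 accumulate) runs on the previously
    // loaded fragments, double-buffered through the comp/load register pair.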
    for (; iter >= 2; iter -= 2) {
        if (guard0[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[0]),
                    reinterpret_cast<int4*>(write_src_s[0]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard0[1]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[1]),
                    reinterpret_cast<int4*>(write_src_s[1]));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[0]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[2]),
                    reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard1[1]) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_src_ptr[3]),
                    reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                    pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                    pk_src_zero_point);
        }
        if (guard) {
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr0),
                    reinterpret_cast<int4*>(write_flt_s));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr1),
                    reinterpret_cast<int4*>(write_flt_s + BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr2),
                    reinterpret_cast<int4*>(write_flt_s + 2 * BK));
            g2s_int4(
                    reinterpret_cast<const int4*>(g_filter_ptr3),
                    reinterpret_cast<int4*>(write_flt_s + 3 * BK));
        } else {
            *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) = make_int4(0, 0, 0, 0);
            *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) = make_int4(0, 0, 0, 0);
        }
        stage++;
        cp_async_fence();
#pragma unroll  // low
        for (int i = 0; i < reg_nd4; ++i) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_src[0][i] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int j = 0; j < reg_md4; ++j) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_flt[0][j] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int k_inner = 0; k_inner < BKd32; k_inner++) {
            int comp = (k_inner & 0x1);
            int load = 1 - comp;
            if (k_inner < BKd32 - 1) {
                int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
                int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
                read_src_s += 32 * ((k_inner + 1) >> 1);
                read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
                for (int i = 0; i < reg_nd4; ++i) {
                    int x, y, z, w;
                    unsigned addr =
                            get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                            "%3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_src[load][i] = make_int4(x, y, z, w);
                }
#pragma unroll
                for (int j = 0; j < reg_md4; ++j) {
                    int x, y, z, w;
                    unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                            "%3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_flt[load][j] = make_int4(x, y, z, w);
                }
            }
            int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
            int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
            for (int x = 0; x < reg_n; x++) {
#pragma unroll
                for (int y = 0; y < reg_m; y++) {
                    int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                    int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                    asm volatile(
                            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
                            "s32 "
                            "{%0,%1}, {%2}, {%3}, "
                            "{%4,%5};\n"
                            : "=r"(D[0]), "=r"(D[1])
                            : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
                }
            }
        }
        if (stage == stages) {
            stage = 0;
            write_src_s[0] += smem_switch_back;
            write_src_s[1] += smem_switch_back;
            write_flt_s += smem_switch_back;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        } else if (stage == stages - 1) {
            write_src_s[0] += smem_switch;
            write_src_s[1] += smem_switch;
            write_flt_s += smem_switch;
            read_src_s_0 += smem_switch_back;
            read_src_s_1 += smem_switch_back;
            read_flt_s_0 += smem_switch_back;
            read_flt_s_1 += smem_switch_back;
        } else {
            write_src_s[0] += smem_switch;
            write_src_s[1] += smem_switch;
            write_flt_s += smem_switch;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        }
        int src_step[2];
        for (int i = 0; i < 2; i++) {
            src_step[i] = conv2d_constant.c_offset[offset[i]];
            uint32_t spatial = *(reinterpret_cast<const uint32_t*>(
                    &(conv2d_constant.c_offset[offset[i] + 1])));
            x[i] = (spatial & 0xff);
            y[i] = ((spatial >> 8) & 0xff);
            if (offset[i] < conv2d_constant.c_offset_param.max) {
                offset[i] += 4;
            } else {
                offset[i] += conv2d_constant.c_offset_param.rewind;
            }
        }
        guard0[0] = g[0] && x[0] >= h_start[0] && x[0] < h_end[0] &&
                    y[0] >= w_start[0] && y[0] < w_end[0];
        guard0[1] = g[0] && x[1] >= h_start[0] && x[1] < h_end[0] &&
                    y[1] >= w_start[0] && y[1] < w_end[0];
        guard1[0] = g[1] && x[0] >= h_start[1] && x[0] < h_end[1] &&
                    y[0] >= w_start[1] && y[0] < w_end[1];
        guard1[1] = g[1] && x[1] >= h_start[1] && x[1] < h_end[1] &&
                    y[1] >= w_start[1] && y[1] < w_end[1];
        g_src_ptr[0] += src_step[0];
        g_src_ptr[1] += src_step[1];
        g_src_ptr[2] += src_step[0];
        g_src_ptr[3] += src_step[1];
        g_filter_ptr0 += 8 * 64;
        g_filter_ptr1 += 8 * 64;
        g_filter_ptr2 += 8 * 64;
        g_filter_ptr3 += 8 * 64;
        cp_async_wait(stages - 2);
        __syncthreads();
    }
    if (iter > 0) {
        if (!is_copy) {
            if (guard0[0]) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[0]),
                        reinterpret_cast<int4*>(write_src_s[0]));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[0])) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard0[1] && iter > 1) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[1]),
                        reinterpret_cast<int4*>(write_src_s[1]));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[1])) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard1[0]) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[2]),
                        reinterpret_cast<int4*>(write_src_s[0] + 8 * BK));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[0] + 8 * BK)) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard1[1] && iter > 1) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_src_ptr[3]),
                        reinterpret_cast<int4*>(write_src_s[1] + 8 * BK));
            } else {
                *(reinterpret_cast<int4*>(write_src_s[1] + 8 * BK)) = make_int4(
                        pk_src_zero_point, pk_src_zero_point, pk_src_zero_point,
                        pk_src_zero_point);
            }
            if (guard && iter > htid) {
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr0),
                        reinterpret_cast<int4*>(write_flt_s));
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr1),
                        reinterpret_cast<int4*>(write_flt_s + BK));
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr2),
                        reinterpret_cast<int4*>(write_flt_s + 2 * BK));
                g2s_int4(
                        reinterpret_cast<const int4*>(g_filter_ptr3),
                        reinterpret_cast<int4*>(write_flt_s + 3 * BK));
            } else {
                *(reinterpret_cast<int4*>(write_flt_s)) = make_int4(0, 0, 0, 0);
                *(reinterpret_cast<int4*>(write_flt_s + BK)) = make_int4(0, 0, 0, 0);
                *(reinterpret_cast<int4*>(write_flt_s + 2 * BK)) =
                        make_int4(0, 0, 0, 0);
                *(reinterpret_cast<int4*>(write_flt_s + 3 * BK)) =
                        make_int4(0, 0, 0, 0);
            }
            cp_async_fence();
        }
#pragma unroll  // low
        for (int i = 0; i < reg_nd4; ++i) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_src[0][i] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int j = 0; j < reg_md4; ++j) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_flt[0][j] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int k_inner = 0; k_inner < BKd32; k_inner++) {
            int comp = (k_inner & 0x1);
            int load = 1 - comp;
            if (k_inner < BKd32 - 1) {
                int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
                int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
                read_src_s += 32 * ((k_inner + 1) >> 1);
                read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
                for (int i = 0; i < reg_nd4; ++i) {
                    int x, y, z, w;
                    unsigned addr =
                            get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                            "%3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_src[load][i] = make_int4(x, y, z, w);
                }
#pragma unroll
                for (int j = 0; j < reg_md4; ++j) {
                    int x, y, z, w;
                    unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                            "%3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_flt[load][j] = make_int4(x, y, z, w);
                }
            }
            int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
            int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
            for (int x = 0; x < reg_n; x++) {
#pragma unroll
                for (int y = 0; y < reg_m; y++) {
                    int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                    int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                    asm volatile(
                            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
                            "s32 "
                            "{%0,%1}, {%2}, {%3}, "
                            "{%4,%5};\n"
                            : "=r"(D[0]), "=r"(D[1])
                            : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
                }
            }
        }
        stage++;
        if (stage == stages) {
            stage = 0;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        } else if (stage == stages - 1) {
            read_src_s_0 += smem_switch_back;
            read_src_s_1 += smem_switch_back;
            read_flt_s_0 += smem_switch_back;
            read_flt_s_1 += smem_switch_back;
        } else {
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        }
        cp_async_wait(stages - 2);
    }
    if (!only_one_stage) {
#pragma unroll  // low
        for (int i = 0; i < reg_nd4; ++i) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_src[0][i] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int j = 0; j < reg_md4; ++j) {
            int x, y, z, w;
            unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
            asm volatile(
                    "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, "
                    "[%4];"
                    : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                    : "r"(addr));
            reg_flt[0][j] = make_int4(x, y, z, w);
        }
#pragma unroll
        for (int k_inner = 0; k_inner < BKd32; k_inner++) {
            int comp = (k_inner & 0x1);
            int load = 1 - comp;
            if (k_inner < BKd32 - 1) {
                int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
                int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
                read_src_s += 32 * ((k_inner + 1) >> 1);
                read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
                for (int i = 0; i < reg_nd4; ++i) {
                    int x, y, z, w;
                    unsigned addr =
                            get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                            "%3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_src[load][i] = make_int4(x, y, z, w);
                }
#pragma unroll
                for (int j = 0; j < reg_md4; ++j) {
                    int x, y, z, w;
                    unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                    asm volatile(
                            "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                            "%3}, "
                            "[%4];"
                            : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                            : "r"(addr));
                    reg_flt[load][j] = make_int4(x, y, z, w);
                }
            }
            int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
            int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
            for (int x = 0; x < reg_n; x++) {
#pragma unroll
                for (int y = 0; y < reg_m; y++) {
                    int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                    int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                    asm volatile(
                            "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
                            "s32 "
                            "{%0,%1}, {%2}, {%3}, "
                            "{%4,%5};\n"
                            : "=r"(D[0]), "=r"(D[1])
                            : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
                }
            }
        }
        stage++;
        if (stage == stages) {
            stage = 0;
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        } else if (stage == stages - 1) {
            read_src_s_0 += smem_switch_back;
            read_src_s_1 += smem_switch_back;
            read_flt_s_0 += smem_switch_back;
            read_flt_s_1 += smem_switch_back;
        } else {
            read_src_s_0 += smem_switch;
            read_src_s_1 += smem_switch;
            read_flt_s_0 += smem_switch;
            read_flt_s_1 += smem_switch;
        }
        cp_async_wait(0);
    }
    guard = iter < 0;
#pragma unroll  // low
    for (int i = 0; i < reg_nd4; ++i) {
        int x, y, z, w;
        unsigned addr = get_smem_pointer(read_src_s_0 + i * 4 * BK);  // BK*32/8
        asm volatile(
                "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                "%3}, "
                "[%4];"
                : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                : "r"(addr));
        reg_src[0][i] = make_int4(x, y, z, w);
    }
#pragma unroll
    for (int j = 0; j < reg_md4; ++j) {
        int x, y, z, w;
        unsigned addr = get_smem_pointer(read_flt_s_0 + 4 * j * BK);
        asm volatile(
                "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                "%3}, "
                "[%4];"
                : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                : "r"(addr));
        reg_flt[0][j] = make_int4(x, y, z, w);
    }
    // compute
#pragma unroll
    for (int k_inner = 0; k_inner < BKd32; k_inner++) {
        int comp = (k_inner & 0x1);
        int load = 1 - comp;
        if (k_inner < BKd32 - 1 && !(k_inner == 1 && guard)) {
            int32_t* read_src_s = (k_inner & 1) ? read_src_s_0 : read_src_s_1;
            int32_t* read_flt_s = (k_inner & 1) ? read_flt_s_0 : read_flt_s_1;
            read_src_s += 32 * ((k_inner + 1) >> 1);
            read_flt_s += 32 * ((k_inner + 1) >> 1);
#pragma unroll  // low
            for (int i = 0; i < reg_nd4; ++i) {
                int x, y, z, w;
                unsigned addr = get_smem_pointer(read_src_s + i * 4 * BK);  // BK*32/8
                asm volatile(
                        "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                        "%3}, "
                        "[%4];"
                        : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                        : "r"(addr));
                reg_src[load][i] = make_int4(x, y, z, w);
            }
#pragma unroll
            for (int j = 0; j < reg_md4; ++j) {
                int x, y, z, w;
                unsigned addr = get_smem_pointer(read_flt_s + 4 * j * BK);
                asm volatile(
                        "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, "
                        "%3}, "
                        "[%4];"
                        : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
                        : "r"(addr));
                reg_flt[load][j] = make_int4(x, y, z, w);
            }
        }
        int* A = reinterpret_cast<int*>(&reg_flt[comp][0]);
        int* B = reinterpret_cast<int*>(&reg_src[comp][0]);
#pragma unroll
        for (int x = 0; x < reg_n; x++) {
#pragma unroll
            for (int y = 0; y < reg_m; y++) {
                int* D = reinterpret_cast<int*>(&reg_acc[y][x]);
                int* C = reinterpret_cast<int*>(&reg_acc[y][x]);
                asm volatile(
                        "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4."
                        "s32 "
                        "{%0,%1}, {%2}, {%3}, "
                        "{%4,%5};\n"
                        : "=r"(D[0]), "=r"(D[1])
                        : "r"(B[y]), "r"(A[x]), "r"(C[0]), "r"(C[1]));
            }
        }
        if (k_inner == 1 && guard) {
            break;
        }
    }
    __syncthreads();
    /// output
    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
    const float* bias_ptr = bias + oc;
    int4 load_bias0 = make_int4(0, 0, 0, 0);
    int4 load_bias1 = make_int4(0, 0, 0, 0);
    int4 load_bias2 = make_int4(0, 0, 0, 0);
    int4 load_bias3 = make_int4(0, 0, 0, 0);
    if (oc < param.oc) {
        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
        mul_v4(load_bias0, load_bias0, beta);
        mul_v4(load_bias1, load_bias1, beta);
        mul_v4(load_bias2, load_bias2, beta);
        mul_v4(load_bias3, load_bias3, beta);
    }
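    // Epilogue math (summary comment; not in the original commit): with
    // acc = i2f(s32 accumulator), each output nibble is
    //     out = sat_u4(round(alpha * acc + beta * bias) + dst_zero_point),
    // where `relu` selects the rounding/clamping mode inside
    // pack_f2i_with_relu (macro.cuh); the beta factor was already folded into
    // load_bias0..3 by mul_v4 above.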
    int8_t* __restrict__ g_dst_ptr = dst + d_offset;
#pragma unroll
    for (int y = 0; y < reg_m; y += 4) {
        I2F_4x8(reg_acc, y, 0);
        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2,
                load_bias3);
        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
        nhw_post0 += 32;
        nhw_post1 += 32;
        nhw_post2 += 32;
        nhw_post3 += 32;
    }
#endif
}
}  // namespace
namespace megdnn {
namespace cuda {
namespace ptx {
void run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params) {
#ifdef SM80_SUPPORTED
    cudaFuncSetAttribute(
            ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu,
            cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
    ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu<<<
            grid, block, 49152, stream>>>(
            *((int8_t**)params[0]), *((int8_t**)params[1]),
            *((float**)params[2]), *((int8_t**)params[3]),
            *((float*)params[4]), *((float*)params[5]),
            *((uint32_t*)params[6]), *((float*)params[7]),
            *((uint32_t*)params[8]), *((Conv2dInt4Param*)params[9]),
            *((Conv2dConstantOffset*)params[10]));
#endif
}
}  // namespace ptx
}  // namespace cuda
}  // namespace megdnn
// vim: syntax=cpp.doxygen
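// --- usage sketch (illustration only; not part of this commit) ---
// The launcher unpacks `params` positionally, so a caller packs pointers to
// its own arguments in the same order.  The scalar names used below
// (pk_src_zero_point, dst_zero_point, relu) are assumptions read off the
// kernel body; the kernel signature earlier in this file is authoritative.
static inline void example_launch(
        int8_t* src, int8_t* filter, float* bias, int8_t* dst, float alpha,
        float beta, uint32_t pk_src_zero_point, float dst_zero_point,
        uint32_t relu, Conv2dInt4Param param, Conv2dConstantOffset c_offset,
        dim3 grid, dim3 block, cudaStream_t stream) {
    // one pointer per kernel argument, in the unpacking order above
    void* params[11] = {&src,  &filter, &bias,  &dst,   &alpha, &beta,
                        &pk_src_zero_point, &dst_zero_point, &relu, &param,
                        &c_offset};
    megdnn::cuda::ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
            grid, block, stream, params);
}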
dnn/src/cuda/ptx/uint4_int4/kern.cuh
0 → 100644
#include <cuda_runtime.h>
namespace megdnn {
namespace cuda {
namespace ptx {
void run_ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
void run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu(
        const dim3 grid, const dim3 block, cudaStream_t stream, void** params);
}  // namespace ptx
}  // namespace cuda
}  // namespace megdnn
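// Overview (comment added for readability; not in the original commit): the
// six launchers appear to differ only in tile shape and epilogue --
// ldg16_256x64 / ldgsts16_128x128 / ldg16_128x256 name the per-block output
// tile and the global->shared copy path (ldgsts = cp.async on sm_80), while
// the fuse_z_* variants additionally read a residual tensor z and fold
// gamma * (z - zero_point) into the accumulators before requantization.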
dnn/src/cuda/ptx/uint4_int4/macro.cuh
0 → 100644
#pragma once
//! ============= i2f ===============
__device__ __forceinline__ void i2f(int2& a) {
    ((float*)&a)[0] = static_cast<float>(a.x);
    ((float*)&a)[1] = static_cast<float>(a.y);
}
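// Note (added comment): i2f converts in place -- the two s32 lanes of the
// int2 accumulator pair are overwritten with their float values, and callers
// reinterpret the same registers as f32 from then on.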
//! ============= mul ===============
template <typename T>
__device__ __forceinline__ void mul_v4(int4& c, const int4 a, const T alpha);
template <>
__device__ __forceinline__ void mul_v4<float>(
        int4& c, const int4 a, const float alpha) {
    ((float*)&c)[0] = ((float*)&a)[0] * alpha;
    ((float*)&c)[1] = ((float*)&a)[1] * alpha;
    ((float*)&c)[2] = ((float*)&a)[2] * alpha;
    ((float*)&c)[3] = ((float*)&a)[3] * alpha;
}
//! ============= fma ===============
__device__ __forceinline__ void fma2(
        int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha,
        const int4 b) {
    asm("fma.rz.f32 %0, %1, %2, %3;"
        : "=f"(((float*)&c0)[0])
        : "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0]));
    asm("fma.rz.f32 %0, %1, %2, %3;"
        : "=f"(((float*)&c0)[1])
        : "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1]));
    asm("fma.rz.f32 %0, %1, %2, %3;"
        : "=f"(((float*)&c1)[0])
        : "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2]));
    asm("fma.rz.f32 %0, %1, %2, %3;"
        : "=f"(((float*)&c1)[1])
        : "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3]));
}
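// Note (added comment): fma2 treats c0/c1 and a0/a1 as f32 lane pairs and
// computes c = a * alpha + b lane-wise with round-toward-zero fma; the int4
// argument b supplies the four f32 bias lanes, split across c0 and c1.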
__device__ __forceinline__ void fuse_z_1x8(
        int4* a, const int& j, const int4& fuse_z, const float& gamma,
        const int32_t& zero_point) {
    const int2 z[2] = {
            *reinterpret_cast<const int2*>(&fuse_z),
            *(reinterpret_cast<const int2*>(&fuse_z) + 1)};
    for (int k = 0; k < 4; k++) {
        int f = ((z[0].x >> (k * 8)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
        f = ((z[0].x >> (k * 8 + 4)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
        f = ((z[1].x >> (k * 8)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k]))[2] += (f - zero_point) * gamma;
        f = ((z[1].x >> (k * 8 + 4)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k]))[3] += (f - zero_point) * gamma;
    }
    for (int k = 0; k < 4; k++) {
        int f = ((z[0].y >> (k * 8)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
        f = ((z[0].y >> (k * 8 + 4)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
        f = ((z[1].y >> (k * 8)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma;
        f = ((z[1].y >> (k * 8 + 4)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma;
    }
}
__device__ __forceinline__ void fuse_z_1x8(
        int2* a, const int& j, const int2& fuse_z, const float& gamma,
        const int32_t& zero_point) {
#pragma unroll
    for (int k = 0; k < 4; k++) {
        int f = ((fuse_z.x >> (k * 8)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
        f = ((fuse_z.x >> (k * 8 + 4)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
    }
#pragma unroll
    for (int k = 0; k < 4; k++) {
        int f = ((fuse_z.y >> (k * 8)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
        f = ((fuse_z.y >> (k * 8 + 4)) & 15);
        f = (f << 28) >> 28;
        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
    }
}
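// Sign-extension trick used above (worked example; not in the original
// commit): each s4 lane is pulled out of the packed word with a shift pair
// instead of a lookup.  For a 4-bit pattern f in [0, 15],
//     (f << 28) >> 28
// moves the nibble to the top of a 32-bit int and arithmetic-shifts it back,
// mapping 0..7 -> 0..7 and 8..15 -> -8..-1 (e.g. f = 15 -> -1, f = 7 -> 7).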
__device__ __forceinline__ void pack_f2i(
        int& d0, int& d1, const int4 s0, const int4 s1, const int4 s2,
        const int4 s3, const uint32_t relu, float& dst_zero_point) {
    // uint32_t ix, iy, iz, iw;
    uint32_t x0, y0, z0, w0;
    uint32_t x1, y1, z1, w1;
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x0) : "f"(((float*)&s0)[0]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y0) : "f"(((float*)&s0)[1]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z0) : "f"(((float*)&s1)[0]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w0) : "f"(((float*)&s1)[1]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x1) : "f"(((float*)&s2)[0]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y1) : "f"(((float*)&s2)[1]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z1) : "f"(((float*)&s3)[0]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w1) : "f"(((float*)&s3)[1]));
    asm volatile(
            "{ .reg .u32 r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
            "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
            "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
            "}"
            : "=r"(d0)
            : "r"(x0), "r"(y0), "r"(z0), "r"(w0), "r"(x1), "r"(y1), "r"(z1),
              "r"(w1));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x0) : "f"(((float*)&s0)[2]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y0) : "f"(((float*)&s0)[3]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z0) : "f"(((float*)&s1)[2]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w0) : "f"(((float*)&s1)[3]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x1) : "f"(((float*)&s2)[2]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(y1) : "f"(((float*)&s2)[3]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(z1) : "f"(((float*)&s3)[2]));
    asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(w1) : "f"(((float*)&s3)[3]));
    asm volatile(
            "{ .reg .u32 r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
            "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
            "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
            "}"
            : "=r"(d1)
            : "r"(x0), "r"(y0), "r"(z0), "r"(w0), "r"(x1), "r"(y1), "r"(z1),
              "r"(w1));
}
__device__ __forceinline__ void pack_f2i_with_relu(
        int& d0, const int2 s0, const int2 s1, const int2 s2, const int2 s3,
        const uint32_t relu, float& dst_zero_point) {
    uint32_t x[8];
    if (relu > 0) {
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[0]) : "f"(((float*)&s0)[0]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[1]) : "f"(((float*)&s0)[1]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[2]) : "f"(((float*)&s1)[0]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[3]) : "f"(((float*)&s1)[1]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[4]) : "f"(((float*)&s2)[0]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[5]) : "f"(((float*)&s2)[1]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[6]) : "f"(((float*)&s3)[0]));
        asm volatile("cvt.rni.u8.f32 %0, %1;" : "=r"(x[7]) : "f"(((float*)&s3)[1]));
        x[0] += dst_zero_point;
        x[1] += dst_zero_point;
        x[2] += dst_zero_point;
        x[3] += dst_zero_point;
        x[4] += dst_zero_point;
        x[5] += dst_zero_point;
        x[6] += dst_zero_point;
        x[7] += dst_zero_point;
    } else if (relu == 0) {
        ((float*)&s0)[0] += dst_zero_point;
        ((float*)&s0)[1] += dst_zero_point;
        ((float*)&s1)[0] += dst_zero_point;
        ((float*)&s1)[1] += dst_zero_point;
        ((float*)&s2)[0] += dst_zero_point;
        ((float*)&s2)[1] += dst_zero_point;
        ((float*)&s3)[0] += dst_zero_point;
        ((float*)&s3)[1] += dst_zero_point;
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[0]) : "f"(((float*)&s0)[0]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[1]) : "f"(((float*)&s0)[1]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[2]) : "f"(((float*)&s1)[0]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[3]) : "f"(((float*)&s1)[1]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[4]) : "f"(((float*)&s2)[0]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[5]) : "f"(((float*)&s2)[1]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[6]) : "f"(((float*)&s3)[0]));
        asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(x[7]) : "f"(((float*)&s3)[1]));
    }
    if (relu > 1) {
        int r1, r2;
        r1 = (x[0] >= relu);
        x[0] *= r1;
        r2 = (x[1] >= relu);
        x[1] *= r2;
        r1 = (x[2] >= relu);
        x[2] *= r1;
        r2 = (x[3] >= relu);
        x[3] *= r2;
        r1 = (x[4] >= relu);
        x[4] *= r1;
        r2 = (x[5] >= relu);
        x[5] *= r2;
        r1 = (x[6] >= relu);
        x[6] *= r1;
        r2 = (x[7] >= relu);
        x[7] *= r2;
    }
    asm volatile(
            "{ .reg .u32 r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;"
            "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;"
            "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;"
            "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;"
            "}"
            : "=r"(d0)
            : "r"(x[0]), "r"(x[1]), "r"(x[2]), "r"(x[3]), "r"(x[4]), "r"(x[5]),
              "r"(x[6]), "r"(x[7]));
}
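// Semantics of `relu` above (summary comment; an interpretation of the code,
// not documented behavior): relu > 0 converts with cvt.rni.u8, so negative
// inputs saturate to 0 (ReLU) before dst_zero_point is added; relu == 0 adds
// dst_zero_point first and converts with cvt.rni.s8; when relu > 1, a value
// is additionally zeroed if it falls below the `relu` threshold.  The final
// cvt.pack.sat.u4 chain clamps the eight results to [0, 15] and packs them
// into one 32-bit word.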
#define I2F_1x8(a, i, j) \
i2f(a[i][j]); \
i2f(a[i][j + 1]); \
i2f(a[i][j + 2]); \
i2f(a[i][j + 3]); \
i2f(a[i][j + 4]); \
i2f(a[i][j + 5]); \
i2f(a[i][j + 6]); \
i2f(a[i][j + 7]);
#define I2F_4x8(a, i, j) \
I2F_1x8(a, i, j) I2F_1x8(a, i + 1, j) I2F_1x8(a, i + 2, j) I2F_1x8(a, i + 3, j)
#define FMA_1x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \
fma2(a[i][j], reg_acc[i][j], a[i][j + 1], reg_acc[i][j + 1], alpha, bias0); \
fma2(a[i][j + 2], reg_acc[i][j + 2], a[i][j + 3], reg_acc[i][j + 3], alpha, \
bias1); \
fma2(a[i][j + 4], reg_acc[i][j + 4], a[i][j + 5], reg_acc[i][j + 5], alpha, \
bias2); \
fma2(a[i][j + 6], reg_acc[i][j + 6], a[i][j + 7], reg_acc[i][j + 7], alpha, bias3);
#define FMA_4x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i + 1, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i + 2, j, alpha, bias0, bias1, bias2, bias3) \
FMA_1x8(a, i + 3, j, alpha, bias0, bias1, bias2, bias3)
// pack 1x(8 int2) to int2
#define PACK_F2I_WITH_RELU_1x8(a, i, j, relu, dst_zero_point) \
pack_f2i_with_relu( \
a[i][j].x, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3], relu, \
dst_zero_point); \
pack_f2i_with_relu( \
a[i][j].y, a[i][j + 4], a[i][j + 5], a[i][j + 6], a[i][j + 7], relu, \
dst_zero_point);
// pack 4x8 int2 float to 4 int2
#define PACK_F2I_WITH_RELU_4x8(a, i, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i + 1, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i + 2, j, relu, dst_zero_point) \
PACK_F2I_WITH_RELU_1x8(a, i + 3, j, relu, dst_zero_point)
#define STG(d, s, idx, n_reuse, hw_reuse, g) \
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
d = g_dst_ptr + n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw; \
if (stg_oc < param.oc && g) { \
*(reinterpret_cast<int2*>(d)) = *(reinterpret_cast<int2*>(&s)); \
}
#define STG_4x1(d, a, i, j) \
STG(d[0], a[i][j], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
STG(d[1], a[i + 1][j], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
STG(d[2], a[i + 2][j], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
STG(d[3], a[i + 3][j], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point) \
fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point); \
fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point); \
fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
// 1x8 1x(2x8 int2) to 2 int2
#define PACK_F2I_1x8(a, i, j) \
pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \
pack_f2i(a[i][j].y, a[i][j].w, a[i][j + 4], a[i][j + 5], a[i][j + 6], a[i][j + 7]);
// 4x8 int4
#define PACK_F2I_4x8(a, i, j) \
PACK_F2I_1x8(a, i, j) PACK_F2I_1x8(a, i + 1, j) PACK_F2I_1x8(a, i + 2, j) \
PACK_F2I_1x8(a, i + 3, j)
#define LDG(d, s, idx, n_reuse, hw_reuse, g) \
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw; \
if (stg_oc < param.oc && g) { \
*(reinterpret_cast<int2*>(&d)) = \
*(reinterpret_cast<const int2*>(g_z_ptr + s)); \
}
#define LDG_4x1(d, s, i) \
LDG(d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
LDG(d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
LDG(d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
#define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g) \
n_reuse = nhw_post##idx / param.div_ohow; \
hw_reuse = nhw_post##idx % param.div_ohow; \
s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1); \
g = nhw_post##idx < param.nhw;
#define COMPUTE_OFFSET_4x1(d, s, i) \
COMPUTE_OFFSET( \
d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
COMPUTE_OFFSET( \
d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, \
stg_guard[i + 1]) \
COMPUTE_OFFSET( \
d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, \
stg_guard[i + 2]) \
COMPUTE_OFFSET( \
d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, \
stg_guard[i + 3])
#define STG_AFTER_LDG(d, s, g) \
if (stg_oc < param.oc && g) { \
*(reinterpret_cast<int2*>(g_dst_ptr + d)) = *(reinterpret_cast<int2*>(&s)); \
}
#define STG_AFTER_LDG_4x1(d, a, i, j) \
STG_AFTER_LDG(d[i], a[i][j], stg_guard[i]) \
STG_AFTER_LDG(d[i + 1], a[i + 1][j], stg_guard[i + 1]) \
STG_AFTER_LDG(d[i + 2], a[i + 2][j], stg_guard[i + 2]) \
STG_AFTER_LDG(d[i + 3], a[i + 3][j], stg_guard[i + 3])
// vim: syntax=cpp.doxygen
dnn/src/cuda/ptx/uint4_int4/tools.cuh
0 → 100644
#include <cuda_runtime.h>
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly. Rather, they should use the
// cutlass::arch::ldsm<>() template.
//
__device__ uint32_t __nvvm_get_smem_pointer(void*);
}
#endif
inline __device__ unsigned get_smem_pointer(void* ptr) {
#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
    //
    // This NVVM intrinsic converts an address in shared memory to a plain
    // unsigned integer. This is necessary to pass to shared memory instructions
    // in inline PTX.
    //
    // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only
    // available in 10.2].
    //
    //__device__ size_t __cvta_generic_to_shared(void* ptr);
    /// CUTLASS helper to get SMEM pointer
    return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && \
       __CUDACC_VER_MINOR__ >= 2)
    return __nvvm_get_smem_pointer(ptr);
#elif defined(__CUDA_ARCH__)
    uint32_t smem_ptr;
    asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
        "%0, smem_ptr; }\n"
        : "=r"(smem_ptr)
        : "l"(ptr));
    return smem_ptr;
#else
    return 0;
#endif
}
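// Usage sketch (illustration only; not part of this commit): the returned
// 32-bit value is the shared-memory address form that the ldmatrix /
// cp.async inline PTX in the kernels above expects, e.g.
//
//   __shared__ int32_t smem_tile[1024];
//   int x, y, z, w;
//   unsigned addr = get_smem_pointer(smem_tile);  // or smem_tile + offset
//   asm volatile(
//           "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
//           : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
//           : "r"(addr));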