Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2d4e62ef
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2d4e62ef
编写于
5月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add cuda uint4 pooling
GitOrigin-RevId: a7289772068d08deef1021f71984cb4ecfdc6702
上级
19919384
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
176 addition
and
88 deletion
+176
-88
dnn/src/cuda/pooling/opr_impl.cpp
dnn/src/cuda/pooling/opr_impl.cpp
+26
-15
dnn/src/cuda/pooling/pooling2d_qint.cu
dnn/src/cuda/pooling/pooling2d_qint.cu
+114
-63
dnn/src/cuda/pooling/pooling2d_qint.cuh
dnn/src/cuda/pooling/pooling2d_qint.cuh
+6
-3
dnn/test/cuda/pooling.cpp
dnn/test/cuda/pooling.cpp
+28
-5
dnn/test/naive/pooling.cpp
dnn/test/naive/pooling.cpp
+2
-2
未找到文件。
dnn/src/cuda/pooling/opr_impl.cpp
浏览文件 @
2d4e62ef
...
...
@@ -40,8 +40,7 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
Handle
*
handle
,
PoolingForwardImpl
::
Param
::
Format
format
)
{
bool
is_nchw
=
format
==
PoolingForwardImpl
::
Param
::
Format
::
NCHW
;
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
is_nchw
)
{
if
(
is_nchw
)
{
auto
relayout_opr
=
handle
->
create_operator
<
RelayoutFormat
>
();
deduce_reformat_layout
(
relayout_opr
,
src
,
inner_src
,
RelayoutFormat
::
Param
::
Mode
::
NCHW_NCHW64
,
0
,
1
);
...
...
@@ -66,8 +65,11 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
TensorLayout
fsrc
=
src
;
TensorLayout
fdst
=
dst
;
bool
is_nchw
=
param
().
format
==
Param
::
Format
::
NCHW
;
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
is_nchw
)
{
if
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
(
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
is_nchw
)
{
get_inner_layout
(
src
,
dst
,
fsrc
,
fdst
,
handle
(),
param
().
format
);
sizes
.
push_back
(
fsrc
.
span
().
dist_byte
());
sizes
.
push_back
(
fdst
.
span
().
dist_byte
());
...
...
@@ -97,8 +99,11 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
bool
is_nchw
=
param
().
format
==
Param
::
Format
::
NCHW
;
if
(
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
ctypecvt
.
src_to_comp_type
(
ssrc
,
src
).
src_to_comp_type
(
sdst
,
dst
);
}
else
if
(
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
sdst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
is_nchw
)
{
}
else
if
((
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
(
sdst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
sdst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
is_nchw
)
{
auto
handle_ptr
=
handle
();
get_inner_layout
(
ssrc
.
layout
,
sdst
.
layout
,
src
.
layout
,
dst
.
layout
,
handle_ptr
,
param
().
format
);
...
...
@@ -166,8 +171,6 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
kern_param
,
stream
,
static_cast
<
uint32_t
>
(
param
().
mode
));
}
else
if
(
param
().
format
==
Format
::
NCHW64
||
inner_format
==
Format
::
NCHW64
)
{
megdnn_assert
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
,
"but %s"
,
src
.
layout
.
dtype
.
name
());
pooling2d
::
Param
kern_param
;
size_t
n
=
src
.
layout
[
0
],
hi
=
src
.
layout
[
2
],
wi
=
src
.
layout
[
3
],
c
=
src
.
layout
[
1
],
ho
=
dst
.
layout
[
2
],
wo
=
dst
.
layout
[
3
];
...
...
@@ -180,16 +183,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
kern_param
.
ph
=
ph
,
kern_param
.
pw
=
pw
,
kern_param
.
window_h
=
window_h
,
kern_param
.
window_w
=
window_w
,
kern_param
.
sh
=
sh
,
kern_param
.
sw
=
sw
;
bool
uint_case
=
false
;
int
zero_point
=
0
;
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
uint_case
=
true
;
zero_point
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
()
.
zero_point
;
}
auto
&&
stream
=
cuda_stream
(
handle
());
pooling2d
::
do_pooling2d_int4_ncdiv64hw64
(
(
int8_t
*
)
src
.
raw_ptr
,
(
int8_t
*
)
dst
.
raw_ptr
,
kern_param
,
stream
,
static_cast
<
uint32_t
>
(
param
().
mode
));
if
(
sdst
.
layout
.
ndim
==
4
)
{
auto
relayout_opr
=
handle
()
->
create_operator
<
RelayoutFormat
>
();
RelayoutFormat
::
Param
trans_param
;
trans_param
.
mode
=
RelayoutFormat
::
Param
::
Mode
::
NCHW64_NCHW
;
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,{});
stream
,
static_cast
<
uint32_t
>
(
param
().
mode
),
uint_case
,
zero_point
);
if
(
sdst
.
layout
.
ndim
==
4
)
{
auto
relayout_opr
=
handle
()
->
create_operator
<
RelayoutFormat
>
();
RelayoutFormat
::
Param
trans_param
;
trans_param
.
mode
=
RelayoutFormat
::
Param
::
Mode
::
NCHW64_NCHW
;
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,
{});
}
return
;
}
...
...
dnn/src/cuda/pooling/pooling2d_qint.cu
浏览文件 @
2d4e62ef
...
...
@@ -29,53 +29,51 @@ __device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z,
return
ix
;
}
template
<
int
regs
,
typename
Dtype
,
typename
OutDtype
>
template
<
int
regs
,
int
dtype_bits
,
typename
OutDtype
>
__device__
__forceinline__
OutDtype
pack_int8
(
int8_t
(
&
x
)[
regs
]);
template
<
>
__device__
__forceinline__
int
pack_int8
<
4
,
int8_t
,
int
>
(
int8_t
(
&
x
)[
4
])
{
__device__
__forceinline__
int
pack_int8
<
4
,
8
,
int
>
(
int8_t
(
&
x
)[
4
])
{
return
pack_int8_to_int8x4
(
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]);
}
template
<
>
__device__
__forceinline__
int2
pack_int8
<
8
,
int8_t
,
int2
>
(
int8_t
(
&
x
)[
8
])
{
__device__
__forceinline__
int2
pack_int8
<
8
,
8
,
int2
>
(
int8_t
(
&
x
)[
8
])
{
int8_t
x0
[
4
]{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]};
int8_t
x1
[
4
]{
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
]};
return
::
make_int2
(
pack_int8
<
4
,
int8_t
,
int
>
(
x0
),
pack_int8
<
4
,
int8_t
,
int
>
(
x1
));
return
::
make_int2
(
pack_int8
<
4
,
8
,
int
>
(
x0
),
pack_int8
<
4
,
8
,
int
>
(
x1
));
}
template
<
>
__device__
__forceinline__
int4
pack_int8
<
16
,
int8_t
,
int4
>
(
int8_t
(
&
x
)[
16
])
{
__device__
__forceinline__
int4
pack_int8
<
16
,
8
,
int4
>
(
int8_t
(
&
x
)[
16
])
{
int8_t
x0
[
4
]{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]};
int8_t
x1
[
4
]{
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
]};
int8_t
x2
[
4
]{
x
[
8
],
x
[
9
],
x
[
10
],
x
[
11
]};
int8_t
x3
[
4
]{
x
[
12
],
x
[
13
],
x
[
14
],
x
[
15
]};
return
::
make_int4
(
pack_int8
<
4
,
int8_t
,
int
>
(
x0
),
pack_int8
<
4
,
int8_t
,
int
>
(
x1
),
pack_int8
<
4
,
int8_t
,
int
>
(
x2
),
pack_int8
<
4
,
int8_t
,
int
>
(
x3
));
return
::
make_int4
(
pack_int8
<
4
,
8
,
int
>
(
x0
),
pack_int8
<
4
,
8
,
int
>
(
x1
),
pack_int8
<
4
,
8
,
int
>
(
x2
),
pack_int8
<
4
,
8
,
int
>
(
x3
));
}
__device__
__forceinline__
int8_t
pack_int8_to_int4x2
(
int8_t
x0
,
int8_t
x1
)
{
return
(
x0
&
0xf
)
|
(
x1
<<
4
);
}
template
<
>
__device__
__forceinline__
int
pack_int8
<
8
,
dt_qint
4
,
int
>
(
int8_t
(
&
x
)[
8
])
{
__device__
__forceinline__
int
pack_int8
<
8
,
4
,
int
>
(
int8_t
(
&
x
)[
8
])
{
int8_t
x0
=
pack_int8_to_int4x2
(
x
[
0
],
x
[
1
]);
int8_t
x1
=
pack_int8_to_int4x2
(
x
[
2
],
x
[
3
]);
int8_t
x2
=
pack_int8_to_int4x2
(
x
[
4
],
x
[
5
]);
int8_t
x3
=
pack_int8_to_int4x2
(
x
[
6
],
x
[
7
]);
return
pack_int8_to_int8x4
(
x0
,
x1
,
x2
,
x3
);
}
template
<
>
__device__
__forceinline__
int4
pack_int8
<
32
,
dt_qint
4
,
int4
>
(
int8_t
(
&
x
)[
32
])
{
__device__
__forceinline__
int4
pack_int8
<
32
,
4
,
int4
>
(
int8_t
(
&
x
)[
32
])
{
int8_t
x0
[
8
]{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
],
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
]};
int8_t
x1
[
8
]{
x
[
8
],
x
[
9
],
x
[
10
],
x
[
11
],
x
[
12
],
x
[
13
],
x
[
14
],
x
[
15
]};
int8_t
x2
[
8
]{
x
[
16
],
x
[
17
],
x
[
18
],
x
[
19
],
x
[
20
],
x
[
21
],
x
[
22
],
x
[
23
]};
int8_t
x3
[
8
]{
x
[
24
],
x
[
25
],
x
[
26
],
x
[
27
],
x
[
28
],
x
[
29
],
x
[
30
],
x
[
31
]};
return
::
make_int4
(
pack_int8
<
8
,
dt_qint4
,
int
>
(
x0
),
pack_int8
<
8
,
dt_qint4
,
int
>
(
x1
),
pack_int8
<
8
,
dt_qint4
,
int
>
(
x2
),
pack_int8
<
8
,
dt_qint4
,
int
>
(
x3
));
return
::
make_int4
(
pack_int8
<
8
,
4
,
int
>
(
x0
),
pack_int8
<
8
,
4
,
int
>
(
x1
),
pack_int8
<
8
,
4
,
int
>
(
x2
),
pack_int8
<
8
,
4
,
int
>
(
x3
));
}
template
<
typename
Dtype
>
...
...
@@ -88,6 +86,7 @@ struct TypeTrait<int8_t> {
static
constexpr
int8_t
min
=
-
128
;
static
constexpr
int
elem_per_32bit
=
32
/
bit_width
;
static
constexpr
int
shift_fix_sign
=
0
;
static
constexpr
bool
need_zero_pad
=
false
;
};
template
<
>
...
...
@@ -97,6 +96,16 @@ struct TypeTrait<dt_qint4> {
static
constexpr
int8_t
min
=
-
8
;
static
constexpr
int
elem_per_32bit
=
32
/
bit_width
;
static
constexpr
int
shift_fix_sign
=
4
;
static
constexpr
bool
need_zero_pad
=
false
;
};
template
<
>
struct
TypeTrait
<
dt_quint4
>
{
static
constexpr
int
bit_width
=
4
;
static
constexpr
int
mask
=
0xf
;
static
constexpr
int8_t
min
=
0
;
static
constexpr
int
elem_per_32bit
=
32
/
bit_width
;
static
constexpr
int
shift_fix_sign
=
0
;
static
constexpr
bool
need_zero_pad
=
true
;
};
template
<
typename
src_type
,
typename
_feed_type
>
...
...
@@ -108,7 +117,7 @@ struct MaxPooler {
static
constexpr
int
shift_fix_sign
=
TypeTrait
<
src_type
>::
shift_fix_sign
;
int8_t
res
[
nr_results
];
__device__
MaxPooler
(
int
)
{}
__device__
MaxPooler
(
int
,
int
)
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
++
i
)
{
...
...
@@ -137,7 +146,7 @@ struct MaxPooler {
}
__device__
__forceinline__
feed_type
get_ans
()
{
feed_type
ans
;
ans
=
pack_int8
<
nr_results
,
src_type
,
feed_type
>
(
res
);
ans
=
pack_int8
<
nr_results
,
bit_width
,
feed_type
>
(
res
);
return
ans
;
}
};
...
...
@@ -149,21 +158,27 @@ struct MeanIncludeRoundedPooler {
static
constexpr
int
nr_results
=
sizeof
(
feed_type
)
*
8
/
bit_width
;
static
constexpr
int
elem_per_32bit
=
TypeTrait
<
src_type
>::
elem_per_32bit
;
static
constexpr
int
shift_fix_sign
=
TypeTrait
<
src_type
>::
shift_fix_sign
;
static
constexpr
bool
need_zero_pad
=
TypeTrait
<
src_type
>::
need_zero_pad
;
int32_t
res
[
nr_results
];
const
int
count
;
const
float
fi_count
;
int
real_fi_count
;
const
int
zero_pad
;
__device__
MeanIncludeRoundedPooler
(
int
count
)
:
count
{
count
},
fi_count
{
1.
f
/
count
}
{}
__device__
MeanIncludeRoundedPooler
(
int
count
,
int
zero_point
)
:
count
{
count
},
fi_count
{
1.
f
/
count
}
,
zero_pad
{
zero_point
}
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
++
i
)
{
res
[
i
]
=
0
;
}
if
(
need_zero_pad
)
{
real_fi_count
=
0
;
}
}
__device__
__forceinline__
void
feed
(
int
x
,
int
idx
=
0
)
{
__device__
__forceinline__
void
feed
(
int
x
,
int
idx
)
{
constexpr
int
unroll_n
=
sizeof
(
int
)
*
8
/
bit_width
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_n
;
i
++
)
{
...
...
@@ -173,15 +188,27 @@ struct MeanIncludeRoundedPooler {
res
[
idx
+
i
]
+=
static_cast
<
int32_t
>
(
temp
);
}
}
__device__
__forceinline__
void
feed
(
int
x
)
{
feed
(
x
,
0
);
if
(
need_zero_pad
)
{
real_fi_count
++
;
}
}
__device__
__forceinline__
void
feed
(
int2
x
)
{
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
if
(
need_zero_pad
)
{
real_fi_count
++
;
}
}
__device__
__forceinline__
void
feed
(
int4
x
)
{
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
feed
(
x
.
z
,
2
*
elem_per_32bit
);
feed
(
x
.
w
,
3
*
elem_per_32bit
);
if
(
need_zero_pad
)
{
real_fi_count
++
;
}
}
__device__
__forceinline__
feed_type
get_ans
()
{
feed_type
ans
;
...
...
@@ -189,13 +216,18 @@ struct MeanIncludeRoundedPooler {
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
i
++
)
{
float
f32_res
=
roundf
(
static_cast
<
float
>
(
res
[
i
])
*
fi_count
);
if
(
need_zero_pad
)
{
f32_res
=
roundf
((
static_cast
<
float
>
(
res
[
i
])
+
(
count
-
real_fi_count
)
*
zero_pad
)
*
fi_count
);
}
int
i8_res
;
asm
volatile
(
"cvt.rni.s8.f32 %0, %1;"
:
"=r"
(
i8_res
)
:
"f"
(
f32_res
));
out_res
[
i
]
=
i8_res
;
}
ans
=
pack_int8
<
nr_results
,
src_type
,
feed_type
>
(
out_res
);
ans
=
pack_int8
<
nr_results
,
bit_width
,
feed_type
>
(
out_res
);
return
ans
;
}
};
...
...
@@ -209,7 +241,7 @@ struct MeanExcludeRoundedPooler {
static
constexpr
int
shift_fix_sign
=
TypeTrait
<
src_type
>::
shift_fix_sign
;
int32_t
res
[
nr_results
];
int
count
;
__device__
MeanExcludeRoundedPooler
(
int
)
{}
__device__
MeanExcludeRoundedPooler
(
int
,
int
)
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
...
...
@@ -257,7 +289,7 @@ struct MeanExcludeRoundedPooler {
:
"f"
(
f32_res
));
out_res
[
i
]
=
i8_res
;
}
ans
=
pack_int8
<
nr_results
,
src_type
,
feed_type
>
(
out_res
);
ans
=
pack_int8
<
nr_results
,
bit_width
,
feed_type
>
(
out_res
);
return
ans
;
}
};
...
...
@@ -290,7 +322,7 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4(
packed_ch
*
output_pixels
*
npack
+
(
ho
*
param
.
wo
+
wo
)
*
npack
;
Pooler
pooler
(
param
.
window_h
*
param
.
window_w
);
Pooler
pooler
(
param
.
window_h
*
param
.
window_w
,
0
);
pooler
.
init
();
for
(
int
fh
=
0
;
fh
<
param
.
window_h
;
fh
++
)
{
uint32_t
ih
=
ho
*
param
.
sh
+
fh
-
param
.
ph
;
...
...
@@ -313,7 +345,7 @@ template <typename Pooler, int pack_size, int pack_byte,
int
ldg_width_assert
=
4
>
__global__
void
pooling2d_device_template_nchwc
(
const
int8_t
*
__restrict__
src
,
int8_t
*
__restrict__
dst
,
Param
param
)
{
Param
param
,
int
zero_point
)
{
const
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
using
ldg_type
=
typename
Pooler
::
feed_type
;
static
int
constexpr
ldg_width
=
sizeof
(
ldg_type
)
/
sizeof
(
int32_t
);
...
...
@@ -348,7 +380,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src,
dst
+
(
batch
*
out_batch_stride
+
oc
*
out_channel_stride
+
(
oh
*
param
.
wo
+
ow
)
*
pack_byte
+
sec
*
ldg_width_bytes
);
Pooler
pooler
(
param
.
window_h
*
param
.
window_w
);
Pooler
pooler
(
param
.
window_h
*
param
.
window_w
,
zero_point
);
pooler
.
init
();
for
(
int
fh
=
0
;
fh
<
param
.
window_h
;
fh
++
)
{
uint32_t
ih
=
oh
*
param
.
sh
+
fh
-
param
.
ph
;
...
...
@@ -418,13 +450,12 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src,
after_kernel_launch
();
}
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int8_ncdiv4hw4
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
)
{
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int8_ncdiv4hw4
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
,
bool
uint_case
,
int
zero_point
)
{
using
Mode
=
megdnn
::
param_enumv
::
Pooling
::
Mode
;
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
);
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
,
int
zero_point
);
constexpr
int
ldg_byte
=
4
;
constexpr
int
elem_per_byte
=
1
;
constexpr
int
pack_size
=
4
;
...
...
@@ -455,17 +486,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src,
uint32_t
nr_threads
=
query_blocksize_for_kernel
(
kern
);
nr_threads
=
std
::
min
(
nr_threads
,
vthreads
);
uint32_t
nr_blocks
=
DIVUP
(
vthreads
,
nr_threads
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
,
zero_point
);
after_kernel_launch
();
}
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int8_ncdiv32hw32
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
)
{
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int8_ncdiv32hw32
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
,
bool
uint_case
,
int
zero_point
)
{
using
Mode
=
megdnn
::
param_enumv
::
Pooling
::
Mode
;
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
);
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
,
int
zero_point
);
constexpr
int
ldg_byte
=
16
;
constexpr
int
elem_per_byte
=
1
;
constexpr
int
pack_size
=
32
;
...
...
@@ -494,17 +524,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src,
uint32_t
nr_threads
=
query_blocksize_for_kernel
(
kern
);
nr_threads
=
std
::
min
(
nr_threads
,
vthreads
);
uint32_t
nr_blocks
=
DIVUP
(
vthreads
,
nr_threads
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
,
zero_point
);
after_kernel_launch
();
}
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int4_ncdiv64hw64
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
)
{
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int4_ncdiv64hw64
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
,
bool
uint_case
,
int
zero_point
)
{
using
Mode
=
megdnn
::
param_enumv
::
Pooling
::
Mode
;
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
);
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
,
int
zero_point
);
constexpr
int
ldg_byte
=
16
;
constexpr
int
elem_per_byte
=
2
;
constexpr
int
pack_size
=
64
;
...
...
@@ -512,28 +541,50 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src,
constexpr
int
elem_per_thread
=
ldg_byte
*
elem_per_byte
;
uint32_t
vthreads
=
param
.
n
*
param
.
c
*
param
.
ho
*
param
.
wo
/
elem_per_thread
;
switch
(
mode
)
{
case
Mode
::
MAX
:
kern
=
pooling2d_device_template_nchwc
<
MaxPooler
<
dt_qint4
,
int4
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE
:
kern
=
pooling2d_device_template_nchwc
<
MeanIncludeRoundedPooler
<
dt_qint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
:
kern
=
pooling2d_device_template_nchwc
<
MeanExcludeRoundedPooler
<
dt_qint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
default:
megdnn_assert
(
false
,
"invalid pooling mode"
);
if
(
uint_case
)
{
switch
(
mode
)
{
case
Mode
::
MAX
:
kern
=
pooling2d_device_template_nchwc
<
MaxPooler
<
dt_quint4
,
int4
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE
:
kern
=
pooling2d_device_template_nchwc
<
MeanIncludeRoundedPooler
<
dt_quint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
:
kern
=
pooling2d_device_template_nchwc
<
MeanExcludeRoundedPooler
<
dt_quint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
default:
megdnn_assert
(
false
,
"invalid pooling mode"
);
}
}
else
{
switch
(
mode
)
{
case
Mode
::
MAX
:
kern
=
pooling2d_device_template_nchwc
<
MaxPooler
<
dt_qint4
,
int4
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE
:
kern
=
pooling2d_device_template_nchwc
<
MeanIncludeRoundedPooler
<
dt_qint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
:
kern
=
pooling2d_device_template_nchwc
<
MeanExcludeRoundedPooler
<
dt_qint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
default:
megdnn_assert
(
false
,
"invalid pooling mode"
);
}
}
uint32_t
nr_threads
=
query_blocksize_for_kernel
(
kern
);
nr_threads
=
std
::
min
(
nr_threads
,
vthreads
);
uint32_t
nr_blocks
=
DIVUP
(
vthreads
,
nr_threads
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
,
zero_point
);
after_kernel_launch
();
}
...
...
dnn/src/cuda/pooling/pooling2d_qint.cuh
浏览文件 @
2d4e62ef
...
...
@@ -27,15 +27,18 @@ void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst,
void
do_pooling2d_int8_ncdiv4hw4
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
);
uint32_t
mode
,
bool
uint_case
=
false
,
int
zero_point
=
0
);
void
do_pooling2d_int8_ncdiv32hw32
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
);
uint32_t
mode
,
bool
uint_case
=
false
,
int
zero_point
=
0
);
void
do_pooling2d_int4_ncdiv64hw64
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
);
uint32_t
mode
,
bool
uint_case
=
false
,
int
zero_point
=
0
);
}
// namespace pooling2d
}
// namespace cuda
...
...
dnn/test/cuda/pooling.cpp
浏览文件 @
2d4e62ef
...
...
@@ -254,6 +254,13 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW_Q4) {
checker
.
set_param
(
param
).
exec
({{
20
,
96
,
22
,
33
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
20
,
24
,
22
,
33
},
{}});
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
3.1415926
f
,
3
));
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
).
exec
({{
20
,
64
,
22
,
33
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
20
,
96
,
22
,
33
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
20
,
24
,
22
,
33
},
{}});
}
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW4
)
{
...
...
@@ -291,20 +298,36 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW32) {
}
#endif
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW64
)
{
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW64
_Q4
)
{
require_compute_capability
(
7
,
5
);
using
Param
=
param
::
Pooling
;
Checker
<
Pooling
>
checker
(
handle_cuda
());
Param
param
{
Param
::
Mode
::
MAX
,
0
,
0
,
2
,
2
,
2
,
2
};
Param
param
{
Param
::
Mode
::
MAX
,
1
,
1
,
2
,
2
,
2
,
2
};
UniformIntRNG
int_rng
{
-
8
,
7
};
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
1.
f
));
param
.
format
=
Param
::
Format
::
NCHW64
;
checker
.
set_epsilon
(
1e-3
).
set_rng
(
0
,
&
int_rng
);
checker
.
set_param
(
param
).
exec
({{
6
4
,
8
,
28
,
28
,
64
},
{}});
checker
.
set_param
(
param
).
exec
({{
4
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
6
4
,
8
,
28
,
28
,
64
},
{}});
checker
.
set_param
(
param
).
exec
({{
4
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
64
},
{}});
checker
.
set_param
(
param
).
exec
({{
4
,
8
,
28
,
28
,
64
},
{}});
}
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW64_U4
)
{
require_compute_capability
(
7
,
5
);
using
Param
=
param
::
Pooling
;
Checker
<
Pooling
>
checker
(
handle_cuda
());
Param
param
{
Param
::
Mode
::
MAX
,
1
,
1
,
2
,
2
,
2
,
2
};
UniformIntRNG
int_rng
{
0
,
15
};
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
1.
f
,
3
));
param
.
format
=
Param
::
Format
::
NCHW64
;
checker
.
set_epsilon
(
1e-3
).
set_rng
(
0
,
&
int_rng
);
checker
.
set_param
(
param
).
exec
({{
4
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
4
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
4
,
8
,
28
,
28
,
64
},
{}});
}
TEST_F
(
CUDA
,
POOLING_FORWARD_CHWN4
)
{
...
...
dnn/test/naive/pooling.cpp
浏览文件 @
2d4e62ef
...
...
@@ -84,12 +84,12 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) {
}
{
auto
u4_dt
=
dtype
::
Quantized4Asymm
(
1.
f
,
0
);
auto
u4_dt
=
dtype
::
Quantized4Asymm
(
0.1
f
,
3
);
std
::
vector
<
int
>
u8_src_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
std
::
vector
<
int
>
u8_max_dst_vec
{
1
,
3
,
7
,
9
};
std
::
vector
<
int
>
u8_avg_dst_vec
{
0
,
1
,
3
,
7
};
std
::
vector
<
int
>
u8_avg_dst_vec
{
3
,
3
,
4
,
7
};
std
::
vector
<
int
>
u8_avg_exclu_dst_vec
{
1
,
3
,
6
,
7
};
Pooling
::
Param
param
{
Mode
::
MAX
,
1
,
1
,
2
,
2
,
2
,
2
};
Testcase
input
{
TensorValueLowbit4
({
1
,
1
,
3
,
3
},
u4_dt
,
u8_src_vec
),
{}};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录