Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2398df07
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2398df07
编写于
3月 31, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add cuda int4 pooling
GitOrigin-RevId: 14ed4e6f0095231ca87cace41842b32df18d0818
上级
2a2a7f45
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
909 addition
and
82 deletion
+909
-82
dnn/src/common/pooling.cpp
dnn/src/common/pooling.cpp
+7
-1
dnn/src/cuda/pooling/opr_impl.cpp
dnn/src/cuda/pooling/opr_impl.cpp
+98
-10
dnn/src/cuda/pooling/pooling2d_qint.cu
dnn/src/cuda/pooling/pooling2d_qint.cu
+540
-0
dnn/src/cuda/pooling/pooling2d_qint.cuh
dnn/src/cuda/pooling/pooling2d_qint.cuh
+6
-1
dnn/src/naive/pooling/opr_impl.cpp
dnn/src/naive/pooling/opr_impl.cpp
+135
-46
dnn/src/naive/pooling/opr_impl.h
dnn/src/naive/pooling/opr_impl.h
+6
-4
dnn/test/common/checker.h
dnn/test/common/checker.h
+20
-20
dnn/test/cuda/pooling.cpp
dnn/test/cuda/pooling.cpp
+38
-0
dnn/test/naive/pooling.cpp
dnn/test/naive/pooling.cpp
+59
-0
未找到文件。
dnn/src/common/pooling.cpp
浏览文件 @
2398df07
...
...
@@ -47,7 +47,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW32
)
{
param
().
format
==
Param
::
Format
::
NCHW32
||
param
().
format
==
Param
::
Format
::
NCHW64
)
{
megdnn_assert
(
src
.
ndim
==
5
_z
,
"%s"
,
errmsg_c
);
spatial_pos
=
2
;
...
...
@@ -82,6 +83,9 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
if
(
param
().
format
==
Param
::
Format
::
NCHW32
)
{
c
*=
32
;
}
if
(
param
().
format
==
Param
::
Format
::
NCHW64
)
{
c
*=
64
;
}
size_t
oh
,
ow
;
size_t
fh
=
this
->
param
().
window_h
;
size_t
fw
=
this
->
param
().
window_w
;
...
...
@@ -109,6 +113,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
dst
=
TensorLayout
{{
n
,
c
/
8
,
oh
,
ow
,
8
},
src
.
dtype
,
src
.
format
};
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW32
)
{
dst
=
TensorLayout
{{
n
,
c
/
32
,
oh
,
ow
,
32
},
src
.
dtype
,
src
.
format
};
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW64
)
{
dst
=
TensorLayout
{{
n
,
c
/
64
,
oh
,
ow
,
64
},
src
.
dtype
,
src
.
format
};
}
else
if
(
param
().
format
==
Param
::
Format
::
CHWN4
)
{
dst
=
TensorLayout
{{
c
/
4
,
oh
,
ow
,
n
,
4
},
src
.
dtype
,
src
.
format
};
}
else
{
...
...
dnn/src/cuda/pooling/opr_impl.cpp
浏览文件 @
2398df07
...
...
@@ -9,13 +9,50 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/pooling/opr_impl.h"
#include "src/cuda/relayout_format/opr_impl.h"
#include "./pooling2d_
int8
.cuh"
#include "./pooling2d_
qint
.cuh"
#include "src/cuda/utils.h"
namespace
megdnn
{
namespace
cuda
{
namespace
{
inline
void
deduce_reformat_layout
(
std
::
unique_ptr
<
RelayoutFormat
>&
relayout
,
const
TensorLayout
&
src_layout
,
TensorLayout
&
dst_layout
,
RelayoutFormat
::
Param
::
Mode
mode
,
const
int
oc
=
0
,
const
int
group
=
1
)
{
if
(
src_layout
.
ndim
>
0
)
{
RelayoutFormat
::
Param
trans_param
;
trans_param
.
mode
=
mode
;
trans_param
.
oc
=
oc
;
trans_param
.
group
=
group
;
relayout
->
param
()
=
trans_param
;
relayout
->
deduce_layout
(
src_layout
,
dst_layout
);
}
else
{
dst_layout
=
src_layout
;
}
}
void
get_inner_layout
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
,
TensorLayout
&
inner_src
,
TensorLayout
&
inner_dst
,
Handle
*
handle
,
PoolingForwardImpl
::
Param
::
Format
format
)
{
bool
is_nchw
=
format
==
PoolingForwardImpl
::
Param
::
Format
::
NCHW
;
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
is_nchw
)
{
auto
relayout_opr
=
handle
->
create_operator
<
RelayoutFormat
>
();
deduce_reformat_layout
(
relayout_opr
,
src
,
inner_src
,
RelayoutFormat
::
Param
::
Mode
::
NCHW_NCHW64
,
0
,
1
);
deduce_reformat_layout
(
relayout_opr
,
dst
,
inner_dst
,
RelayoutFormat
::
Param
::
Mode
::
NCHW_NCHW64
,
0
,
1
);
}
else
{
megdnn_assert
(
0
,
"not support"
);
}
}
}
// namespace
void
PoolingForwardImpl
::
setup_descs
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
src_desc
.
set
(
src
,
param
().
format
);
...
...
@@ -28,14 +65,22 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
SmallVector
<
size_t
>
sizes
;
TensorLayout
fsrc
=
src
;
TensorLayout
fdst
=
dst
;
auto
get_workspace
=
[
&
sizes
](
TensorLayout
&
layout
)
{
if
(
layout
.
dtype
==
dtype
::
BFloat16
())
{
layout
.
dtype
=
dtype
::
Float32
();
sizes
.
push_back
(
layout
.
span
().
dist_byte
());
}
};
get_workspace
(
fsrc
);
get_workspace
(
fdst
);
bool
is_nchw
=
param
().
format
==
Param
::
Format
::
NCHW
;
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
is_nchw
)
{
get_inner_layout
(
src
,
dst
,
fsrc
,
fdst
,
handle
(),
param
().
format
);
sizes
.
push_back
(
fsrc
.
span
().
dist_byte
());
sizes
.
push_back
(
fdst
.
span
().
dist_byte
());
}
else
{
auto
get_workspace
=
[
&
sizes
](
TensorLayout
&
layout
)
{
if
(
layout
.
dtype
==
dtype
::
BFloat16
())
{
layout
.
dtype
=
dtype
::
Float32
();
sizes
.
push_back
(
layout
.
span
().
dist_byte
());
}
};
get_workspace
(
fsrc
);
get_workspace
(
fdst
);
}
return
{
ptr
,
std
::
move
(
sizes
)};
}
...
...
@@ -44,12 +89,27 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
check_exec
(
ssrc
.
layout
,
sdst
.
layout
,
sworkspace
.
size
);
TensorND
src
=
ssrc
;
TensorND
dst
=
sdst
;
Param
::
Format
inner_format
=
param
().
format
;
auto
wsb
=
get_workspace_bundle
(
sworkspace
.
raw_ptr
,
ssrc
.
layout
,
sdst
.
layout
);
auto
ctypecvt
=
CompTypeCvter
<
dtype
::
BFloat16
,
dtype
::
Float32
>
(
concrete_handle
(
this
->
handle
()),
&
wsb
);
bool
is_nchw
=
param
().
format
==
Param
::
Format
::
NCHW
;
if
(
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
ctypecvt
.
src_to_comp_type
(
ssrc
,
src
).
src_to_comp_type
(
sdst
,
dst
);
}
else
if
(
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
sdst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
is_nchw
)
{
auto
handle_ptr
=
handle
();
get_inner_layout
(
ssrc
.
layout
,
sdst
.
layout
,
src
.
layout
,
dst
.
layout
,
handle_ptr
,
param
().
format
);
src
.
raw_ptr
=
wsb
.
get
(
0
);
dst
.
raw_ptr
=
wsb
.
get
(
1
);
auto
relayout_opr
=
handle_ptr
->
create_operator
<
RelayoutFormat
>
();
RelayoutFormat
::
Param
trans_param
;
trans_param
.
mode
=
RelayoutFormat
::
Param
::
Mode
::
NCHW_NCHW64
;
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
ssrc
,
src
,
{});
inner_format
=
Param
::
Format
::
NCHW64
;
}
{
using
Format
=
param
::
Pooling
::
Format
;
...
...
@@ -104,6 +164,34 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
return
pooling2d
::
do_pooling2d_int8_ncdiv32hw32
(
src
.
compatible_ptr
<
int8_t
>
(),
dst
.
compatible_ptr
<
int8_t
>
(),
kern_param
,
stream
,
static_cast
<
uint32_t
>
(
param
().
mode
));
}
else
if
(
param
().
format
==
Format
::
NCHW64
||
inner_format
==
Format
::
NCHW64
)
{
megdnn_assert
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
,
"but %s"
,
src
.
layout
.
dtype
.
name
());
pooling2d
::
Param
kern_param
;
size_t
n
=
src
.
layout
[
0
],
hi
=
src
.
layout
[
2
],
wi
=
src
.
layout
[
3
],
c
=
src
.
layout
[
1
],
ho
=
dst
.
layout
[
2
],
wo
=
dst
.
layout
[
3
];
c
=
c
*
64
;
size_t
ph
=
param
().
pad_h
,
pw
=
param
().
pad_w
;
size_t
window_h
=
param
().
window_h
,
window_w
=
param
().
window_w
;
size_t
sh
=
param
().
stride_h
,
sw
=
param
().
stride_w
;
kern_param
.
n
=
n
,
kern_param
.
c
=
c
,
kern_param
.
hi
=
hi
,
kern_param
.
wi
=
wi
,
kern_param
.
ho
=
ho
,
kern_param
.
wo
=
wo
,
kern_param
.
ph
=
ph
,
kern_param
.
pw
=
pw
,
kern_param
.
window_h
=
window_h
,
kern_param
.
window_w
=
window_w
,
kern_param
.
sh
=
sh
,
kern_param
.
sw
=
sw
;
auto
&&
stream
=
cuda_stream
(
handle
());
pooling2d
::
do_pooling2d_int4_ncdiv64hw64
(
(
int8_t
*
)
src
.
raw_ptr
,
(
int8_t
*
)
dst
.
raw_ptr
,
kern_param
,
stream
,
static_cast
<
uint32_t
>
(
param
().
mode
));
if
(
sdst
.
layout
.
ndim
==
4
)
{
auto
relayout_opr
=
handle
()
->
create_operator
<
RelayoutFormat
>
();
RelayoutFormat
::
Param
trans_param
;
trans_param
.
mode
=
RelayoutFormat
::
Param
::
Mode
::
NCHW64_NCHW
;
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,{});
}
return
;
}
auto
handle
=
cudnn_handle
(
this
->
handle
());
setup_descs
(
src
.
layout
,
dst
.
layout
);
...
...
@@ -114,7 +202,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
}
if
(
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
ctypecvt
.
comp_to_dst_type
(
dst
,
sdst
);
}
}
}
void
PoolingBackwardImpl
::
setup_descs
(
const
TensorLayout
&
src
,
...
...
dnn/src/cuda/pooling/pooling2d_
int8
.cu
→
dnn/src/cuda/pooling/pooling2d_
qint
.cu
浏览文件 @
2398df07
/**
* \file dnn/src/cuda/pooling/pooling2d_
int8_cdiv4hwn4
.cu
* \file dnn/src/cuda/pooling/pooling2d_
qint
.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./pooling2d_
int8
.cuh"
#include "./pooling2d_
qint
.cuh"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/query_blocksize.cuh"
...
...
@@ -17,27 +18,6 @@ using namespace cuda;
using
namespace
pooling2d
;
namespace
{
// common macros
#define FEED1 Base::feed(x, 0);
#define FEED2 \
Base::feed(x.x, 0); \
Base::feed(x.y, 4);
#define FEED4 \
FEED2; \
Base::feed(x.z, 8); \
Base::feed(x.w, 12);
#define ANS1(cb) cb(Base::res[0], Base::res[1], Base::res[2], Base::res[3], i1);
#define ANS2(cb) \
ANS1(cb); \
cb(Base::res[4], Base::res[5], Base::res[6], Base::res[7], i2);
#define ANS4(cb) \
ANS2(cb); \
cb(Base::res[8], Base::res[9], Base::res[10], Base::res[11], i3); \
cb(Base::res[12], Base::res[13], Base::res[14], Base::res[15], i4);
__device__
__forceinline__
int
pack_int8_to_int8x4
(
int8_t
x
,
int8_t
y
,
int8_t
z
,
int8_t
w
)
{
int
ix
=
static_cast
<
int
>
(
x
),
iy
=
static_cast
<
int
>
(
y
),
...
...
@@ -49,184 +29,188 @@ __device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z,
return
ix
;
}
template
<
typename
src_type
,
typename
feed_type
>
struct
MaxPoolerBase
;
template
<
typename
feed_type
>
struct
MaxPoolerBase
<
int8_t
,
feed_type
>
{
static
constexpr
int
nr_results
=
sizeof
(
feed_type
)
/
sizeof
(
int8_t
);
int8_t
res
[
nr_results
];
__device__
MaxPoolerBase
(
int
)
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
++
i
)
{
res
[
i
]
=
-
128
;
}
}
__device__
__forceinline__
void
feed
(
int32_t
x
,
int
idx
)
{
int8_t
ix
=
(
x
&
0xff
);
int8_t
iy
=
((
x
>>
8
)
&
0xff
);
int8_t
iz
=
((
x
>>
16
)
&
0xff
);
int8_t
iw
=
((
x
>>
24
)
&
0xff
);
res
[
idx
]
=
res
[
idx
]
>
ix
?
res
[
idx
]
:
ix
;
res
[
idx
+
1
]
=
res
[
idx
+
1
]
>
iy
?
res
[
idx
+
1
]
:
iy
;
res
[
idx
+
2
]
=
res
[
idx
+
2
]
>
iz
?
res
[
idx
+
2
]
:
iz
;
res
[
idx
+
3
]
=
res
[
idx
+
3
]
>
iw
?
res
[
idx
+
3
]
:
iw
;
}
};
template
<
int
regs
,
typename
Dtype
,
typename
OutDtype
>
__device__
__forceinline__
OutDtype
pack_int8
(
int8_t
(
&
x
)[
regs
]);
template
<
typename
src_type
,
typename
feed_type
>
struct
MaxPooler
;
template
<
>
__device__
__forceinline__
int
pack_int8
<
4
,
int8_t
,
int
>
(
int8_t
(
&
x
)[
4
])
{
return
pack_int8_to_int8x4
(
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]);
}
#define SPEC_WITH_FEED_TYPE(_feed_type) \
template <> \
struct MaxPooler<int8_t, _feed_type> : MaxPoolerBase<int8_t, _feed_type>
template
<
>
__device__
__forceinline__
int2
pack_int8
<
8
,
int8_t
,
int2
>
(
int8_t
(
&
x
)[
8
])
{
int8_t
x0
[
4
]{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]};
int8_t
x1
[
4
]{
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
]};
return
::
make_int2
(
pack_int8
<
4
,
int8_t
,
int
>
(
x0
),
pack_int8
<
4
,
int8_t
,
int
>
(
x1
));
}
#define COMMON_DEFS(_feed_type) \
using feed_type = _feed_type; \
using Base = MaxPoolerBase<int8_t, _feed_type>; \
using MaxPoolerBase<int8_t, _feed_type>::MaxPoolerBase;
template
<
>
__device__
__forceinline__
int4
pack_int8
<
16
,
int8_t
,
int4
>
(
int8_t
(
&
x
)[
16
])
{
int8_t
x0
[
4
]{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
]};
int8_t
x1
[
4
]{
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
]};
int8_t
x2
[
4
]{
x
[
8
],
x
[
9
],
x
[
10
],
x
[
11
]};
int8_t
x3
[
4
]{
x
[
12
],
x
[
13
],
x
[
14
],
x
[
15
]};
return
::
make_int4
(
pack_int8
<
4
,
int8_t
,
int
>
(
x0
),
pack_int8
<
4
,
int8_t
,
int
>
(
x1
),
pack_int8
<
4
,
int8_t
,
int
>
(
x2
),
pack_int8
<
4
,
int8_t
,
int
>
(
x3
));
}
#define cb(_x, _y, _z, _w, _ret) \
{ _ret = pack_int8_to_int8x4(_x, _y, _z, _w); }
__device__
__forceinline__
int8_t
pack_int8_to_int4x2
(
int8_t
x0
,
int8_t
x1
)
{
return
(
x0
&
0xf
)
|
(
x1
<<
4
);
}
template
<
>
__device__
__forceinline__
int
pack_int8
<
8
,
dt_qint4
,
int
>
(
int8_t
(
&
x
)[
8
])
{
int8_t
x0
=
pack_int8_to_int4x2
(
x
[
0
],
x
[
1
]);
int8_t
x1
=
pack_int8_to_int4x2
(
x
[
2
],
x
[
3
]);
int8_t
x2
=
pack_int8_to_int4x2
(
x
[
4
],
x
[
5
]);
int8_t
x3
=
pack_int8_to_int4x2
(
x
[
6
],
x
[
7
]);
return
pack_int8_to_int8x4
(
x0
,
x1
,
x2
,
x3
);
}
template
<
>
__device__
__forceinline__
int4
pack_int8
<
32
,
dt_qint4
,
int4
>
(
int8_t
(
&
x
)[
32
])
{
int8_t
x0
[
8
]{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
],
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
]};
int8_t
x1
[
8
]{
x
[
8
],
x
[
9
],
x
[
10
],
x
[
11
],
x
[
12
],
x
[
13
],
x
[
14
],
x
[
15
]};
int8_t
x2
[
8
]{
x
[
16
],
x
[
17
],
x
[
18
],
x
[
19
],
x
[
20
],
x
[
21
],
x
[
22
],
x
[
23
]};
int8_t
x3
[
8
]{
x
[
24
],
x
[
25
],
x
[
26
],
x
[
27
],
x
[
28
],
x
[
29
],
x
[
30
],
x
[
31
]};
return
::
make_int4
(
pack_int8
<
8
,
dt_qint4
,
int
>
(
x0
),
pack_int8
<
8
,
dt_qint4
,
int
>
(
x1
),
pack_int8
<
8
,
dt_qint4
,
int
>
(
x2
),
pack_int8
<
8
,
dt_qint4
,
int
>
(
x3
));
}
SPEC_WITH_FEED_TYPE
(
int32_t
)
{
COMMON_DEFS
(
int32_t
);
__device__
__forceinline__
void
feed
(
int32_t
x
)
{
FEED1
;
}
template
<
typename
Dtype
>
struct
TypeTrait
;
__device__
__forceinline__
int
get_ans
()
{
int
i1
;
ANS1
(
cb
);
return
i1
;
}
template
<
>
struct
TypeTrait
<
int8_t
>
{
static
constexpr
int
bit_width
=
8
;
static
constexpr
int
mask
=
0xff
;
static
constexpr
int8_t
min
=
-
128
;
static
constexpr
int
elem_per_32bit
=
32
/
bit_width
;
static
constexpr
int
shift_fix_sign
=
0
;
};
SPEC_WITH_FEED_TYPE
(
int2
)
{
COMMON_DEFS
(
int2
);
__device__
__forceinline__
void
feed
(
int2
x
)
{
FEED2
;
}
__device__
__forceinline__
int2
get_ans
()
{
int
i1
,
i2
;
ANS2
(
cb
);
return
::
make_int2
(
i1
,
i2
);
}
template
<
>
struct
TypeTrait
<
dt_qint4
>
{
static
constexpr
int
bit_width
=
4
;
static
constexpr
int
mask
=
0xf
;
static
constexpr
int8_t
min
=
-
8
;
static
constexpr
int
elem_per_32bit
=
32
/
bit_width
;
static
constexpr
int
shift_fix_sign
=
4
;
};
SPEC_WITH_FEED_TYPE
(
int4
)
{
COMMON_DEFS
(
int4
);
__device__
__forceinline__
void
feed
(
int4
x
)
{
FEED4
;
}
template
<
typename
src_type
,
typename
_feed_type
>
struct
MaxPooler
{
using
feed_type
=
_feed_type
;
static
constexpr
int
bit_width
=
TypeTrait
<
src_type
>::
bit_width
;
static
constexpr
int
nr_results
=
sizeof
(
feed_type
)
*
8
/
bit_width
;
static
constexpr
int
elem_per_32bit
=
TypeTrait
<
src_type
>::
elem_per_32bit
;
static
constexpr
int
shift_fix_sign
=
TypeTrait
<
src_type
>::
shift_fix_sign
;
int8_t
res
[
nr_results
];
__device__
__forceinline__
int4
get_ans
()
{
int
i1
,
i2
,
i3
,
i4
;
ANS4
(
cb
);
return
::
make_int4
(
i1
,
i2
,
i3
,
i4
);
__device__
MaxPooler
(
int
)
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
++
i
)
{
res
[
i
]
=
TypeTrait
<
src_type
>::
min
;
}
}
__device__
__forceinline__
void
feed
(
int
x
,
int
idx
=
0
)
{
constexpr
int
unroll_n
=
sizeof
(
int
)
*
8
/
bit_width
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_n
;
i
++
)
{
int8_t
temp
=
((
x
>>
(
i
*
bit_width
))
&
TypeTrait
<
src_type
>::
mask
)
<<
shift_fix_sign
;
temp
=
temp
>>
shift_fix_sign
;
res
[
idx
+
i
]
=
res
[
idx
+
i
]
>
temp
?
res
[
idx
+
i
]
:
temp
;
}
}
__device__
__forceinline__
void
feed
(
int2
x
)
{
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
}
__device__
__forceinline__
void
feed
(
int4
x
)
{
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
feed
(
x
.
z
,
2
*
elem_per_32bit
);
feed
(
x
.
w
,
3
*
elem_per_32bit
);
}
__device__
__forceinline__
feed_type
get_ans
()
{
feed_type
ans
;
ans
=
pack_int8
<
nr_results
,
src_type
,
feed_type
>
(
res
);
return
ans
;
}
};
#undef cb
#undef COMMON_DEFS
#undef SPEC_WITH_FEED_TYPE
template
<
typename
src_type
,
typename
_feed_type
,
typename
inter_type
>
struct
MeanIncludeRoundedPooler
{
using
feed_type
=
_feed_type
;
static
constexpr
int
bit_width
=
TypeTrait
<
src_type
>::
bit_width
;
static
constexpr
int
nr_results
=
sizeof
(
feed_type
)
*
8
/
bit_width
;
static
constexpr
int
elem_per_32bit
=
TypeTrait
<
src_type
>::
elem_per_32bit
;
static
constexpr
int
shift_fix_sign
=
TypeTrait
<
src_type
>::
shift_fix_sign
;
template
<
typename
src_type
,
typename
feed_type
,
typename
inter_type
>
struct
MeanIncludeRoundedPoolerBase
;
template
<
typename
feed_type
>
struct
MeanIncludeRoundedPoolerBase
<
int8_t
,
feed_type
,
int32_t
>
{
static
constexpr
int
nr_results
=
sizeof
(
feed_type
)
/
sizeof
(
int8_t
);
int32_t
res
[
nr_results
];
const
int
count
;
const
float
fi_count
;
__device__
MeanIncludeRoundedPooler
Base
(
int
count
)
__device__
MeanIncludeRoundedPooler
(
int
count
)
:
count
{
count
},
fi_count
{
1.
f
/
count
}
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
++
i
)
{
res
[
i
]
=
0
;
}
}
__device__
__forceinline__
void
feed
(
int32_t
x
,
int
idx
)
{
int8_t
ix
=
(
x
&
0xff
);
int8_t
iy
=
((
x
>>
8
)
&
0xff
);
int8_t
iz
=
((
x
>>
16
)
&
0xff
);
int8_t
iw
=
((
x
>>
24
)
&
0xff
);
res
[
idx
]
+=
static_cast
<
int32_t
>
(
ix
);
res
[
idx
+
1
]
+=
static_cast
<
int32_t
>
(
iy
);
res
[
idx
+
2
]
+=
static_cast
<
int32_t
>
(
iz
);
res
[
idx
+
3
]
+=
static_cast
<
int32_t
>
(
iw
);
}
};
template
<
typename
src_type
,
typename
feed_type
,
typename
inter_type
>
struct
MeanIncludeRoundedPooler
;
#define SPEC_WITH_FEED_TYPE(_feed_type) \
template <> \
struct MeanIncludeRoundedPooler<int8_t, _feed_type, int32_t> \
: MeanIncludeRoundedPoolerBase<int8_t, _feed_type, int32_t>
#define COMMON_DEFS(_feed_type) \
using feed_type = _feed_type; \
using Base = MeanIncludeRoundedPoolerBase<int8_t, _feed_type, int32_t>; \
using MeanIncludeRoundedPoolerBase<int8_t, _feed_type, \
int32_t>::MeanIncludeRoundedPoolerBase;
#define cb(_x, _y, _z, _w, _ret) \
{ \
float fx = roundf(static_cast<float>(_x) * Base::fi_count); \
float fy = roundf(static_cast<float>(_y) * Base::fi_count); \
float fz = roundf(static_cast<float>(_z) * Base::fi_count); \
float fw = roundf(static_cast<float>(_w) * Base::fi_count); \
_ret = transform_float4_to_int8x4(::make_float4(fx, fy, fz, fw)); \
__device__
__forceinline__
void
feed
(
int
x
,
int
idx
=
0
)
{
constexpr
int
unroll_n
=
sizeof
(
int
)
*
8
/
bit_width
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_n
;
i
++
)
{
int8_t
temp
=
((
x
>>
(
i
*
bit_width
))
&
TypeTrait
<
src_type
>::
mask
)
<<
shift_fix_sign
;
temp
=
temp
>>
shift_fix_sign
;
res
[
idx
+
i
]
+=
static_cast
<
int32_t
>
(
temp
);
}
}
SPEC_WITH_FEED_TYPE
(
int32_t
)
{
COMMON_DEFS
(
int32_t
);
__device__
__forceinline__
void
feed
(
int32_t
x
)
{
FEED1
;
}
__device__
__forceinline__
int
get_ans
()
{
int
i1
;
ANS1
(
cb
);
return
i1
;
__device__
__forceinline__
void
feed
(
int2
x
)
{
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
}
};
SPEC_WITH_FEED_TYPE
(
int2
)
{
COMMON_DEFS
(
int2
);
__device__
__forceinline__
void
feed
(
int2
x
)
{
FEED2
;
}
__device__
__forceinline__
int2
get_ans
()
{
int
i1
,
i2
;
ANS2
(
cb
);
return
::
make_int2
(
i1
,
i2
);
__device__
__forceinline__
void
feed
(
int4
x
)
{
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
feed
(
x
.
z
,
2
*
elem_per_32bit
);
feed
(
x
.
w
,
3
*
elem_per_32bit
);
}
};
SPEC_WITH_FEED_TYPE
(
int4
)
{
COMMON_DEFS
(
int4
);
__device__
__forceinline__
void
feed
(
int4
x
)
{
FEED4
;
}
__device__
__forceinline__
int4
get_ans
()
{
int
i1
,
i2
,
i3
,
i4
;
ANS4
(
cb
);
return
::
make_int4
(
i1
,
i2
,
i3
,
i4
);
__device__
__forceinline__
feed_type
get_ans
()
{
feed_type
ans
;
int8_t
out_res
[
nr_results
];
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
i
++
)
{
float
f32_res
=
roundf
(
static_cast
<
float
>
(
res
[
i
])
*
fi_count
);
int
i8_res
;
asm
volatile
(
"cvt.rni.s8.f32 %0, %1;"
:
"=r"
(
i8_res
)
:
"f"
(
f32_res
));
out_res
[
i
]
=
i8_res
;
}
ans
=
pack_int8
<
nr_results
,
src_type
,
feed_type
>
(
out_res
);
return
ans
;
}
};
#undef cb
#undef COMMON_DEFS
#undef SPEC_WITH_FEED_TYPE
template
<
typename
src_type
,
typename
feed_type
,
typename
inter_type
>
struct
MeanExcludeRoundedPoolerBase
;
template
<
typename
feed_type
>
struct
MeanExcludeRoundedPoolerBase
<
int8_t
,
feed_type
,
int32_t
>
{
static
const
int
nr_results
=
sizeof
(
feed_type
)
/
sizeof
(
int8_t
);
template
<
typename
src_type
,
typename
_feed_type
,
typename
inter_type
>
struct
MeanExcludeRoundedPooler
{
using
feed_type
=
_feed_type
;
static
constexpr
int
bit_width
=
TypeTrait
<
src_type
>::
bit_width
;
static
constexpr
int
nr_results
=
sizeof
(
feed_type
)
*
8
/
bit_width
;
static
constexpr
int
elem_per_32bit
=
TypeTrait
<
src_type
>::
elem_per_32bit
;
static
constexpr
int
shift_fix_sign
=
TypeTrait
<
src_type
>::
shift_fix_sign
;
int32_t
res
[
nr_results
];
int
count
;
__device__
MeanExcludeRoundedPooler
(
int
)
{}
__device__
MeanExcludeRoundedPoolerBase
(
int
/* count */
)
{}
__device__
__forceinline__
void
init
()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
++
i
)
{
...
...
@@ -234,87 +218,50 @@ struct MeanExcludeRoundedPoolerBase<int8_t, feed_type, int32_t> {
}
count
=
0
;
}
__device__
__forceinline__
void
feed
(
int32_t
x
,
int
idx
)
{
int8_t
ix
=
(
x
&
0xff
);
int8_t
iy
=
((
x
>>
8
)
&
0xff
);
int8_t
iz
=
((
x
>>
16
)
&
0xff
);
int8_t
iw
=
((
x
>>
24
)
&
0xff
);
res
[
idx
]
+=
static_cast
<
int32_t
>
(
ix
);
res
[
idx
+
1
]
+=
static_cast
<
int32_t
>
(
iy
);
res
[
idx
+
2
]
+=
static_cast
<
int32_t
>
(
iz
);
res
[
idx
+
3
]
+=
static_cast
<
int32_t
>
(
iw
);
}
};
template
<
typename
src_type
,
typename
feed_type
,
typename
inter_type
>
struct
MeanExcludeRoundedPooler
;
#define SPEC_WITH_FEED_TYPE(_feed_type) \
template <> \
struct MeanExcludeRoundedPooler<int8_t, _feed_type, int32_t> \
: MeanExcludeRoundedPoolerBase<int8_t, _feed_type, int32_t>
#define COMMON_DEFS(_feed_type) \
using feed_type = _feed_type; \
using Base = MeanExcludeRoundedPoolerBase<int8_t, _feed_type, int32_t>; \
using MeanExcludeRoundedPoolerBase<int8_t, _feed_type, \
int32_t>::MeanExcludeRoundedPoolerBase;
#define cb(_x, _y, _z, _w, _ret) \
{ \
float fx = roundf(static_cast<float>(_x) / Base::count); \
float fy = roundf(static_cast<float>(_y) / Base::count); \
float fz = roundf(static_cast<float>(_z) / Base::count); \
float fw = roundf(static_cast<float>(_w) / Base::count); \
_ret = transform_float4_to_int8x4(::make_float4(fx, fy, fz, fw)); \
__device__
__forceinline__
void
feed
(
int
x
,
int
idx
)
{
constexpr
int
unroll_n
=
sizeof
(
int
)
*
8
/
bit_width
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_n
;
i
++
)
{
int8_t
temp
=
((
x
>>
(
i
*
bit_width
))
&
TypeTrait
<
src_type
>::
mask
)
<<
shift_fix_sign
;
temp
=
temp
>>
shift_fix_sign
;
res
[
idx
+
i
]
+=
static_cast
<
int32_t
>
(
temp
);
}
}
SPEC_WITH_FEED_TYPE
(
int32_t
)
{
COMMON_DEFS
(
int32_t
);
__device__
__forceinline__
void
feed
(
int32_t
x
)
{
FEED1
;
__device__
__forceinline__
void
feed
(
int
x
)
{
feed
(
x
,
0
);
count
++
;
}
__device__
__forceinline__
int
get_ans
()
{
int
i1
;
ANS1
(
cb
);
return
i1
;
}
};
SPEC_WITH_FEED_TYPE
(
int2
)
{
COMMON_DEFS
(
int2
);
__device__
__forceinline__
void
feed
(
int2
x
)
{
FEED2
;
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
count
++
;
}
__device__
__forceinline__
int2
get_ans
()
{
int
i1
,
i2
;
ANS2
(
cb
);
return
::
make_int2
(
i1
,
i2
);
}
};
SPEC_WITH_FEED_TYPE
(
int4
)
{
COMMON_DEFS
(
int4
);
__device__
__forceinline__
void
feed
(
int4
x
)
{
FEED4
;
feed
(
x
.
x
,
0
*
elem_per_32bit
);
feed
(
x
.
y
,
1
*
elem_per_32bit
);
feed
(
x
.
z
,
2
*
elem_per_32bit
);
feed
(
x
.
w
,
3
*
elem_per_32bit
);
count
++
;
}
__device__
__forceinline__
int4
get_ans
()
{
int
i1
,
i2
,
i3
,
i4
;
ANS4
(
cb
);
return
::
make_int4
(
i1
,
i2
,
i3
,
i4
);
__device__
__forceinline__
feed_type
get_ans
()
{
feed_type
ans
;
int8_t
out_res
[
nr_results
];
#pragma unroll
for
(
int
i
=
0
;
i
<
nr_results
;
i
++
)
{
float
f32_res
=
roundf
(
static_cast
<
float
>
(
res
[
i
])
/
count
);
int
i8_res
;
asm
volatile
(
"cvt.rni.s8.f32 %0, %1;"
:
"=r"
(
i8_res
)
:
"f"
(
f32_res
));
out_res
[
i
]
=
i8_res
;
}
ans
=
pack_int8
<
nr_results
,
src_type
,
feed_type
>
(
out_res
);
return
ans
;
}
};
#undef cb
#undef COMMON_DEFS
#undef SPEC_WITH_FEED_TYPE
template
<
typename
Pooler
>
__global__
void
pooling2d_device_template_int8_cdiv4hwn4
(
const
int8_t
*
__restrict__
src
,
int8_t
*
__restrict__
dst
,
Param
param
)
{
...
...
@@ -362,70 +309,19 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4(
*
(
reinterpret_cast
<
ldg_type
*>
(
g_dst_ptr
))
=
res
;
}
template
<
typename
Pooler
>
__global__
void
pooling2d_device_template_int8_ncdiv4hw4
(
const
int8_t
*
__restrict__
src
,
int8_t
*
__restrict__
dst
,
Param
param
)
{
const
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
using
ldg_type
=
typename
Pooler
::
feed_type
;
static
int
constexpr
pack_size
=
4
;
static
int
constexpr
ldg_width
=
sizeof
(
ldg_type
)
/
sizeof
(
int32_t
);
MEGDNN_STATIC_ASSERT
(
ldg_width
==
1
,
"pooling2d (NCHW4) kernel must use 32bit width ldg instruction"
);
const
int
wo_ldg
=
param
.
wo
/
ldg_width
;
const
int
c_packed
=
param
.
c
/
pack_size
;
const
int
batch
=
tid
/
(
param
.
ho
*
wo_ldg
*
c_packed
);
const
int
chw
=
tid
-
batch
*
param
.
ho
*
wo_ldg
*
c_packed
;
const
int
oc_packed
=
chw
/
(
param
.
ho
*
wo_ldg
);
const
int
hw
=
chw
-
oc_packed
*
param
.
ho
*
wo_ldg
;
const
int
oh
=
hw
/
wo_ldg
;
const
int
ow
=
(
hw
-
wo_ldg
*
oh
)
*
ldg_width
;
if
(
batch
>=
param
.
n
||
oc_packed
>=
c_packed
||
oh
>=
param
.
ho
||
ow
>=
param
.
wo
)
return
;
const
int
in_batch_stride
=
param
.
hi
*
param
.
wi
*
param
.
c
;
const
int
out_batch_stride
=
param
.
ho
*
param
.
wo
*
param
.
c
;
const
int
in_channel_stride
=
param
.
hi
*
param
.
wi
*
pack_size
;
const
int
out_channel_stride
=
param
.
ho
*
param
.
wo
*
pack_size
;
const
int8_t
*
__restrict__
g_src_ptr
=
src
+
batch
*
in_batch_stride
+
oc_packed
*
in_channel_stride
;
int8_t
*
__restrict__
g_dst_ptr
=
dst
+
batch
*
out_batch_stride
+
oc_packed
*
out_channel_stride
+
(
oh
*
param
.
wo
+
ow
)
*
pack_size
;
Pooler
pooler
(
param
.
window_h
*
param
.
window_w
);
pooler
.
init
();
for
(
int
fh
=
0
;
fh
<
param
.
window_h
;
fh
++
)
{
uint32_t
ih
=
oh
*
param
.
sh
+
fh
-
param
.
ph
;
for
(
int
fw
=
0
;
fw
<
param
.
window_w
;
fw
++
)
{
uint32_t
iw
=
ow
*
param
.
sw
+
fw
-
param
.
pw
;
if
(
ih
<
param
.
hi
&&
iw
<
param
.
wi
)
{
const
int8_t
*
__restrict__
cur_src_ptr
=
g_src_ptr
+
(
ih
*
param
.
wi
+
iw
)
*
pack_size
;
ldg_type
sval
=
__ldg
(
reinterpret_cast
<
const
ldg_type
*>
(
cur_src_ptr
));
pooler
.
feed
(
sval
);
}
}
}
ldg_type
res
=
pooler
.
get_ans
();
*
(
reinterpret_cast
<
ldg_type
*>
(
g_dst_ptr
))
=
res
;
}
template
<
typename
Pooler
>
__global__
void
pooling2d_device_template_int8_ncdiv32hw32
(
const
int8_t
*
__restrict__
src
,
int8_t
*
__restrict__
dst
,
Param
param
)
{
template
<
typename
Pooler
,
int
pack_size
,
int
pack_byte
,
int
ldg_width_assert
=
4
>
__global__
void
pooling2d_device_template_nchwc
(
const
int8_t
*
__restrict__
src
,
int8_t
*
__restrict__
dst
,
Param
param
)
{
const
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
using
ldg_type
=
typename
Pooler
::
feed_type
;
static
int
constexpr
pack_size
=
32
;
static
int
constexpr
ldg_width
=
sizeof
(
ldg_type
)
/
sizeof
(
int32_t
);
static
int
constexpr
ldg_width_bytes
=
sizeof
(
ldg_type
);
static
int
constexpr
section
=
pack_
siz
e
/
sizeof
(
ldg_type
);
static
int
constexpr
section
=
pack_
byt
e
/
sizeof
(
ldg_type
);
MEGDNN_STATIC_ASSERT
(
ldg_width
==
4
,
"pooling2d (NCHW
32
) kernel must use 128bit width ldg instruction"
);
ldg_width
==
ldg_width_assert
,
"pooling2d (NCHW
64
) kernel must use 128bit width ldg instruction"
);
const
int
c_packed
=
param
.
c
/
pack_size
;
const
int
batch
=
tid
/
(
param
.
ho
*
param
.
wo
*
c_packed
*
section
);
const
int
batch_residual
=
...
...
@@ -439,16 +335,18 @@ __global__ void pooling2d_device_template_int8_ncdiv32hw32(
if
(
batch
>=
param
.
n
||
oc
>=
c_packed
||
oh
>=
param
.
ho
||
ow
>=
param
.
wo
)
return
;
const
int
in_batch_stride
=
param
.
hi
*
param
.
wi
*
param
.
c
;
const
int
out_batch_stride
=
param
.
ho
*
param
.
wo
*
param
.
c
;
const
int
in_channel_stride
=
param
.
hi
*
param
.
wi
*
pack_size
;
const
int
out_channel_stride
=
param
.
ho
*
param
.
wo
*
pack_size
;
const
int8_t
*
__restrict__
g_src_ptr
=
src
+
batch
*
in_batch_stride
+
oc
*
in_channel_stride
+
sec
*
ldg_width_bytes
;
const
int
in_batch_stride
=
param
.
hi
*
param
.
wi
*
param
.
c
*
pack_byte
/
pack_size
;
const
int
out_batch_stride
=
param
.
ho
*
param
.
wo
*
param
.
c
*
pack_byte
/
pack_size
;
const
int
in_channel_stride
=
param
.
hi
*
param
.
wi
*
pack_byte
;
const
int
out_channel_stride
=
param
.
ho
*
param
.
wo
*
pack_byte
;
const
int8_t
*
__restrict__
g_src_ptr
=
src
+
(
batch
*
in_batch_stride
+
oc
*
in_channel_stride
+
sec
*
ldg_width_bytes
);
int8_t
*
__restrict__
g_dst_ptr
=
dst
+
batch
*
out_batch_stride
+
oc
*
out_channel_stride
+
(
oh
*
param
.
wo
+
ow
)
*
pack_size
+
sec
*
ldg_width_bytes
;
dst
+
(
batch
*
out_batch_stride
+
oc
*
out_channel_stride
+
(
oh
*
param
.
wo
+
ow
)
*
pack_byte
+
sec
*
ldg_width_bytes
)
;
Pooler
pooler
(
param
.
window_h
*
param
.
window_w
);
pooler
.
init
();
...
...
@@ -458,7 +356,7 @@ __global__ void pooling2d_device_template_int8_ncdiv32hw32(
uint32_t
iw
=
ow
*
param
.
sw
+
fw
-
param
.
pw
;
if
(
ih
<
param
.
hi
&&
iw
<
param
.
wi
)
{
const
int8_t
*
__restrict__
cur_src_ptr
=
g_src_ptr
+
(
ih
*
param
.
wi
+
iw
)
*
pack_
siz
e
;
g_src_ptr
+
(
ih
*
param
.
wi
+
iw
)
*
pack_
byt
e
;
ldg_type
sval
=
__ldg
(
reinterpret_cast
<
const
ldg_type
*>
(
cur_src_ptr
));
pooler
.
feed
(
sval
);
...
...
@@ -527,19 +425,29 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src,
uint32_t
mode
)
{
using
Mode
=
megdnn
::
param_enumv
::
Pooling
::
Mode
;
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
);
uint32_t
vthreads
=
param
.
n
*
param
.
c
*
param
.
ho
*
param
.
wo
/
4
;
constexpr
int
ldg_byte
=
4
;
constexpr
int
elem_per_byte
=
1
;
constexpr
int
pack_size
=
4
;
constexpr
int
pack_byte
=
pack_size
/
elem_per_byte
;
constexpr
int
elem_per_thread
=
ldg_byte
*
elem_per_byte
;
constexpr
int
ldg_assert_width
=
ldg_byte
/
sizeof
(
int32_t
);
uint32_t
vthreads
=
param
.
n
*
param
.
c
*
param
.
ho
*
param
.
wo
/
elem_per_thread
;
switch
(
mode
)
{
case
Mode
::
MAX
:
kern
=
pooling2d_device_template_int8_ncdiv4hw4
<
MaxPooler
<
int8_t
,
int32_t
>>
;
kern
=
pooling2d_device_template_nchwc
<
MaxPooler
<
int8_t
,
int32_t
>
,
pack_size
,
pack_byte
,
ldg_assert_width
>
;
break
;
case
Mode
::
AVERAGE
:
kern
=
pooling2d_device_template_int8_ncdiv4hw4
<
MeanIncludeRoundedPooler
<
int8_t
,
int32_t
,
int32_t
>>
;
kern
=
pooling2d_device_template_nchwc
<
MeanIncludeRoundedPooler
<
int8_t
,
int32_t
,
int32_t
>
,
pack_size
,
pack_byte
,
ldg_assert_width
>
;
break
;
case
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
:
kern
=
pooling2d_device_template_int8_ncdiv4hw4
<
MeanExcludeRoundedPooler
<
int8_t
,
int32_t
,
int32_t
>>
;
kern
=
pooling2d_device_template_nchwc
<
MeanExcludeRoundedPooler
<
int8_t
,
int32_t
,
int32_t
>
,
pack_size
,
pack_byte
,
ldg_assert_width
>
;
break
;
default:
megdnn_assert
(
false
,
"invalid pooling mode"
);
...
...
@@ -558,19 +466,27 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src,
uint32_t
mode
)
{
using
Mode
=
megdnn
::
param_enumv
::
Pooling
::
Mode
;
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
);
uint32_t
vthreads
=
param
.
n
*
param
.
c
*
param
.
ho
*
param
.
wo
/
16
;
constexpr
int
ldg_byte
=
16
;
constexpr
int
elem_per_byte
=
1
;
constexpr
int
pack_size
=
32
;
constexpr
int
pack_byte
=
pack_size
/
elem_per_byte
;
constexpr
int
elem_per_thread
=
ldg_byte
*
elem_per_byte
;
uint32_t
vthreads
=
param
.
n
*
param
.
c
*
param
.
ho
*
param
.
wo
/
elem_per_thread
;
switch
(
mode
)
{
case
Mode
::
MAX
:
kern
=
pooling2d_device_template_
int8_ncdiv32hw32
<
MaxPooler
<
int8_t
,
int4
>
>
;
kern
=
pooling2d_device_template_
nchwc
<
MaxPooler
<
int8_t
,
int4
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE
:
kern
=
pooling2d_device_template_int8_ncdiv32hw32
<
MeanIncludeRoundedPooler
<
int8_t
,
int4
,
int32_t
>>
;
kern
=
pooling2d_device_template_nchwc
<
MeanIncludeRoundedPooler
<
int8_t
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
:
kern
=
pooling2d_device_template_int8_ncdiv32hw32
<
MeanExcludeRoundedPooler
<
int8_t
,
int4
,
int32_t
>>
;
kern
=
pooling2d_device_template_nchwc
<
MeanExcludeRoundedPooler
<
int8_t
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
default:
megdnn_assert
(
false
,
"invalid pooling mode"
);
...
...
@@ -582,11 +498,43 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src,
after_kernel_launch
();
}
#undef FEED1
#undef FEED2
#undef FEED3
#undef ANS1
#undef ANS2
#undef ANS4
void
megdnn
::
cuda
::
pooling2d
::
do_pooling2d_int4_ncdiv64hw64
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
)
{
using
Mode
=
megdnn
::
param_enumv
::
Pooling
::
Mode
;
void
(
*
kern
)(
const
int8_t
*
__restrict__
,
int8_t
*
__restrict__
,
Param
param
);
constexpr
int
ldg_byte
=
16
;
constexpr
int
elem_per_byte
=
2
;
constexpr
int
pack_size
=
64
;
constexpr
int
pack_byte
=
pack_size
/
elem_per_byte
;
constexpr
int
elem_per_thread
=
ldg_byte
*
elem_per_byte
;
uint32_t
vthreads
=
param
.
n
*
param
.
c
*
param
.
ho
*
param
.
wo
/
elem_per_thread
;
switch
(
mode
)
{
case
Mode
::
MAX
:
kern
=
pooling2d_device_template_nchwc
<
MaxPooler
<
dt_qint4
,
int4
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE
:
kern
=
pooling2d_device_template_nchwc
<
MeanIncludeRoundedPooler
<
dt_qint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
case
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
:
kern
=
pooling2d_device_template_nchwc
<
MeanExcludeRoundedPooler
<
dt_qint4
,
int4
,
int32_t
>
,
pack_size
,
pack_byte
>
;
break
;
default:
megdnn_assert
(
false
,
"invalid pooling mode"
);
}
uint32_t
nr_threads
=
query_blocksize_for_kernel
(
kern
);
nr_threads
=
std
::
min
(
nr_threads
,
vthreads
);
uint32_t
nr_blocks
=
DIVUP
(
vthreads
,
nr_threads
);
kern
<<<
nr_blocks
,
nr_threads
,
0
,
stream
>>>
(
d_src
,
d_dst
,
param
);
after_kernel_launch
();
}
// vim: syntax=cuda.doxygen
dnn/src/cuda/pooling/pooling2d_
int8
.cuh
→
dnn/src/cuda/pooling/pooling2d_
qint
.cuh
浏览文件 @
2398df07
/**
* \file dnn/src/cuda/pooling/pooling2d_
int8
.cuh
* \file dnn/src/cuda/pooling/pooling2d_
qint
.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...
...
@@ -32,6 +32,11 @@ void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst,
void
do_pooling2d_int8_ncdiv32hw32
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
);
void
do_pooling2d_int4_ncdiv64hw64
(
const
int8_t
*
d_src
,
int8_t
*
d_dst
,
const
Param
&
param
,
cudaStream_t
stream
,
uint32_t
mode
);
}
// namespace pooling2d
}
// namespace cuda
}
// namespace megdnn
...
...
dnn/src/naive/pooling/opr_impl.cpp
浏览文件 @
2398df07
...
...
@@ -15,6 +15,7 @@
#include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/lowbit_utils.h"
#include "midout.h"
MIDOUT_DECL
(
megdnn_naive_pooling
)
...
...
@@ -190,6 +191,12 @@ struct NCHW32IdxGetter {
return
(((
n
*
(
C
>>
5
)
+
(
c
>>
5
))
*
H
+
h
)
*
W
+
w
)
*
32
+
(
c
&
0x1f
);
}
};
struct
NCHW64IdxGetter
{
static
size_t
get_idx
(
size_t
n
,
size_t
c
,
size_t
h
,
size_t
w
,
size_t
,
size_t
C
,
size_t
H
,
size_t
W
)
{
return
(((
n
*
(
C
>>
6
)
+
(
c
>>
6
))
*
H
+
h
)
*
W
+
w
)
*
64
+
(
c
&
0x3f
);
}
};
/*!
* Pooler for AVERAGE_COUNT_EXCLUDE_PADDING mode
*/
...
...
@@ -375,15 +382,81 @@ void pooling_backward_max_impl(const ctype* __restrict src,
namespace
megdnn
{
namespace
naive
{
WorkspaceBundle
PoolingForwardImpl
::
get_workspace_bundle
(
void
*
ptr
,
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
const
{
SmallVector
<
size_t
>
sizes
;
TensorLayout
fsrc
=
src
;
TensorLayout
fdst
=
dst
;
auto
get_workspace
=
[
&
sizes
](
TensorLayout
&
layout
)
{
if
(
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
||
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
layout
.
dtype
=
dtype
::
Int8
();
layout
.
format
=
TensorLayout
::
Format
(
layout
.
dtype
);
sizes
.
push_back
(
layout
.
span
().
dist_byte
());
}
};
get_workspace
(
fsrc
);
get_workspace
(
fdst
);
return
{
ptr
,
std
::
move
(
sizes
)};
};
size_t
PoolingForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
return
get_workspace_bundle
(
nullptr
,
src
,
dst
).
total_size_in_bytes
();
}
namespace
{
void
post_process
(
const
TensorND
&
dst
,
TensorND
&
comp_dst
,
Handle
*
handle
,
WorkspaceBundle
&
workspace_bundle
)
{
if
(
dst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
int8_to_int4
(
comp_dst
,
dst
);
}
else
if
(
dst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
uint8_to_uint4
(
comp_dst
,
dst
);
}
}
}
// namespace
void
PoolingForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
dst
.
layout
,
workspace
.
size
);
TensorND
comp_src
=
src
;
TensorND
comp_dst
=
dst
;
auto
wsb
=
get_workspace_bundle
(
workspace
.
raw_ptr
,
src
.
layout
,
dst
.
layout
);
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
float
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
QuantizedS4
>
().
scale
;
comp_src
.
layout
.
dtype
=
dtype
::
QuantizedS8
(
scale
);
comp_src
.
layout
.
init_contiguous_stride
();
comp_src
.
layout
.
format
=
TensorLayout
::
Format
(
comp_src
.
layout
.
dtype
);
comp_src
.
raw_ptr
=
wsb
.
get
(
0
);
comp_dst
.
layout
.
dtype
=
dtype
::
QuantizedS8
(
scale
);
comp_dst
.
layout
.
format
=
TensorLayout
::
Format
(
comp_dst
.
layout
.
dtype
);
comp_dst
.
layout
.
init_contiguous_stride
();
comp_dst
.
raw_ptr
=
wsb
.
get
(
1
);
int4_to_int8
(
src
,
comp_src
);
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
float
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
().
scale
;
uint8_t
zero_point
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
().
zero_point
;
comp_src
.
layout
.
dtype
=
dtype
::
Quantized8Asymm
(
scale
,
zero_point
);
comp_src
.
layout
.
format
=
TensorLayout
::
Format
(
comp_src
.
layout
.
dtype
);
comp_src
.
layout
.
init_contiguous_stride
();
comp_src
.
raw_ptr
=
wsb
.
get
(
0
);
comp_dst
.
layout
.
dtype
=
dtype
::
Quantized8Asymm
(
scale
,
zero_point
);
comp_dst
.
layout
.
format
=
TensorLayout
::
Format
(
comp_dst
.
layout
.
dtype
);
comp_dst
.
layout
.
init_contiguous_stride
();
comp_dst
.
raw_ptr
=
wsb
.
get
(
1
);
uint4_to_uint8
(
src
,
comp_src
);
}
size_t
c_pos
,
spatial_pos
,
batch_pos
=
0
;
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW88
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW32
)
{
param
().
format
==
Param
::
Format
::
NCHW32
||
param
().
format
==
Param
::
Format
::
NCHW64
)
{
c_pos
=
1
;
spatial_pos
=
2
;
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC
)
{
...
...
@@ -398,27 +471,35 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
c_pos
=
2
;
spatial_pos
=
1
;
}
size_t
N
=
src
.
layout
.
shape
[
batch_pos
],
C
=
src
.
layout
.
shape
[
c_pos
],
IH
=
src
.
layout
.
shape
[
spatial_pos
+
0
],
IW
=
src
.
layout
.
shape
[
spatial_pos
+
1
];
size_t
OH
=
dst
.
layout
.
shape
[
spatial_pos
+
0
],
OW
=
dst
.
layout
.
shape
[
spatial_pos
+
1
];
if
(
param
().
format
==
Param
::
Format
::
NHWCD4
)
{
C
*=
4
;
IW
=
src
.
layout
.
shape
[
spatial_pos
+
2
];
OW
=
dst
.
layout
.
shape
[
spatial_pos
+
2
];
}
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
CHWN4
)
{
C
*=
4
;
}
if
(
param
().
format
==
Param
::
Format
::
NCHW88
)
{
C
*=
8
;
}
if
(
param
().
format
==
Param
::
Format
::
NCHW32
)
{
C
*=
32
;
size_t
N
=
comp_src
.
layout
.
shape
[
batch_pos
],
C
=
comp_src
.
layout
.
shape
[
c_pos
],
IH
=
comp_src
.
layout
.
shape
[
spatial_pos
+
0
],
IW
=
comp_src
.
layout
.
shape
[
spatial_pos
+
1
];
size_t
OH
=
comp_dst
.
layout
.
shape
[
spatial_pos
+
0
],
OW
=
comp_dst
.
layout
.
shape
[
spatial_pos
+
1
];
switch
(
param
().
format
)
{
case
Param
::
Format
::
NHWCD4
:
C
*=
4
;
IW
=
comp_src
.
layout
.
shape
[
spatial_pos
+
2
];
OW
=
comp_dst
.
layout
.
shape
[
spatial_pos
+
2
];
break
;
case
Param
::
Format
::
NCHW4
:
case
Param
::
Format
::
NCHW44
:
case
Param
::
Format
::
CHWN4
:
C
*=
4
;
break
;
case
Param
::
Format
::
NCHW88
:
C
*=
8
;
break
;
case
Param
::
Format
::
NCHW32
:
C
*=
32
;
break
;
case
Param
::
Format
::
NCHW64
:
C
*=
64
;
break
;
default:
;
}
size_t
PH
=
param
().
pad_h
,
PW
=
param
().
pad_w
;
size_t
FH
=
param
().
window_h
,
FW
=
param
().
window_w
;
size_t
SH
=
param
().
stride_h
,
SW
=
param
().
stride_w
;
...
...
@@ -427,8 +508,8 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
sptr, dptr,
src.layout.dtype, N, C, IH, IW, OH, OW,
\
PH, PW, SH, SW, FH, FW));
\
sptr, dptr,
comp_src.layout.dtype, N, C, IH, IW, OH,
\
OW, PH, PW, SH, SW, FH, FW));
\
} \
MIDOUT_END();
...
...
@@ -455,6 +536,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
case Param::Format::NCHW32: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \
break; \
case Param::Format::NCHW64: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW64IdxGetter); \
break; \
case Param::Format::CHWN4: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, CHWN4IdxGetter); \
break; \
...
...
@@ -462,30 +546,35 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
megdnn_throw("invalid pooling format"); \
}
#define cb(DType) \
if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
switch (param().mode) { \
case Mode::MAX: { \
auto sptr = src.ptr<ctype>(); \
auto dptr = dst.ptr<ctype>(); \
DISPATCH_WITH_POOLER(MaxPooler<ctype>); \
return; \
} \
case Mode::AVERAGE: { \
auto sptr = src.ptr<ctype>(); \
auto dptr = dst.ptr<ctype>(); \
DISPATCH_WITH_POOLER(MeanIncludePooler<ctype>); \
return; \
} \
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: { \
auto sptr = src.ptr<ctype>(); \
auto dptr = dst.ptr<ctype>(); \
DISPATCH_WITH_POOLER(MeanExcludePooler<ctype>); \
return; \
} \
} \
#define cb(DType) \
if (comp_src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
switch (param().mode) { \
case Mode::MAX: { \
auto sptr = comp_src.ptr<ctype>(); \
auto dptr = comp_dst.ptr<ctype>(); \
DISPATCH_WITH_POOLER(MaxPooler<ctype>); \
break; \
} \
case Mode::AVERAGE: { \
auto sptr = comp_src.ptr<ctype>(); \
auto dptr = comp_dst.ptr<ctype>(); \
DISPATCH_WITH_POOLER(MeanIncludePooler<ctype>); \
break; \
} \
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: { \
auto sptr = comp_src.ptr<ctype>(); \
auto dptr = comp_dst.ptr<ctype>(); \
DISPATCH_WITH_POOLER(MeanExcludePooler<ctype>); \
break; \
} \
default: \
megdnn_assert(0, "not support mode"); \
} \
post_process(dst, comp_dst, handle(), wsb); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
#undef cb
...
...
dnn/src/naive/pooling/opr_impl.h
浏览文件 @
2398df07
...
...
@@ -20,10 +20,12 @@ class PoolingForwardImpl: public PoolingForward {
using
PoolingForward
::
PoolingForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
)
override
;
private:
WorkspaceBundle
get_workspace_bundle
(
void
*
ptr
,
const
TensorLayout
&
,
const
TensorLayout
&
)
const
;
};
class
PoolingBackwardImpl
:
public
PoolingBackward
{
...
...
dnn/test/common/checker.h
浏览文件 @
2398df07
...
...
@@ -414,34 +414,34 @@ TensorND TensorValue(const TensorShape& shape, T dtype,
template
<
typename
T
,
typename
U
>
TensorND
TensorValueLowbit4
(
const
TensorShape
&
shape
,
T
dtype
,
std
::
vector
<
U
>
values
)
{
std
::
vector
<
U
>
values
)
{
TensorND
tensor
;
tensor
.
layout
=
{
shape
,
dtype
};
tensor
.
raw_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
tensor
.
layout
.
span
().
dist_byte
()));
megdnn_assert
(
values
.
size
()
==
tensor
.
layout
.
total_nr_elems
());
auto
ptr
=
tensor
.
ptr
<
typename
DTypeTrait
<
T
>::
ctype
>
();
size_t
i
;
for
(
i
=
0
;
i
+
1
<
values
.
size
();
i
+=
2
)
{
U
val0
=
values
[
i
],
val1
=
values
[
i
+
1
];
megdnn_assert
(
val0
>=
DTypeTrait
<
T
>::
min
());
megdnn_assert
(
val1
<=
DTypeTrait
<
T
>::
max
());
ptr
[
i
/
2
]
=
typename
DTypeTrait
<
T
>::
ctype
((
val0
&
0xF
)
|
(
val1
<<
4
));
}
if
(
i
<
values
.
size
())
{
U
val0
=
values
[
i
];
megdnn_assert
(
val0
>=
DTypeTrait
<
T
>::
min
()
&&
val0
<=
DTypeTrait
<
T
>::
max
());
if
(
i
+
1
<
values
.
size
())
{
U
val1
=
values
[
i
+
1
];
megdnn_assert
(
val1
>=
DTypeTrait
<
T
>::
min
()
&&
val1
<=
DTypeTrait
<
T
>::
max
());
ptr
[
i
/
2
]
=
typename
DTypeTrait
<
T
>::
ctype
((
val0
&
0xF
)
|
(
val1
<<
4
));
}
else
{
ptr
[
i
/
2
]
=
typename
DTypeTrait
<
T
>::
ctype
(
val0
&
0xF
);
auto
layout
=
tensor
.
layout
;
auto
dim_in
=
shape
[
layout
.
ndim
-
1
];
auto
elems
=
tensor
.
layout
.
total_nr_elems
();
auto
dim_out
=
elems
/
dim_in
;
auto
stride_out
=
div_ceil
(
dim_in
,
2
_z
);
size_t
in_offset
=
0
;
for
(
size_t
i
=
0
;
i
<
dim_out
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
dim_in
;
j
+=
2
)
{
U
a
=
values
[
in_offset
+
j
];
U
b
=
0
;
if
(
j
+
1
<
dim_in
)
b
=
values
[
in_offset
+
j
+
1
];
megdnn_assert
(
a
>=
DTypeTrait
<
T
>::
min
());
megdnn_assert
(
a
<=
DTypeTrait
<
T
>::
max
());
megdnn_assert
(
b
>=
DTypeTrait
<
T
>::
min
());
megdnn_assert
(
b
<=
DTypeTrait
<
T
>::
max
());
ptr
[
j
/
2
]
=
(
a
&
0xF
)
|
(
b
<<
4
);
}
in_offset
+=
dim_in
;
ptr
+=
stride_out
;
}
return
tensor
;
}
...
...
dnn/test/cuda/pooling.cpp
浏览文件 @
2398df07
...
...
@@ -242,6 +242,20 @@ TEST_F(CUDA, POOLING_BACKWARD)
.
exec
(
TensorShapeArray
{
ilayout
,
olayout
,
olayout
,
ilayout
});
}
}
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW_Q4
)
{
require_compute_capability
(
7
,
5
);
using
Param
=
param
::
Pooling
;
Checker
<
Pooling
>
checker
(
handle_cuda
());
Param
param
{
Param
::
Mode
::
MAX
,
0
,
0
,
2
,
2
,
2
,
2
};
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
0.1
f
));
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_epsilon
(
1
+
1e-3
);
checker
.
set_param
(
param
).
exec
({{
20
,
64
,
22
,
33
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
20
,
64
,
22
,
33
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
20
,
64
,
22
,
33
},
{}});
}
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW4
)
{
require_compute_capability
(
7
,
5
);
...
...
@@ -252,6 +266,10 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW4) {
param
.
format
=
Param
::
Format
::
NCHW4
;
checker
.
set_epsilon
(
1
+
1e-3
);
checker
.
set_param
(
param
).
exec
({{
20
,
3
,
50
,
50
,
4
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
20
,
3
,
50
,
50
,
4
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
20
,
3
,
50
,
50
,
4
},
{}});
}
#if CUDNN_VERSION >= 7500
...
...
@@ -267,9 +285,29 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW32) {
param
.
format
=
Param
::
Format
::
NCHW32
;
checker
.
set_epsilon
(
1e-3
).
set_rng
(
0
,
&
int_rng
);
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
32
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
64
},
{}});
}
#endif
TEST_F
(
CUDA
,
POOLING_FORWARD_NCHW64
)
{
require_compute_capability
(
7
,
5
);
using
Param
=
param
::
Pooling
;
Checker
<
Pooling
>
checker
(
handle_cuda
());
Param
param
{
Param
::
Mode
::
MAX
,
0
,
0
,
2
,
2
,
2
,
2
};
UniformIntRNG
int_rng
{
-
8
,
7
};
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
1.
f
));
param
.
format
=
Param
::
Format
::
NCHW64
;
checker
.
set_epsilon
(
1e-3
).
set_rng
(
0
,
&
int_rng
);
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE
;
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
64
},
{}});
param
.
mode
=
Param
::
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
;
checker
.
set_param
(
param
).
exec
({{
64
,
8
,
28
,
28
,
64
},
{}});
}
TEST_F
(
CUDA
,
POOLING_FORWARD_CHWN4
)
{
require_compute_capability
(
6
,
1
);
using
Param
=
param
::
Pooling
;
...
...
dnn/test/naive/pooling.cpp
浏览文件 @
2398df07
...
...
@@ -50,4 +50,63 @@ TEST_F(NAIVE, POOLING_QUANTIZED) {
12306
,
23333
})});
}
TEST_F
(
NAIVE
,
POOLING_QUANTIZED_Q4
)
{
using
Mode
=
Pooling
::
Param
::
Mode
;
Checker
<
Pooling
>
checker
(
handle
(),
/* check_dispatch */
false
);
{
auto
q4_dt
=
dtype
::
QuantizedS4
(
1.
f
);
std
::
vector
<
int
>
i8_src_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
-
1
,
-
2
};
std
::
vector
<
int
>
i8_max_dst_vec
{
1
,
3
,
7
,
6
};
std
::
vector
<
int
>
i8_avg_dst_vec
{
0
,
1
,
3
,
2
};
std
::
vector
<
int
>
i8_avg_exclu_dst_vec
{
1
,
3
,
6
,
2
};
Pooling
::
Param
param
{
Mode
::
MAX
,
1
,
1
,
2
,
2
,
2
,
2
};
Testcase
input
{
TensorValueLowbit4
({
1
,
1
,
3
,
3
},
q4_dt
,
i8_src_vec
),
{}};
checker
.
set_param
(
param
).
exect
(
input
,
Testcase
{{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
q4_dt
,
i8_max_dst_vec
)});
param
=
{
Mode
::
AVERAGE
,
1
,
1
,
2
,
2
,
2
,
2
};
checker
.
set_param
(
param
).
exect
(
input
,
Testcase
{{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
q4_dt
,
i8_avg_dst_vec
)});
param
=
{
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
,
1
,
1
,
2
,
2
,
2
,
2
};
checker
.
set_param
(
param
).
exect
(
input
,
Testcase
{{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
q4_dt
,
i8_avg_exclu_dst_vec
)});
}
{
auto
u4_dt
=
dtype
::
Quantized4Asymm
(
1.
f
,
0
);
std
::
vector
<
int
>
u8_src_vec
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
std
::
vector
<
int
>
u8_max_dst_vec
{
1
,
3
,
7
,
9
};
std
::
vector
<
int
>
u8_avg_dst_vec
{
0
,
1
,
3
,
7
};
std
::
vector
<
int
>
u8_avg_exclu_dst_vec
{
1
,
3
,
6
,
7
};
Pooling
::
Param
param
{
Mode
::
MAX
,
1
,
1
,
2
,
2
,
2
,
2
};
Testcase
input
{
TensorValueLowbit4
({
1
,
1
,
3
,
3
},
u4_dt
,
u8_src_vec
),
{}};
checker
.
set_param
(
param
).
exect
(
input
,
Testcase
{{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
u4_dt
,
u8_max_dst_vec
)});
param
=
{
Mode
::
AVERAGE
,
1
,
1
,
2
,
2
,
2
,
2
};
checker
.
set_param
(
param
).
exect
(
input
,
Testcase
{{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
u4_dt
,
u8_avg_dst_vec
)});
param
=
{
Mode
::
AVERAGE_COUNT_EXCLUDE_PADDING
,
1
,
1
,
2
,
2
,
2
,
2
};
checker
.
set_param
(
param
).
exect
(
input
,
Testcase
{{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
u4_dt
,
u8_avg_exclu_dst_vec
)});
}
}
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录