Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
19919384
MegEngine
项目概览
MegEngine 天元
/
MegEngine
10 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
19919384
编写于
5月 17, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add cuda uint warp perspective
GitOrigin-RevId: 2aec72010f81ad92b726924fcdc4b65069ec0cab
上级
01354337
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
334 addition
and
85 deletion
+334
-85
dnn/src/common/rounding_converter.cuh
dnn/src/common/rounding_converter.cuh
+10
-0
dnn/src/common/warp_perspective.cpp
dnn/src/common/warp_perspective.cpp
+6
-4
dnn/src/cuda/warp_perspective/forward.cpp
dnn/src/cuda/warp_perspective/forward.cpp
+36
-5
dnn/src/cuda/warp_perspective/forward.cu
dnn/src/cuda/warp_perspective/forward.cu
+129
-59
dnn/src/naive/warp_perspective/opr_impl.cpp
dnn/src/naive/warp_perspective/opr_impl.cpp
+32
-13
dnn/src/naive/warp_perspective/opr_impl.h
dnn/src/naive/warp_perspective/opr_impl.h
+2
-1
dnn/test/cuda/warp_perspective.cpp
dnn/test/cuda/warp_perspective.cpp
+92
-1
dnn/test/naive/warp_perspective.cpp
dnn/test/naive/warp_perspective.cpp
+27
-2
未找到文件。
dnn/src/common/rounding_converter.cuh
浏览文件 @
19919384
...
...
@@ -86,6 +86,16 @@ struct RoundingConverter<dt_qint4> {
}
};
template
<
>
struct
RoundingConverter
<
dt_quint4
>
{
__host__
__device__
__forceinline__
dt_quint4
operator
()(
float
x
)
const
{
#if MEGDNN_CC_HOST
using
std
::
round
;
#endif
return
static_cast
<
dt_quint4
>
(
round
(
x
));
}
};
}
// namespace rounding
}
// namespace megdnn
...
...
dnn/src/common/warp_perspective.cpp
浏览文件 @
19919384
...
...
@@ -73,9 +73,10 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
||
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
,
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
,
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8"
DNN_FLOAT16_SELECT
(
"Float32/Int8/Uint8/QInt8/QUint8
/QInt4/QUInt4
"
DNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
"."
);
megdnn_assert
(
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
&&
...
...
@@ -118,8 +119,9 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW64
)
{
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
,
"src expected QuantizedS4, but got %s"
,
megdnn_assert
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
),
"src expected QuantizedS4/Quantized4Asymm, but got %s"
,
src
.
dtype
.
name
());
megdnn_assert
(
mat
.
dtype
==
dtype
::
Float32
(),
"matrix dtype expected float, got %s"
,
...
...
dnn/src/cuda/warp_perspective/forward.cpp
浏览文件 @
19919384
...
...
@@ -44,8 +44,9 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
TensorLayout
&
inner_src
,
TensorLayout
&
inner_dst
,
Handle
*
handle
,
WarpPerspectiveForwardImpl
::
Param
::
Format
format
)
{
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
if
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
dst
.
dtype
.
enumv
()
==
src
.
dtype
.
enumv
()
&&
format
==
param
::
WarpPerspective
::
Format
::
NCHW
)
{
auto
relayout_opr
=
handle
->
create_operator
<
RelayoutFormat
>
();
deduce_reformat_layout
(
relayout_opr
,
src
,
inner_src
,
...
...
@@ -130,7 +131,8 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
TensorLayout
fsrc
=
src
;
TensorLayout
fmat
=
mat
;
TensorLayout
fdst
=
dst
;
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
if
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
)
{
get_inner_layout
(
src
,
dst
,
fsrc
,
fdst
,
handle
(),
param
().
format
);
sizes
.
push_back
(
fsrc
.
span
().
dist_byte
());
...
...
@@ -177,7 +179,8 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
ctypecvt
.
src_to_comp_type
(
ssrc
,
src
)
.
src_to_comp_type
(
smat
,
mat
)
.
src_to_comp_type
(
sdst
,
dst
);
}
else
if
(
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
}
else
if
((
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
ssrc
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
param
().
format
==
Param
::
Format
::
NCHW
)
{
auto
handle_ptr
=
handle
();
get_inner_layout
(
ssrc
.
layout
,
sdst
.
layout
,
src
.
layout
,
dst
.
layout
,
...
...
@@ -330,7 +333,7 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
param
().
format
==
Param
::
Format
::
NCHW64
||
param
().
format
==
Param
::
Format
::
NCHW
,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"QuantizedS4
only
"
);
"QuantizedS4"
);
bval
=
roundf
(
bval
);
bval
=
fmin
(
fmax
(
-
8.
f
,
bval
),
7.
f
);
warp_perspective
::
forward_proxy_nchw64
<
dt_qint4
>
(
...
...
@@ -352,6 +355,34 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,
{});
}
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCHW64
||
param
().
format
==
Param
::
Format
::
NCHW
,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"Quantized4Asymm"
);
bval
=
roundf
(
bval
);
bval
=
fmin
(
fmax
(
0
,
bval
),
15
);
warp_perspective
::
forward_proxy_nchw64
<
dt_quint4
>
(
src
.
compatible_ptr
<
dt_quint4
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
compatible_ptr
<
dt_quint4
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_quint4
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
if
(
param
().
format
==
Param
::
Format
::
NCHW
)
{
auto
relayout_opr
=
handle
()
->
create_operator
<
RelayoutFormat
>
();
RelayoutFormat
::
Param
trans_param
;
trans_param
.
mode
=
RelayoutFormat
::
Param
::
Mode
::
NCHW64_NCHW
;
trans_param
.
oc
=
sdst
.
layout
[
1
];
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,
{});
}
}
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
...
...
dnn/src/cuda/warp_perspective/forward.cu
浏览文件 @
19919384
...
...
@@ -144,25 +144,68 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat,
}
}
#define warp_perspective_transform(idx) \
template
<
bool
signedness
>
MEGDNN_DEVICE
__forceinline__
int
transform_int8_to_bit4x8
(
int
s0
,
int
s1
,
int
s2
,
int
s3
,
int
s4
,
int
s5
,
int
s6
,
int
s7
);
template
<
>
MEGDNN_DEVICE
__forceinline__
int
transform_int8_to_bit4x8
<
true
>
(
int
s0
,
int
s1
,
int
s2
,
int
s3
,
int
s4
,
int
s5
,
int
s6
,
int
s7
)
{
return
transform_int8_to_int4x8
(
s0
,
s1
,
s2
,
s3
,
s4
,
s5
,
s6
,
s7
);
}
template
<
>
MEGDNN_DEVICE
__forceinline__
int
transform_int8_to_bit4x8
<
false
>
(
int
s0
,
int
s1
,
int
s2
,
int
s3
,
int
s4
,
int
s5
,
int
s6
,
int
s7
)
{
return
transform_int8_to_uint4x8
(
s0
,
s1
,
s2
,
s3
,
s4
,
s5
,
s6
,
s7
);
}
template
<
bool
signedness
>
MEGDNN_DEVICE
__forceinline__
void
transform_bit4x8_to_int8
(
int
(
&
result
)[
8
],
const
int
&
source
);
template
<
>
MEGDNN_DEVICE
__forceinline__
void
transform_bit4x8_to_int8
<
true
>
(
int
(
&
result
)[
8
],
const
int
&
source
){
transform_int4x8_to_int8
(
result
,
source
);
}
template
<
>
MEGDNN_DEVICE
__forceinline__
void
transform_bit4x8_to_int8
<
false
>
(
int
(
&
result
)[
8
],
const
int
&
source
){
transform_uint4x8_to_int8
(
result
,
source
);
}
template
<
bool
signedness
,
typename
OutputConverter
>
MEGDNN_DEVICE
__forceinline__
int
pack_output_func
(
OutputConverter
&
output_converter
,
int
(
&
s00
)[
8
],
int
(
&
s01
)[
8
],
int
(
&
s10
)[
8
],
int
(
&
s11
)[
8
],
float
palpha
,
float
pbeta
,
float
nalpha
,
float
nbeta
)
{
#define warp_perspective_transform(idx) \
static_cast<int>(output_converter(s00[idx] * nalpha * nbeta + \
s01[idx] * nalpha * pbeta + \
s10[idx] * palpha * nbeta + \
s11[idx] * palpha * pbeta) \
.as_int8())
#define pack_output \
transform_int8_to_int4x8( \
warp_perspective_transform(0), warp_perspective_transform(1), \
warp_perspective_transform(2), warp_perspective_transform(3), \
warp_perspective_transform(4), warp_perspective_transform(5), \
warp_perspective_transform(6), warp_perspective_transform(7))
.as_storage())
return
transform_int8_to_bit4x8
<
signedness
>
(
warp_perspective_transform
(
0
),
warp_perspective_transform
(
1
),
warp_perspective_transform
(
2
),
warp_perspective_transform
(
3
),
warp_perspective_transform
(
4
),
warp_perspective_transform
(
5
),
warp_perspective_transform
(
6
),
warp_perspective_transform
(
7
));
#undef warp_perspective_transform
}
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_general_nchw64
(
SrcVisitor
src
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
)
{
constexpr
bool
signedness
=
std
::
is_same
<
ctype
,
dt_qint4
>::
value
;
Getter
getter
;
OutputConverter
output_converter
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -199,29 +242,37 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
s
[
2
]
=
__ldg
(
sptr_int4
+
i_coor_10
+
c1
);
s
[
3
]
=
__ldg
(
sptr_int4
+
i_coor_11
+
c1
);
transform_int4x8_to_int8
(
s00
,
s
[
0
].
x
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
x
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
x
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
x
);
d
.
x
=
pack_output
;
transform_int4x8_to_int8
(
s00
,
s
[
0
].
y
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
y
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
y
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
y
);
d
.
y
=
pack_output
;
transform_int4x8_to_int8
(
s00
,
s
[
0
].
z
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
z
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
z
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
z
);
d
.
z
=
pack_output
;
transform_int4x8_to_int8
(
s00
,
s
[
0
].
w
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
w
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
w
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
w
);
d
.
w
=
pack_output
;
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
x
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
x
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
x
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
x
);
d
.
x
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
y
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
y
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
y
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
y
);
d
.
y
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
z
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
z
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
z
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
z
);
d
.
z
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
w
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
w
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
w
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
w
);
d
.
w
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
dst_int4
[
o_coor
+
c1
]
=
d
;
sptr_int4
+=
IH
*
IW
*
2
;
...
...
@@ -320,15 +371,25 @@ __global__ void kern_const_border_nchw4(SrcVisitor src,
}
}
}
template
<
bool
signedness
>
MEGDNN_DEVICE
__forceinline__
static
void
transform_bit4x8_to_int8
(
int
(
&
result
)[
8
],
const
int
&
source
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
result
[
i
]
=
unpack_integer_4bits
<
signedness
>
(
reinterpret_cast
<
unsigned
const
&>
(
source
),
(
i
<<
2
));
}
}
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_const_border_nchw64
(
SrcVisitor
src
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
)
{
constexpr
bool
signedness
=
std
::
is_same
<
ctype
,
dt_qint4
>::
value
;
OutputConverter
output_converter
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
c1
=
ow
%
2
;
int
c1
=
ow
%
2
;
ow
=
ow
/
2
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
ctype
*
__restrict
sptr
=
src
.
get
(
blockIdx
.
z
,
C
*
IH
*
IW
/
2
);
...
...
@@ -359,9 +420,9 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
int
i_coor_11
=
(
ih1
*
IW
+
iw1
)
<<
1
;
bool
flag00
=
okh0
&&
okw0
,
flag01
=
okh0
&&
okw1
,
flag10
=
okh1
&&
okw0
,
flag11
=
okh1
&&
okw1
;
int8_t
bval_4
=
bval
.
as_
int8
()
&
0xF
;
int
bval_8
=
transform_int8_to_
int4x8
(
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
);
int8_t
bval_4
=
bval
.
as_
storage
()
&
0xF
;
int
bval_8
=
transform_int8_to_
bit4x8
<
signedness
>
(
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
,
bval_4
);
int4
bval_int4
;
bval_int4
.
x
=
bval_8
;
bval_int4
.
y
=
bval_8
;
...
...
@@ -391,29 +452,37 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
s
[
3
]
=
bval_int4
;
}
transform_int4x8_to_int8
(
s00
,
s
[
0
].
x
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
x
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
x
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
x
);
d
.
x
=
pack_output
;
transform_int4x8_to_int8
(
s00
,
s
[
0
].
y
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
y
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
y
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
y
);
d
.
y
=
pack_output
;
transform_int4x8_to_int8
(
s00
,
s
[
0
].
z
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
z
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
z
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
z
);
d
.
z
=
pack_output
;
transform_int4x8_to_int8
(
s00
,
s
[
0
].
w
);
transform_int4x8_to_int8
(
s01
,
s
[
1
].
w
);
transform_int4x8_to_int8
(
s10
,
s
[
2
].
w
);
transform_int4x8_to_int8
(
s11
,
s
[
3
].
w
);
d
.
w
=
pack_output
;
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
x
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
x
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
x
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
x
);
d
.
x
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
y
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
y
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
y
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
y
);
d
.
y
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
z
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
z
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
z
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
z
);
d
.
z
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
transform_bit4x8_to_int8
<
signedness
>
(
s00
,
s
[
0
].
w
);
transform_bit4x8_to_int8
<
signedness
>
(
s01
,
s
[
1
].
w
);
transform_bit4x8_to_int8
<
signedness
>
(
s10
,
s
[
2
].
w
);
transform_bit4x8_to_int8
<
signedness
>
(
s11
,
s
[
3
].
w
);
d
.
w
=
pack_output_func
<
signedness
>
(
output_converter
,
s00
,
s01
,
s10
,
s11
,
palpha
,
pbeta
,
nalpha
,
nbeta
);
dst_int4
[
o_coor
+
c1
]
=
d
;
sptr_int4
+=
IH
*
IW
*
2
;
...
...
@@ -1448,6 +1517,7 @@ INST(int8_t)
void*, cudaStream_t);
INST
(
dt_qint4
)
INST
(
dt_quint4
)
#undef INST
template
<
typename
src_dtype
,
typename
src_ctype
,
typename
dst_ctype
>
...
...
dnn/src/naive/warp_perspective/opr_impl.cpp
浏览文件 @
19919384
...
...
@@ -249,6 +249,7 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4(
MIDOUT_END
();
}
template
<
typename
ctype
,
typename
mtype
>
void
WarpPerspectiveForwardImpl
::
kern_naive_int4
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
)
{
...
...
@@ -257,6 +258,7 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
(
kern_param
);
MEGDNN_MARK_USED_VAR
(
N_MAT
);
uint8_t
c_shift
,
c_mask
,
iw_shift
=
0
,
ow_shift
=
0
;
constexpr
bool
signedness
=
std
::
is_same
<
ctype
,
dt_qint4
>::
value
;
switch
(
param
().
format
)
{
case
Format
::
NCHW
:
c_shift
=
0
;
...
...
@@ -282,8 +284,13 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
<<
c_shift
)
+
(
c
&
c_mask
);
uint8_t
result
=
(
sptr
[
index
/
2
].
as_int8
()
>>
(
4
*
(
index
%
2
)))
&
0xF
;
return
result
&
uint8_t
(
1
<<
3
)
?
result
|
~
mask
:
result
;
(
sptr
[
index
/
2
].
as_storage
()
>>
(
4
*
(
index
%
2
)))
&
0xF
;
if
(
signedness
)
{
return
result
&
uint8_t
(
1
<<
3
)
?
result
|
~
mask
:
result
;
}
else
{
megdnn_assert
((
std
::
is_same
<
ctype
,
dt_quint4
>::
value
));
return
result
;
}
};
auto
visit_src_bd
=
[
&
sptr
,
sstrd
,
border_val
,
c_shift
,
c_mask
](
size_t
c
,
int
h
,
int
w
)
->
float
{
...
...
@@ -292,8 +299,14 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
<<
c_shift
)
+
(
c
&
c_mask
);
uint8_t
result
=
(
sptr
[
index
/
2
].
as_int8
()
>>
(
4
*
(
index
%
2
)))
&
0xF
;
return
result
&
uint8_t
(
1
<<
3
)
?
result
|
~
mask
:
result
;
(
sptr
[
index
/
2
].
as_storage
()
>>
(
4
*
(
index
%
2
)))
&
0xF
;
if
(
signedness
)
{
return
result
&
uint8_t
(
1
<<
3
)
?
result
|
~
mask
:
result
;
}
else
{
megdnn_assert
((
std
::
is_same
<
ctype
,
dt_quint4
>::
value
));
return
result
;;
}
}
else
return
border_val
;
};
...
...
@@ -302,9 +315,9 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
size_t
index
=
((
dstrd
[
0
]
*
(
c
>>
c_shift
)
+
dstrd
[
1
]
*
h
+
w
)
<<
c_shift
)
+
(
c
&
c_mask
);
dptr
[
index
/
2
]
=
(
dptr
[
index
/
2
].
as_int8
()
&
(
0xF0
>>
(
4
*
(
index
%
2
))))
|
(
v
.
as_int8
()
<<
(
4
*
(
index
%
2
)));
dptr
[
index
/
2
]
=
(
dptr
[
index
/
2
].
as_storage
()
&
(
0xF0
>>
(
4
*
(
index
%
2
))))
|
(
v
.
as_storage
()
<<
(
4
*
(
index
%
2
)));
};
rounding
::
RoundingConverter
<
ctype
>
output_converter
;
...
...
@@ -334,21 +347,20 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
int
iw1
=
get_real_coord
(
std
::
floor
(
alphaw
)
+
1
,
IW
);
int
ih0
=
get_real_coord
(
std
::
floor
(
alphah
)
+
0
,
IH
);
int
ih1
=
get_real_coord
(
std
::
floor
(
alphah
)
+
1
,
IH
);
alphaw
-=
floor
(
alphaw
);
alphah
-=
floor
(
alphah
);
if
(
bmode
!=
BorderMode
::
CONSTANT
)
{
rep
(
c
,
C
)
{
set_visit_dst
(
c
,
oh
,
ow
,
output_converter
(
visit_src
(
c
,
ih0
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
auto
val
=
visit_src
(
c
,
ih0
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
(
1.0
f
-
alphah
)
+
visit_src
(
c
,
ih0
,
iw1
)
*
alphaw
*
(
1.0
f
-
alphah
)
+
visit_src
(
c
,
ih1
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
alphah
+
visit_src
(
c
,
ih1
,
iw1
)
*
alphaw
*
alphah
));
visit_src
(
c
,
ih1
,
iw1
)
*
alphaw
*
alphah
;
set_visit_dst
(
c
,
oh
,
ow
,
output_converter
(
val
));
}
}
else
{
rep
(
c
,
C
)
{
...
...
@@ -613,6 +625,13 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
"WarpPerspective: %s"
,
src
.
layout
.
dtype
.
name
())
.
c_str
());
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Quantized4Asymm
>::
enumv
)
{
DISPATCH_ST
(
dtype
::
Quantized4Asymm
,
dt_quint4
,
float
,
KERN_INT4
);
megdnn_throw
(
ssprintf
(
"Unsupported input DType in "
"WarpPerspective: %s"
,
src
.
layout
.
dtype
.
name
())
.
c_str
());
}
bool
is_fusion_dtype
=
src
.
layout
.
dtype
.
enumv
()
!=
dst
.
layout
.
dtype
.
enumv
();
...
...
dnn/src/naive/warp_perspective/opr_impl.h
浏览文件 @
19919384
...
...
@@ -107,7 +107,8 @@ protected:
ret
.
mptr
=
mat
.
ptr
<
mtype
>
();
ret
.
dptr
=
dst
.
compatible_ptr
<
ctype
>
();
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
ret
.
sptr
=
src
.
compatible_ptr
<
ctype
>
();
ret
.
mptr
=
mat
.
ptr
<
mtype
>
();
ret
.
dptr
=
dst
.
compatible_ptr
<
ctype
>
();
...
...
dnn/test/cuda/warp_perspective.cpp
浏览文件 @
19919384
...
...
@@ -647,6 +647,31 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_QINT4) {
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_FORWARD_QUINT4
)
{
using
Param
=
WarpPerspective
::
Param
;
Checker
<
WarpPerspectiveForward
>
checker
(
handle_cuda
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
1.25
f
,
0
))
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Quantized4Asymm
(
1.25
f
,
0
));
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
WarpPerspective
::
Param
param
;
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
);
checker
.
set_epsilon
(
1
+
1e-3
);
checker
.
execs
({{
1
,
64
,
11
,
11
},
{
1
,
3
,
3
},
{
1
,
64
,
11
,
11
}});
checker
.
execs
({{
20
,
640
,
11
,
12
},
{
20
,
3
,
3
},
{
20
,
640
,
11
,
12
}});
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16
)
{
Checker
<
WarpPerspectiveBackwardData
>
checker
(
handle_cuda
());
WarpPerspectiveMatRNG
rng
;
...
...
@@ -701,7 +726,7 @@ TEST_F(CUDA, WARP_PERSPECTIVE_MAT_IDX) {
warp_perspective
::
run_mat_idx_test
(
handle_cuda
());
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_NCHW64
)
{
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_NCHW64
_QINT4
)
{
using
Param
=
WarpPerspective
::
Param
;
WarpPerspective
::
Param
param
;
Checker
<
WarpPerspectiveForward
>
checker
(
handle_cuda
());
...
...
@@ -767,6 +792,72 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64) {
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_NCHW64_QUINT4
)
{
using
Param
=
WarpPerspective
::
Param
;
WarpPerspective
::
Param
param
;
Checker
<
WarpPerspectiveForward
>
checker
(
handle_cuda
());
WarpPerspectiveMatRNG_V2
rng
;
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
checker
.
set_dtype
(
2
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW64
;
checker
.
set_param
(
param
);
checker
.
set_epsilon
(
1
+
1e-3
);
rng
.
set_hw
(
10
,
11
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
2
,
1
,
10
,
11
,
64
},
{
2
,
3
,
3
},
{
2
,
1
,
11
,
12
,
64
}});
checker
.
execs
(
{{
20
,
300
,
10
,
11
,
64
},
{
20
,
3
,
3
},
{
20
,
300
,
11
,
12
,
64
}});
checker
.
execs
(
{{
2200
,
3
,
10
,
11
,
64
},
{
2200
,
3
,
3
},
{
2200
,
3
,
11
,
12
,
64
}});
rng
.
set_hw
(
25
,
25
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
1
,
25
,
25
,
25
,
64
},
{
1
,
3
,
3
},
{
1
,
25
,
25
,
51
,
64
}});
rng
.
set_hw
(
25
,
510
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
1
,
1
,
25
,
510
,
64
},
{
1
,
3
,
3
},
{
1
,
1
,
25
,
25
,
64
}});
rng
.
set_hw
(
25
,
25
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
1
,
1
,
25
,
25
,
64
},
{
1
,
3
,
3
},
{
1
,
1
,
51
,
51
,
64
}});
rng
.
set_hw
(
51
,
51
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
1
,
1
,
51
,
51
,
64
},
{
1
,
3
,
3
},
{
1
,
1
,
25
,
25
,
64
}});
}
{
Checker
<
WarpPerspective
,
WarpPerspectiveMatIdxProxy
>
checker
(
handle_cuda
());
constexpr
int
N_SRC
=
5
;
UniformIntRNG
mat_idx_rng
{
0
,
N_SRC
-
1
};
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
2
,
dtype
::
Int32
());
checker
.
set_rng
(
2
,
&
mat_idx_rng
);
checker
.
set_dtype
(
3
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
param
.
bmode
=
WarpPerspective
::
Param
::
BorderMode
::
REFLECT
;
param
.
imode
=
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
;
checker
.
set_param
(
param
);
checker
.
set_epsilon
(
1
+
1e-3
);
rng
.
set_hw
(
10
,
11
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
(
{{
N_SRC
,
3
,
10
,
11
,
64
},
{
2
,
3
,
3
},
{
2
},
{
2
,
3
,
11
,
12
,
64
}});
rng
.
set_hw
(
17
,
13
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
N_SRC
,
14
,
17
,
13
,
64
},
{
123
,
3
,
3
},
{
123
},
{
123
,
14
,
16
,
15
,
64
}});
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F
(
CUDA
,
BENCHMARK_WARP_PERSPECTIVE_NCHW4
)
{
...
...
dnn/test/naive/warp_perspective.cpp
浏览文件 @
19919384
...
...
@@ -196,8 +196,8 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) {
param
.
imode
=
WarpPerspective
::
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
WarpPerspective
::
Param
::
Format
::
NCHW
;
std
::
vector
<
int
>
input_values
=
{
1
,
3
,
2
,
2
,
0
,
0
,
0
,
0
,
2
},
output_values
=
{
1
,
2
,
2
,
2
};
std
::
vector
<
int
>
input_values
=
{
-
1
,
-
3
,
-
2
,
-
2
,
0
,
0
,
0
,
0
,
-
2
},
output_values
=
{
-
1
,
-
2
,
-
2
,
-
2
};
checker
.
set_param
(
param
).
exect
(
Testcase
{
TensorValueLowbit4
({
1
,
1
,
3
,
3
},
dtype
::
QuantizedS4
(
0.1
),
...
...
@@ -212,6 +212,31 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) {
output_values
)});
}
TEST_F
(
NAIVE
,
WARP_PERSPECTIVE_NCHW_QUINT4
)
{
Checker
<
WarpPerspective
>
checker
(
handle
(),
false
);
WarpPerspective
::
Param
param
;
param
.
bmode
=
WarpPerspective
::
Param
::
BorderMode
::
BORDER_REFLECT
;
param
.
imode
=
WarpPerspective
::
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
WarpPerspective
::
Param
::
Format
::
NCHW
;
std
::
vector
<
int
>
input_values
=
{
4
,
13
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
output_values
=
{
6
,
8
,
8
,
9
};
checker
.
set_param
(
param
).
exect
(
Testcase
{
TensorValueLowbit4
({
1
,
1
,
3
,
3
},
dtype
::
Quantized4Asymm
(
0.1
,
3
),
input_values
),
TensorValue
({
1
,
3
,
3
},
dtype
::
Float32
{},
{
1.2
f
,
1.2
f
,
0.6
f
,
-
1.05
f
,
-
2.0
f
,
-
0.7
f
,
1.3
f
,
1.5
f
,
3.0
f
}),
{}},
Testcase
{{},
{},
TensorValueLowbit4
({
1
,
1
,
2
,
2
},
dtype
::
Quantized4Asymm
(
0.1
,
3
),
output_values
)});
}
TEST_F
(
NAIVE_MULTI_THREADS
,
WARP_PERSPECTIVE_NCHW4
)
{
using
Param
=
WarpPerspective
::
Param
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录