Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
606540be
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
411
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
606540be
编写于
5月 24, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add nhwc 4bit warp perspective
GitOrigin-RevId: fbec4a4a1f7a5dd184b4e76d60b26a6cac8a24ef
上级
1e601943
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
611 addition
and
97 deletion
+611
-97
dnn/src/common/warp_perspective.cpp
dnn/src/common/warp_perspective.cpp
+1
-1
dnn/src/cuda/warp_perspective/common.cuh
dnn/src/cuda/warp_perspective/common.cuh
+8
-0
dnn/src/cuda/warp_perspective/common.h
dnn/src/cuda/warp_perspective/common.h
+8
-0
dnn/src/cuda/warp_perspective/forward.cpp
dnn/src/cuda/warp_perspective/forward.cpp
+67
-13
dnn/src/cuda/warp_perspective/forward.cu
dnn/src/cuda/warp_perspective/forward.cu
+313
-67
dnn/src/naive/warp_perspective/opr_impl.cpp
dnn/src/naive/warp_perspective/opr_impl.cpp
+20
-12
dnn/test/cuda/warp_perspective.cpp
dnn/test/cuda/warp_perspective.cpp
+90
-2
dnn/test/naive/warp_perspective.cpp
dnn/test/naive/warp_perspective.cpp
+104
-2
未找到文件。
dnn/src/common/warp_perspective.cpp
浏览文件 @
606540be
...
...
@@ -226,7 +226,7 @@ std::string WarpPerspectiveBase::param_msg() const {
res
.
append
(
"LANCZOS4"
);
break
;
}
res
.
append
(
"bmode="
);
res
.
append
(
"
,
bmode="
);
switch
(
param
().
bmode
)
{
case
BorderMode
::
WRAP
:
res
.
append
(
"WRAP"
);
...
...
dnn/src/cuda/warp_perspective/common.cuh
浏览文件 @
606540be
...
...
@@ -63,6 +63,14 @@ class WrapGetter {
}
};
class
ConstGetter
{
public:
__device__
int
operator
()(
int
i
,
int
n
)
{
return
i
;
}
};
}
// namespace warp_perspective
}
// namespace cuda
}
// namespace megdnn
...
...
dnn/src/cuda/warp_perspective/common.h
浏览文件 @
606540be
...
...
@@ -28,6 +28,14 @@ void forward_proxy(bool is_nhwc, const ctype* src, const float* mat,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
template
<
typename
ctype
,
int
pack_c
>
void
forward_proxy_nhwc_bit4
(
const
ctype
*
src
,
const
float
*
mat
,
const
int
*
mat_idx
,
ctype
*
dst
,
int
N_SRC
,
int
N_MAT
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
template
<
typename
ctype
>
void
forward_proxy_nchw4
(
const
ctype
*
src
,
const
float
*
mat
,
const
int
*
mat_idx
,
ctype
*
dst
,
int
N_SRC
,
int
N_MAT
,
int
C
,
int
IH
,
...
...
dnn/src/cuda/warp_perspective/forward.cpp
浏览文件 @
606540be
...
...
@@ -328,12 +328,10 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCHW64
||
param
().
format
==
Param
::
Format
::
NCHW
,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"QuantizedS4"
);
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
&&
(
param
().
format
==
Param
::
Format
::
NCHW64
||
param
().
format
==
Param
::
Format
::
NCHW
))
{
bval
=
roundf
(
bval
);
bval
=
fmin
(
fmax
(
-
8.
f
,
bval
),
7.
f
);
warp_perspective
::
forward_proxy_nchw64
<
dt_qint4
>
(
...
...
@@ -355,13 +353,10 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,
{});
}
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCHW64
||
param
().
format
==
Param
::
Format
::
NCHW
,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"Quantized4Asymm"
);
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
(
param
().
format
==
Param
::
Format
::
NCHW64
||
param
().
format
==
Param
::
Format
::
NCHW
))
{
bval
=
roundf
(
bval
);
bval
=
fmin
(
fmax
(
0
,
bval
),
15
);
warp_perspective
::
forward_proxy_nchw64
<
dt_quint4
>
(
...
...
@@ -383,6 +378,65 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
relayout_opr
->
param
()
=
trans_param
;
relayout_opr
->
exec
(
dst
,
sdst
,
{});
}
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
(
param
().
format
==
Param
::
Format
::
NHWC
))
{
constexpr
int
pack_c
=
8
;
megdnn_assert
(
C
%
pack_c
==
0
);
bval
=
roundf
(
bval
);
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
bval
=
fmin
(
fmax
(
-
8.
f
,
bval
),
7.
f
);
if
(
C
%
16
==
0
)
{
warp_perspective
::
forward_proxy_nhwc_bit4
<
dt_qint4
,
16
>
(
src
.
ptr
<
dt_qint4
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_qint4
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_qint4
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
{
warp_perspective
::
forward_proxy_nhwc_bit4
<
dt_qint4
,
pack_c
>
(
src
.
ptr
<
dt_qint4
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_qint4
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_qint4
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
}
else
{
bval
=
fmin
(
fmax
(
0.
f
,
bval
),
15.
f
);
if
(
C
%
16
==
0
)
{
warp_perspective
::
forward_proxy_nhwc_bit4
<
dt_quint4
,
16
>
(
src
.
ptr
<
dt_quint4
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_quint4
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_quint4
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
{
warp_perspective
::
forward_proxy_nhwc_bit4
<
dt_quint4
,
pack_c
>
(
src
.
ptr
<
dt_quint4
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_quint4
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_quint4
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
}
}
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
...
...
dnn/src/cuda/warp_perspective/forward.cu
浏览文件 @
606540be
此差异已折叠。
点击以展开。
dnn/src/naive/warp_perspective/opr_impl.cpp
浏览文件 @
606540be
...
...
@@ -257,7 +257,7 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
MIDOUT_BEGIN
(
megdnn_naive_warpperspective
,
ctype
,
mtype
,
midout_iv
(
0
))
{
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
(
kern_param
);
MEGDNN_MARK_USED_VAR
(
N_MAT
);
uint
8
_t
c_shift
,
c_mask
,
iw_shift
=
0
,
ow_shift
=
0
;
uint
32
_t
c_shift
,
c_mask
,
iw_shift
=
0
,
ow_shift
=
0
;
constexpr
bool
signedness
=
std
::
is_same
<
ctype
,
dt_qint4
>::
value
;
switch
(
param
().
format
)
{
case
Format
::
NCHW
:
...
...
@@ -270,19 +270,29 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
c_shift
=
6
;
c_mask
=
0x3F
;
break
;
case
Format
::
NHWC
:
megdnn_assert
(
C
%
2
==
0
);
c_shift
=
0
;
c_mask
=
0
;
break
;
default:
megdnn_throw
(
"bad format"
);
break
;
}
//! strides of C, H, W on src and dst
size_t
sstrd
[
2
]
=
{
IH
*
(
IW
+
iw_shift
),
IW
+
iw_shift
},
dstrd
[
2
]
=
{
OH
*
(
OW
+
ow_shift
),
OW
+
ow_shift
};
std
::
vector
<
size_t
>
sstrd
=
{
IH
*
((
IW
+
iw_shift
)
<<
c_shift
),
(
IW
+
iw_shift
)
<<
c_shift
,
1
};
std
::
vector
<
size_t
>
dstrd
=
{
OH
*
((
OW
+
ow_shift
)
<<
c_shift
),
(
OW
+
ow_shift
)
<<
c_shift
,
1
};
if
(
param
().
format
==
Format
::
NHWC
)
{
sstrd
=
{
1
,
IW
*
C
,
C
};
dstrd
=
{
1
,
OW
*
C
,
C
};
}
static
constexpr
uint8_t
mask
=
(
uint8_t
)((
1
<<
4
)
-
1
);
auto
visit_src
=
[
&
sptr
,
sstrd
,
c_shift
,
c_mask
](
size_t
c
,
int
h
,
int
w
)
->
float
{
size_t
index
=
((
sstrd
[
0
]
*
(
c
>>
c_shift
)
+
sstrd
[
1
]
*
h
+
w
)
<<
c_shift
)
+
(
c
&
c_mask
);
size_t
index
=
(
c
>>
c_shift
)
*
sstrd
[
0
]
+
h
*
sstrd
[
1
]
+
(
w
<<
c_shift
)
*
sstrd
[
2
]
+
(
c
&
c_mask
);
uint8_t
result
=
(
sptr
[
index
/
2
].
as_storage
()
>>
(
4
*
(
index
%
2
)))
&
0xF
;
if
(
signedness
)
{
...
...
@@ -295,9 +305,8 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
auto
visit_src_bd
=
[
&
sptr
,
sstrd
,
border_val
,
c_shift
,
c_mask
](
size_t
c
,
int
h
,
int
w
)
->
float
{
if
(
h
!=
-
1
&&
w
!=
-
1
)
{
size_t
index
=
((
sstrd
[
0
]
*
(
c
>>
c_shift
)
+
sstrd
[
1
]
*
h
+
w
)
<<
c_shift
)
+
(
c
&
c_mask
);
size_t
index
=
(
c
>>
c_shift
)
*
sstrd
[
0
]
+
h
*
sstrd
[
1
]
+
(
w
<<
c_shift
)
*
sstrd
[
2
]
+
(
c
&
c_mask
);
uint8_t
result
=
(
sptr
[
index
/
2
].
as_storage
()
>>
(
4
*
(
index
%
2
)))
&
0xF
;
...
...
@@ -312,9 +321,8 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
};
auto
set_visit_dst
=
[
&
dptr
,
dstrd
,
c_shift
,
c_mask
](
size_t
c
,
int
h
,
int
w
,
ctype
v
)
{
size_t
index
=
((
dstrd
[
0
]
*
(
c
>>
c_shift
)
+
dstrd
[
1
]
*
h
+
w
)
<<
c_shift
)
+
(
c
&
c_mask
);
size_t
index
=
(
c
>>
c_shift
)
*
dstrd
[
0
]
+
h
*
dstrd
[
1
]
+
(
w
<<
c_shift
)
*
dstrd
[
2
]
+
(
c
&
c_mask
);
dptr
[
index
/
2
]
=
(
dptr
[
index
/
2
].
as_storage
()
&
(
0xF0
>>
(
4
*
(
index
%
2
))))
|
(
v
.
as_storage
()
<<
(
4
*
(
index
%
2
)));
...
...
dnn/test/cuda/warp_perspective.cpp
浏览文件 @
606540be
...
...
@@ -176,10 +176,12 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD) {
Checker
<
WarpPerspectiveForward
>
checker
(
handle_cuda
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
WarpPerspective
::
Param
param
;
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
...
...
@@ -215,6 +217,84 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD) {
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_FORWARD_NHWC
)
{
using
Param
=
WarpPerspective
::
Param
;
Checker
<
WarpPerspectiveForward
>
checker
(
handle_cuda
());
WarpPerspectiveMatRNG_V2
rng
;
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
0.1
f
));
checker
.
set_dtype
(
2
,
dtype
::
QuantizedS4
(
0.1
f
));
checker
.
set_rng
(
1
,
&
rng
);
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
WarpPerspective
::
Param
param
;
param
.
border_val
=
1.2
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
checker
.
set_epsilon
(
1
+
1e-3
);
rng
.
set_hw
(
10
,
11
);
checker
.
execs
({{
23
,
10
,
11
,
16
},
{
23
,
3
,
3
},
{
23
,
11
,
12
,
16
}});
checker
.
execs
({{
20
,
10
,
11
,
32
},
{
20
,
3
,
3
},
{
20
,
11
,
12
,
32
}});
checker
.
execs
({{
20
,
10
,
11
,
32
},
{
20
,
3
,
3
},
{
20
,
11
,
12
,
32
}});
rng
.
set_hw
(
55
,
66
);
checker
.
execs
({{
20
,
55
,
66
,
32
},
{
20
,
3
,
3
},
{
20
,
44
,
34
,
32
}});
}
{
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
checker
.
set_dtype
(
2
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
checker
.
set_rng
(
1
,
&
rng
);
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
WarpPerspective
::
Param
param
;
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
checker
.
set_epsilon
(
1
+
1e-3
);
rng
.
set_hw
(
10
,
11
);
checker
.
execs
({{
23
,
10
,
11
,
16
},
{
23
,
3
,
3
},
{
23
,
11
,
12
,
16
}});
checker
.
execs
({{
20
,
10
,
11
,
32
},
{
20
,
3
,
3
},
{
20
,
11
,
12
,
32
}});
checker
.
execs
({{
20
,
10
,
11
,
32
},
{
20
,
3
,
3
},
{
20
,
11
,
12
,
32
}});
rng
.
set_hw
(
55
,
66
);
checker
.
execs
({{
20
,
55
,
66
,
32
},
{
20
,
3
,
3
},
{
20
,
44
,
34
,
32
}});
}
}
{
Checker
<
WarpPerspective
,
WarpPerspectiveMatIdxProxy
>
checker
(
handle_cuda
());
constexpr
int
N_SRC
=
5
;
UniformIntRNG
mat_idx_rng
{
0
,
N_SRC
-
1
};
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
0.1
f
));
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
2
,
dtype
::
Int32
());
checker
.
set_rng
(
2
,
&
mat_idx_rng
);
checker
.
set_dtype
(
3
,
dtype
::
QuantizedS4
(
0.1
f
));
WarpPerspective
::
Param
param
;
param
.
border_val
=
0.3
f
;
param
.
format
=
Param
::
Format
::
NHWC
;
param
.
bmode
=
WarpPerspective
::
Param
::
BorderMode
::
REFLECT
;
param
.
imode
=
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
;
checker
.
set_param
(
param
);
checker
.
set_epsilon
(
1
+
1e-3
);
rng
.
set_hw
(
10
,
11
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
({{
N_SRC
,
10
,
11
,
48
},
{
2
,
3
,
3
},
{
2
},
{
2
,
11
,
12
,
48
}});
rng
.
set_hw
(
17
,
13
);
checker
.
set_rng
(
1
,
&
rng
);
checker
.
execs
(
{{
N_SRC
,
17
,
13
,
64
},
{
123
,
3
,
3
},
{
123
},
{
123
,
16
,
15
,
64
}});
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_FORWARD_INTMAX
)
{
require_compute_capability
(
6
,
0
);
using
Param
=
WarpPerspective
::
Param
;
...
...
@@ -895,6 +975,14 @@ TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) {
run
({
TensorShape
{
1
,
25
,
256
,
5120
,
4
},
{
1
,
3
,
3
},
{
1
,
25
,
256
,
256
,
4
}});
run
({
TensorShape
{
1
,
25
,
256
,
256
,
4
},
{
1
,
3
,
3
},
{
1
,
25
,
512
,
512
,
4
}});
run
({
TensorShape
{
1
,
25
,
512
,
512
,
4
},
{
1
,
3
,
3
},
{
1
,
25
,
256
,
256
,
4
}});
param
.
format
=
Param
::
Format
::
NHWC
;
benchmarker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
1.
f
));
benchmarker
.
set_dtype
(
2
,
dtype
::
QuantizedS4
(
1.
f
));
run
({
TensorShape
{
1
,
256
,
256
,
4
*
24
},
{
1
,
3
,
3
},
{
1
,
256
,
5120
,
4
*
24
}});
run
({
TensorShape
{
1
,
256
,
5120
,
4
*
24
},
{
1
,
3
,
3
},
{
1
,
256
,
256
,
4
*
24
}});
run
({
TensorShape
{
1
,
256
,
256
,
4
*
24
},
{
1
,
3
,
3
},
{
1
,
512
,
512
,
4
*
24
}});
run
({
TensorShape
{
1
,
512
,
512
,
4
*
24
},
{
1
,
3
,
3
},
{
1
,
256
,
256
,
4
*
24
}});
}
#endif
...
...
dnn/test/naive/warp_perspective.cpp
浏览文件 @
606540be
...
...
@@ -642,12 +642,114 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW64) {
param
.
format
=
Param
::
Format
::
NCHW64
;
checker
.
set_param
(
param
);
checker
.
execs
({{
2
,
1
,
10
,
10
,
64
},
{
2
,
3
,
3
},
{
2
,
1
,
10
,
12
,
64
}});
checker
.
execs
(
{{
20
,
3
,
10
,
12
,
64
},
{
20
,
3
,
3
},
{
20
,
3
,
11
,
12
,
64
}});
checker
.
execs
({{
20
,
3
,
10
,
12
,
64
},
{
20
,
3
,
3
},
{
20
,
3
,
11
,
12
,
64
}});
checker
.
execs
({{
1
,
3
,
25
,
24
,
64
},
{
1
,
3
,
3
},
{
1
,
3
,
25
,
51
,
64
}});
checker
.
execs
({{
1
,
3
,
25
,
51
,
64
},
{
1
,
3
,
3
},
{
1
,
3
,
25
,
24
,
64
}});
checker
.
execs
({{
1
,
3
,
25
,
24
,
64
},
{
1
,
3
,
3
},
{
1
,
3
,
51
,
50
,
64
}});
checker
.
execs
({{
1
,
3
,
51
,
50
,
64
},
{
1
,
3
,
3
},
{
1
,
3
,
25
,
24
,
64
}});
}
}
TEST_F
(
NAIVE
,
WARP_PERSPECTIVE_NHWC
)
{
using
Param
=
WarpPerspective
::
Param
;
auto
convert_true_format
=
[](
const
TensorLayout
&
layout
)
{
if
(
layout
.
ndim
==
4
)
{
TensorLayout
ret
{{
layout
[
0
],
layout
[
2
],
layout
[
3
],
layout
[
1
]},
layout
.
dtype
};
return
ret
.
dimshuffle
({
0
,
3
,
1
,
2
});
}
else
return
layout
;
};
WarpPerspective
::
Param
param
;
auto
extra_impl
=
[
&
param
,
this
,
convert_true_format
](
const
TensorNDArray
&
tensors
)
{
auto
warp_perspective
=
handle
()
->
create_operator
<
WarpPerspective
>
();
warp_perspective
->
param
()
=
param
;
warp_perspective
->
param
().
format
=
Param
::
Format
::
NCHW
;
TensorNDArray
nchw_tensors
;
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
TensorLayout
ly
;
auto
layout
=
tensors
[
i
].
layout
;
if
(
layout
.
ndim
==
4
)
{
ly
=
TensorLayout
{{
layout
[
0
],
layout
[
3
],
layout
[
1
],
layout
[
2
]},
layout
.
dtype
};
}
else
{
ly
=
layout
;
}
nchw_tensors
.
emplace_back
(
malloc
(
ly
.
span
().
dist_byte
()),
ly
);
}
TensorNDArray
nhwc_tensors
;
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
auto
layout
=
convert_true_format
(
nchw_tensors
[
i
].
layout
);
nhwc_tensors
.
emplace_back
(
tensors
[
i
].
raw_ptr
,
std
::
move
(
layout
));
}
auto
workspace_size
=
warp_perspective
->
get_workspace_in_bytes
(
tensors
[
0
].
layout
,
tensors
[
1
].
layout
,
tensors
[
2
].
layout
);
dt_byte
*
workspace_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
workspace_size
));
Workspace
workspace
{
workspace_ptr
,
workspace_size
};
auto
relayout
=
handle
()
->
create_operator
<
RelayoutForward
>
();
relayout
->
exec
(
nhwc_tensors
[
0
],
nchw_tensors
[
0
]);
relayout
->
exec
(
nhwc_tensors
[
1
],
nchw_tensors
[
1
]);
warp_perspective
->
exec
(
nchw_tensors
[
0
],
nchw_tensors
[
1
],
nchw_tensors
[
2
],
workspace
);
relayout
->
exec
(
nchw_tensors
[
2
],
nhwc_tensors
[
2
]);
free
(
workspace_ptr
);
for
(
auto
&&
tensor
:
nchw_tensors
)
{
free
(
tensor
.
raw_ptr
);
}
};
{
Checker
<
WarpPerspectiveForward
>
checker
(
handle
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS4
(
0.1
f
));
checker
.
set_dtype
(
2
,
dtype
::
QuantizedS4
(
0.1
f
));
checker
.
set_extra_opr_impl
(
extra_impl
);
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
checker
.
execs
({{
1
,
2
,
2
,
4
},
{
1
,
3
,
3
},
{
1
,
2
,
2
,
4
}});
checker
.
execs
({{
2
,
10
,
10
,
4
},
{
2
,
3
,
3
},
{
2
,
10
,
12
,
4
}});
checker
.
execs
({{
3
,
25
,
24
,
8
},
{
3
,
3
,
3
},
{
3
,
12
,
10
,
8
}});
checker
.
execs
({{
4
,
33
,
22
,
16
},
{
4
,
3
,
3
},
{
4
,
9
,
12
,
16
}});
}
}
{
Checker
<
WarpPerspectiveForward
>
checker
(
handle
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
0
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
checker
.
set_dtype
(
2
,
dtype
::
Quantized4Asymm
(
0.1
f
,
3
));
checker
.
set_extra_opr_impl
(
extra_impl
);
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
checker
.
execs
({{
1
,
2
,
2
,
4
},
{
1
,
3
,
3
},
{
1
,
2
,
2
,
4
}});
checker
.
execs
({{
2
,
10
,
10
,
4
},
{
2
,
3
,
3
},
{
2
,
10
,
12
,
4
}});
checker
.
execs
({{
3
,
25
,
24
,
8
},
{
3
,
3
,
3
},
{
3
,
12
,
10
,
8
}});
checker
.
execs
({{
4
,
33
,
22
,
16
},
{
4
,
3
,
3
},
{
4
,
9
,
12
,
16
}});
}
}
}
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录