Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
669816e2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
669816e2
编写于
7月 31, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): warpperspective support multi src input
GitOrigin-RevId: 8a4789852e6df47b5b44ac3e2fa2999d4e0ab5d6
上级
33b27be8
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
1693 addition
and
0 deletion
+1693
-0
dnn/include/megdnn/oprs/imgproc.h
dnn/include/megdnn/oprs/imgproc.h
+33
-0
dnn/src/common/warp_common.cpp
dnn/src/common/warp_common.cpp
+7
-0
dnn/src/common/warp_common.h
dnn/src/common/warp_common.h
+4
-0
dnn/src/common/warp_perspective.cpp
dnn/src/common/warp_perspective.cpp
+103
-0
dnn/src/cuda/warp_perspective/common.h
dnn/src/cuda/warp_perspective/common.h
+7
-0
dnn/src/cuda/warp_perspective/forward.cpp
dnn/src/cuda/warp_perspective/forward.cpp
+146
-0
dnn/src/cuda/warp_perspective/forward.cu
dnn/src/cuda/warp_perspective/forward.cu
+298
-0
dnn/src/cuda/warp_perspective/opr_impl.h
dnn/src/cuda/warp_perspective/opr_impl.h
+13
-0
dnn/src/fallback/warp_perspective/opr_impl.cpp
dnn/src/fallback/warp_perspective/opr_impl.cpp
+113
-0
dnn/src/fallback/warp_perspective/opr_impl.h
dnn/src/fallback/warp_perspective/opr_impl.h
+10
-0
dnn/src/naive/warp_perspective/opr_impl.cpp
dnn/src/naive/warp_perspective/opr_impl.cpp
+178
-0
dnn/src/naive/warp_perspective/opr_impl.h
dnn/src/naive/warp_perspective/opr_impl.h
+79
-0
dnn/test/common/warp_perspective.cpp
dnn/test/common/warp_perspective.cpp
+48
-0
dnn/test/common/warp_perspective.h
dnn/test/common/warp_perspective.h
+6
-0
dnn/test/cuda/warp_perspective.cpp
dnn/test/cuda/warp_perspective.cpp
+188
-0
dnn/test/fallback/warp_perspective.cpp
dnn/test/fallback/warp_perspective.cpp
+184
-0
dnn/test/naive/warp_perspective.cpp
dnn/test/naive/warp_perspective.cpp
+276
-0
未找到文件。
dnn/include/megdnn/oprs/imgproc.h
浏览文件 @
669816e2
...
@@ -16,10 +16,18 @@ protected:
...
@@ -16,10 +16,18 @@ protected:
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
check_layout_fwd
(
src
,
mat
,
{},
dst
);
check_layout_fwd
(
src
,
mat
,
{},
dst
);
}
}
void
check_layout_fwd
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
check_layout_fwd
(
srcs
,
mat
,
{},
dst
);
}
void
check_layout_fwd
(
void
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
);
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
);
void
check_layout_fwd
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
);
std
::
string
param_msg
()
const
;
std
::
string
param_msg
()
const
;
int
get_real_coord
(
int
p
,
int
len
);
int
get_real_coord
(
int
p
,
int
len
);
};
};
...
@@ -49,6 +57,12 @@ public:
...
@@ -49,6 +57,12 @@ public:
exec
(
src
,
mat
,
{},
dst
,
workspace
);
exec
(
src
,
mat
,
{},
dst
,
workspace
);
}
}
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
exec
(
srcs
,
mat
,
{},
dst
,
workspace
);
}
/**
/**
* \p src should have batch size m, and \p mat and \p mat_idx should
* \p src should have batch size m, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range
* both have batch size n. Each item in \p mat_idx must be in the range
...
@@ -62,15 +76,30 @@ public:
...
@@ -62,15 +76,30 @@ public:
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
virtual
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
return
get_workspace_in_bytes
(
src
,
mat
,
{},
dst
);
return
get_workspace_in_bytes
(
src
,
mat
,
{},
dst
);
}
}
size_t
get_workspace_in_bytes
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
return
get_workspace_in_bytes
(
srcs
,
mat
,
{},
dst
);
}
virtual
size_t
get_workspace_in_bytes
(
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
=
0
;
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
=
0
;
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
=
0
;
protected:
protected:
void
check_exec
(
void
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
...
@@ -81,6 +110,10 @@ protected:
...
@@ -81,6 +110,10 @@ protected:
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
);
size_t
workspace_in_bytes
);
void
check_exec_allow_nhwc_mat_idx
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
);
};
};
using
WarpPerspective
=
WarpPerspectiveForward
;
using
WarpPerspective
=
WarpPerspectiveForward
;
...
...
dnn/src/common/warp_common.cpp
浏览文件 @
669816e2
...
@@ -22,4 +22,11 @@ bool warp::is_dnn_available(
...
@@ -22,4 +22,11 @@ bool warp::is_dnn_available(
return
imode
==
param
::
WarpAffine
::
InterpolationMode
::
LINEAR
;
return
imode
==
param
::
WarpAffine
::
InterpolationMode
::
LINEAR
;
}
}
bool
warp
::
is_dnn_available
(
const
TensorLayoutArray
&
/*src*/
,
const
TensorLayout
&
/*mat*/
,
const
TensorLayout
&
/*dst*/
,
param
::
WarpAffine
::
InterpolationMode
imode
,
param
::
WarpAffine
::
Format
/*format*/
)
{
return
imode
==
param
::
WarpAffine
::
InterpolationMode
::
LINEAR
;
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/common/warp_common.h
浏览文件 @
669816e2
...
@@ -90,6 +90,10 @@ bool is_dnn_available(
...
@@ -90,6 +90,10 @@ bool is_dnn_available(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
param
::
WarpAffine
::
InterpolationMode
imode
,
param
::
WarpAffine
::
Format
format
);
param
::
WarpAffine
::
InterpolationMode
imode
,
param
::
WarpAffine
::
Format
format
);
bool
is_dnn_available
(
const
TensorLayoutArray
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
param
::
WarpAffine
::
InterpolationMode
imode
,
param
::
WarpAffine
::
Format
format
);
using
namespace
megcv
;
using
namespace
megcv
;
using
IMode
=
InterpolationMode
;
using
IMode
=
InterpolationMode
;
using
BMode
=
BorderMode
;
using
BMode
=
BorderMode
;
...
...
dnn/src/common/warp_perspective.cpp
浏览文件 @
669816e2
...
@@ -3,7 +3,97 @@
...
@@ -3,7 +3,97 @@
#include "src/common/utils.h"
#include "src/common/utils.h"
namespace
megdnn
{
namespace
megdnn
{
void
WarpPerspectiveBase
::
check_layout_fwd
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
{
megdnn_assert
(
srcs
.
size
()
>
0
);
auto
s
=
srcs
.
front
();
for
(
auto
&&
src
:
srcs
)
{
megdnn_assert_contiguous
(
src
);
megdnn_assert
(
src
.
dtype
==
s
.
dtype
);
megdnn_assert
(
src
.
ndim
==
s
.
ndim
);
megdnn_assert
(
src
.
shape
[
0
]
==
1
);
for
(
size_t
i
=
0
;
i
<
s
.
ndim
;
i
++
)
{
megdnn_assert
(
src
.
shape
[
i
]
==
s
.
shape
[
i
]);
}
megdnn_assert
(
src
.
format
==
s
.
format
);
}
megdnn_assert_contiguous
(
mat
);
megdnn_assert_contiguous
(
dst
);
auto
errmsg
=
[
&
]()
{
std
::
string
msg
=
"{"
;
for
(
auto
&&
src
:
srcs
)
{
msg
.
append
(
megdnn_layout_msg
(
src
)
+
", "
);
}
return
msg
+
"} "
+
megdnn_layout_msg
(
mat
)
+
", "
+
megdnn_layout_msg
(
mat_idx
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
param_msg
();
};
MEGDNN_MARK_USED_VAR
(
errmsg
);
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC
||
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
);
megdnn_assert
(
s
.
ndim
==
4
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
ndim
==
4
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
mat
.
ndim
==
3
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
shape
[
0
]
==
mat
.
shape
[
0
],
"%s"
,
errmsg
().
c_str
());
if
(
mat_idx
.
ndim
)
{
megdnn_assert
(
mat_idx
.
dtype
==
dtype
::
Int32
()
&&
mat_idx
.
ndim
==
1
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
mat
.
shape
[
0
]
==
mat_idx
.
shape
[
0
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert_contiguous
(
mat_idx
);
}
else
{
megdnn_assert
(
s
.
shape
[
0
]
*
srcs
.
size
()
==
dst
.
shape
[
0
],
"%s"
,
errmsg
().
c_str
());
}
megdnn_assert
(
mat
.
shape
[
1
]
==
3
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
mat
.
shape
[
2
]
==
3
_z
,
"%s"
,
errmsg
().
c_str
());
if
(
s
.
format
==
dst
.
format
&&
dst
.
dtype
==
s
.
dtype
)
{
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
)
{
megdnn_assert
(
s
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
DNN_FLOAT16_SELECT
(
(
s
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
s
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
),
"WarpPerspective multi src NCHW input dtype should be "
"Float32"
DNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
"."
);
megdnn_assert
(
(
s
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
&&
(
s
.
dtype
==
mat
.
dtype
||
mat
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
)),
"The input to WarpPerspective multi src is in NCHW format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, otherwise, it should be in Float32, %s given."
,
mat
.
dtype
.
name
());
megdnn_assert
(
s
.
shape
[
1
]
==
dst
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
TRANSPARENT
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
else
{
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC
);
megdnn_assert
(
s
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
DNN_FLOAT16_SELECT
(
(
s
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
s
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
),
"WarpPerspective multi src NHWC input dtype should be "
"Float32"
DNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
"."
);
megdnn_assert
(
s
.
shape
[
3
]
==
dst
.
shape
[
3
],
"%s"
,
errmsg
().
c_str
());
}
}
else
{
megdnn_assert
(
0
,
"WarpPerspective multi src only support format NHWC/NCHW, dtype "
"Float32"
DNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
"."
);
}
}
void
WarpPerspectiveBase
::
check_layout_fwd
(
void
WarpPerspectiveBase
::
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
{
const
TensorLayout
&
dst
)
{
...
@@ -295,6 +385,19 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
...
@@ -295,6 +385,19 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
}
}
}
}
void
WarpPerspectiveForward
::
check_exec_allow_nhwc_mat_idx
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
check_layout_fwd
(
srcs
,
mat
,
mat_idx
,
dst
);
auto
required_workspace_in_bytes
=
get_workspace_in_bytes
(
srcs
,
mat
,
mat_idx
,
dst
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
if
(
param
().
format
!=
Param
::
Format
::
NHWC
&&
param
().
format
!=
Param
::
Format
::
NCHW
)
{
megdnn_assert
(
!
mat_idx
.
ndim
,
"mat_idx not supported for current format"
);
}
}
void
WarpPerspectiveBackwardData
::
check_exec
(
void
WarpPerspectiveBackwardData
::
check_exec
(
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
diff
,
const
TensorLayout
&
grad
,
size_t
workspace_in_bytes
)
{
const
TensorLayout
&
grad
,
size_t
workspace_in_bytes
)
{
...
...
dnn/src/cuda/warp_perspective/common.h
浏览文件 @
669816e2
...
@@ -17,6 +17,13 @@ void forward_proxy(
...
@@ -17,6 +17,13 @@ void forward_proxy(
ctype
bval
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
ctype
bval
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
void
*
error_tracker
,
cudaStream_t
stream
);
template
<
typename
ctype
>
void
forward_proxy_multi_src
(
bool
is_nhwc
,
const
ctype
**
srcs
,
const
float
*
mat
,
const
int
*
mat_idx
,
ctype
*
dst
,
int
N_SRC
,
int
N_MAT
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
template
<
typename
ctype
,
int
pack_c
>
template
<
typename
ctype
,
int
pack_c
>
void
forward_proxy_nhwc_bit4
(
void
forward_proxy_nhwc_bit4
(
const
ctype
*
src
,
const
float
*
mat
,
const
int
*
mat_idx
,
ctype
*
dst
,
int
N_SRC
,
const
ctype
*
src
,
const
float
*
mat
,
const
int
*
mat_idx
,
ctype
*
dst
,
int
N_SRC
,
...
...
dnn/src/cuda/warp_perspective/forward.cpp
浏览文件 @
669816e2
...
@@ -143,6 +143,34 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
...
@@ -143,6 +143,34 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
return
{
ptr
,
std
::
move
(
sizes
)};
return
{
ptr
,
std
::
move
(
sizes
)};
}
}
WorkspaceBundle
WarpPerspectiveForwardImpl
::
get_workspace_bundle
(
void
*
ptr
,
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
const
{
MEGDNN_MARK_USED_VAR
(
mat_idx
);
SmallVector
<
size_t
>
sizes
;
TensorLayoutArray
fsrcs
=
srcs
;
TensorLayout
fmat
=
mat
;
TensorLayout
fdst
=
dst
;
auto
get_workspace
=
[
&
sizes
](
TensorLayout
&
layout
)
{
if
(
layout
.
dtype
==
dtype
::
BFloat16
())
{
layout
.
dtype
=
dtype
::
Float32
();
sizes
.
push_back
(
layout
.
span
().
dist_byte
());
}
};
for
(
auto
&&
fsrc
:
fsrcs
)
{
get_workspace
(
fsrc
);
}
get_workspace
(
fmat
);
get_workspace
(
fdst
);
sizes
.
push_back
(
sizeof
(
dt_float32
*
)
*
srcs
.
size
());
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC
)
{
//! use double for the workspace dtype as float may cause
//! accuracy problems
sizes
.
push_back
(
mat
.
total_nr_elems
()
*
sizeof
(
double
));
}
return
{
ptr
,
std
::
move
(
sizes
)};
}
void
WarpPerspectiveForwardImpl
::
exec
(
void
WarpPerspectiveForwardImpl
::
exec
(
_megdnn_tensor_in
ssrc
,
_megdnn_tensor_in
smat
,
_megdnn_tensor_in
smat_idx
,
_megdnn_tensor_in
ssrc
,
_megdnn_tensor_in
smat
,
_megdnn_tensor_in
smat_idx
,
_megdnn_tensor_out
sdst
,
_megdnn_workspace
sworkspace
)
{
_megdnn_tensor_out
sdst
,
_megdnn_workspace
sworkspace
)
{
...
@@ -453,6 +481,124 @@ void WarpPerspectiveForwardImpl::exec(
...
@@ -453,6 +481,124 @@ void WarpPerspectiveForwardImpl::exec(
}
}
}
}
void
WarpPerspectiveForwardImpl
::
exec
(
_megdnn_in
const
TensorNDArray
&
ssrcs
,
_megdnn_tensor_in
smat
,
_megdnn_tensor_in
smat_idx
,
_megdnn_tensor_out
sdst
,
_megdnn_workspace
sworkspace
)
{
TensorLayoutArray
ssrcs_layout
;
for
(
auto
&&
s
:
ssrcs
)
{
ssrcs_layout
.
push_back
(
s
.
layout
);
}
check_exec_allow_nhwc_mat_idx
(
ssrcs_layout
,
smat
.
layout
,
smat_idx
.
layout
,
sdst
.
layout
,
sworkspace
.
size
);
TensorNDArray
srcs
=
ssrcs
;
TensorND
mat
=
smat
;
TensorND
mat_idx
=
smat_idx
;
TensorND
dst
=
sdst
;
Param
::
Format
inner_format
=
param
().
format
;
auto
bundle
=
get_workspace_bundle
(
sworkspace
.
raw_ptr
,
ssrcs_layout
,
smat
.
layout
,
smat_idx
.
layout
,
sdst
.
layout
);
auto
ctypecvt
=
CompTypeCvter
<
dtype
::
BFloat16
,
dtype
::
Float32
>
(
concrete_handle
(
this
->
handle
()),
&
bundle
);
if
(
ssrcs
.
front
().
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
for
(
size_t
i
=
0
;
i
<
ssrcs
.
size
();
i
++
)
{
ctypecvt
.
src_to_comp_type
(
ssrcs
[
i
],
srcs
[
i
]);
}
ctypecvt
.
src_to_comp_type
(
smat
,
mat
).
src_to_comp_type
(
sdst
,
dst
);
}
{
auto
stream
=
cuda_stream
(
this
->
handle
());
bool
is_nhwc
=
inner_format
==
param
::
WarpPerspective
::
Format
::
NHWC
;
TensorND
src
=
srcs
.
front
();
megdnn_assert
(
warp
::
is_dnn_available
(
ssrcs_layout
,
mat
.
layout
,
dst
.
layout
,
param
().
imode
,
inner_format
));
size_t
C
,
IH
,
IW
,
OH
,
OW
;
if
(
is_nhwc
)
{
C
=
src
.
layout
.
shape
[
3
];
IH
=
src
.
layout
.
shape
[
1
];
IW
=
src
.
layout
.
shape
[
2
];
OH
=
dst
.
layout
.
shape
[
1
];
OW
=
dst
.
layout
.
shape
[
2
];
}
else
{
megdnn_assert
(
inner_format
==
param
::
WarpPerspective
::
Format
::
NCHW
,
"invalid warp_perspective format"
);
C
=
src
.
layout
.
shape
[
1
];
IH
=
src
.
layout
.
shape
[
2
];
IW
=
src
.
layout
.
shape
[
3
];
OH
=
dst
.
layout
.
shape
[
2
];
OW
=
dst
.
layout
.
shape
[
3
];
}
megdnn_assert
(
param
().
imode
==
Param
::
InterpolationMode
::
LINEAR
,
"unsupported interpolation mode form NCHW format"
);
auto
bval
=
param
().
border_val
;
auto
bmode
=
warp_perspective
::
get_bmode
(
param
().
bmode
);
if
(
src
.
layout
.
dtype
==
dst
.
layout
.
dtype
)
{
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
SmallVector
<
size_t
>
workspace_sizes
{
sizeof
(
dt_float32
*
)
*
srcs
.
size
()};
WorkspaceBundle
workspace_cpu
(
nullptr
,
workspace_sizes
);
auto
total_workspace_size
=
workspace_cpu
.
total_size_in_bytes
();
void
*
workspace_cpu_raw
=
malloc
(
total_workspace_size
);
workspace_cpu
=
WorkspaceBundle
(
workspace_cpu_raw
,
workspace_sizes
);
auto
srcs_cpu
=
static_cast
<
const
dt_float32
**>
(
workspace_cpu
.
get
(
0
));
size_t
i
=
is_nhwc
?
bundle
.
nr_workspace
()
-
2
:
bundle
.
nr_workspace
()
-
1
;
auto
srcs_gpu
=
static_cast
<
const
dt_float32
**>
(
bundle
.
get
(
i
));
for
(
size_t
i
=
0
;
i
<
srcs
.
size
();
++
i
)
{
srcs_cpu
[
i
]
=
srcs
[
i
].
ptr
<
dt_float32
>
();
}
cuda_check
(
cudaMemcpyAsync
(
bundle
.
get
(
i
),
workspace_cpu
.
get
(
0
),
workspace_cpu
.
get_size
(
0
),
cudaMemcpyHostToDevice
,
stream
));
cuda_check
(
cudaStreamAddCallback
(
stream
,
callback_free
,
static_cast
<
void
*>
(
workspace_cpu_raw
),
0
));
warp_perspective
::
forward_proxy_multi_src
(
is_nhwc
,
srcs_gpu
,
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
()
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_float32
>
(),
srcs
.
size
(),
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
DNN_FLOAT16_SELECT
(
src
.
layout
.
dtype
==
dtype
::
Float16
(),
false
))
{
#ifndef MEGDNN_DISABLE_FLOAT16
SmallVector
<
size_t
>
workspace_sizes
{
sizeof
(
dt_float16
*
)
*
srcs
.
size
()};
WorkspaceBundle
workspace_cpu
(
nullptr
,
workspace_sizes
);
auto
total_workspace_size
=
workspace_cpu
.
total_size_in_bytes
();
void
*
workspace_cpu_raw
=
malloc
(
total_workspace_size
);
workspace_cpu
=
WorkspaceBundle
(
workspace_cpu_raw
,
workspace_sizes
);
auto
srcs_cpu
=
static_cast
<
const
dt_float16
**>
(
workspace_cpu
.
get
(
0
));
auto
srcs_gpu
=
static_cast
<
const
dt_float16
**>
(
bundle
.
get
(
0
));
for
(
size_t
i
=
0
;
i
<
srcs
.
size
();
++
i
)
{
srcs_cpu
[
i
]
=
srcs
[
i
].
ptr
<
dt_float16
>
();
}
cuda_check
(
cudaMemcpyAsync
(
bundle
.
get
(
0
),
workspace_cpu
.
get
(
0
),
workspace_cpu
.
get_size
(
0
),
cudaMemcpyHostToDevice
,
stream
));
cuda_check
(
cudaStreamAddCallback
(
stream
,
callback_free
,
static_cast
<
void
*>
(
workspace_cpu_raw
),
0
));
warp_perspective
::
forward_proxy_multi_src
(
is_nhwc
,
srcs_gpu
,
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
()
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_float16
>
(),
srcs
.
size
(),
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_float16
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
#endif
}
}
else
{
megdnn_throw
(
ssprintf
(
"unsupported dtype: %s"
,
src
.
layout
.
dtype
.
name
()));
}
}
if
(
ssrcs
.
front
().
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
ctypecvt
.
comp_to_dst_type
(
dst
,
sdst
);
}
}
}
// namespace cuda
}
// namespace cuda
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/cuda/warp_perspective/forward.cu
浏览文件 @
669816e2
...
@@ -47,11 +47,16 @@ struct CtypeHelper<dt_quint4> {
...
@@ -47,11 +47,16 @@ struct CtypeHelper<dt_quint4> {
template
<
typename
ctype
>
template
<
typename
ctype
>
struct
DirectSrcVisitor
{
struct
DirectSrcVisitor
{
const
void
*
ptr
;
const
void
*
ptr
;
const
void
**
ptrs
;
__device__
__forceinline__
const
ctype
*
get
(
int
batch
,
int
im_size
)
{
__device__
__forceinline__
const
ctype
*
get
(
int
batch
,
int
im_size
)
{
return
(
ctype
*
)((
char
*
)
ptr
+
static_cast
<
int64_t
>
(
batch
)
*
static_cast
<
int64_t
>
(
im_size
)
*
CtypeHelper
<
ctype
>::
bit_width
/
8
);
return
(
ctype
*
)((
char
*
)
ptr
+
static_cast
<
int64_t
>
(
batch
)
*
static_cast
<
int64_t
>
(
im_size
)
*
CtypeHelper
<
ctype
>::
bit_width
/
8
);
}
}
__device__
__forceinline__
const
ctype
*
get
(
int
batch
)
{
return
(
ctype
*
)(
ptrs
[
batch
]);
}
void
move_batch
(
size_t
batch
,
size_t
im_size
)
{
void
move_batch
(
size_t
batch
,
size_t
im_size
)
{
ptr
=
(
char
*
)
ptr
+
batch
*
im_size
*
CtypeHelper
<
ctype
>::
bit_width
/
8
;
ptr
=
(
char
*
)
ptr
+
batch
*
im_size
*
CtypeHelper
<
ctype
>::
bit_width
/
8
;
}
}
...
@@ -60,6 +65,7 @@ struct DirectSrcVisitor {
...
@@ -60,6 +65,7 @@ struct DirectSrcVisitor {
template
<
typename
ctype
>
template
<
typename
ctype
>
struct
IndexedSrcVisitor
{
struct
IndexedSrcVisitor
{
const
void
*
ptr
;
const
void
*
ptr
;
const
void
**
ptrs
;
const
int
*
idx
;
const
int
*
idx
;
int
N_SRC
;
int
N_SRC
;
...
@@ -79,9 +85,58 @@ struct IndexedSrcVisitor {
...
@@ -79,9 +85,58 @@ struct IndexedSrcVisitor {
return
(
ctype
*
)((
char
*
)
ptr
+
static_cast
<
int64_t
>
(
batch
)
*
static_cast
<
int64_t
>
(
im_size
)
*
CtypeHelper
<
ctype
>::
bit_width
/
8
);
return
(
ctype
*
)((
char
*
)
ptr
+
static_cast
<
int64_t
>
(
batch
)
*
static_cast
<
int64_t
>
(
im_size
)
*
CtypeHelper
<
ctype
>::
bit_width
/
8
);
}
}
__device__
__forceinline__
const
ctype
*
get
(
int
batch
)
{
int
orig_batch
=
batch
;
batch
=
idx
[
batch
];
if
(
batch
<
0
||
batch
>=
N_SRC
)
{
set_async_error_info
(
error_info
,
error_tracker
,
"mat_idx out of bound: mat_idx[%d]=%d src_batch=%d"
,
orig_batch
,
batch
,
N_SRC
);
batch
=
0
;
}
return
(
ctype
*
)(
ptrs
[
batch
]);
}
void
move_batch
(
size_t
batch
,
size_t
)
{
idx
+=
batch
;
}
void
move_batch
(
size_t
batch
,
size_t
)
{
idx
+=
batch
;
}
};
};
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_general_multi_src
(
SrcVisitor
srcs
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
)
{
Getter
getter
;
OutputConverter
output_converter
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
ctype
*
__restrict
sptr
=
srcs
.
get
(
blockIdx
.
z
);
dst
+=
blockIdx
.
z
*
C
*
OH
*
OW
;
mat
+=
blockIdx
.
z
*
3
*
3
;
if
(
ow
<
OW
&&
oh
<
OH
)
{
float
denominator
=
mat
[
6
]
*
ow
+
mat
[
7
]
*
oh
+
mat
[
8
];
float
iw
=
(
mat
[
0
]
*
ow
+
mat
[
1
]
*
oh
+
mat
[
2
])
/
denominator
;
float
ih
=
(
mat
[
3
]
*
ow
+
mat
[
4
]
*
oh
+
mat
[
5
])
/
denominator
;
int
iw0
=
getter
(
floor
(
iw
)
+
0
,
IW
);
int
iw1
=
getter
(
floor
(
iw
)
+
1
,
IW
);
int
ih0
=
getter
(
floor
(
ih
)
+
0
,
IH
);
int
ih1
=
getter
(
floor
(
ih
)
+
1
,
IH
);
float
palpha
=
ih
-
floor
(
ih
);
float
pbeta
=
iw
-
floor
(
iw
);
float
nalpha
=
1.0
f
-
palpha
;
float
nbeta
=
1.0
f
-
pbeta
;
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
dst
[
oh
*
OW
+
ow
]
=
output_converter
(
sptr
[
ih0
*
IW
+
iw0
]
*
nalpha
*
nbeta
+
sptr
[
ih0
*
IW
+
iw1
]
*
nalpha
*
pbeta
+
sptr
[
ih1
*
IW
+
iw0
]
*
palpha
*
nbeta
+
sptr
[
ih1
*
IW
+
iw1
]
*
palpha
*
pbeta
);
sptr
+=
IH
*
IW
;
dst
+=
OH
*
OW
;
}
}
}
template
<
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
>
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_general
(
__global__
void
kern_general
(
...
@@ -261,6 +316,47 @@ __global__ void kern_general_nchw64(
...
@@ -261,6 +316,47 @@ __global__ void kern_general_nchw64(
}
}
}
}
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_const_border_multi_src
(
SrcVisitor
srcs
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
)
{
OutputConverter
output_converter
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
ctype
*
__restrict
sptr
=
srcs
.
get
(
blockIdx
.
z
);
dst
+=
blockIdx
.
z
*
C
*
OH
*
OW
;
mat
+=
blockIdx
.
z
*
3
*
3
;
if
(
ow
<
OW
&&
oh
<
OH
)
{
float
denominator
=
mat
[
6
]
*
ow
+
mat
[
7
]
*
oh
+
mat
[
8
];
float
iw
=
(
mat
[
0
]
*
ow
+
mat
[
1
]
*
oh
+
mat
[
2
])
/
denominator
;
float
ih
=
(
mat
[
3
]
*
ow
+
mat
[
4
]
*
oh
+
mat
[
5
])
/
denominator
;
int
iw0
=
floor
(
iw
)
+
0
;
int
iw1
=
floor
(
iw
)
+
1
;
int
ih0
=
floor
(
ih
)
+
0
;
int
ih1
=
floor
(
ih
)
+
1
;
bool
okw0
=
(
iw0
>=
0
&&
iw0
<
IW
);
bool
okw1
=
(
iw1
>=
0
&&
iw1
<
IW
);
bool
okh0
=
(
ih0
>=
0
&&
ih0
<
IH
);
bool
okh1
=
(
ih1
>=
0
&&
ih1
<
IH
);
float
palpha
=
ih
-
floor
(
ih
);
float
pbeta
=
iw
-
floor
(
iw
);
float
nalpha
=
1.0
f
-
palpha
;
float
nbeta
=
1.0
f
-
pbeta
;
for
(
int
c
=
0
;
c
<
C
;
++
c
)
{
ctype
v00
=
(
okh0
&&
okw0
?
sptr
[
ih0
*
IW
+
iw0
]
:
bval
);
ctype
v01
=
(
okh0
&&
okw1
?
sptr
[
ih0
*
IW
+
iw1
]
:
bval
);
ctype
v10
=
(
okh1
&&
okw0
?
sptr
[
ih1
*
IW
+
iw0
]
:
bval
);
ctype
v11
=
(
okh1
&&
okw1
?
sptr
[
ih1
*
IW
+
iw1
]
:
bval
);
ctype
val
=
output_converter
(
v00
*
nalpha
*
nbeta
+
v01
*
nalpha
*
pbeta
+
v10
*
palpha
*
nbeta
+
v11
*
palpha
*
pbeta
);
dst
[
oh
*
OW
+
ow
]
=
val
;
sptr
+=
IH
*
IW
;
dst
+=
OH
*
OW
;
}
}
}
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
template
<
typename
ctype
,
typename
SrcVisitor
,
typename
OutputConverter
>
__global__
void
kern_const_border
(
__global__
void
kern_const_border
(
SrcVisitor
src
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
SrcVisitor
src
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
...
@@ -553,6 +649,51 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> {
...
@@ -553,6 +649,51 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> {
}
}
};
};
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
,
int
pack_c
>
__global__
void
kern_general_nhwc_multi_src
(
SrcVisitor
srcs
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
)
{
Getter
getter
;
OutputConverter
output_converter
;
constexpr
int
bit_width
=
CtypeHelper
<
ctype
>::
bit_width
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
ctype
*
__restrict
sptr
=
srcs
.
get
(
blockIdx
.
z
);
dst
=
(
ctype
*
)((
char
*
)
dst
+
blockIdx
.
z
*
C
*
OH
*
OW
*
bit_width
/
8
);
mat
+=
blockIdx
.
z
*
3
*
3
;
if
(
ow
<
OW
&&
oh
<
OH
)
{
float
denominator
=
mat
[
6
]
*
ow
+
mat
[
7
]
*
oh
+
mat
[
8
];
float
iw
=
(
mat
[
0
]
*
ow
+
mat
[
1
]
*
oh
+
mat
[
2
])
/
denominator
;
float
ih
=
(
mat
[
3
]
*
ow
+
mat
[
4
]
*
oh
+
mat
[
5
])
/
denominator
;
int
iw0
=
getter
(
floor
(
iw
)
+
0
,
IW
);
int
iw1
=
getter
(
floor
(
iw
)
+
1
,
IW
);
int
ih0
=
getter
(
floor
(
ih
)
+
0
,
IH
);
int
ih1
=
getter
(
floor
(
ih
)
+
1
,
IH
);
float
palpha
=
ih
-
floor
(
ih
);
float
pbeta
=
iw
-
floor
(
iw
);
float
nalpha
=
1.0
f
-
palpha
;
float
nbeta
=
1.0
f
-
pbeta
;
float
w00
=
nalpha
*
nbeta
;
float
w01
=
nalpha
*
pbeta
;
float
w10
=
palpha
*
nbeta
;
float
w11
=
palpha
*
pbeta
;
const
char
*
src_ptr0
=
(
char
*
)
sptr
+
(
ih0
*
IW
+
iw0
)
*
C
*
bit_width
/
8
;
const
char
*
src_ptr1
=
(
char
*
)
sptr
+
(
ih0
*
IW
+
iw1
)
*
C
*
bit_width
/
8
;
const
char
*
src_ptr2
=
(
char
*
)
sptr
+
(
ih1
*
IW
+
iw0
)
*
C
*
bit_width
/
8
;
const
char
*
src_ptr3
=
(
char
*
)
sptr
+
(
ih1
*
IW
+
iw1
)
*
C
*
bit_width
/
8
;
char
*
dst_ptr
=
(
char
*
)
dst
+
(
oh
*
OW
+
ow
)
*
C
*
bit_width
/
8
;
for
(
int
c
=
0
;
c
<
C
;
c
+=
pack_c
)
{
KernCoreNHWC
<
ctype
,
OutputConverter
,
pack_c
>::
func
(
dst_ptr
,
src_ptr0
,
src_ptr1
,
src_ptr2
,
src_ptr3
,
c
*
bit_width
/
8
,
w00
,
w01
,
w10
,
w11
,
output_converter
,
true
,
true
,
true
,
true
,
(
ctype
)
0
);
}
}
}
template
<
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
,
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
,
int
pack_c
>
int
pack_c
>
...
@@ -598,6 +739,58 @@ __global__ void kern_general_nhwc(
...
@@ -598,6 +739,58 @@ __global__ void kern_general_nhwc(
}
}
}
}
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
,
int
pack_c
>
__global__
void
kern_general_nhwc_const_multi_src
(
SrcVisitor
srcs
,
const
float
*
__restrict
mat
,
ctype
*
__restrict
dst
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
)
{
Getter
getter
;
OutputConverter
output_converter
;
constexpr
int
bit_width
=
CtypeHelper
<
ctype
>::
bit_width
;
int
ow
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
oh
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
const
ctype
*
__restrict
sptr
=
srcs
.
get
(
blockIdx
.
z
);
dst
=
(
ctype
*
)((
char
*
)
dst
+
blockIdx
.
z
*
C
*
OH
*
OW
*
bit_width
/
8
);
mat
+=
blockIdx
.
z
*
3
*
3
;
if
(
ow
<
OW
&&
oh
<
OH
)
{
float
denominator
=
mat
[
6
]
*
ow
+
mat
[
7
]
*
oh
+
mat
[
8
];
float
iw
=
(
mat
[
0
]
*
ow
+
mat
[
1
]
*
oh
+
mat
[
2
])
/
denominator
;
float
ih
=
(
mat
[
3
]
*
ow
+
mat
[
4
]
*
oh
+
mat
[
5
])
/
denominator
;
int
iw0
=
getter
(
floor
(
iw
)
+
0
,
IW
);
int
iw1
=
getter
(
floor
(
iw
)
+
1
,
IW
);
int
ih0
=
getter
(
floor
(
ih
)
+
0
,
IH
);
int
ih1
=
getter
(
floor
(
ih
)
+
1
,
IH
);
float
palpha
=
ih
-
floor
(
ih
);
float
pbeta
=
iw
-
floor
(
iw
);
float
nalpha
=
1.0
f
-
palpha
;
float
nbeta
=
1.0
f
-
pbeta
;
float
w00
=
nalpha
*
nbeta
;
float
w01
=
nalpha
*
pbeta
;
float
w10
=
palpha
*
nbeta
;
float
w11
=
palpha
*
pbeta
;
const
char
*
src_ptr0
=
(
char
*
)
sptr
+
(
ih0
*
IW
+
iw0
)
*
C
*
bit_width
/
8
;
const
char
*
src_ptr1
=
(
char
*
)
sptr
+
(
ih0
*
IW
+
iw1
)
*
C
*
bit_width
/
8
;
const
char
*
src_ptr2
=
(
char
*
)
sptr
+
(
ih1
*
IW
+
iw0
)
*
C
*
bit_width
/
8
;
const
char
*
src_ptr3
=
(
char
*
)
sptr
+
(
ih1
*
IW
+
iw1
)
*
C
*
bit_width
/
8
;
char
*
dst_ptr
=
(
char
*
)
dst
+
(
oh
*
OW
+
ow
)
*
C
*
bit_width
/
8
;
bool
okw0
=
(
iw0
>=
0
&&
iw0
<
IW
);
bool
okw1
=
(
iw1
>=
0
&&
iw1
<
IW
);
bool
okh0
=
(
ih0
>=
0
&&
ih0
<
IH
);
bool
okh1
=
(
ih1
>=
0
&&
ih1
<
IH
);
bool
src0_ok
=
okh0
&&
okw0
;
bool
src1_ok
=
okh0
&&
okw1
;
bool
src2_ok
=
okh1
&&
okw0
;
bool
src3_ok
=
okh1
&&
okw1
;
for
(
int
c
=
0
;
c
<
C
;
c
+=
pack_c
)
{
KernCoreNHWC
<
ctype
,
OutputConverter
,
pack_c
>::
func
(
dst_ptr
,
src_ptr0
,
src_ptr1
,
src_ptr2
,
src_ptr3
,
c
*
bit_width
/
8
,
w00
,
w01
,
w10
,
w11
,
output_converter
,
src0_ok
,
src1_ok
,
src2_ok
,
src3_ok
,
bval
);
}
}
}
template
<
template
<
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
,
typename
ctype
,
typename
Getter
,
typename
SrcVisitor
,
typename
OutputConverter
,
int
pack_c
>
int
pack_c
>
...
@@ -650,6 +843,73 @@ __global__ void kern_general_nhwc_const(
...
@@ -650,6 +843,73 @@ __global__ void kern_general_nhwc_const(
}
}
}
}
template
<
typename
ctype
,
typename
SrcVisitor
>
void
dispatch_with_visitor_multi_src
(
bool
is_nhwc
,
SrcVisitor
srcs
,
const
float
*
mat
,
ctype
*
dst
,
int
N
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
,
BorderMode
bmode
,
cudaStream_t
stream
)
{
constexpr
int
pack_c
=
1
;
const
int
BY
=
16
,
BX
=
32
;
#define DISPATCH(Getter) \
do { \
if (is_nhwc) { \
kern_general_nhwc_multi_src< \
ctype, Getter, SrcVisitor, rounding::RoundingConverter<ctype>, \
pack_c><<<blocks, threads, 0, stream>>>( \
srcs, mat, dst, C, IH, IW, OH, OW); \
} else { \
kern_general_multi_src< \
ctype, Getter, SrcVisitor, rounding::RoundingConverter<ctype>> \
<<<blocks, threads, 0, stream>>>( \
srcs, mat, dst, C, IH, IW, OH, OW); \
} \
} while (0)
const
int
max_batch_size
=
65535
;
while
(
N
)
{
size_t
curr_batch_size
=
N
<
max_batch_size
?
N
:
max_batch_size
;
dim3
threads
(
BX
,
BY
);
dim3
blocks
((
OW
+
BX
-
1
)
/
BX
,
(
OH
+
BY
-
1
)
/
BY
,
curr_batch_size
);
switch
(
bmode
)
{
case
BORDER_REPLICATE
:
DISPATCH
(
ReplicateGetter
);
break
;
case
BORDER_REFLECT
:
DISPATCH
(
ReflectGetter
);
break
;
case
BORDER_REFLECT_101
:
DISPATCH
(
Reflect101Getter
);
break
;
case
BORDER_WRAP
:
DISPATCH
(
WrapGetter
);
break
;
#undef DISPATCH
case
BORDER_CONSTANT
:
if
(
is_nhwc
)
{
kern_general_nhwc_const_multi_src
<
ctype
,
ConstGetter
,
SrcVisitor
,
rounding
::
RoundingConverter
<
ctype
>
,
pack_c
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
srcs
,
mat
,
dst
,
C
,
IH
,
IW
,
OH
,
OW
,
bval
);
}
else
{
kern_const_border_multi_src
<
ctype
,
SrcVisitor
,
rounding
::
RoundingConverter
<
ctype
>>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
srcs
,
mat
,
dst
,
C
,
IH
,
IW
,
OH
,
OW
,
bval
);
}
break
;
default:
break
;
}
N
-=
curr_batch_size
;
srcs
.
move_batch
(
curr_batch_size
,
C
*
IH
*
IW
);
mat
+=
curr_batch_size
*
3
*
3
;
dst
+=
curr_batch_size
*
C
*
OH
*
OW
;
}
}
template
<
typename
ctype
,
typename
SrcVisitor
>
template
<
typename
ctype
,
typename
SrcVisitor
>
void
dispatch_with_visitor
(
void
dispatch_with_visitor
(
bool
is_nhwc
,
SrcVisitor
src
,
const
float
*
mat
,
ctype
*
dst
,
int
N
,
int
C
,
bool
is_nhwc
,
SrcVisitor
src
,
const
float
*
mat
,
ctype
*
dst
,
int
N
,
int
C
,
...
@@ -1534,6 +1794,33 @@ void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw(
...
@@ -1534,6 +1794,33 @@ void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw(
namespace
megdnn
{
namespace
megdnn
{
namespace
cuda
{
namespace
cuda
{
namespace
warp_perspective
{
namespace
warp_perspective
{
template
<
typename
ctype
>
void
forward_proxy_multi_src
(
bool
is_nhwc
,
const
ctype
**
srcs
,
const
float
*
mat
,
const
int
*
mat_idx
,
ctype
*
dst
,
int
N_SRC
,
int
N_MAT
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
ctype
bval
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
)
{
if
(
mat_idx
)
{
IndexedSrcVisitor
<
ctype
>
visitor
;
visitor
.
ptrs
=
reinterpret_cast
<
const
void
**>
(
srcs
);
visitor
.
ptr
=
srcs
;
visitor
.
idx
=
mat_idx
;
visitor
.
N_SRC
=
N_SRC
;
visitor
.
error_info
=
error_info
;
visitor
.
error_tracker
=
error_tracker
;
dispatch_with_visitor_multi_src
(
is_nhwc
,
visitor
,
mat
,
dst
,
N_MAT
,
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
stream
);
}
else
{
DirectSrcVisitor
<
ctype
>
visitor
;
visitor
.
ptrs
=
reinterpret_cast
<
const
void
**>
(
srcs
);
visitor
.
ptr
=
srcs
;
dispatch_with_visitor_multi_src
(
is_nhwc
,
visitor
,
mat
,
dst
,
N_MAT
,
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
stream
);
}
after_kernel_launch
();
}
template
<
typename
ctype
>
template
<
typename
ctype
>
void
forward_proxy
(
void
forward_proxy
(
...
@@ -1643,6 +1930,17 @@ INST(dt_float16)
...
@@ -1643,6 +1930,17 @@ INST(dt_float16)
INST
(
int8_t
)
INST
(
int8_t
)
#undef INST
#undef INST
#define INST(ctype) \
template void forward_proxy_multi_src( \
bool, const ctype**, const float*, const int*, ctype*, int, int, int, int, \
int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, void*, \
cudaStream_t);
INST
(
float
)
#ifndef MEGDNN_DISABLE_FLOAT16
INST
(
dt_float16
)
#endif
#undef INST
#define INST(ctype) \
#define INST(ctype) \
template void forward_proxy_nchw4( \
template void forward_proxy_nchw4( \
const ctype*, const float*, const int*, ctype*, int, int, int, int, int, \
const ctype*, const float*, const int*, ctype*, int, int, int, int, int, \
...
...
dnn/src/cuda/warp_perspective/opr_impl.h
浏览文件 @
669816e2
...
@@ -15,12 +15,22 @@ public:
...
@@ -15,12 +15,22 @@ public:
void
exec
(
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
override
{
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
override
{
return
get_workspace_bundle
(
nullptr
,
src
,
mat
,
mat_idx
,
dst
)
return
get_workspace_bundle
(
nullptr
,
src
,
mat
,
mat_idx
,
dst
)
.
total_size_in_bytes
();
.
total_size_in_bytes
();
}
}
size_t
get_workspace_in_bytes
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
override
{
return
get_workspace_bundle
(
nullptr
,
srcs
,
mat
,
mat_idx
,
dst
)
.
total_size_in_bytes
();
}
void
set_error_tracker
(
void
*
tracker
)
override
{
m_error_tracker
=
tracker
;
}
void
set_error_tracker
(
void
*
tracker
)
override
{
m_error_tracker
=
tracker
;
}
...
@@ -28,6 +38,9 @@ private:
...
@@ -28,6 +38,9 @@ private:
WorkspaceBundle
get_workspace_bundle
(
WorkspaceBundle
get_workspace_bundle
(
void
*
ptr
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
void
*
ptr
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
const
;
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
const
;
WorkspaceBundle
get_workspace_bundle
(
void
*
ptr
,
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
const
;
};
};
class
WarpPerspectiveBackwardDataImpl
final
:
public
WarpPerspectiveBackwardData
{
class
WarpPerspectiveBackwardDataImpl
final
:
public
WarpPerspectiveBackwardData
{
...
...
dnn/src/fallback/warp_perspective/opr_impl.cpp
浏览文件 @
669816e2
...
@@ -51,6 +51,56 @@ size_t WarpPerspectiveImpl::get_workspace_in_bytes(
...
@@ -51,6 +51,56 @@ size_t WarpPerspectiveImpl::get_workspace_in_bytes(
}
}
}
}
size_t
WarpPerspectiveImpl
::
get_workspace_in_bytes
(
const
TensorLayoutArray
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
dst
)
{
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
)
{
size_t
OH
=
dst
.
shape
[
2
],
OW
=
dst
.
shape
[
3
];
return
get_bundle
(
OH
,
OW
).
total_size_in_bytes
();
}
else
{
return
0
;
}
}
void
WarpPerspectiveImpl
::
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
TensorLayoutArray
srcs_layout
;
for
(
auto
&&
src
:
srcs
)
{
srcs_layout
.
push_back
(
src
.
layout
);
}
check_exec_allow_nhwc_mat_idx
(
srcs_layout
,
mat
.
layout
,
mat_idx
.
layout
,
dst
.
layout
,
workspace
.
size
);
size_t
nr_threads
=
static_cast
<
naive
::
HandleImpl
*>
(
handle
())
->
megcore_dispatcher
()
->
nr_threads
();
if
(
param
().
format
==
Format
::
NCHW
&&
nr_threads
==
1
_z
)
{
#define cb(dt, ct, mct) \
case DTypeTrait<dt>::enumv: { \
auto kparam = KernParam<ct, mct>::from_tensors( \
param().format, param().bmode, param().border_val, srcs, mat, mat_idx, \
dst, workspace); \
MIDOUT_BEGIN(megdnn_fallback_warpperspective, midout_iv(0), dt, ct, mct) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback_multi_src(kparam)); \
return; \
} \
MIDOUT_END(); \
}
switch
(
srcs
.
front
().
layout
.
dtype
.
enumv
())
{
cb
(
dtype
::
Float32
,
float
,
float
);
DNN_INC_FLOAT16
(
cb
(
dtype
::
Float16
,
dt_float16
,
float
));
default:
megdnn_throw
(
ssprintf
(
"Unsupported input DType in "
"WarpPerspective: %s"
,
srcs
.
front
().
layout
.
dtype
.
name
())
.
c_str
());
}
#undef cb
}
naive
::
WarpPerspectiveForwardImpl
::
exec
(
srcs
,
mat
,
mat_idx
,
dst
,
workspace
);
}
void
WarpPerspectiveImpl
::
exec
(
void
WarpPerspectiveImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_in
dst
,
_megdnn_workspace
workspace
)
{
...
@@ -95,6 +145,69 @@ void WarpPerspectiveImpl::exec(
...
@@ -95,6 +145,69 @@ void WarpPerspectiveImpl::exec(
naive
::
WarpPerspectiveForwardImpl
::
exec
(
src
,
mat
,
mat_idx
,
dst
,
workspace
);
naive
::
WarpPerspectiveForwardImpl
::
exec
(
src
,
mat
,
mat_idx
,
dst
,
workspace
);
}
}
template
<
typename
ctype
,
typename
mtype
>
void
WarpPerspectiveImpl
::
kern_fallback_multi_src
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
)
{
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
(
kern_param
);
// cause error if accidentally used
sptr
=
nullptr
;
mptr
=
nullptr
;
dptr
=
nullptr
;
MEGDNN_MARK_USED_VAR
(
sptr
);
MEGDNN_MARK_USED_VAR
(
mptr
);
MEGDNN_MARK_USED_VAR
(
dptr
);
MEGDNN_MARK_USED_VAR
(
border_val
);
MEGDNN_MARK_USED_VAR
(
IH
);
MEGDNN_MARK_USED_VAR
(
IW
);
KernParam
<
ctype
,
mtype
>
sub_param
=
kern_param
;
sub_param
.
n_src
=
1
;
sub_param
.
n_mat
=
1
;
sub_param
.
midx_ptr
=
RefPtr
();
sub_param
.
src_ptr
=
RefPtr
(
kern_param
.
srcs_ptr
.
front
().
get_ptr
());
sub_param
.
mat_ptr
=
RefPtr
(
kern_param
.
mat_ptr
.
get_ptr
());
sub_param
.
dst_ptr
=
RefPtr
(
kern_param
.
dst_ptr
.
get_ptr
());
sub_param
.
srcs_ptr
=
kern_param
.
srcs_ptr
;
rep
(
n
,
N_MAT
)
{
if
(
midx_ptr
)
{
size_t
idx
=
midx_ptr
[
n
];
megdnn_assert
(
idx
<
N_SRC
,
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu"
,
n
,
idx
,
N_SRC
);
sub_param
.
src_ptr
.
reset
(
static_cast
<
ctype
*>
(
kern_param
.
srcs_ptr
[
idx
].
get_ptr
()));
}
else
if
(
n
)
{
sub_param
.
src_ptr
.
reset
(
static_cast
<
ctype
*>
(
kern_param
.
srcs_ptr
[
n
].
get_ptr
()));
}
if
(
is_resize_optimizable
(
static_cast
<
mtype
*>
(
sub_param
.
mat_ptr
.
get_ptr
())))
{
if
(
bmode
==
BorderMode
::
CONSTANT
)
{
MIDOUT_BEGIN
(
megdnn_fallback_warpperspective
,
midout_iv
(
1
),
midout_iv
(
true
),
ctype
,
mtype
)
{
kern_resize
<
true
,
ctype
,
mtype
>
(
sub_param
);
}
MIDOUT_END
();
}
else
{
MIDOUT_BEGIN
(
megdnn_fallback_warpperspective
,
midout_iv
(
1
),
midout_iv
(
false
),
ctype
,
mtype
)
{
kern_resize
<
false
,
ctype
,
mtype
>
(
sub_param
);
}
MIDOUT_END
();
}
}
else
{
MIDOUT_BEGIN
(
megdnn_fallback_warpperspective
,
midout_iv
(
2
),
ctype
,
mtype
)
{
rep
(
oh
,
OH
)
kern_naive
<
ctype
,
mtype
>
(
sub_param
,
oh
);
}
MIDOUT_END
();
}
sub_param
.
mat_ptr
+=
3
*
3
*
sizeof
(
mtype
);
sub_param
.
dst_ptr
+=
C
*
OH
*
OW
*
sizeof
(
ctype
);
}
}
template
<
typename
ctype
,
typename
mtype
>
template
<
typename
ctype
,
typename
mtype
>
void
WarpPerspectiveImpl
::
kern_fallback
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
)
{
void
WarpPerspectiveImpl
::
kern_fallback
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
)
{
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
(
kern_param
);
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
(
kern_param
);
...
...
dnn/src/fallback/warp_perspective/opr_impl.h
浏览文件 @
669816e2
...
@@ -9,14 +9,24 @@ protected:
...
@@ -9,14 +9,24 @@ protected:
template
<
typename
ctype
,
typename
mtype
>
template
<
typename
ctype
,
typename
mtype
>
void
kern_fallback
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
);
void
kern_fallback
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
);
template
<
typename
ctype
,
typename
mtype
>
void
kern_fallback_multi_src
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
);
public:
public:
using
naive
::
WarpPerspectiveForwardImpl
::
WarpPerspectiveForwardImpl
;
using
naive
::
WarpPerspectiveForwardImpl
::
WarpPerspectiveForwardImpl
;
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
override
;
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
override
;
void
exec
(
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
private:
private:
template
<
typename
ctype
>
template
<
typename
ctype
>
...
...
dnn/src/naive/warp_perspective/opr_impl.cpp
浏览文件 @
669816e2
...
@@ -14,6 +14,119 @@ MIDOUT_DECL(megdnn_naive_warpperspective)
...
@@ -14,6 +14,119 @@ MIDOUT_DECL(megdnn_naive_warpperspective)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
naive
;
using
namespace
naive
;
template
<
typename
ctype
,
typename
mtype
>
void
WarpPerspectiveForwardImpl
::
kern_naive_multi_src
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
)
{
MEGDNN_MARK_USED_VAR
(
kern_param
);
MIDOUT_BEGIN
(
megdnn_naive_warpperspective
,
ctype
,
mtype
,
midout_iv
(
0
))
{
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
(
kern_param
);
MEGDNN_MARK_USED_VAR
(
N_MAT
);
//! strides of C, H, W on src and dst
size_t
sstrd
[
3
],
dstrd
[
3
];
auto
set_sstrd
=
[
&
](
size_t
s0
,
size_t
s1
,
size_t
s2
)
{
sstrd
[
0
]
=
s0
;
sstrd
[
1
]
=
s1
;
sstrd
[
2
]
=
s2
;
};
auto
set_dstrd
=
[
&
](
size_t
s0
,
size_t
s1
,
size_t
s2
)
{
dstrd
[
0
]
=
s0
;
dstrd
[
1
]
=
s1
;
dstrd
[
2
]
=
s2
;
};
switch
(
kern_param
.
format
)
{
case
Format
::
NCHW
:
set_sstrd
(
IH
*
IW
,
IW
,
1
);
set_dstrd
(
OH
*
OW
,
OW
,
1
);
break
;
case
Format
::
NHWC
:
set_sstrd
(
1
,
IW
*
C
,
C
);
set_dstrd
(
1
,
OW
*
C
,
C
);
break
;
default:
megdnn_throw
(
"bad format"
);
}
auto
visit_src
=
[
&
sptr
,
sstrd
](
size_t
c
,
int
h
,
int
w
)
->
float
{
return
sptr
[
sstrd
[
0
]
*
c
+
sstrd
[
1
]
*
h
+
sstrd
[
2
]
*
w
];
};
auto
visit_src_bd
=
[
&
sptr
,
sstrd
,
border_val
](
size_t
c
,
int
h
,
int
w
)
->
float
{
if
(
h
!=
-
1
&&
w
!=
-
1
)
{
return
sptr
[
sstrd
[
0
]
*
c
+
sstrd
[
1
]
*
h
+
sstrd
[
2
]
*
w
];
}
else
return
border_val
;
};
auto
visit_dst
=
[
&
dptr
,
dstrd
](
size_t
c
,
int
h
,
int
w
)
->
ctype
&
{
return
dptr
[
dstrd
[
0
]
*
c
+
dstrd
[
1
]
*
h
+
dstrd
[
2
]
*
w
];
};
rounding
::
RoundingConverter
<
ctype
>
output_converter
;
sptr
=
static_cast
<
const
ctype
*>
(
kern_param
.
srcs_ptr
.
front
().
get_ptr
());
size_t
n
=
task_id
/
OH
;
size_t
oh
=
task_id
%
OH
;
mptr
=
mptr
+
n
*
3
*
3
;
dptr
=
dptr
+
n
*
C
*
OH
*
OW
;
if
(
midx_ptr
)
{
size_t
idx
=
midx_ptr
[
n
];
megdnn_assert
(
idx
<
N_SRC
,
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu"
,
n
,
idx
,
N_SRC
);
sptr
=
sptrs
[
idx
];
}
else
if
(
n
)
{
sptr
=
sptrs
[
n
];
}
rep
(
ow
,
OW
)
{
float
numeratorw
=
mptr
[
0
]
*
ow
+
mptr
[
1
]
*
oh
+
mptr
[
2
];
float
numeratorh
=
mptr
[
3
]
*
ow
+
mptr
[
4
]
*
oh
+
mptr
[
5
];
float
denominator
=
mptr
[
6
]
*
ow
+
mptr
[
7
]
*
oh
+
mptr
[
8
];
float
alphaw
=
numeratorw
/
denominator
;
float
alphah
=
numeratorh
/
denominator
;
int
iw0
=
get_real_coord
(
std
::
floor
(
alphaw
)
+
0
,
IW
);
int
iw1
=
get_real_coord
(
std
::
floor
(
alphaw
)
+
1
,
IW
);
int
ih0
=
get_real_coord
(
std
::
floor
(
alphah
)
+
0
,
IH
);
int
ih1
=
get_real_coord
(
std
::
floor
(
alphah
)
+
1
,
IH
);
alphaw
-=
floor
(
alphaw
);
alphah
-=
floor
(
alphah
);
if
(
bmode
!=
BorderMode
::
CONSTANT
)
{
rep
(
c
,
C
)
{
visit_dst
(
c
,
oh
,
ow
)
=
output_converter
(
visit_src
(
c
,
ih0
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
(
1.0
f
-
alphah
)
+
visit_src
(
c
,
ih0
,
iw1
)
*
alphaw
*
(
1.0
f
-
alphah
)
+
visit_src
(
c
,
ih1
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
alphah
+
visit_src
(
c
,
ih1
,
iw1
)
*
alphaw
*
alphah
);
}
}
else
{
rep
(
c
,
C
)
{
auto
val
=
visit_src_bd
(
c
,
ih0
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
(
1.0
f
-
alphah
)
+
visit_src_bd
(
c
,
ih0
,
iw1
)
*
alphaw
*
(
1.0
f
-
alphah
)
+
visit_src_bd
(
c
,
ih1
,
iw0
)
*
(
1.0
f
-
alphaw
)
*
alphah
+
visit_src_bd
(
c
,
ih1
,
iw1
)
*
alphaw
*
alphah
;
visit_dst
(
c
,
oh
,
ow
)
=
output_converter
(
std
::
isfinite
(
val
)
?
val
:
border_val
);
}
}
}
}
MIDOUT_END
();
}
#define INST(ctype, mtype) \
template void WarpPerspectiveForwardImpl::kern_naive_multi_src<ctype, mtype>( \
const KernParam<ctype, mtype>&, size_t);
INST
(
float
,
float
);
#if !MEGDNN_DISABLE_FLOAT16
INST
(
dt_float16
,
float
);
INST
(
dt_float16
,
dt_float16
);
INST
(
dt_bfloat16
,
float
);
INST
(
dt_bfloat16
,
dt_bfloat16
);
#endif
#undef INST
template
<
typename
ctype
,
typename
mtype
>
template
<
typename
ctype
,
typename
mtype
>
void
WarpPerspectiveForwardImpl
::
kern_naive
(
void
WarpPerspectiveForwardImpl
::
kern_naive
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
)
{
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
)
{
...
@@ -504,6 +617,71 @@ INST(uint8_t, float, float);
...
@@ -504,6 +617,71 @@ INST(uint8_t, float, float);
#undef INST
#undef INST
void
WarpPerspectiveForwardImpl
::
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
TensorLayoutArray
srcs_layout
;
for
(
auto
&&
src
:
srcs
)
{
srcs_layout
.
push_back
(
src
.
layout
);
}
check_exec_allow_nhwc_mat_idx
(
srcs_layout
,
mat
.
layout
,
mat_idx
.
layout
,
dst
.
layout
,
workspace
.
size
);
size_t
batch
=
dst
.
layout
[
0
];
#define KERN_NAIVE_MULTI_SRC(ct, mct) \
auto kparam = KernParam<ct, mct>::from_tensors( \
param().format, param().bmode, param().border_val, srcs, mat, mat_idx, \
dst, workspace); \
auto run = [kparam, this](size_t index, size_t) { \
kern_naive_multi_src(kparam, index); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch);
#define DISPATCH_ST_MULTI_SRC(dt, ct, mct, kern) \
if (srcs.front().layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
kern(ct, mct); \
return; \
}
#define DISPATCH_ST_MT_MULTI_SRC(dt, ct, kern) \
if (srcs.front().layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
if (mat.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) { \
kern(ct, float); \
return; \
} else { \
kern(ct, ct); \
return; \
} \
}
megdnn_assert
(
warp
::
is_dnn_available
(
srcs_layout
,
mat
.
layout
,
dst
.
layout
,
param
().
imode
,
param
().
format
));
/*!
* We currently use floating point for all WarpPerspective
* computation, so even if the input ctype is one of the integer
* type, mtype should always be float32.
*
* \warning It's different with \c WarpAffine, with mtype be float16
* if input type is float16.
*/
DISPATCH_ST_MULTI_SRC
(
dtype
::
Float32
,
float
,
float
,
KERN_NAIVE_MULTI_SRC
);
DNN_INC_FLOAT16
(
DISPATCH_ST_MT_MULTI_SRC
(
dtype
::
Float16
,
dt_float16
,
KERN_NAIVE_MULTI_SRC
));
DNN_INC_FLOAT16
(
DISPATCH_ST_MT_MULTI_SRC
(
dtype
::
BFloat16
,
dt_bfloat16
,
KERN_NAIVE_MULTI_SRC
));
megdnn_throw
(
ssprintf
(
"Unsupported input DType in "
"WarpPerspective: %s"
,
srcs
.
front
().
layout
.
dtype
.
name
())
.
c_str
());
#undef KERN_NAIVE_MULTI_SRC
#undef DISPATCH_ST_MT_MULTI_SRC
#undef DISPATCH_ST_MULTI_SRC
}
void
WarpPerspectiveForwardImpl
::
exec
(
void
WarpPerspectiveForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
...
...
dnn/src/naive/warp_perspective/opr_impl.h
浏览文件 @
669816e2
...
@@ -17,8 +17,70 @@ protected:
...
@@ -17,8 +17,70 @@ protected:
DType
src_dtype
,
dst_dtype
;
DType
src_dtype
,
dst_dtype
;
RefPtr
src_ptr
,
mat_ptr
,
dst_ptr
;
RefPtr
src_ptr
,
mat_ptr
,
dst_ptr
;
RefPtr
midx_ptr
;
//!< can be null
RefPtr
midx_ptr
;
//!< can be null
SmallVector
<
RefPtr
>
srcs_ptr
;
Workspace
workspace
;
Workspace
workspace
;
static
KernParam
from_tensors
(
Format
format
,
BorderMode
bmode
,
float
border_val
,
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
auto
src
=
srcs
.
front
();
KernParam
ret
;
ret
.
format
=
format
;
ret
.
bmode
=
bmode
;
ret
.
border_val
=
border_val
;
ret
.
n_src
=
srcs
.
size
();
ret
.
src_dtype
=
src
.
layout
.
dtype
;
ret
.
dst_dtype
=
dst
.
layout
.
dtype
;
if
(
mat_idx
.
raw_ptr
())
{
megdnn_assert
(
mat_idx
.
layout
.
ndim
==
1
);
ret
.
n_mat
=
mat_idx
.
layout
.
shape
[
0
];
ret
.
midx_ptr
=
mat_idx
.
get_ref_ptr
();
}
else
{
megdnn_assert
(
mat_idx
.
layout
.
ndim
==
0
);
ret
.
n_mat
=
ret
.
n_src
;
ret
.
midx_ptr
=
nullptr
;
}
if
(
format
==
Format
::
NCHW
)
{
ret
.
c
=
src
.
layout
.
shape
[
1
];
ret
.
ih
=
src
.
layout
.
shape
[
2
];
ret
.
iw
=
src
.
layout
.
shape
[
3
];
ret
.
oh
=
dst
.
layout
.
shape
[
2
];
ret
.
ow
=
dst
.
layout
.
shape
[
3
];
}
else
{
megdnn_assert
(
format
==
Format
::
NHWC
);
ret
.
c
=
src
.
layout
.
shape
[
3
];
ret
.
ih
=
src
.
layout
.
shape
[
1
];
ret
.
iw
=
src
.
layout
.
shape
[
2
];
ret
.
oh
=
dst
.
layout
.
shape
[
1
];
ret
.
ow
=
dst
.
layout
.
shape
[
2
];
}
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
DNN_FLOAT16_SELECT
(
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
))
&&
(
src
.
layout
.
dtype
==
dst
.
layout
.
dtype
))
{
for
(
auto
&&
s
:
srcs
)
{
ret
.
srcs_ptr
.
push_back
(
s
.
get_ref_ptr
());
}
ret
.
mat_ptr
=
mat
.
get_ref_ptr
();
ret
.
dst_ptr
=
dst
.
get_ref_ptr
();
}
else
{
for
(
size_t
i
=
0
;
i
<
srcs
.
size
();
i
++
)
{
ret
.
srcs_ptr
.
push_back
(
nullptr
);
}
ret
.
mat_ptr
=
nullptr
;
ret
.
dst_ptr
=
nullptr
;
}
ret
.
src_ptr
=
nullptr
;
ret
.
workspace
=
workspace
;
return
ret
;
}
static
KernParam
from_tensors
(
static
KernParam
from_tensors
(
Format
format
,
BorderMode
bmode
,
float
border_val
,
Format
format
,
BorderMode
bmode
,
float
border_val
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
...
@@ -124,16 +186,29 @@ protected:
...
@@ -124,16 +186,29 @@ protected:
template
<
typename
ctype
,
typename
mtype
>
template
<
typename
ctype
,
typename
mtype
>
void
kern_naive
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
);
void
kern_naive
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
);
template
<
typename
ctype
,
typename
mtype
>
void
kern_naive_multi_src
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
);
public:
public:
using
WarpPerspectiveForward
::
WarpPerspectiveForward
;
using
WarpPerspectiveForward
::
WarpPerspectiveForward
;
void
exec
(
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
const
TensorLayout
&
)
override
{
return
0
;
return
0
;
}
}
size_t
get_workspace_in_bytes
(
const
TensorLayoutArray
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
private:
private:
template
<
typename
ctype
,
typename
mtype
>
template
<
typename
ctype
,
typename
mtype
>
...
@@ -253,6 +328,10 @@ private:
...
@@ -253,6 +328,10 @@ private:
auto mptr = static_cast<const mtype*>(p.mat_ptr.get_ptr()); \
auto mptr = static_cast<const mtype*>(p.mat_ptr.get_ptr()); \
auto dptr = static_cast<ctype*>(p.dst_ptr.get_ptr()); \
auto dptr = static_cast<ctype*>(p.dst_ptr.get_ptr()); \
auto midx_ptr = static_cast<int*>(p.midx_ptr.get_ptr()); \
auto midx_ptr = static_cast<int*>(p.midx_ptr.get_ptr()); \
SmallVector<const ctype*> sptrs; \
for (auto&& s_ptr : p.srcs_ptr) { \
sptrs.push_back(static_cast<const ctype*>(s_ptr.get_ptr())); \
} \
auto bmode = p.bmode; \
auto bmode = p.bmode; \
float border_val = p.border_val
float border_val = p.border_val
...
...
dnn/test/common/warp_perspective.cpp
浏览文件 @
669816e2
...
@@ -50,6 +50,54 @@ void WarpPerspectiveMatIdxProxy::exec(
...
@@ -50,6 +50,54 @@ void WarpPerspectiveMatIdxProxy::exec(
tensors
[
0
],
tensors
[
1
],
tensors
[
2
],
tensors
[
3
],
tensors
[
4
],
W
.
workspace
());
tensors
[
0
],
tensors
[
1
],
tensors
[
2
],
tensors
[
3
],
tensors
[
4
],
W
.
workspace
());
}
}
void
WarpPerspectiveMultiSrcProxy
::
deduce_layout
(
WarpPerspectiveForward
*
,
TensorLayoutArray
&
)
{}
void
WarpPerspectiveMultiSrcProxy
::
exec
(
WarpPerspectiveForward
*
opr
,
const
TensorNDArray
&
tensors
)
{
if
(
!
W
.
valid
())
{
W
=
WorkspaceWrapper
(
opr
->
handle
(),
0
);
}
megdnn_assert
(
tensors
.
size
()
>=
3
);
bool
has_mat_idx
=
false
;
TensorLayout
mat_idx_layout
;
TensorND
mat_idx_tensor
;
TensorLayoutArray
layouts
(
tensors
.
size
());
std
::
transform
(
tensors
.
begin
(),
tensors
.
end
(),
layouts
.
begin
(),
[](
const
TensorND
&
tensor
)
{
return
tensor
.
layout
;
});
auto
srcs_layouts
=
layouts
;
srcs_layouts
.
pop_back
();
// dst
if
(
srcs_layouts
.
back
().
ndim
==
1
)
{
has_mat_idx
=
true
;
mat_idx_layout
=
srcs_layouts
.
back
();
srcs_layouts
.
pop_back
();
// mat_idx;
}
auto
mat_layout
=
srcs_layouts
.
back
();
srcs_layouts
.
pop_back
();
// mat
if
(
has_mat_idx
)
W
.
update
(
opr
->
get_workspace_in_bytes
(
srcs_layouts
,
mat_layout
,
mat_idx_layout
,
layouts
.
back
()));
else
W
.
update
(
opr
->
get_workspace_in_bytes
(
srcs_layouts
,
mat_layout
,
layouts
.
back
()));
auto
srcs_tensors
=
tensors
;
srcs_tensors
.
pop_back
();
// dst
if
(
has_mat_idx
)
{
mat_idx_tensor
=
srcs_tensors
.
back
();
srcs_tensors
.
pop_back
();
// mat_idx;
}
auto
mat_tensor
=
srcs_tensors
.
back
();
srcs_tensors
.
pop_back
();
// mat
if
(
has_mat_idx
)
opr
->
exec
(
srcs_tensors
,
mat_tensor
,
mat_idx_tensor
,
tensors
.
back
(),
W
.
workspace
());
else
opr
->
exec
(
srcs_tensors
,
mat_tensor
,
tensors
.
back
(),
W
.
workspace
());
}
std
::
vector
<
TestArg
>
warp_perspective
::
get_cv_args
()
{
std
::
vector
<
TestArg
>
warp_perspective
::
get_cv_args
()
{
std
::
vector
<
TestArg
>
args
;
std
::
vector
<
TestArg
>
args
;
...
...
dnn/test/common/warp_perspective.h
浏览文件 @
669816e2
...
@@ -19,6 +19,12 @@ struct WarpPerspectiveMatIdxProxy {
...
@@ -19,6 +19,12 @@ struct WarpPerspectiveMatIdxProxy {
void
exec
(
WarpPerspectiveBackwardMat
*
opr
,
const
TensorNDArray
&
tensors
);
void
exec
(
WarpPerspectiveBackwardMat
*
opr
,
const
TensorNDArray
&
tensors
);
};
};
struct
WarpPerspectiveMultiSrcProxy
{
WorkspaceWrapper
W
;
static
void
deduce_layout
(
WarpPerspectiveForward
*
,
TensorLayoutArray
&
);
void
exec
(
WarpPerspectiveForward
*
opr
,
const
TensorNDArray
&
tensors
);
};
class
WarpPerspectiveMatRNG
final
:
public
IIDRNG
{
class
WarpPerspectiveMatRNG
final
:
public
IIDRNG
{
public:
public:
WarpPerspectiveMatRNG
()
:
idx
(
0
)
{}
WarpPerspectiveMatRNG
()
:
idx
(
0
)
{}
...
...
dnn/test/cuda/warp_perspective.cpp
浏览文件 @
669816e2
...
@@ -887,6 +887,194 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) {
...
@@ -887,6 +887,194 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) {
}
}
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_MULTI_SRC_NCHW
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
auto
run
=
[
&
param
,
&
rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle_cuda
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
c
,
ih
,
iw
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
bs
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
bs
,
c
,
oh
,
ow
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
dtype
);
run
(
2
,
100
,
110
,
10
,
50
,
50
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
dtype
);
run
(
2200
,
10
,
11
,
3
,
11
,
12
,
dtype
);
}
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_MULTI_SRC_NHWC
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
auto
run
=
[
&
param
,
&
rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle_cuda
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
ih
,
iw
,
c
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
bs
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
bs
,
oh
,
ow
,
c
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
dtype
);
run
(
2
,
100
,
110
,
10
,
50
,
50
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
dtype
);
run
(
2200
,
10
,
11
,
3
,
11
,
12
,
dtype
);
}
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
UniformIntRNG
idx_rng
{
0
,
0
};
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
auto
run
=
[
&
param
,
&
rng
,
&
idx_rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
size_t
idx
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle_cuda
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
c
,
ih
,
iw
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// mat_idx
shapes
.
emplace_back
(
TensorShape
{{
idx
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
idx
,
c
,
oh
,
ow
}});
checker
.
set_dtype
(
bs
+
2
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
1
,
dtype
);
run
(
2
,
100
,
110
,
10
,
50
,
50
,
1
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
10
,
dtype
);
run
(
2200
,
10
,
11
,
3
,
11
,
12
,
100
,
dtype
);
}
}
}
TEST_F
(
CUDA
,
WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
UniformIntRNG
idx_rng
{
0
,
0
};
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
auto
run
=
[
&
param
,
&
rng
,
&
idx_rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
size_t
idx
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle_cuda
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
ih
,
iw
,
c
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// mat_idx
shapes
.
emplace_back
(
TensorShape
{{
idx
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
idx
,
oh
,
ow
,
c
}});
checker
.
set_dtype
(
bs
+
2
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
1
,
dtype
);
run
(
2
,
100
,
110
,
10
,
50
,
50
,
1
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
10
,
dtype
);
run
(
2200
,
10
,
11
,
3
,
11
,
12
,
100
,
dtype
);
}
}
}
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
TEST_F
(
CUDA
,
BENCHMARK_WARP_PERSPECTIVE_NCHW4
)
{
TEST_F
(
CUDA
,
BENCHMARK_WARP_PERSPECTIVE_NCHW4
)
{
...
...
dnn/test/fallback/warp_perspective.cpp
浏览文件 @
669816e2
...
@@ -172,6 +172,190 @@ TEST_F(FALLBACK, WARP_PERSPECTIFVE_NCHW_QUINT8) {
...
@@ -172,6 +172,190 @@ TEST_F(FALLBACK, WARP_PERSPECTIFVE_NCHW_QUINT8) {
warp_perspective
::
run_quint8_test
(
handle
());
warp_perspective
::
run_quint8_test
(
handle
());
}
}
TEST_F
(
FALLBACK
,
WARP_PERSPECTIVE_MULTI_SRC_NCHW
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
auto
run
=
[
&
param
,
&
rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
c
,
ih
,
iw
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
bs
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
bs
,
c
,
oh
,
ow
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
dtype
);
run
(
100
,
10
,
11
,
3
,
11
,
12
,
dtype
);
}
}
}
TEST_F
(
FALLBACK
,
WARP_PERSPECTIVE_MULTI_SRC_NHWC
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
auto
run
=
[
&
param
,
&
rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
ih
,
iw
,
c
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
bs
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
bs
,
oh
,
ow
,
c
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
dtype
);
run
(
100
,
10
,
11
,
3
,
11
,
12
,
dtype
);
}
}
}
TEST_F
(
FALLBACK
,
WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
UniformIntRNG
idx_rng
{
0
,
0
};
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
auto
run
=
[
&
param
,
&
rng
,
&
idx_rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
size_t
idx
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
c
,
ih
,
iw
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// mat_idx
shapes
.
emplace_back
(
TensorShape
{{
idx
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
idx
,
c
,
oh
,
ow
}});
checker
.
set_dtype
(
bs
+
2
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
1
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
10
,
dtype
);
run
(
100
,
10
,
11
,
3
,
11
,
12
,
100
,
dtype
);
}
}
}
TEST_F
(
FALLBACK
,
WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC
)
{
using
Param
=
WarpPerspective
::
Param
;
Param
param
;
WarpPerspectiveMatRNG
rng
;
UniformIntRNG
idx_rng
{
0
,
0
};
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
auto
run
=
[
&
param
,
&
rng
,
&
idx_rng
,
this
](
size_t
bs
,
size_t
ih
,
size_t
iw
,
size_t
c
,
size_t
oh
,
size_t
ow
,
size_t
idx
,
DType
dtype
)
{
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMultiSrcProxy
>
checker
(
handle
());
checker
.
set_param
(
param
);
TensorShapeArray
shapes
;
// src
for
(
size_t
i
=
0
;
i
<
bs
;
i
++
)
{
shapes
.
emplace_back
(
TensorShape
{{
1
,
ih
,
iw
,
c
}});
checker
.
set_dtype
(
i
,
dtype
);
}
// mat
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
// mat_idx
shapes
.
emplace_back
(
TensorShape
{{
idx
}});
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
// dst
shapes
.
emplace_back
(
TensorShape
{{
idx
,
oh
,
ow
,
c
}});
checker
.
set_dtype
(
bs
+
2
,
dtype
);
checker
.
execs
(
shapes
);
};
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
run
(
1
,
20
,
18
,
4
,
6
,
6
,
1
,
dtype
);
run
(
20
,
10
,
11
,
123
,
15
,
16
,
10
,
dtype
);
run
(
100
,
10
,
11
,
3
,
11
,
12
,
100
,
dtype
);
}
}
}
}
// namespace test
}
// namespace test
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/test/naive/warp_perspective.cpp
浏览文件 @
669816e2
...
@@ -55,6 +55,282 @@ class NanMatRNG : public RNG {
...
@@ -55,6 +55,282 @@ class NanMatRNG : public RNG {
};
};
}
// namespace
}
// namespace
TEST_F
(
NAIVE
,
WARP_PERSPECTIVE_MULTI_SRC
)
{
using
Param
=
WarpPerspective
::
Param
;
WarpPerspective
::
Param
param
;
auto
extra_impl
=
[
&
param
,
this
](
const
TensorNDArray
&
tensors
)
{
//! split src
TensorND
src
=
tensors
[
0
];
// n h w c
size_t
n
=
src
.
layout
[
0
];
TensorNDArray
srcs
;
// n 个 1 h w c
TensorLayoutArray
srcs_layouts
;
for
(
size_t
i
=
0
;
i
<
n
;
i
++
)
{
TensorLayout
ly
;
ly
=
TensorLayout
{
{
1
,
src
.
layout
[
1
],
src
.
layout
[
2
],
src
.
layout
[
3
]},
src
.
layout
.
dtype
};
srcs
.
emplace_back
(
malloc
(
ly
.
span
().
dist_byte
()),
ly
);
srcs_layouts
.
emplace_back
(
ly
);
}
auto
split
=
handle
()
->
create_operator
<
SplitForward
>
();
split
->
param
().
axis
=
0
;
auto
split_ws_size
=
split
->
get_workspace_in_bytes
(
src
.
layout
,
srcs_layouts
);
dt_byte
*
split_ws_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
split_ws_size
));
Workspace
split_ws
{
split_ws_ptr
,
split_ws_size
};
split
->
exec
(
src
,
srcs
,
split_ws
);
auto
warp_perspective
=
handle
()
->
create_operator
<
WarpPerspective
>
();
warp_perspective
->
param
()
=
param
;
auto
warp_ws_size
=
warp_perspective
->
get_workspace_in_bytes
(
srcs_layouts
,
tensors
[
1
].
layout
,
tensors
[
2
].
layout
);
dt_byte
*
warp_ws_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
warp_ws_size
));
Workspace
warp_ws
{
warp_ws_ptr
,
warp_ws_size
};
warp_perspective
->
exec
(
srcs
,
tensors
[
1
],
tensors
[
2
],
warp_ws
);
free
(
split_ws_ptr
);
free
(
warp_ws_ptr
);
for
(
auto
&&
s
:
srcs
)
{
free
(
s
.
raw_ptr
());
}
};
{
// Float32
Checker
<
WarpPerspectiveForward
>
checker
(
handle
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_extra_opr_impl
(
extra_impl
);
// NHWC
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
checker
.
execs
({{
1
,
2
,
2
,
4
},
{
1
,
3
,
3
},
{
1
,
2
,
2
,
4
}});
checker
.
execs
({{
2
,
10
,
10
,
4
},
{
2
,
3
,
3
},
{
2
,
10
,
12
,
4
}});
checker
.
execs
({{
3
,
25
,
24
,
8
},
{
3
,
3
,
3
},
{
3
,
12
,
10
,
8
}});
checker
.
execs
({{
4
,
33
,
22
,
16
},
{
4
,
3
,
3
},
{
4
,
9
,
12
,
16
}});
}
// NCHW
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
);
checker
.
execs
({{
1
,
4
,
2
,
2
},
{
1
,
3
,
3
},
{
1
,
4
,
2
,
2
}});
checker
.
execs
({{
2
,
4
,
10
,
10
},
{
2
,
3
,
3
},
{
2
,
4
,
10
,
12
}});
checker
.
execs
({{
3
,
8
,
25
,
24
},
{
3
,
3
,
3
},
{
3
,
8
,
12
,
10
}});
checker
.
execs
({{
4
,
16
,
33
,
22
},
{
4
,
3
,
3
},
{
4
,
16
,
9
,
12
}});
}
}
{
// Float16
Checker
<
WarpPerspectiveForward
>
checker
(
handle
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
0
,
dtype
::
Float16
());
checker
.
set_dtype
(
2
,
dtype
::
Float16
());
checker
.
set_extra_opr_impl
(
extra_impl
);
// NHWC
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
checker
.
execs
({{
1
,
2
,
2
,
4
},
{
1
,
3
,
3
},
{
1
,
2
,
2
,
4
}});
checker
.
execs
({{
2
,
10
,
10
,
4
},
{
2
,
3
,
3
},
{
2
,
10
,
12
,
4
}});
checker
.
execs
({{
3
,
25
,
24
,
8
},
{
3
,
3
,
3
},
{
3
,
12
,
10
,
8
}});
checker
.
execs
({{
4
,
33
,
22
,
16
},
{
4
,
3
,
3
},
{
4
,
9
,
12
,
16
}});
}
// NCHW
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
);
checker
.
execs
({{
1
,
4
,
2
,
2
},
{
1
,
3
,
3
},
{
1
,
4
,
2
,
2
}});
checker
.
execs
({{
2
,
4
,
10
,
10
},
{
2
,
3
,
3
},
{
2
,
4
,
10
,
12
}});
checker
.
execs
({{
3
,
8
,
25
,
24
},
{
3
,
3
,
3
},
{
3
,
8
,
12
,
10
}});
checker
.
execs
({{
4
,
16
,
33
,
22
},
{
4
,
3
,
3
},
{
4
,
16
,
9
,
12
}});
}
}
}
TEST_F
(
NAIVE
,
WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX
)
{
using
Param
=
WarpPerspective
::
Param
;
WarpPerspective
::
Param
param
;
auto
extra_impl
=
[
&
param
,
this
](
const
TensorNDArray
&
tensors
)
{
//! split src
TensorND
src
=
tensors
[
0
];
// n h w c
size_t
n
=
src
.
layout
[
0
];
TensorNDArray
srcs
;
// n 个 1 h w c
TensorLayoutArray
srcs_layouts
;
for
(
size_t
i
=
0
;
i
<
n
;
i
++
)
{
TensorLayout
ly
;
ly
=
TensorLayout
{
{
1
,
src
.
layout
[
1
],
src
.
layout
[
2
],
src
.
layout
[
3
]},
src
.
layout
.
dtype
};
srcs
.
emplace_back
(
malloc
(
ly
.
span
().
dist_byte
()),
ly
);
srcs_layouts
.
emplace_back
(
ly
);
}
auto
split
=
handle
()
->
create_operator
<
SplitForward
>
();
split
->
param
().
axis
=
0
;
auto
split_ws_size
=
split
->
get_workspace_in_bytes
(
src
.
layout
,
srcs_layouts
);
dt_byte
*
split_ws_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
split_ws_size
));
Workspace
split_ws
{
split_ws_ptr
,
split_ws_size
};
split
->
exec
(
src
,
srcs
,
split_ws
);
auto
warp_perspective
=
handle
()
->
create_operator
<
WarpPerspective
>
();
warp_perspective
->
param
()
=
param
;
auto
warp_ws_size
=
warp_perspective
->
get_workspace_in_bytes
(
srcs_layouts
,
tensors
[
1
].
layout
,
tensors
[
2
].
layout
,
tensors
[
3
].
layout
);
dt_byte
*
warp_ws_ptr
=
static_cast
<
dt_byte
*>
(
malloc
(
warp_ws_size
));
Workspace
warp_ws
{
warp_ws_ptr
,
warp_ws_size
};
warp_perspective
->
exec
(
srcs
,
tensors
[
1
],
tensors
[
2
],
tensors
[
3
],
warp_ws
);
free
(
split_ws_ptr
);
free
(
warp_ws_ptr
);
for
(
auto
&&
s
:
srcs
)
{
free
(
s
.
raw_ptr
());
}
};
{
// Float32
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMatIdxProxy
>
checker
(
handle
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
0
,
dtype
::
Float32
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Int32
());
checker
.
set_dtype
(
3
,
dtype
::
Float32
());
checker
.
set_extra_opr_impl
(
extra_impl
);
// NHWC
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
UniformIntRNG
idx_rng
{
0
,
0
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
1
,
2
,
2
,
4
},
{
1
,
3
,
3
},
{
1
},
{
1
,
2
,
2
,
4
}});
idx_rng
=
UniformIntRNG
{
0
,
1
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
2
,
10
,
10
,
4
},
{
1
,
3
,
3
},
{
1
},
{
1
,
10
,
12
,
4
}});
idx_rng
=
UniformIntRNG
{
0
,
2
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
3
,
25
,
24
,
8
},
{
2
,
3
,
3
},
{
2
},
{
2
,
12
,
10
,
8
}});
checker
.
execs
({{
4
,
33
,
22
,
16
},
{
2
,
3
,
3
},
{
2
},
{
2
,
9
,
12
,
16
}});
}
// NCHW
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
);
UniformIntRNG
idx_rng
{
0
,
0
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
1
,
4
,
2
,
2
},
{
1
,
3
,
3
},
{
1
},
{
1
,
4
,
2
,
2
}});
idx_rng
=
UniformIntRNG
{
0
,
1
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
2
,
4
,
10
,
10
},
{
1
,
3
,
3
},
{
1
},
{
1
,
4
,
10
,
12
}});
idx_rng
=
UniformIntRNG
{
0
,
2
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
3
,
8
,
25
,
24
},
{
2
,
3
,
3
},
{
2
},
{
2
,
8
,
12
,
10
}});
checker
.
execs
({{
4
,
16
,
33
,
22
},
{
2
,
3
,
3
},
{
2
},
{
2
,
16
,
9
,
12
}});
}
}
{
// Float16
Checker
<
WarpPerspectiveForward
,
WarpPerspectiveMatIdxProxy
>
checker
(
handle
());
WarpPerspectiveMatRNG
rng
;
checker
.
set_rng
(
1
,
&
rng
);
checker
.
set_dtype
(
0
,
dtype
::
Float16
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
checker
.
set_dtype
(
2
,
dtype
::
Int32
());
checker
.
set_dtype
(
3
,
dtype
::
Float16
());
checker
.
set_extra_opr_impl
(
extra_impl
);
// NHWC
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NHWC
;
checker
.
set_param
(
param
);
UniformIntRNG
idx_rng
{
0
,
0
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
1
,
2
,
2
,
4
},
{
1
,
3
,
3
},
{
1
},
{
1
,
2
,
2
,
4
}});
idx_rng
=
UniformIntRNG
{
0
,
1
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
2
,
10
,
10
,
4
},
{
1
,
3
,
3
},
{
1
},
{
1
,
10
,
12
,
4
}});
idx_rng
=
UniformIntRNG
{
0
,
2
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
3
,
25
,
24
,
8
},
{
2
,
3
,
3
},
{
2
},
{
2
,
12
,
10
,
8
}});
checker
.
execs
({{
4
,
33
,
22
,
16
},
{
2
,
3
,
3
},
{
2
},
{
2
,
9
,
12
,
16
}});
}
// NCHW
for
(
auto
bmode
:
{
WarpPerspective
::
BorderMode
::
WRAP
,
WarpPerspective
::
BorderMode
::
REFLECT
,
WarpPerspective
::
BorderMode
::
REPLICATE
,
WarpPerspective
::
BorderMode
::
CONSTANT
})
{
param
.
border_val
=
0.3
f
;
param
.
bmode
=
bmode
;
param
.
imode
=
Param
::
InterpolationMode
::
LINEAR
;
param
.
format
=
Param
::
Format
::
NCHW
;
checker
.
set_param
(
param
);
UniformIntRNG
idx_rng
{
0
,
0
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
1
,
4
,
2
,
2
},
{
1
,
3
,
3
},
{
1
},
{
1
,
4
,
2
,
2
}});
idx_rng
=
UniformIntRNG
{
0
,
1
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
2
,
4
,
10
,
10
},
{
1
,
3
,
3
},
{
1
},
{
1
,
4
,
10
,
12
}});
idx_rng
=
UniformIntRNG
{
0
,
2
};
checker
.
set_rng
(
2
,
&
idx_rng
);
checker
.
execs
({{
3
,
8
,
25
,
24
},
{
2
,
3
,
3
},
{
2
},
{
2
,
8
,
12
,
10
}});
checker
.
execs
({{
4
,
16
,
33
,
22
},
{
2
,
3
,
3
},
{
2
},
{
2
,
16
,
9
,
12
}});
}
}
}
TEST_F
(
NAIVE
,
WARP_PERSPECTIVE_NCHW4
)
{
TEST_F
(
NAIVE
,
WARP_PERSPECTIVE_NCHW4
)
{
using
Param
=
WarpPerspective
::
Param
;
using
Param
=
WarpPerspective
::
Param
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录