Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f31e52d5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f31e52d5
编写于
8月 01, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb): warpperspective support multi src input
GitOrigin-RevId: 0887656864ea43c7f564882c312d0a8a29a90295
上级
669816e2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
518 addition
and
77 deletion
+518
-77
dnn/include/megdnn/oprs/imgproc.h
dnn/include/megdnn/oprs/imgproc.h
+39
-4
dnn/src/common/warp_perspective.cpp
dnn/src/common/warp_perspective.cpp
+1
-5
dnn/test/fallback/warp_perspective.cpp
dnn/test/fallback/warp_perspective.cpp
+2
-2
src/opr/impl/imgproc.cpp
src/opr/impl/imgproc.cpp
+178
-62
src/opr/impl/imgproc.sereg.h
src/opr/impl/imgproc.sereg.h
+28
-4
src/opr/include/megbrain/opr/imgproc.h
src/opr/include/megbrain/opr/imgproc.h
+26
-0
src/opr/test/imgproc.cpp
src/opr/test/imgproc.cpp
+244
-0
未找到文件。
dnn/include/megdnn/oprs/imgproc.h
浏览文件 @
f31e52d5
...
@@ -56,7 +56,22 @@ public:
...
@@ -56,7 +56,22 @@ public:
_megdnn_workspace
workspace
)
{
_megdnn_workspace
workspace
)
{
exec
(
src
,
mat
,
{},
dst
,
workspace
);
exec
(
src
,
mat
,
{},
dst
,
workspace
);
}
}
/**
* \param[in] srcs consists of n TensorNDs, each TensorND has shape (1, channel,
* in_height, in_width) \param[in] mat (n, 3, 3) \param[out] dst (n, channel,
* out_height, out_width)
*
* \note
* srcs and dst can have different shapes, as long as their c agree and the size of
* srcs is equal to n. every element of srcs, mat and dst should be contiguous.
*
* equivalent to:
* TensorND src{nullptr, TensorLayout({n, channel, in_height, in_width},
* srcs[0].layout.dtype)}; auto concat = handle()->create_operator<Concat>();
* concat->exec(srcs, src);
* auto warp = handle()->create_operator<WarpPerspectiveForward>();
* warp->exec(src, mat, dst, workspace);
*/
void
exec
(
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
...
@@ -75,11 +90,25 @@ public:
...
@@ -75,11 +90,25 @@ public:
virtual
void
exec
(
virtual
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
/**
* \p srcs should have m elements, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range
* of [0, m-1].
*
* \param mat_idx the indices of input image that each matrix in \p mat
* should act on. It can also be empty and in such case \p mat batch size
* should be the same as the number of elements in \p srcs .
*/
virtual
void
exec
(
virtual
void
exec
(
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_in
const
TensorNDArray
&
srcs
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_tensor_in
mat_idx
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_workspace
workspace
)
{
static_cast
<
void
>
(
srcs
);
static_cast
<
void
>
(
mat
);
static_cast
<
void
>
(
mat_idx
);
static_cast
<
void
>
(
dst
);
static_cast
<
void
>
(
workspace
);
}
size_t
get_workspace_in_bytes
(
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
dst
)
{
...
@@ -98,7 +127,13 @@ public:
...
@@ -98,7 +127,13 @@ public:
virtual
size_t
get_workspace_in_bytes
(
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayoutArray
&
srcs
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
=
0
;
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
{
static_cast
<
void
>
(
srcs
);
static_cast
<
void
>
(
mat
);
static_cast
<
void
>
(
mat_idx
);
static_cast
<
void
>
(
dst
);
return
0
;
}
protected:
protected:
void
check_exec
(
void
check_exec
(
...
...
dnn/src/common/warp_perspective.cpp
浏览文件 @
f31e52d5
...
@@ -10,12 +10,8 @@ void WarpPerspectiveBase::check_layout_fwd(
...
@@ -10,12 +10,8 @@ void WarpPerspectiveBase::check_layout_fwd(
auto
s
=
srcs
.
front
();
auto
s
=
srcs
.
front
();
for
(
auto
&&
src
:
srcs
)
{
for
(
auto
&&
src
:
srcs
)
{
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
src
);
megdnn_assert
(
src
.
dtype
==
s
.
dtype
);
src
.
eq_layout
(
s
);
megdnn_assert
(
src
.
ndim
==
s
.
ndim
);
megdnn_assert
(
src
.
shape
[
0
]
==
1
);
megdnn_assert
(
src
.
shape
[
0
]
==
1
);
for
(
size_t
i
=
0
;
i
<
s
.
ndim
;
i
++
)
{
megdnn_assert
(
src
.
shape
[
i
]
==
s
.
shape
[
i
]);
}
megdnn_assert
(
src
.
format
==
s
.
format
);
megdnn_assert
(
src
.
format
==
s
.
format
);
}
}
megdnn_assert_contiguous
(
mat
);
megdnn_assert_contiguous
(
mat
);
...
...
dnn/test/fallback/warp_perspective.cpp
浏览文件 @
f31e52d5
...
@@ -289,7 +289,7 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW) {
...
@@ -289,7 +289,7 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW) {
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
checker
.
set_rng
(
bs
,
&
rng
);
// mat_idx
// mat_idx
shapes
.
emplace_back
(
TensorShape
{{
idx
}}
);
shapes
.
emplace_back
(
TensorShape
({
idx
})
);
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
...
@@ -338,7 +338,7 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC) {
...
@@ -338,7 +338,7 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC) {
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
shapes
.
emplace_back
(
TensorShape
{{
idx
,
3
,
3
}});
checker
.
set_rng
(
bs
,
&
rng
);
checker
.
set_rng
(
bs
,
&
rng
);
// mat_idx
// mat_idx
shapes
.
emplace_back
(
TensorShape
{{
idx
}}
);
shapes
.
emplace_back
(
TensorShape
({
idx
})
);
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
checker
.
set_dtype
(
bs
+
1
,
dtype
::
Int32
());
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
idx_rng
=
UniformIntRNG
{
0
,
(
int
)
bs
-
1
};
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
checker
.
set_rng
(
bs
+
1
,
&
idx_rng
);
...
...
src/opr/impl/imgproc.cpp
浏览文件 @
f31e52d5
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/imgproc.h"
#include "./internal/megdnn_opr_wrapper.inl"
#include "./internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/utility.h"
...
@@ -25,6 +26,26 @@ WarpPerspectiveForward::WarpPerspectiveForward(
...
@@ -25,6 +26,26 @@ WarpPerspectiveForward::WarpPerspectiveForward(
outshape_by_symvar_enable
(
input
().
size
()
-
1
,
input
().
size
()
-
1
);
outshape_by_symvar_enable
(
input
().
size
()
-
1
,
input
().
size
()
-
1
);
}
}
WarpPerspectiveForward
::
WarpPerspectiveForward
(
const
VarNodeArrayView
&
srcs
,
VarNode
*
mat
,
VarNode
*
mat_idx
,
VarNode
*
out_shape
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
)
:
Super
(
OperatorNodeBaseCtorParam
{
srcs
[
0
]
->
owner_graph
(),
config
,
"warp_perspective"
,
{
srcs
[
0
],
mat
}})
{
mgb_assert
(
!
srcs
.
empty
());
m_is_multi_src
=
true
;
m_srcs_size
=
srcs
.
size
();
init_megdnn_opr
(
*
this
,
param
);
for
(
auto
&&
src
:
srcs
)
{
add_input
({
src
});
}
if
(
mat_idx
)
{
add_input
({
mat
,
mat_idx
,
out_shape
});
}
else
{
add_input
({
mat
,
out_shape
});
}
outshape_by_symvar_enable
(
input
().
size
()
-
1
,
input
().
size
()
-
1
);
}
SymbolVar
WarpPerspectiveForward
::
make
(
SymbolVar
WarpPerspectiveForward
::
make
(
SymbolVar
i0
,
SymbolVar
i1
,
SymbolVar
i2
,
SymbolVar
i3
,
const
Param
&
param
,
SymbolVar
i0
,
SymbolVar
i1
,
SymbolVar
i2
,
SymbolVar
i3
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
)
{
const
OperatorNodeConfig
&
config
)
{
...
@@ -32,6 +53,15 @@ SymbolVar WarpPerspectiveForward::make(
...
@@ -32,6 +53,15 @@ SymbolVar WarpPerspectiveForward::make(
i0
.
node
(),
i1
.
node
(),
i2
.
node
(),
i3
.
node
(),
param
,
config
);
i0
.
node
(),
i1
.
node
(),
i2
.
node
(),
i3
.
node
(),
param
,
config
);
}
}
SymbolVar
WarpPerspectiveForward
::
make
(
const
VarNodeArrayView
&
i0
,
SymbolVar
i1
,
SymbolVar
i2
,
SymbolVar
i3
,
const
Param
&
param
,
OperatorNodeConfig
config
)
{
mgb_assert
(
!
i0
.
empty
());
intl
::
BatchedDTypePromotion
dtp
{
i0
};
return
SymbolVar
{
i0
[
0
]}.
insert_single_output_opr
<
WarpPerspectiveForward
>
(
dtp
.
get_vars
(),
i1
.
node
(),
i2
.
node
(),
i3
.
node
(),
param
,
config
);
}
void
WarpPerspectiveForward
::
init_output_dtype
()
{
void
WarpPerspectiveForward
::
init_output_dtype
()
{
if
(
config
().
output_dtype
().
valid
())
{
if
(
config
().
output_dtype
().
valid
())
{
output
(
0
)
->
dtype
(
config
().
output_dtype
());
output
(
0
)
->
dtype
(
config
().
output_dtype
());
...
@@ -48,63 +78,110 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
...
@@ -48,63 +78,110 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
TensorShape
&
dest
,
const
ShapeInferInfo
&
shpinfo
)
{
TensorShape
&
dest
,
const
ShapeInferInfo
&
shpinfo
)
{
TensorShape
oshp2d
;
TensorShape
oshp2d
;
cg
::
copy_tensor_value_to_shape
(
oshp2d
,
*
shpinfo
.
shpval_inp_val
.
at
(
0
));
cg
::
copy_tensor_value_to_shape
(
oshp2d
,
*
shpinfo
.
shpval_inp_val
.
at
(
0
));
auto
imgshp
=
shpinfo
.
shape_inp_shp
.
at
(
0
),
matshp
=
shpinfo
.
shape_inp_shp
.
at
(
1
);
mgb_assert
(
TensorShape
imgshp
,
matshp
,
mat_idx_shp
;
(
imgshp
.
ndim
==
4
||
imgshp
.
ndim
==
5
)
&&
matshp
.
ndim
==
3
&&
TensorShapeArray
imgshps
;
oshp2d
.
ndim
==
2
&&
matshp
.
shape
[
1
]
==
3
&&
matshp
.
shape
[
2
]
==
3
,
if
(
!
m_is_multi_src
)
{
"shape mismatch for WarpPerspectiveForward: img=%s mat=%s "
imgshp
=
shpinfo
.
shape_inp_shp
.
at
(
0
);
"out2d=%s"
,
matshp
=
shpinfo
.
shape_inp_shp
.
at
(
1
);
imgshp
.
to_string
().
c_str
(),
matshp
.
to_string
().
c_str
(),
oshp2d
.
to_string
().
c_str
());
if
(
input
().
size
()
==
3
)
{
mgb_assert
(
imgshp
[
0
]
==
matshp
[
0
],
"batchsize mismatch: img=%zu mat=%zu"
,
imgshp
[
0
],
matshp
[
0
]);
}
else
{
mgb_assert
(
input
().
size
()
==
4
);
auto
mat_idx_shp
=
shpinfo
.
shape_inp_shp
.
at
(
2
);
mgb_assert
(
mgb_assert
(
mat_idx_shp
[
0
]
==
matshp
[
0
]
&&
mat_idx_shp
.
ndim
==
1
,
(
imgshp
.
ndim
==
4
||
imgshp
.
ndim
==
5
)
&&
matshp
.
ndim
==
3
&&
"invalid mat_idx shape: mat=%zu mat_idx=%s"
,
matshp
[
0
],
oshp2d
.
ndim
==
2
&&
matshp
.
shape
[
1
]
==
3
&&
mat_idx_shp
.
to_string
().
c_str
());
matshp
.
shape
[
2
]
==
3
,
}
"shape mismatch for WarpPerspectiveForward: img=%s mat=%s "
"out2d=%s"
,
imgshp
.
to_string
().
c_str
(),
matshp
.
to_string
().
c_str
(),
oshp2d
.
to_string
().
c_str
());
if
(
input
().
size
()
==
3
)
{
mgb_assert
(
imgshp
[
0
]
==
matshp
[
0
],
"batchsize mismatch: img=%zu mat=%zu"
,
imgshp
[
0
],
matshp
[
0
]);
}
else
{
mgb_assert
(
input
().
size
()
==
4
);
mat_idx_shp
=
shpinfo
.
shape_inp_shp
.
at
(
2
);
mgb_assert
(
mat_idx_shp
[
0
]
==
matshp
[
0
]
&&
mat_idx_shp
.
ndim
==
1
,
"invalid mat_idx shape: mat=%zu mat_idx=%s"
,
matshp
[
0
],
mat_idx_shp
.
to_string
().
c_str
());
}
switch
(
param
().
format
)
{
switch
(
param
().
format
)
{
case
Param
::
Format
::
NCHW_NCHW4_IC_SMALL
:
case
Param
::
Format
::
NCHW_NCHW4_IC_SMALL
:
case
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
:
case
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
:
dest
.
ndim
=
5
;
dest
.
ndim
=
5
;
dest
[
0
]
=
matshp
[
0
];
dest
[
0
]
=
matshp
[
0
];
dest
.
shape
[
1
]
=
1
;
dest
.
shape
[
1
]
=
1
;
dest
.
shape
[
2
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
2
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
3
]
=
oshp2d
.
shape
[
1
];
dest
.
shape
[
3
]
=
oshp2d
.
shape
[
1
];
dest
.
shape
[
4
]
=
4
;
dest
.
shape
[
4
]
=
4
;
break
;
break
;
case
Param
::
Format
::
NHWC_NCHW
:
case
Param
::
Format
::
NHWC_NCHW
:
dest
.
ndim
=
4
;
dest
.
ndim
=
4
;
dest
[
0
]
=
matshp
[
0
];
dest
[
0
]
=
matshp
[
0
];
dest
.
shape
[
1
]
=
imgshp
.
shape
[
3
];
dest
.
shape
[
1
]
=
imgshp
.
shape
[
3
];
dest
.
shape
[
2
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
2
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
3
]
=
oshp2d
.
shape
[
1
];
dest
.
shape
[
3
]
=
oshp2d
.
shape
[
1
];
break
;
break
;
default:
default:
size_t
height_idx
=
0
;
size_t
height_idx
=
0
;
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW64
)
{
param
().
format
==
Param
::
Format
::
NCHW64
)
{
height_idx
=
2
;
height_idx
=
2
;
}
else
{
}
else
{
height_idx
=
1
;
height_idx
=
1
;
}
}
dest
=
imgshp
;
dest
=
imgshp
;
dest
[
0
]
=
matshp
[
0
];
dest
[
0
]
=
matshp
[
0
];
if
(
param
().
format
==
Param
::
Format
::
NHWCD4
)
{
if
(
param
().
format
==
Param
::
Format
::
NHWCD4
)
{
dest
.
shape
[
height_idx
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
height_idx
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
height_idx
+
2
]
=
oshp2d
.
shape
[
1
];
dest
.
shape
[
height_idx
+
2
]
=
oshp2d
.
shape
[
1
];
}
else
{
}
else
{
for
(
int
i
=
0
;
i
<
2
;
++
i
)
for
(
int
i
=
0
;
i
<
2
;
++
i
)
dest
.
shape
[
height_idx
+
i
]
=
oshp2d
.
shape
[
i
];
dest
.
shape
[
height_idx
+
i
]
=
oshp2d
.
shape
[
i
];
}
break
;
}
}
else
{
imgshp
=
shpinfo
.
shape_inp_shp
.
at
(
0
);
matshp
=
shpinfo
.
shape_inp_shp
.
at
(
m_srcs_size
);
for
(
size_t
i
=
0
;
i
<
m_srcs_size
;
i
++
)
{
imgshps
.
emplace_back
(
shpinfo
.
shape_inp_shp
.
at
(
i
));
mgb_assert
(
imgshps
[
i
].
ndim
==
imgshp
.
ndim
);
for
(
size_t
j
=
0
;
j
<
imgshp
.
ndim
;
j
++
)
{
mgb_assert
(
imgshps
[
i
].
shape
[
j
]
==
imgshp
[
j
]);
}
}
break
;
}
mgb_assert
(
imgshp
[
0
]
==
1
&&
imgshp
.
ndim
==
4
&&
matshp
.
ndim
==
3
&&
oshp2d
.
ndim
==
2
&&
matshp
.
shape
[
1
]
==
3
&&
matshp
.
shape
[
2
]
==
3
,
"shape mismatch for WarpPerspectiveForward: img=%s mat=%s "
"out2d=%s"
,
imgshp
.
to_string
().
c_str
(),
matshp
.
to_string
().
c_str
(),
oshp2d
.
to_string
().
c_str
());
if
(
input
().
size
()
-
m_srcs_size
==
2
)
{
mgb_assert
(
m_srcs_size
==
matshp
[
0
],
"batchsize mismatch: img=%zu mat=%zu"
,
m_srcs_size
,
matshp
[
0
]);
}
else
{
mgb_assert
(
input
().
size
()
-
m_srcs_size
==
3
);
mat_idx_shp
=
shpinfo
.
shape_inp_shp
.
at
(
m_srcs_size
+
1
);
mgb_assert
(
mat_idx_shp
[
0
]
==
matshp
[
0
]
&&
mat_idx_shp
.
ndim
==
1
,
"invalid mat_idx shape: mat=%zu mat_idx=%s"
,
matshp
[
0
],
mat_idx_shp
.
to_string
().
c_str
());
}
size_t
height_idx
=
0
;
if
(
param
().
format
==
Param
::
Format
::
NCHW
)
{
height_idx
=
2
;
}
else
{
height_idx
=
1
;
}
dest
=
imgshp
;
dest
[
0
]
=
matshp
[
0
];
for
(
int
i
=
0
;
i
<
2
;
++
i
)
dest
.
shape
[
height_idx
+
i
]
=
oshp2d
.
shape
[
i
];
}
}
}
}
...
@@ -114,22 +191,61 @@ void WarpPerspectiveForward::init_output_static_infer_desc() {
...
@@ -114,22 +191,61 @@ void WarpPerspectiveForward::init_output_static_infer_desc() {
}
}
void
WarpPerspectiveForward
::
scn_do_execute
()
{
void
WarpPerspectiveForward
::
scn_do_execute
()
{
if
(
input
().
size
()
==
3
)
{
if
(
!
m_is_multi_src
)
{
intl
::
_MegDNNOprMethInvoker
<
2
,
1
>::
exec
(
megdnn_opr
(),
this
);
if
(
input
().
size
()
==
3
)
{
intl
::
_MegDNNOprMethInvoker
<
2
,
1
>::
exec
(
megdnn_opr
(),
this
);
}
else
{
intl
::
_MegDNNOprMethInvoker
<
3
,
1
>::
exec
(
megdnn_opr
(),
this
);
}
}
else
{
}
else
{
intl
::
_MegDNNOprMethInvoker
<
3
,
1
>::
exec
(
megdnn_opr
(),
this
);
megdnn
::
TensorNDArray
srcs
;
for
(
size_t
i
=
0
;
i
<
m_srcs_size
;
i
++
)
{
srcs
.
push_back
(
input
(
i
)
->
dev_tensor
().
as_megdnn
());
}
if
(
input
().
size
()
-
m_srcs_size
==
2
)
{
megdnn_opr
()
->
exec
(
srcs
,
input
(
m_srcs_size
)
->
dev_tensor
().
as_megdnn
(),
output
(
0
)
->
dev_tensor
().
as_megdnn
(),
intl
::
get_megdnn_workspace_from_var
(
output
().
back
()));
}
else
{
megdnn_opr
()
->
exec
(
srcs
,
input
(
m_srcs_size
)
->
dev_tensor
().
as_megdnn
(),
input
(
m_srcs_size
+
1
)
->
dev_tensor
().
as_megdnn
(),
output
(
0
)
->
dev_tensor
().
as_megdnn
(),
intl
::
get_megdnn_workspace_from_var
(
output
().
back
()));
}
}
}
}
}
size_t
WarpPerspectiveForward
::
get_workspace_size_bytes
(
size_t
WarpPerspectiveForward
::
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
{
const
TensorShapeArray
&
output_shapes
)
const
{
if
(
input
().
size
()
==
3
)
{
if
(
!
m_is_multi_src
)
{
return
intl
::
_MegDNNOprMethInvoker
<
2
,
1
>::
get_workspace_in_bytes
(
if
(
input
().
size
()
==
3
)
{
megdnn_opr
(),
this
,
input_shapes
,
output_shapes
);
return
intl
::
_MegDNNOprMethInvoker
<
2
,
1
>::
get_workspace_in_bytes
(
megdnn_opr
(),
this
,
input_shapes
,
output_shapes
);
}
else
{
return
intl
::
_MegDNNOprMethInvoker
<
3
,
1
>::
get_workspace_in_bytes
(
megdnn_opr
(),
this
,
input_shapes
,
output_shapes
);
}
}
else
{
}
else
{
return
intl
::
_MegDNNOprMethInvoker
<
3
,
1
>::
get_workspace_in_bytes
(
TensorLayoutArray
srcs
;
megdnn_opr
(),
this
,
input_shapes
,
output_shapes
);
for
(
size_t
i
=
0
;
i
<
m_srcs_size
;
i
++
)
{
srcs
.
push_back
(
TensorLayout
{
input_shapes
[
i
],
input
(
i
)
->
dtype
(),
input
(
i
)
->
format
()});
}
TensorLayout
mat
{
input_shapes
[
m_srcs_size
],
input
(
m_srcs_size
)
->
dtype
(),
input
(
m_srcs_size
)
->
format
()};
TensorLayout
dst
{
output_shapes
[
0
],
output
(
0
)
->
dtype
(),
output
(
0
)
->
format
()};
if
(
input
().
size
()
-
m_srcs_size
==
2
)
{
return
megdnn_opr
()
->
get_workspace_in_bytes
(
srcs
,
mat
,
dst
);
}
else
{
TensorLayout
mat_idx
{
input_shapes
[
m_srcs_size
+
1
],
input
(
m_srcs_size
+
1
)
->
dtype
(),
input
(
m_srcs_size
+
1
)
->
format
()};
return
megdnn_opr
()
->
get_workspace_in_bytes
(
srcs
,
mat
,
mat_idx
,
dst
);
}
}
}
}
}
...
...
src/opr/impl/imgproc.sereg.h
浏览文件 @
f31e52d5
...
@@ -19,10 +19,34 @@ struct OprMaker<opr::WarpPerspective, 0> {
...
@@ -19,10 +19,34 @@ struct OprMaker<opr::WarpPerspective, 0> {
.
node
()
.
node
()
->
owner_opr
();
->
owner_opr
();
}
else
{
}
else
{
mgb_assert
(
inputs
.
size
()
==
4
);
bool
with_mat_idx
=
false
;
return
Opr
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
3
],
param
,
config
)
VarNodeArray
inps
=
inputs
;
.
node
()
VarNode
*
mat
,
*
mat_idx
,
*
outshp
;
->
owner_opr
();
outshp
=
inps
.
back
();
inps
.
pop_back
();
if
(
inps
.
back
()
->
shape
().
ndim
==
3
)
{
mat
=
inps
.
back
();
}
else
{
mat_idx
=
inps
.
back
();
inps
.
pop_back
();
mat
=
inps
.
back
();
with_mat_idx
=
true
;
}
inps
.
pop_back
();
if
(
inps
.
size
()
==
1
)
{
mgb_assert
(
with_mat_idx
);
return
Opr
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
3
],
param
,
config
)
.
node
()
->
owner_opr
();
}
else
if
(
with_mat_idx
)
{
return
Opr
::
make
(
inps
,
mat
,
mat_idx
,
outshp
,
param
,
config
)
.
node
()
->
owner_opr
();
}
else
{
return
Opr
::
make
(
inps
,
mat
,
outshp
,
param
,
config
).
node
()
->
owner_opr
();
}
}
}
}
}
};
};
...
...
src/opr/include/megbrain/opr/imgproc.h
浏览文件 @
f31e52d5
...
@@ -31,6 +31,10 @@ public:
...
@@ -31,6 +31,10 @@ public:
VarNode
*
in_tensor
,
VarNode
*
mat
,
VarNode
*
mat_idx
,
VarNode
*
out_shape
,
VarNode
*
in_tensor
,
VarNode
*
mat
,
VarNode
*
mat_idx
,
VarNode
*
out_shape
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
WarpPerspectiveForward
(
const
VarNodeArrayView
&
in_tensor
,
VarNode
*
mat
,
VarNode
*
mat_idx
,
VarNode
*
out_shape
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
in_tensor
,
SymbolVar
mat
,
SymbolVar
mat_idx
,
SymbolVar
out_shape
,
SymbolVar
in_tensor
,
SymbolVar
mat
,
SymbolVar
mat_idx
,
SymbolVar
out_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{});
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{});
...
@@ -49,6 +53,26 @@ public:
...
@@ -49,6 +53,26 @@ public:
config
);
config
);
}
}
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
const
VarNodeArrayView
&
in_tensor
,
SymbolVar
mat
,
SymbolVar
mat_idx
,
SymbolVar
out_shape
,
const
Param
&
param
=
{},
OperatorNodeConfig
config
=
{});
static
SymbolVar
make
(
const
VarNodeArrayView
&
in_tensor
,
SymbolVar
mat
,
SymbolVar
out_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{})
{
return
make
(
in_tensor
,
mat
,
SymbolVar
{},
out_shape
,
param
,
config
);
}
static
SymbolVar
make
(
const
VarNodeArrayView
&
in_tensor
,
SymbolVar
mat
,
const
TensorShape
&
out_shape
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{})
{
return
make
(
in_tensor
,
mat
,
cg
::
var_from_tensor_shape
(
in_tensor
[
0
],
out_shape
),
param
,
config
);
}
private:
private:
void
init_output_dtype
()
override
;
void
init_output_dtype
()
override
;
void
add_input_layout_constraint
()
override
;
void
add_input_layout_constraint
()
override
;
...
@@ -62,6 +86,8 @@ private:
...
@@ -62,6 +86,8 @@ private:
const
TensorShapeArray
&
output_shapes
)
const
override
;
const
TensorShapeArray
&
output_shapes
)
const
override
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
;
bool
m_is_multi_src
=
false
;
size_t
m_srcs_size
=
0
;
};
};
using
WarpPerspective
=
WarpPerspectiveForward
;
using
WarpPerspective
=
WarpPerspectiveForward
;
...
...
src/opr/test/imgproc.cpp
浏览文件 @
f31e52d5
...
@@ -89,6 +89,250 @@ TEST(TestOprImgproc, WarpPerspective) {
...
@@ -89,6 +89,250 @@ TEST(TestOprImgproc, WarpPerspective) {
.
run
({
TensorShape
{
N
,
C
,
10
,
9
},
{
N
,
3
,
3
}},
opt
);
.
run
({
TensorShape
{
N
,
C
,
10
,
9
},
{
N
,
3
,
3
}},
opt
);
}
}
TEST
(
TestOprImgproc
,
WarpPerspective_MultiSrc
)
{
set_rand_seed
(
20220801
);
// a seed that can pass the test
constexpr
size_t
INP_H
=
6
,
INP_W
=
4
,
N
=
3
,
C
=
3
;
using
Checker
=
AutoOprChecker
<
4
,
1
>
;
TensorShape
out_shp
{
N
,
C
,
9
,
10
};
auto
make_graph
=
[
&
](
const
Checker
::
SymInpArray
&
inputs
)
->
Checker
::
SymOutArray
{
SymbolVarArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N
;
i
++
)
{
srcs
.
push_back
(
inputs
[
i
]);
}
return
{
opr
::
WarpPerspective
::
make
(
srcs
,
inputs
[
N
],
TensorShape
{
out_shp
.
shape
[
2
],
out_shp
.
shape
[
3
]})};
};
auto
fwd
=
[
&
](
Checker
::
NumOutArray
&
dest
,
Checker
::
NumInpArray
inp
)
{
auto
opr
=
megdnn_naive_handle
()
->
create_operator
<
megdnn
::
WarpPerspective
>
();
dest
[
0
].
resize
(
out_shp
);
megdnn
::
TensorNDArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N
;
i
++
)
{
srcs
.
push_back
(
inp
[
i
]
->
as_megdnn
());
}
opr
->
exec
(
srcs
,
inp
[
N
]
->
as_megdnn
(),
dest
[
0
].
as_megdnn
(),
{});
};
auto
dump_mat
=
[
&
](
const
Checker
::
NumInpArray
&
inp
)
->
std
::
string
{
std
::
ostringstream
ostr
;
ostr
<<
std
::
setprecision
(
3
);
auto
&&
mat
=
*
inp
[
N
];
mgb_assert
(
mat
.
shape
().
ndim
==
3
);
auto
ptr
=
mat
.
ptr
<
float
>
();
for
(
size_t
n
=
0
;
n
<
mat
.
shape
().
shape
[
0
];
++
n
)
{
ostr
<<
"mat "
<<
n
<<
":
\n
"
;
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
3
;
++
j
)
{
ostr
<<
std
::
setw
(
10
)
<<
*
(
ptr
++
);
}
ostr
<<
'\n'
;
}
}
return
ostr
.
str
();
};
Checker
::
RunOptions
opt
;
opt
.
numdiff_eps_single_inp
[
1
]
=
1e-5
;
opt
.
numdiff_max_err_single_inp
[
1
]
=
0.5
;
Checker
(
make_graph
,
fwd
)
.
set_input_generator
(
N
,
warp_perspective_mat_gen
(
N
,
INP_H
,
INP_W
))
.
set_input_dump_on_error
(
dump_mat
)
.
disable_grad_check
()
.
run
({
TensorShape
{
1
,
C
,
10
,
9
},
{
1
,
C
,
10
,
9
},
{
1
,
C
,
10
,
9
},
{
N
,
3
,
3
}},
opt
)
.
run
({
TensorShape
{
1
,
C
,
4
,
5
},
{
1
,
C
,
4
,
5
},
{
1
,
C
,
4
,
5
},
{
N
,
3
,
3
}},
opt
)
.
run
({
TensorShape
{
1
,
C
,
6
,
5
},
{
1
,
C
,
6
,
5
},
{
1
,
C
,
6
,
5
},
{
N
,
3
,
3
}},
opt
);
}
TEST
(
TestOprImgproc
,
WarpPerspective_MultiSrc_NHWC
)
{
set_rand_seed
(
20220801
);
// a seed that can pass the test
opr
::
WarpPerspective
::
Param
param
;
param
.
format
=
opr
::
WarpPerspective
::
Param
::
Format
::
NHWC
;
constexpr
size_t
INP_H
=
6
,
INP_W
=
4
,
N
=
3
,
C
=
3
;
using
Checker
=
AutoOprChecker
<
4
,
1
>
;
TensorShape
out_shp
{
N
,
9
,
10
,
C
};
auto
make_graph
=
[
&
](
const
Checker
::
SymInpArray
&
inputs
)
->
Checker
::
SymOutArray
{
SymbolVarArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N
;
i
++
)
{
srcs
.
push_back
(
inputs
[
i
]);
}
return
{
opr
::
WarpPerspective
::
make
(
srcs
,
inputs
[
N
],
TensorShape
{
out_shp
.
shape
[
1
],
out_shp
.
shape
[
2
]},
param
)};
};
auto
fwd
=
[
&
](
Checker
::
NumOutArray
&
dest
,
Checker
::
NumInpArray
inp
)
{
auto
opr
=
megdnn_naive_handle
()
->
create_operator
<
megdnn
::
WarpPerspective
>
();
opr
->
param
()
=
param
;
dest
[
0
].
resize
(
out_shp
);
megdnn
::
TensorNDArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N
;
i
++
)
{
srcs
.
push_back
(
inp
[
i
]
->
as_megdnn
());
}
opr
->
exec
(
srcs
,
inp
[
N
]
->
as_megdnn
(),
dest
[
0
].
as_megdnn
(),
{});
};
Checker
::
RunOptions
opt
;
opt
.
numdiff_eps_single_inp
[
1
]
=
1e-5
;
opt
.
numdiff_max_err_single_inp
[
1
]
=
0.5
;
Checker
(
make_graph
,
fwd
)
.
set_input_generator
(
N
,
warp_perspective_mat_gen
(
N
,
INP_H
,
INP_W
))
.
disable_grad_check
()
.
run
({
TensorShape
{
1
,
10
,
9
,
C
},
{
1
,
10
,
9
,
C
},
{
1
,
10
,
9
,
C
},
{
N
,
3
,
3
}},
opt
)
.
run
({
TensorShape
{
1
,
4
,
5
,
C
},
{
1
,
4
,
5
,
C
},
{
1
,
4
,
5
,
C
},
{
N
,
3
,
3
}},
opt
)
.
run
({
TensorShape
{
1
,
6
,
5
,
C
},
{
1
,
6
,
5
,
C
},
{
1
,
6
,
5
,
C
},
{
N
,
3
,
3
}},
opt
);
}
TEST
(
TestOprImgproc
,
WarpPerspectiveWithMatIdx_MultiSrc
)
{
constexpr
size_t
INP_H
=
13
,
INP_W
=
9
,
N_MAT
=
23
,
N_SRC
=
3
,
C
=
3
;
std
::
mt19937
rng
(
next_rand_seed
());
auto
rand_real
=
[
&
](
double
lo
,
double
hi
)
{
return
rng
()
/
(
std
::
mt19937
::
max
()
+
1.0
)
*
(
hi
-
lo
)
+
lo
;
};
auto
rand_real2
=
[
&
](
double
range
)
{
return
rand_real
(
-
range
,
range
);
};
using
Checker
=
AutoOprChecker
<
5
,
1
>
;
TensorShape
out_shp
{
N_MAT
,
C
,
9
,
10
};
auto
make_graph
=
[
&
](
const
Checker
::
SymInpArray
&
inputs
)
->
Checker
::
SymOutArray
{
SymbolVarArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N_SRC
;
i
++
)
{
srcs
.
push_back
(
inputs
[
i
]);
}
return
{
opr
::
WarpPerspective
::
make
(
srcs
,
inputs
[
N_SRC
],
inputs
[
N_SRC
+
1
],
cg
::
var_from_tensor_shape
(
srcs
[
0
],
{
out_shp
.
shape
[
2
],
out_shp
.
shape
[
3
]}))};
};
auto
fwd
=
[
&
](
Checker
::
NumOutArray
&
dest
,
Checker
::
NumInpArray
inp
)
{
auto
opr
=
megdnn_naive_handle
()
->
create_operator
<
megdnn
::
WarpPerspective
>
();
dest
[
0
].
resize
(
out_shp
);
megdnn
::
TensorNDArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N_SRC
;
i
++
)
{
srcs
.
push_back
(
inp
[
i
]
->
as_megdnn
());
}
opr
->
exec
(
srcs
,
inp
[
N_SRC
]
->
as_megdnn
(),
inp
[
N_SRC
+
1
]
->
as_megdnn
(),
dest
[
0
].
as_megdnn
(),
{});
};
auto
gen_mat
=
[
&
](
HostTensorND
&
mat
)
{
auto
ptr
=
mat
.
ptr
<
float
>
();
for
(
size_t
i
=
0
;
i
<
N_MAT
;
++
i
)
{
auto
rot
=
rand_real
(
0
,
M_PI
*
2
),
scale
=
rand_real
(
0.8
,
1.2
),
sheer
=
rand_real
(
0.9
,
1.1
),
dy
=
rand_real2
(
INP_H
*
0.5
),
dx
=
rand_real2
(
INP_W
*
0.5
),
ky
=
rand_real2
(
0.1
/
INP_H
),
kx
=
rand_real2
(
0.1
/
INP_W
),
kb
=
rand_real2
(
0.1
)
+
1
;
ptr
[
0
]
=
ptr
[
4
]
=
cos
(
rot
)
*
scale
;
ptr
[
1
]
=
-
(
ptr
[
3
]
=
sin
(
rot
)
*
scale
);
ptr
[
3
]
*=
sheer
;
ptr
[
4
]
*=
sheer
;
ptr
[
2
]
=
dx
;
ptr
[
5
]
=
dy
;
ptr
[
6
]
=
kx
;
ptr
[
7
]
=
ky
;
ptr
[
8
]
=
kb
;
ptr
+=
9
;
}
mgb_assert
(
ptr
==
mat
.
ptr
<
float
>
()
+
mat
.
shape
().
total_nr_elems
());
};
HostTensorGenerator
<
dtype
::
Int32
>
gen_mat_idx_rng
{
0
,
N_SRC
};
auto
gen_mat_idx
=
[
&
](
HostTensorND
&
mat
)
{
mat
=
*
gen_mat_idx_rng
(
mat
.
shape
());
};
Checker
(
make_graph
,
fwd
)
.
set_input_generator
(
N_SRC
,
gen_mat
)
.
set_input_generator
(
N_SRC
+
1
,
gen_mat_idx
)
.
set_input_dtype
(
N_SRC
+
1
,
dtype
::
Int32
{})
.
disable_grad_check
()
.
run
({
TensorShape
{
1
,
C
,
4
,
5
},
{
1
,
C
,
4
,
5
},
{
1
,
C
,
4
,
5
},
{
N_MAT
,
3
,
3
},
{
N_MAT
}})
.
run
({
TensorShape
{
1
,
C
,
6
,
5
},
{
1
,
C
,
6
,
5
},
{
1
,
C
,
6
,
5
},
{
N_MAT
,
3
,
3
},
{
N_MAT
}})
.
run
({
TensorShape
{
1
,
C
,
22
,
19
},
{
1
,
C
,
22
,
19
},
{
1
,
C
,
22
,
19
},
{
N_MAT
,
3
,
3
},
{
N_MAT
}});
}
TEST
(
TestOprImgproc
,
WarpPerspectiveWithMatIdx_MultiSrc_NHWC
)
{
constexpr
size_t
INP_H
=
13
,
INP_W
=
9
,
N_MAT
=
23
,
N_SRC
=
3
,
C
=
3
;
opr
::
WarpPerspective
::
Param
param
;
param
.
format
=
opr
::
WarpPerspective
::
Param
::
Format
::
NHWC
;
std
::
mt19937
rng
(
next_rand_seed
());
auto
rand_real
=
[
&
](
double
lo
,
double
hi
)
{
return
rng
()
/
(
std
::
mt19937
::
max
()
+
1.0
)
*
(
hi
-
lo
)
+
lo
;
};
auto
rand_real2
=
[
&
](
double
range
)
{
return
rand_real
(
-
range
,
range
);
};
using
Checker
=
AutoOprChecker
<
5
,
1
>
;
TensorShape
out_shp
{
N_MAT
,
9
,
10
,
C
};
auto
make_graph
=
[
&
](
const
Checker
::
SymInpArray
&
inputs
)
->
Checker
::
SymOutArray
{
SymbolVarArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N_SRC
;
i
++
)
{
srcs
.
push_back
(
inputs
[
i
]);
}
return
{
opr
::
WarpPerspective
::
make
(
srcs
,
inputs
[
N_SRC
],
inputs
[
N_SRC
+
1
],
cg
::
var_from_tensor_shape
(
srcs
[
0
],
{
out_shp
.
shape
[
1
],
out_shp
.
shape
[
2
]}),
param
)};
};
auto
fwd
=
[
&
](
Checker
::
NumOutArray
&
dest
,
Checker
::
NumInpArray
inp
)
{
auto
opr
=
megdnn_naive_handle
()
->
create_operator
<
megdnn
::
WarpPerspective
>
();
opr
->
param
()
=
param
;
dest
[
0
].
resize
(
out_shp
);
megdnn
::
TensorNDArray
srcs
;
for
(
size_t
i
=
0
;
i
<
N_SRC
;
i
++
)
{
srcs
.
push_back
(
inp
[
i
]
->
as_megdnn
());
}
opr
->
exec
(
srcs
,
inp
[
N_SRC
]
->
as_megdnn
(),
inp
[
N_SRC
+
1
]
->
as_megdnn
(),
dest
[
0
].
as_megdnn
(),
{});
};
auto
gen_mat
=
[
&
](
HostTensorND
&
mat
)
{
auto
ptr
=
mat
.
ptr
<
float
>
();
for
(
size_t
i
=
0
;
i
<
N_MAT
;
++
i
)
{
auto
rot
=
rand_real
(
0
,
M_PI
*
2
),
scale
=
rand_real
(
0.8
,
1.2
),
sheer
=
rand_real
(
0.9
,
1.1
),
dy
=
rand_real2
(
INP_H
*
0.5
),
dx
=
rand_real2
(
INP_W
*
0.5
),
ky
=
rand_real2
(
0.1
/
INP_H
),
kx
=
rand_real2
(
0.1
/
INP_W
),
kb
=
rand_real2
(
0.1
)
+
1
;
ptr
[
0
]
=
ptr
[
4
]
=
cos
(
rot
)
*
scale
;
ptr
[
1
]
=
-
(
ptr
[
3
]
=
sin
(
rot
)
*
scale
);
ptr
[
3
]
*=
sheer
;
ptr
[
4
]
*=
sheer
;
ptr
[
2
]
=
dx
;
ptr
[
5
]
=
dy
;
ptr
[
6
]
=
kx
;
ptr
[
7
]
=
ky
;
ptr
[
8
]
=
kb
;
ptr
+=
9
;
}
mgb_assert
(
ptr
==
mat
.
ptr
<
float
>
()
+
mat
.
shape
().
total_nr_elems
());
};
HostTensorGenerator
<
dtype
::
Int32
>
gen_mat_idx_rng
{
0
,
N_SRC
};
auto
gen_mat_idx
=
[
&
](
HostTensorND
&
mat
)
{
mat
=
*
gen_mat_idx_rng
(
mat
.
shape
());
};
Checker
(
make_graph
,
fwd
)
.
set_input_generator
(
N_SRC
,
gen_mat
)
.
set_input_generator
(
N_SRC
+
1
,
gen_mat_idx
)
.
set_input_dtype
(
N_SRC
+
1
,
dtype
::
Int32
{})
.
disable_grad_check
()
.
run
({
TensorShape
{
1
,
4
,
5
,
C
},
{
1
,
4
,
5
,
C
},
{
1
,
4
,
5
,
C
},
{
N_MAT
,
3
,
3
},
{
N_MAT
}})
.
run
({
TensorShape
{
1
,
6
,
5
,
C
},
{
1
,
6
,
5
,
C
},
{
1
,
6
,
5
,
C
},
{
N_MAT
,
3
,
3
},
{
N_MAT
}})
.
run
({
TensorShape
{
1
,
22
,
19
,
C
},
{
1
,
22
,
19
,
C
},
{
1
,
22
,
19
,
C
},
{
N_MAT
,
3
,
3
},
{
N_MAT
}});
}
TEST
(
TestOprImgproc
,
WarpPerspective_NCHW4
)
{
TEST
(
TestOprImgproc
,
WarpPerspective_NCHW4
)
{
set_rand_seed
(
19931102
);
set_rand_seed
(
19931102
);
constexpr
size_t
INP_H
=
6
,
INP_W
=
4
,
N
=
2
,
C
=
12
;
constexpr
size_t
INP_H
=
6
,
INP_W
=
4
,
N
=
2
,
C
=
12
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录