Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f96429c0
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f96429c0
编写于
5月 05, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): support empty tensor in roi_align
GitOrigin-RevId: aeb2770401e8dc6b0eea1469a54bb977dd1521db
上级
2f829aaa
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
202 addition
and
45 deletion
+202
-45
dnn/src/common/roi_align.cpp
dnn/src/common/roi_align.cpp
+4
-2
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+11
-11
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+31
-0
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+0
-31
imperative/src/impl/ops/vision.cpp
imperative/src/impl/ops/vision.cpp
+118
-1
src/opr/impl/dnn/roi_align.cpp
src/opr/impl/dnn/roi_align.cpp
+31
-0
src/opr/include/megbrain/opr/dnn/roi_align.h
src/opr/include/megbrain/opr/dnn/roi_align.h
+7
-0
未找到文件。
dnn/src/common/roi_align.cpp
浏览文件 @
f96429c0
...
...
@@ -7,8 +7,10 @@ namespace megdnn {
void
ROIAlignBase
::
deduce_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
rois
,
TensorLayout
&
dst
,
TensorLayout
&
index
)
{
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
rois
);
if
(
!
src
.
is_empty
())
megdnn_assert_contiguous
(
src
);
if
(
!
rois
.
is_empty
())
megdnn_assert_contiguous
(
rois
);
megdnn_assert_contiguous
(
dst
);
megdnn_assert_contiguous
(
index
);
auto
errmsg
=
[
&
]()
{
...
...
imperative/python/megengine/functional/vision.py
浏览文件 @
f96429c0
...
...
@@ -16,14 +16,14 @@ from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
__all__
=
[
"correlation"
,
"cvt_color"
,
"roi_pooling"
,
"roi_align"
,
"interpolate"
,
"nms"
,
"nvof"
,
"remap"
,
"roi_align"
,
"roi_pooling"
,
"warp_affine"
,
"warp_perspective"
,
"interpolate"
,
"nvof"
,
]
...
...
@@ -95,9 +95,9 @@ def roi_pooling(
Args:
inp: tensor that represents the input feature, `(N, C, H, W)` images.
rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
output_shape: height, width)` of output rois feature.
mode: max" or "average", use max/average align just like max/average pooling. Default: "max"
rois:
`(
K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
output_shape:
`(
height, width)` of output rois feature.
mode:
"
max" or "average", use max/average align just like max/average pooling. Default: "max"
scale: scale the input boxes by this number. Default: 1.0
Returns:
...
...
@@ -176,9 +176,9 @@ def roi_align(
Args:
inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
output_shape: height, width)` shape of output rois feature.
mode: max" or "average", use max/average align just like max/average pooling. Default: "average"
rois:
`(
N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
output_shape:
`(
height, width)` shape of output rois feature.
mode:
"
max" or "average", use max/average align just like max/average pooling. Default: "average"
spatial_scale: scale the input boxes by this number. Default: 1.0
sample_points: number of inputs samples to take for each output sample.
0 to take samples densely. Default: 2
...
...
@@ -345,7 +345,7 @@ def warp_affine(
Args:
inp: input image.
mat: batch, 2, 3)` transformation matrix.
mat:
`(
batch, 2, 3)` transformation matrix.
out_shape: output tensor shape.
border_mode: pixel extrapolation method.
Default: "wrap". Currently "constant", "reflect",
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
f96429c0
...
...
@@ -289,6 +289,37 @@ def test_roi_align():
assert
make_shape_tuple
(
inp_feat
.
grad
.
shape
)
==
make_shape_tuple
(
inp_feat
.
shape
)
@
pytest
.
mark
.
parametrize
(
"shapes"
,
[((
2
,
0
,
26
,
26
),
(
4
,
5
)),
((
2
,
3
,
26
,
26
),
(
0
,
5
))])
@
pytest
.
mark
.
parametrize
(
"is_tracing"
,
[
False
,
True
])
def
test_roi_align_empty
(
shapes
,
is_tracing
):
inp_feat
=
tensor
(
np
.
random
.
randn
(
*
(
shapes
[
0
])))
rois
=
tensor
(
np
.
random
.
random
(
shapes
[
1
]))
output_shape
=
(
7
,
7
)
def
func
(
inp
,
rois
):
out_feat
=
F
.
vision
.
roi_align
(
inp_feat
,
rois
,
output_shape
=
output_shape
,
mode
=
"average"
,
spatial_scale
=
1.0
/
4
,
sample_points
=
2
,
aligned
=
True
,
)
return
out_feat
if
is_tracing
:
func
=
jit
.
trace
(
func
)
for
_
in
range
(
3
):
out_feat
=
func
(
inp_feat
,
rois
)
assert
make_shape_tuple
(
out_feat
.
shape
)
==
(
rois
.
shape
[
0
],
inp_feat
.
shape
[
1
],
*
output_shape
,
)
def
_gen_correlation
(
random
=
True
,
constant
=
1
,
image_shape
=
(
2
,
1
,
160
,
160
)):
if
random
:
inp_feat1
=
np
.
random
.
randn
(
...
...
imperative/src/impl/ops/specializations.cpp
浏览文件 @
f96429c0
...
...
@@ -441,21 +441,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fall
}
// namespace assert_equal
}
// namespace
namespace
{
namespace
roi_align
{
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
ROIAlign
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
2
);
OperatorNodeConfig
config
{
op
.
make_name
()};
auto
*
opr
=
opr
::
ROIAlign
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
(),
config
)
.
node
()
->
owner_opr
();
return
{
opr
->
output
(
0
),
opr
->
output
(
1
)};
}
OP_TRAIT_REG
(
ROIAlign
,
ROIAlign
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace roi_align
}
// namespace
namespace
{
namespace
correlation
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
...
...
@@ -522,22 +507,6 @@ OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
}
// namespace diag
}
// namespace
namespace
{
namespace
roi_pooling
{
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
ROIPooling
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
3
);
OperatorNodeConfig
config
{
op
.
make_name
()};
auto
*
opr
=
opr
::
ROIPooling
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
op
.
param
(),
config
)
.
node
()
->
owner_opr
();
return
{
opr
->
output
(
0
),
opr
->
output
(
1
)};
}
OP_TRAIT_REG
(
ROIPooling
,
ROIPooling
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace roi_pooling
}
// namespace
namespace
{
namespace
remap
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
...
...
imperative/src/impl/ops/vision.cpp
浏览文件 @
f96429c0
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/imgproc.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace
mgb
{
namespace
imperative
{
...
...
@@ -15,5 +18,119 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
OP_TRAIT_REG
(
CvtColor
,
CvtColor
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace
namespace
{
namespace
roi_align
{
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
ROIAlign
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
2
);
OperatorNodeConfig
config
{
op
.
make_name
()};
auto
*
opr
=
opr
::
ROIAlign
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
(),
config
)
.
node
()
->
owner_opr
();
return
{
opr
->
output
(
0
),
opr
->
output
(
1
)};
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
op
=
static_cast
<
const
ROIAlign
&>
(
def
);
if
(
inputs
[
0
].
layout
.
is_empty
()
||
inputs
[
1
].
layout
.
is_empty
())
{
return
{{{
TensorLayout
(
inputs
[
0
].
layout
.
dtype
),
inputs
[
0
].
comp_node
},
{
TensorLayout
(
dtype
::
Int32
()),
inputs
[
1
].
comp_node
}},
false
};
}
SmallVector
<
LogicalTensorDesc
>
descs
(
2u
);
size_t
n
=
inputs
[
1
].
layout
[
0
];
size_t
c
=
inputs
[
0
].
layout
[
1
];
descs
[
0
].
layout
=
TensorLayout
(
{
n
,
c
,
op
.
pooled_height
,
op
.
pooled_width
},
inputs
[
0
].
layout
.
dtype
);
descs
[
0
].
layout
.
init_contiguous_stride
();
descs
[
0
].
comp_node
=
inputs
[
0
].
comp_node
;
descs
[
1
].
layout
=
TensorLayout
({
n
,
c
,
op
.
pooled_height
,
op
.
pooled_width
},
dtype
::
Int32
());
descs
[
1
].
layout
.
init_contiguous_stride
();
descs
[
1
].
comp_node
=
descs
[
0
].
comp_node
;
return
{
descs
,
true
};
}
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
,
SmallVector
<
LogicalTensorDesc
>&
output_descs
,
const
bool
&
validated
)
{
auto
&&
op
=
static_cast
<
const
ROIAlign
&>
(
def
);
CompNode
cn
=
inputs
[
0
]
->
comp_node
();
TensorLayout
out_layout
=
output_descs
[
0
].
layout
;
TensorLayout
ind_layout
=
output_descs
[
1
].
layout
;
if
(
!
validated
)
{
size_t
n
=
inputs
[
1
]
->
layout
()[
0
];
size_t
c
=
inputs
[
0
]
->
layout
()[
1
];
out_layout
=
TensorLayout
(
{
n
,
c
,
op
.
pooled_height
,
op
.
pooled_width
},
inputs
[
0
]
->
layout
().
dtype
);
out_layout
.
init_contiguous_stride
();
ind_layout
=
TensorLayout
({
n
,
c
,
op
.
pooled_height
,
op
.
pooled_width
},
dtype
::
Int32
());
ind_layout
.
init_contiguous_stride
();
}
DeviceTensorND
out
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
out_layout
);
DeviceTensorND
inds
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
ind_layout
);
if
(
out_layout
.
is_empty
()
||
ind_layout
.
is_empty
())
{
return
{
Tensor
::
make
(
out
),
Tensor
::
make
(
inds
)};
}
DnnOprCaller
<
megdnn
::
ROIAlign
>
dnn_opr
(
cn
);
dnn_opr
.
op
->
param
()
=
op
.
param
();
size_t
sz
=
dnn_opr
.
op
->
get_workspace_in_bytes
(
inputs
[
0
]
->
layout
(),
inputs
[
1
]
->
layout
(),
out_layout
,
ind_layout
);
TensorLayout
w_layout
({
sz
},
dtype
::
Byte
());
auto
dnn_wk
=
dnn_opr
.
create_workspace
(
w_layout
);
dnn_opr
.
op
->
exec
(
inputs
[
0
]
->
dnn_tensor
(),
inputs
[
1
]
->
dnn_tensor
(),
out
.
as_megdnn
(),
inds
.
as_megdnn
(),
dnn_wk
);
return
{
Tensor
::
make
(
out
),
Tensor
::
make
(
inds
)};
}
SmallVector
<
VarNode
::
LayoutConstraintCallback
>
get_input_layout_constraint
(
const
OpDef
&
def
,
const
SmallVector
<
TensorPtr
>&
inputs
)
{
SmallVector
<
VarNode
::
LayoutConstraintCallback
>
layout_checker
(
inputs
.
size
());
layout_checker
[
0
]
=
layout_checker
[
1
]
=
[](
const
TensorLayout
&
layout
)
{
return
layout
.
is_contiguous
();
};
return
layout_checker
;
}
OP_TRAIT_REG
(
ROIAlign
,
ROIAlign
)
.
apply_on_var_node
(
apply_on_var_node
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
get_input_layout_constraint
(
get_input_layout_constraint
)
.
fallback
();
}
// namespace roi_align
}
// namespace
namespace
{
namespace
roi_pooling
{
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
ROIPooling
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
3
);
OperatorNodeConfig
config
{
op
.
make_name
()};
auto
*
opr
=
opr
::
ROIPooling
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
op
.
param
(),
config
)
.
node
()
->
owner_opr
();
return
{
opr
->
output
(
0
),
opr
->
output
(
1
)};
}
OP_TRAIT_REG
(
ROIPooling
,
ROIPooling
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace roi_pooling
}
// namespace
}
// namespace imperative
}
// namespace mgb
src/opr/impl/dnn/roi_align.cpp
浏览文件 @
f96429c0
...
...
@@ -20,6 +20,8 @@ ROIAlignForward::ROIAlignForward(
add_input
({
src
,
rois
});
output
(
0
)
->
dtype
(
dtype
::
Float32
());
output
(
1
)
->
dtype
(
dtype
::
Int32
());
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
output
(
1
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
}
SymbolVar
ROIAlignForward
::
make
(
...
...
@@ -29,6 +31,35 @@ SymbolVar ROIAlignForward::make(
src
.
node
(),
rois
.
node
(),
param
,
config
);
}
ROIAlignForward
::
NodeProp
*
ROIAlignForward
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
ret
->
add_dep_type_existing_var
(
input
(
1
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
void
ROIAlignForward
::
scn_do_execute
()
{
auto
src
=
input
(
0
)
->
dev_tensor
().
as_megdnn
(),
rois
=
input
(
1
)
->
dev_tensor
().
as_megdnn
(),
dst
=
output
(
0
)
->
dev_tensor
().
as_megdnn
(),
index
=
output
(
1
)
->
dev_tensor
().
as_megdnn
();
if
((
src
.
layout
.
is_empty
()
||
rois
.
layout
.
is_empty
()))
{
return
;
}
megdnn_opr
()
->
exec
(
src
,
rois
,
dst
,
index
,
intl
::
get_megdnn_workspace_from_var
(
output
(
2
)));
}
size_t
ROIAlignForward
::
get_workspace_size_bytes
(
const
TensorShapeArray
&
inp_shapes
,
const
TensorShapeArray
&
out_shapes
)
const
{
TensorLayout
inp
{
inp_shapes
[
0
],
input
(
0
)
->
dtype
(),
input
(
0
)
->
format
()},
rois
{
inp_shapes
[
1
],
input
(
1
)
->
dtype
(),
input
(
1
)
->
format
()},
out
{
out_shapes
[
0
],
output
(
0
)
->
dtype
(),
output
(
0
)
->
format
()},
index
{
out_shapes
[
1
],
output
(
1
)
->
dtype
(),
output
(
1
)
->
format
()};
return
megdnn_opr
()
->
get_workspace_in_bytes
(
inp
,
rois
,
index
,
out
);
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
ROIAlignForward
)
{
if
(
wrt_idx
==
0
)
{
...
...
src/opr/include/megbrain/opr/dnn/roi_align.h
浏览文件 @
f96429c0
...
...
@@ -16,6 +16,13 @@ public:
MGE_WIN_DECLSPEC_FUC
static
SymbolVar
make
(
SymbolVar
src
,
SymbolVar
rois
,
const
Param
&
param
=
{},
const
OperatorNodeConfig
&
config
=
{});
private:
void
scn_do_execute
()
override
;
NodeProp
*
do_make_node_prop
()
const
override
;
size_t
get_workspace_size_bytes
(
const
TensorShapeArray
&
input_shapes
,
const
TensorShapeArray
&
output_shapes
)
const
override
;
};
using
ROIAlign
=
ROIAlignForward
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录