Commit f96429c0 authored by Megvii Engine Team

feat(imperative): support empty tensor in roi_align

GitOrigin-RevId: aeb2770401e8dc6b0eea1469a54bb977dd1521db
Parent 2f829aaa
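For context, a minimal usage sketch of what this commit enables. This is not part of the diff; it assumes the public MegEngine API exactly as exercised by the new `test_roi_align_empty` below. Before this change, an empty feature map or an empty `rois` tensor would trip the contiguity assertions in `deduce_layout_fwd`:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

inp_feat = tensor(np.random.randn(2, 3, 26, 26).astype("float32"))
rois = tensor(np.zeros((0, 5), dtype="float32"))  # zero boxes

# With this commit, the op propagates emptiness instead of asserting.
out = F.vision.roi_align(
    inp_feat,
    rois,
    output_shape=(7, 7),
    mode="average",
    spatial_scale=1.0 / 4,
    sample_points=2,
    aligned=True,
)
# Output shape follows (num_rois, channels, *output_shape).
assert out.numpy().shape == (0, 3, 7, 7)
```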
@@ -7,7 +7,9 @@ namespace megdnn {
 void ROIAlignBase::deduce_layout_fwd(
         const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst,
         TensorLayout& index) {
+    if (!src.is_empty())
     megdnn_assert_contiguous(src);
+    if (!rois.is_empty())
     megdnn_assert_contiguous(rois);
     megdnn_assert_contiguous(dst);
     megdnn_assert_contiguous(index);
......
@@ -16,14 +16,14 @@ from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
 __all__ = [
     "correlation",
     "cvt_color",
-    "roi_pooling",
-    "roi_align",
+    "interpolate",
     "nms",
+    "nvof",
     "remap",
+    "roi_align",
+    "roi_pooling",
     "warp_affine",
     "warp_perspective",
-    "interpolate",
-    "nvof",
 ]
@@ -95,9 +95,9 @@ def roi_pooling(
     Args:
         inp: tensor that represents the input feature, `(N, C, H, W)` images.
-        rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
-        output_shape: height, width)` of output rois feature.
-        mode: max" or "average", use max/average align just like max/average pooling. Default: "max"
+        rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
+        output_shape: `(height, width)` of output rois feature.
+        mode: "max" or "average", use max/average align just like max/average pooling. Default: "max"
         scale: scale the input boxes by this number. Default: 1.0

     Returns:
@@ -176,9 +176,9 @@ def roi_align(
     Args:
         inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
-        rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
-        output_shape: height, width)` shape of output rois feature.
-        mode: max" or "average", use max/average align just like max/average pooling. Default: "average"
+        rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``.
+        output_shape: `(height, width)` shape of output rois feature.
+        mode: "max" or "average", use max/average align just like max/average pooling. Default: "average"
         spatial_scale: scale the input boxes by this number. Default: 1.0
         sample_points: number of input samples to take for each output sample.
             0 to take samples densely. Default: 2
@@ -345,7 +345,7 @@ def warp_affine(
     Args:
         inp: input image.
-        mat: batch, 2, 3)` transformation matrix.
+        mat: `(batch, 2, 3)` transformation matrix.
         out_shape: output tensor shape.
         border_mode: pixel extrapolation method.
             Default: "wrap". Currently "constant", "reflect",
......
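To make the rois layout in the corrected docstrings concrete, here is a hedged sketch (not part of the diff) of `roi_pooling` with two boxes; column 0 of each `(K, 5)` row is the index into the batch dimension `N`, and the remaining four columns are `xyxy` coordinates:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

inp = tensor(np.random.randn(2, 3, 16, 16).astype("float32"))
# Column 0: batch index; columns 1-4: x0, y0, x1, y1.
rois = tensor(np.array([[0, 0, 0, 8, 8], [1, 4, 4, 12, 12]], dtype="float32"))

out = F.vision.roi_pooling(inp, rois, output_shape=(4, 4), mode="max", scale=1.0)
assert out.numpy().shape == (2, 3, 4, 4)  # (K, C, height, width)
```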
@@ -289,6 +289,37 @@ def test_roi_align():
     assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


+@pytest.mark.parametrize("shapes", [((2, 0, 26, 26), (4, 5)), ((2, 3, 26, 26), (0, 5))])
+@pytest.mark.parametrize("is_tracing", [False, True])
+def test_roi_align_empty(shapes, is_tracing):
+    inp_feat = tensor(np.random.randn(*(shapes[0])))
+    rois = tensor(np.random.random(shapes[1]))
+    output_shape = (7, 7)
+
+    def func(inp, rois):
+        out_feat = F.vision.roi_align(
+            inp,
+            rois,
+            output_shape=output_shape,
+            mode="average",
+            spatial_scale=1.0 / 4,
+            sample_points=2,
+            aligned=True,
+        )
+        return out_feat
+
+    if is_tracing:
+        func = jit.trace(func)
+
+    for _ in range(3):
+        out_feat = func(inp_feat, rois)
+        assert make_shape_tuple(out_feat.shape) == (
+            rois.shape[0],
+            inp_feat.shape[1],
+            *output_shape,
+        )
+
+
 def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
     if random:
         inp_feat1 = np.random.randn(
......
@@ -441,21 +441,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fall
 } // namespace assert_equal
 } // namespace

-namespace {
-namespace roi_align {
-VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const ROIAlign&>(def);
-    mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{op.make_name()};
-    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
-                        .node()
-                        ->owner_opr();
-    return {opr->output(0), opr->output(1)};
-}
-OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback();
-} // namespace roi_align
-} // namespace
-
 namespace {
 namespace correlation {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
@@ -522,22 +507,6 @@ OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace diag
 } // namespace

-namespace {
-namespace roi_pooling {
-VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const ROIPooling&>(def);
-    mgb_assert(inputs.size() == 3);
-    OperatorNodeConfig config{op.make_name()};
-    auto* opr =
-            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
-                    .node()
-                    ->owner_opr();
-    return {opr->output(0), opr->output(1)};
-}
-OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
-} // namespace roi_pooling
-} // namespace
-
 namespace {
 namespace remap {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/imgproc.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
@@ -15,5 +18,119 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 }
 OP_TRAIT_REG(CvtColor, CvtColor).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace
+
+namespace {
+namespace roi_align {
+VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const ROIAlign&>(def);
+    mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
+                        .node()
+                        ->owner_opr();
+    return {opr->output(0), opr->output(1)};
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op = static_cast<const ROIAlign&>(def);
+    if (inputs[0].layout.is_empty() || inputs[1].layout.is_empty()) {
+        return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node},
+                 {TensorLayout(dtype::Int32()), inputs[1].comp_node}},
+                false};
+    }
+    SmallVector<LogicalTensorDesc> descs(2u);
+    size_t n = inputs[1].layout[0];
+    size_t c = inputs[0].layout[1];
+    descs[0].layout = TensorLayout(
+            {n, c, op.pooled_height, op.pooled_width}, inputs[0].layout.dtype);
+    descs[0].layout.init_contiguous_stride();
+    descs[0].comp_node = inputs[0].comp_node;
+    descs[1].layout =
+            TensorLayout({n, c, op.pooled_height, op.pooled_width}, dtype::Int32());
+    descs[1].layout.init_contiguous_stride();
+    descs[1].comp_node = descs[0].comp_node;
+    return {descs, true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = static_cast<const ROIAlign&>(def);
+    CompNode cn = inputs[0]->comp_node();
+    TensorLayout out_layout = output_descs[0].layout;
+    TensorLayout ind_layout = output_descs[1].layout;
+    if (!validated) {
+        size_t n = inputs[1]->layout()[0];
+        size_t c = inputs[0]->layout()[1];
+        out_layout = TensorLayout(
+                {n, c, op.pooled_height, op.pooled_width}, inputs[0]->layout().dtype);
+        out_layout.init_contiguous_stride();
+        ind_layout =
+                TensorLayout({n, c, op.pooled_height, op.pooled_width}, dtype::Int32());
+        ind_layout.init_contiguous_stride();
+    }
+
+    DeviceTensorND out =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);
+    DeviceTensorND inds =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, ind_layout);
+    if (out_layout.is_empty() || ind_layout.is_empty()) {
+        return {Tensor::make(out), Tensor::make(inds)};
+    }
+
+    DnnOprCaller<megdnn::ROIAlign> dnn_opr(cn);
+    dnn_opr.op->param() = op.param();
+    size_t sz = dnn_opr.op->get_workspace_in_bytes(
+            inputs[0]->layout(), inputs[1]->layout(), out_layout, ind_layout);
+    TensorLayout w_layout({sz}, dtype::Byte());
+    auto dnn_wk = dnn_opr.create_workspace(w_layout);
+    dnn_opr.op->exec(
+            inputs[0]->dnn_tensor(), inputs[1]->dnn_tensor(), out.as_megdnn(),
+            inds.as_megdnn(), dnn_wk);
+    return {Tensor::make(out), Tensor::make(inds)};
+}
+
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) {
+        return layout.is_contiguous();
+    };
+    return layout_checker;
+}
+
+OP_TRAIT_REG(ROIAlign, ROIAlign)
+        .apply_on_var_node(apply_on_var_node)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+} // namespace roi_align
+} // namespace
+
+namespace {
+namespace roi_pooling {
+VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const ROIPooling&>(def);
+    mgb_assert(inputs.size() == 3);
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr =
+            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
+                    .node()
+                    ->owner_opr();
+    return {opr->output(0), opr->output(1)};
+}
+OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
+} // namespace roi_pooling
+} // namespace
+
 } // namespace imperative
 } // namespace mgb
@@ -20,6 +20,8 @@ ROIAlignForward::ROIAlignForward(
     add_input({src, rois});
     output(0)->dtype(dtype::Float32());
     output(1)->dtype(dtype::Int32());
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVar ROIAlignForward::make(
@@ -29,6 +31,35 @@ SymbolVar ROIAlignForward::make(
             src.node(), rois.node(), param, config);
 }

+ROIAlignForward::NodeProp* ROIAlignForward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void ROIAlignForward::scn_do_execute() {
+    auto src = input(0)->dev_tensor().as_megdnn(),
+         rois = input(1)->dev_tensor().as_megdnn(),
+         dst = output(0)->dev_tensor().as_megdnn(),
+         index = output(1)->dev_tensor().as_megdnn();
+    if (src.layout.is_empty() || rois.layout.is_empty()) {
+        return;
+    }
+    megdnn_opr()->exec(
+            src, rois, dst, index, intl::get_megdnn_workspace_from_var(output(2)));
+}
+
+size_t ROIAlignForward::get_workspace_size_bytes(
+        const TensorShapeArray& inp_shapes, const TensorShapeArray& out_shapes) const {
+    TensorLayout inp{inp_shapes[0], input(0)->dtype(), input(0)->format()},
+            rois{inp_shapes[1], input(1)->dtype(), input(1)->format()},
+            out{out_shapes[0], output(0)->dtype(), output(0)->format()},
+            index{out_shapes[1], output(1)->dtype(), output(1)->format()};
+    return megdnn_opr()->get_workspace_in_bytes(inp, rois, out, index);
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(ROIAlignForward) {
     if (wrt_idx == 0) {
......
@@ -16,6 +16,13 @@ public:
     MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, SymbolVar rois, const Param& param = {},
            const OperatorNodeConfig& config = {});
+
+private:
+    void scn_do_execute() override;
+    NodeProp* do_make_node_prop() const override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
 };

 using ROIAlign = ROIAlignForward;
......