diff --git a/dnn/src/common/roi_align.cpp b/dnn/src/common/roi_align.cpp index 4ad829bad2e9b2acc540f2429010e136f5c3ee86..206b4ce4448eac6f4b4d5446c6c50ac67a499069 100644 --- a/dnn/src/common/roi_align.cpp +++ b/dnn/src/common/roi_align.cpp @@ -7,8 +7,10 @@ namespace megdnn { void ROIAlignBase::deduce_layout_fwd( const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, TensorLayout& index) { - megdnn_assert_contiguous(src); - megdnn_assert_contiguous(rois); + if (!src.is_empty()) + megdnn_assert_contiguous(src); + if (!rois.is_empty()) + megdnn_assert_contiguous(rois); megdnn_assert_contiguous(dst); megdnn_assert_contiguous(index); auto errmsg = [&]() { diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index 114d70a8d1144f8beeae3a5eb5aa3ff5fa1a1572..5492683c694d391b0f1a8c1bb113d216f42a3688 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -16,14 +16,14 @@ from .tensor import broadcast_to, concat, expand_dims, reshape, transpose __all__ = [ "correlation", "cvt_color", - "roi_pooling", - "roi_align", + "interpolate", "nms", + "nvof", "remap", + "roi_align", + "roi_pooling", "warp_affine", "warp_perspective", - "interpolate", - "nvof", ] @@ -95,9 +95,9 @@ def roi_pooling( Args: inp: tensor that represents the input feature, `(N, C, H, W)` images. - rois: K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. - output_shape: height, width)` of output rois feature. - mode: max" or "average", use max/average align just like max/average pooling. Default: "max" + rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. + output_shape: `(height, width)` of output rois feature. + mode: "max" or "average", use max/average align just like max/average pooling. Default: "max" scale: scale the input boxes by this number. Default: 1.0 Returns: @@ -176,9 +176,9 @@ def roi_align( Args: inp: tensor that represents the input feature, shape is `(N, C, H, W)`. - rois: N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``. - output_shape: height, width)` shape of output rois feature. - mode: max" or "average", use max/average align just like max/average pooling. Default: "average" + rois: `(N, 5)` boxes. First column is the box index. The other 4 columns are ``xyxy``. + output_shape: `(height, width)` shape of output rois feature. + mode: "max" or "average", use max/average align just like max/average pooling. Default: "average" spatial_scale: scale the input boxes by this number. Default: 1.0 sample_points: number of inputs samples to take for each output sample. 0 to take samples densely. Default: 2 @@ -345,7 +345,7 @@ def warp_affine( Args: inp: input image. - mat: batch, 2, 3)` transformation matrix. + mat: `(batch, 2, 3)` transformation matrix. out_shape: output tensor shape. border_mode: pixel extrapolation method. Default: "wrap". Currently "constant", "reflect", diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 3616b9e75d2e45f9ff2958b53a650c9b98519f82..f7838b3beda90fcbae512faffa9fe0a45ecf8122 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -289,6 +289,37 @@ def test_roi_align(): assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) +@pytest.mark.parametrize("shapes", [((2, 0, 26, 26), (4, 5)), ((2, 3, 26, 26), (0, 5))]) +@pytest.mark.parametrize("is_tracing", [False, True]) +def test_roi_align_empty(shapes, is_tracing): + inp_feat = tensor(np.random.randn(*(shapes[0]))) + rois = tensor(np.random.random(shapes[1])) + output_shape = (7, 7) + + def func(inp, rois): + out_feat = F.vision.roi_align( + inp_feat, + rois, + output_shape=output_shape, + mode="average", + spatial_scale=1.0 / 4, + sample_points=2, + aligned=True, + ) + return out_feat + + if is_tracing: + func = jit.trace(func) + + for _ in range(3): + out_feat = func(inp_feat, rois) + assert make_shape_tuple(out_feat.shape) == ( + rois.shape[0], + inp_feat.shape[1], + *output_shape, + ) + + def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)): if random: inp_feat1 = np.random.randn( diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 528d3ae7792ef921ec6f46e2bce05467ac4ac2ab..1708a0887e3dbd75dd61fcf301aa9010736973c0 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -441,21 +441,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fall } // namespace assert_equal } // namespace -namespace { -namespace roi_align { -VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = static_cast(def); - mgb_assert(inputs.size() == 2); - OperatorNodeConfig config{op.make_name()}; - auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) - .node() - ->owner_opr(); - return {opr->output(0), opr->output(1)}; -} -OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback(); -} // namespace roi_align -} // namespace - namespace { namespace correlation { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { @@ -522,22 +507,6 @@ OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback(); } // namespace diag } // namespace -namespace { -namespace roi_pooling { -VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = static_cast(def); - mgb_assert(inputs.size() == 3); - OperatorNodeConfig config{op.make_name()}; - auto* opr = - opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config) - .node() - ->owner_opr(); - return {opr->output(0), opr->output(1)}; -} -OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback(); -} // namespace roi_pooling -} // namespace - namespace { namespace remap { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { diff --git a/imperative/src/impl/ops/vision.cpp b/imperative/src/impl/ops/vision.cpp index 69c1750355857e8a9fe022dc4391d8d5049ea95f..5cfb223bd22e182e6b3eb801759dd95aa59a6c6a 100644 --- a/imperative/src/impl/ops/vision.cpp +++ b/imperative/src/impl/ops/vision.cpp @@ -1,8 +1,11 @@ #include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/dnn/roi_align.h" +#include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/imgproc.h" +#include "../blob_manager_impl.h" +#include "../dnn_op_helper.h" #include "../op_trait.h" - namespace mgb { namespace imperative { @@ -15,5 +18,119 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { } OP_TRAIT_REG(CvtColor, CvtColor).apply_on_var_node(apply_on_var_node).fallback(); } // namespace + +namespace { +namespace roi_align { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + OperatorNodeConfig config{op.make_name()}; + auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) + .node() + ->owner_opr(); + return {opr->output(0), opr->output(1)}; +} + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + auto&& op = static_cast(def); + if (inputs[0].layout.is_empty() || inputs[1].layout.is_empty()) { + return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}, + {TensorLayout(dtype::Int32()), inputs[1].comp_node}}, + false}; + } + + SmallVector descs(2u); + size_t n = inputs[1].layout[0]; + size_t c = inputs[0].layout[1]; + descs[0].layout = TensorLayout( + {n, c, op.pooled_height, op.pooled_width}, inputs[0].layout.dtype); + descs[0].layout.init_contiguous_stride(); + descs[0].comp_node = inputs[0].comp_node; + + descs[1].layout = + TensorLayout({n, c, op.pooled_height, op.pooled_width}, dtype::Int32()); + descs[1].layout.init_contiguous_stride(); + descs[1].comp_node = descs[0].comp_node; + + return {descs, true}; +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = static_cast(def); + CompNode cn = inputs[0]->comp_node(); + + TensorLayout out_layout = output_descs[0].layout; + TensorLayout ind_layout = output_descs[1].layout; + if (!validated) { + size_t n = inputs[1]->layout()[0]; + size_t c = inputs[0]->layout()[1]; + out_layout = TensorLayout( + {n, c, op.pooled_height, op.pooled_width}, inputs[0]->layout().dtype); + out_layout.init_contiguous_stride(); + ind_layout = + TensorLayout({n, c, op.pooled_height, op.pooled_width}, dtype::Int32()); + ind_layout.init_contiguous_stride(); + } + + DeviceTensorND out = + BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout); + DeviceTensorND inds = + BlobManager::inst()->alloc_workspace_with_defrag(cn, ind_layout); + + if (out_layout.is_empty() || ind_layout.is_empty()) { + return {Tensor::make(out), Tensor::make(inds)}; + } + + DnnOprCaller dnn_opr(cn); + dnn_opr.op->param() = op.param(); + + size_t sz = dnn_opr.op->get_workspace_in_bytes( + inputs[0]->layout(), inputs[1]->layout(), out_layout, ind_layout); + TensorLayout w_layout({sz}, dtype::Byte()); + auto dnn_wk = dnn_opr.create_workspace(w_layout); + + dnn_opr.op->exec( + inputs[0]->dnn_tensor(), inputs[1]->dnn_tensor(), out.as_megdnn(), + inds.as_megdnn(), dnn_wk); + return {Tensor::make(out), Tensor::make(inds)}; +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + SmallVector layout_checker(inputs.size()); + layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { + return layout.is_contiguous(); + }; + return layout_checker; +} + +OP_TRAIT_REG(ROIAlign, ROIAlign) + .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .get_input_layout_constraint(get_input_layout_constraint) + .fallback(); +} // namespace roi_align +} // namespace + +namespace { +namespace roi_pooling { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + OperatorNodeConfig config{op.make_name()}; + auto* opr = + opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config) + .node() + ->owner_opr(); + return {opr->output(0), opr->output(1)}; +} +OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace roi_pooling +} // namespace + } // namespace imperative } // namespace mgb diff --git a/src/opr/impl/dnn/roi_align.cpp b/src/opr/impl/dnn/roi_align.cpp index 572403fd3489edf57a5de76cc5b303ac7f011151..5733baabd7113e8fc562d77b465e391c0e5a5394 100644 --- a/src/opr/impl/dnn/roi_align.cpp +++ b/src/opr/impl/dnn/roi_align.cpp @@ -20,6 +20,8 @@ ROIAlignForward::ROIAlignForward( add_input({src, rois}); output(0)->dtype(dtype::Float32()); output(1)->dtype(dtype::Int32()); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } SymbolVar ROIAlignForward::make( @@ -29,6 +31,35 @@ SymbolVar ROIAlignForward::make( src.node(), rois.node(), param, config); } +ROIAlignForward::NodeProp* ROIAlignForward::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY); + ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; +} + +void ROIAlignForward::scn_do_execute() { + auto src = input(0)->dev_tensor().as_megdnn(), + rois = input(1)->dev_tensor().as_megdnn(), + dst = output(0)->dev_tensor().as_megdnn(), + index = output(1)->dev_tensor().as_megdnn(); + + if ((src.layout.is_empty() || rois.layout.is_empty())) { + return; + } + megdnn_opr()->exec( + src, rois, dst, index, intl::get_megdnn_workspace_from_var(output(2))); +} + +size_t ROIAlignForward::get_workspace_size_bytes( + const TensorShapeArray& inp_shapes, const TensorShapeArray& out_shapes) const { + TensorLayout inp{inp_shapes[0], input(0)->dtype(), input(0)->format()}, + rois{inp_shapes[1], input(1)->dtype(), input(1)->format()}, + out{out_shapes[0], output(0)->dtype(), output(0)->format()}, + index{out_shapes[1], output(1)->dtype(), output(1)->format()}; + return megdnn_opr()->get_workspace_in_bytes(inp, rois, index, out); +} + #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIAlignForward) { if (wrt_idx == 0) { diff --git a/src/opr/include/megbrain/opr/dnn/roi_align.h b/src/opr/include/megbrain/opr/dnn/roi_align.h index 7cdb97053ad6fd0812e6cfffb29d6320dc411ca0..51f4ee33d4bf8e9631845b99e431f0a658f3d091 100644 --- a/src/opr/include/megbrain/opr/dnn/roi_align.h +++ b/src/opr/include/megbrain/opr/dnn/roi_align.h @@ -16,6 +16,13 @@ public: MGE_WIN_DECLSPEC_FUC static SymbolVar make( SymbolVar src, SymbolVar rois, const Param& param = {}, const OperatorNodeConfig& config = {}); + +private: + void scn_do_execute() override; + NodeProp* do_make_node_prop() const override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; }; using ROIAlign = ROIAlignForward;