// nms_opr.cpp: implementation of the standalone NMSKeep (non-maximum
// suppression) operator
#include "megbrain/opr/standalone/nms_opr.h"

#if MGB_CUDA
#include "./nms_kern.cuh"
#endif
#include "./nms_cpu.h"

#include "megbrain/comp_node_env.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/utils/arith_helper.h"  // for get_aligned_power2

#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#include "megbrain/serialization/internal/schema_generated.h"
#endif

using namespace mgb::opr::standalone;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep);

//! device-dispatch interface: one concrete Kern per device type (CUDA / CPU)
class NMSKeep::Kern {
public:
    virtual ~Kern() = default;

    //! get workspace size in bytes for input of the given shape
    virtual size_t get_workspace_size(const NMSKeep* opr, const TensorShape& boxes) = 0;

    //! run NMS on \p inp, writing kept box indices to \p out_idx and the
    //! number of kept boxes per batch to \p out_size
    virtual void exec(
            const NMSKeep* opr, const DeviceTensorND& inp,
            const DeviceTensorND& out_idx, const DeviceTensorND& out_size,
            const DeviceTensorND& workspace) = 0;
};

// f{{{ cuda kernel begins
#if MGB_CUDA
class NMSKeep::CUDAKern final : public Kern {
    size_t m_workspace_overlap_mask_bytes, m_workspace_overlap_mask_bytes_align,
            m_workspace_rm_mask_bytes;

    //! compute the workspace section sizes for the given input shape
    void init(const NMSKeep* opr, const TensorShape& boxes) {
        auto align = opr->comp_node().get_mem_addr_alignment();
        size_t nr_boxes = boxes[1];
        if (nr_boxes == 0) {
            m_workspace_overlap_mask_bytes = 0;
            m_workspace_overlap_mask_bytes_align = 0;
            m_workspace_rm_mask_bytes = 0;
        } else {
            // pairwise overlap mask: one bit per box pair, packed in uint64
            m_workspace_overlap_mask_bytes =
                    nr_boxes * DIVUP(nr_boxes, 64) * sizeof(uint64_t);
            // align so the rm-mask section starts at a properly aligned address
            m_workspace_overlap_mask_bytes_align =
                    get_aligned_power2(m_workspace_overlap_mask_bytes, align);
            m_workspace_rm_mask_bytes = DIVUP(nr_boxes, 64) * sizeof(uint64_t);
        }
    }

public:
    size_t get_workspace_size(const NMSKeep* opr, const TensorShape& boxes) override {
        init(opr, boxes);
        return m_workspace_overlap_mask_bytes_align + m_workspace_rm_mask_bytes;
    }

    void exec(
            const NMSKeep* opr, const DeviceTensorND& inp,
            const DeviceTensorND& out_idx, const DeviceTensorND& out_size,
            const DeviceTensorND& workspace) override;
};

M
Megvii Engine Team 已提交
67 68 69
void NMSKeep::CUDAKern::exec(
        const NMSKeep* opr, const DeviceTensorND& inp, const DeviceTensorND& out_idx,
        const DeviceTensorND& out_size, const DeviceTensorND& workspace) {
70 71 72 73 74 75 76 77 78 79 80
    // NOTE: input comp node might be different from output comp node (for
    // example, CUDA stream may be modified to overlap computations); a
    // SingleCNOperatorNodeBase is expected to execute on a single comp node,
    // and the comp node is defined as the output comp node
    CompNode comp_node = out_idx.comp_node();

    // comp ndoe is also accessible from SingleCNOperatorNode
    mgb_assert(comp_node == opr->comp_node());

    // CompNodeEnv contains platform-specific properties of a CompNode
    auto&& cuda_env = CompNodeEnv::from_comp_node(comp_node).cuda_env();
M
Megvii Engine Team 已提交
81 82 83
    mgb_assert(
            cuda_env.device_prop.warpSize == 32, "invalid warp size: %d",
            cuda_env.device_prop.warpSize);
84 85 86 87 88
    auto stream = cuda_env.stream;

    init(opr, inp.shape());

    auto inp_ptr = inp.ptr<float>();
89 90
    void* workspace_ptr = workspace.raw_ptr();
    auto dev_overlap_mask = reinterpret_cast<uint64_t*>(workspace_ptr),
M
Megvii Engine Team 已提交
91 92
         dev_rm_mask =
                 (uint64_t*)(workspace.raw_ptr() + m_workspace_overlap_mask_bytes_align);
93 94 95
    auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()),
         out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>());
    size_t batch = inp.shape(0), nr_boxes = inp.shape(1);
96
    if (nr_boxes == 0) {
M
Megvii Engine Team 已提交
97 98
        MGB_CUDA_CHECK(
                cudaMemsetAsync(out_size_ptr, 0, batch * sizeof(uint32_t), stream));
99 100
        return;
    }
M
Megvii Engine Team 已提交
101 102
    MGB_CUDA_CHECK(cudaMemsetAsync(
            dev_overlap_mask, 0, m_workspace_overlap_mask_bytes, stream));
103 104 105 106

    auto max_output = opr->param().max_output;

    for (size_t i = 0; i < batch; ++i) {
M
Megvii Engine Team 已提交
107 108 109 110 111 112 113 114 115
        nms::launch_gen_mask(
                nr_boxes, opr->param().iou_thresh, inp_ptr + i * nr_boxes * 4,
                DIVUP(nr_boxes, 64), dev_overlap_mask, stream);

        MGB_CUDA_CHECK(
                cudaMemsetAsync(dev_rm_mask, 0, m_workspace_rm_mask_bytes, stream));
        nms::launch_gen_indices(
                nr_boxes, max_output, DIVUP(nr_boxes, 64), dev_overlap_mask,
                dev_rm_mask, out_idx_ptr + i * max_output, out_size_ptr + i, stream);
116 117 118 119 120 121 122 123 124 125 126
    }
}

#endif  // MGB_CUDA for CUDAKern
// f}}} cuda kernel ends

// f{{{ cpu kernel begins
class NMSKeep::CPUKern final : public Kern {
public:
    ~CPUKern() = default;

    size_t get_workspace_size(const NMSKeep*, const TensorShape& boxes) override {
        // workspace size depends only on the number of boxes per batch item
        return nms::cpu_kern_workspace(boxes.shape[1]);
    }

    void exec(
            const NMSKeep* opr, const DeviceTensorND& inp,
            const DeviceTensorND& out_idx, const DeviceTensorND& out_size,
            const DeviceTensorND& workspace) override;
};
M
Megvii Engine Team 已提交
136 137 138
void NMSKeep::CPUKern::exec(
        const NMSKeep* opr, const DeviceTensorND& inp, const DeviceTensorND& out_idx,
        const DeviceTensorND& out_size, const DeviceTensorND& workspace) {
139 140 141 142
    // See CUDAKern::exec for more explanation on output comp nodes.
    CompNode comp_node = out_idx.comp_node();

    size_t batch = inp.shape(0), nr_boxes = inp.shape(1);
143
    if (nr_boxes == 0) {
144
        auto out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>());
145 146 147 148 149
        for (size_t i = 0; i < batch; ++i) {
            *(out_size_ptr + i) = 0;
        }
        return;
    }
150 151 152 153 154 155 156 157
    auto param = opr->param();

    auto workspace_ptr = workspace.raw_ptr();

    // NOTE: we must copy all the params into the kernel closure since it would
    // be dispatched on a different thread
    auto kern = [=]() {
        for (size_t i = 0; i < batch; ++i) {
158 159 160 161 162
            auto inp_ptr = inp.as_megdnn().ptr<float>();
            auto out_idx_ptr =
                    reinterpret_cast<uint32_t*>(out_idx.as_megdnn().ptr<int32_t>());
            auto out_size_ptr =
                    reinterpret_cast<uint32_t*>(out_size.as_megdnn().ptr<int32_t>());
M
Megvii Engine Team 已提交
163 164 165 166
            nms::cpu_kern(
                    nr_boxes, param.max_output, param.iou_thresh,
                    inp_ptr + i * nr_boxes * 4, out_idx_ptr + i * param.max_output,
                    out_size_ptr + i, workspace_ptr);
167 168 169 170 171 172 173 174 175
        }
    };

    // The kernel should not be invoked
    CompNodeEnv::from_comp_node(comp_node).cpu_env().dispatch(kern);
}

// f}}} cpu kernel ends

M
Megvii Engine Team 已提交
176 177 178
NMSKeep::NMSKeep(
        VarNode* boxes, const Param& param,
        const OperatorNodeConfig& config)
179 180
        : Super(boxes->owner_graph(),  // owner graph
                config,                // OperatorNodeConfig
M
Megvii Engine Team 已提交
181 182
                "nms_keep",            // opr type name (used for generating opr name)
                {boxes}                // input vars for generating opr name
183 184
                ),
          m_param{param} {
M
Megvii Engine Team 已提交
185 186 187
    mgb_assert(
            boxes->dtype() == dtype::Float32(), "input should be float32; got %s",
            boxes->dtype().name());
188 189 190 191 192 193 194 195 196 197 198
    // setup m_kern according to device type
    switch (boxes->comp_node().device_type()) {
#if MGB_CUDA
        case CompNode::DeviceType::CUDA:
            m_kern = std::make_unique<CUDAKern>();
            break;
#endif
        case CompNode::DeviceType::CPU:
            m_kern = std::make_unique<CPUKern>();
            break;
        default:
M
Megvii Engine Team 已提交
199 200 201
            mgb_throw(
                    MegBrainError, "NMSKeep: unsupported device type: %s",
                    boxes->comp_node().to_string().c_str());
202 203 204
    }

    add_input({boxes});
M
Megvii Engine Team 已提交
205 206 207
    add_output("indices")
            ->dtype(dtype::Int32())
            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
208 209 210 211 212 213 214 215 216 217 218
    add_output("sizes")->dtype(dtype::Int32());
    cg::add_workspace_output(this);  // workspace is also an output var

    // make the graph deduplication system consider m_param (so two oprs with
    // same input vars but different param values would not be deduplicated)
    add_equivalence_component<PODHash<Param>>(&m_param);
}

// impl dtor here, after Kern is fully defined: m_kern is presumably a
// std::unique_ptr<Kern> (it is assigned via make_unique in the ctor), and its
// deleter needs the complete Kern type at the point of instantiation
NMSKeep::~NMSKeep() noexcept = default;

M
Megvii Engine Team 已提交
219 220
mgb::SymbolVar NMSKeep::make(
        SymbolVar boxes, const Param& param, const OperatorNodeConfig& config) {
221 222 223 224 225 226 227
    // SymbolVar is just a wrapper of VarNode*, with overloaded methods such as
    // operator+()
    auto bvar = boxes.node();
    // insert opr into the owner graph of boxes
    return boxes.insert_single_output_opr<NMSKeep>(bvar, param, config);
}

M
Megvii Engine Team 已提交
228 229
void NMSKeep::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
230
    auto boxes = inp_shape.at(0);
M
Megvii Engine Team 已提交
231 232 233
    mgb_assert(
            boxes.ndim == 3 && boxes.shape[2] == 4, "invalid box shape: %s",
            boxes.to_string().c_str());
234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249

    // out_shape should match the outputs added in the constructor
    mgb_assert(out_shape.size() == 3);

    auto batch = boxes[0];
    out_shape[0] = {batch, m_param.max_output};                // indices
    out_shape[1] = {batch};                                    // sizes
    out_shape[2] = {m_kern->get_workspace_size(this, boxes)};  // workspace
}

// the kernels index the boxes tensor with plain pointer arithmetic
// (inp_ptr + i * nr_boxes * 4), so the input layout must be contiguous
void NMSKeep::add_input_layout_constraint() {
    input(0)->add_layout_constraint_contiguous();
}

void NMSKeep::scn_do_execute() {
    DeviceTensorND empty_workspace;
M
Megvii Engine Team 已提交
250 251 252 253 254 255
    m_kern->exec(
            this, input(0)->dev_tensor(), output(0)->dev_tensor(),
            output(1)->dev_tensor(),
            // if workspace size is 0, output(2) would be invalid and its
            // dev_tensor() can not be accessed
            output(2)->dev_tensor_valid() ? output(2)->dev_tensor() : empty_workspace);
256 257
}

258 259
NMSKeep::NodeProp* NMSKeep::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
M
Megvii Engine Team 已提交
260
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
261 262 263
    return ret;
}

#if MGB_ENABLE_FBS_SERIALIZATION

namespace mgb {
namespace serialization {
namespace fbs {

//! convert NMSKeep::Param to/from its flatbuffers representation for
//! (de)serialization
template <>
struct ParamConverter<opr::standalone::NMSKeep::Param> {
    using FlatBufferType = param::NMSKeep;
    static opr::standalone::NMSKeep::Param to_param(const FlatBufferType* fb) {
        return {fb->iou_thresh(), fb->max_output()};
    }
    static flatbuffers::Offset<FlatBufferType> to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder,
            const opr::standalone::NMSKeep::Param& p) {
        return param::CreateNMSKeep(builder, p.iou_thresh, p.max_output);
    }
};

}  // namespace fbs
}  // namespace serialization
}  // namespace mgb

#endif

namespace mgb {

// empty function, presumably referenced from another translation unit so the
// linker does not drop this object file (and with it the static serialization
// registration below) when linking statically -- verify against the caller
void _hack_pull_in_nms_opr_object() {}

}  // namespace mgb

// register serialization: the default implementation uses Opr::Param; it
// requires Param::TAG, Opr::param() and Opr::make(..., param) to exist.
// Note: the second macro argument (1) means that this operator has one input.
// The alias keeps the name passed to the registration macro stable.
using NMSKeepMGB = NMSKeep;
MGB_SEREG_OPR(NMSKeepMGB, 1);

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}