/**
 * \file src/opr/impl/misc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./internal/megdnn_opr_wrapper.inl"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
using namespace opr;

namespace mgb {
namespace opr {
namespace intl {
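
//! MegDNNOprInitPostCtor runs right after the wrapped MegDNN operator node is
//! constructed; the specializations below set the proper output dtypes.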
template <>
struct MegDNNOprInitPostCtor<Argmax> {
    static void apply(cg::OperatorNodeBase& opr) {
        opr.output(0)->dtype(dtype::Int32());
    }
};

template <>
struct MegDNNOprInitPostCtor<Argmin> : public MegDNNOprInitPostCtor<Argmax> {};

template <>
struct MegDNNOprInitPostCtor<ArgsortForward> {
    static void apply(cg::OperatorNodeBase& opr) {
        opr.output(0)->dtype(opr.input(0)->dtype());
        opr.output(1)->dtype(dtype::Int32());
    }
};
}  // namespace intl
}  // namespace opr
}  // namespace mgb

/* ================= Argmxx ================= */

#if MGB_ENABLE_GRAD
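//! Argmax outputs integer indices, so there is no gradient to propagate.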
MGB_IMPL_OPR_GRAD(Argmax) {
    MGB_MARK_USED_VAR(out_grad);
    MGB_MARK_USED_VAR(opr);
    mgb_assert(!wrt_idx);
    return nullptr;
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
MEGDNN_OPR_INIT1(Argmax, "argmax")

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmin) {
    MGB_MARK_USED_VAR(out_grad);
    MGB_MARK_USED_VAR(opr);
    mgb_assert(!wrt_idx);
    return nullptr;
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin);
MEGDNN_OPR_INIT1(Argmin, "argmin")

/* ================= ArgsortForward =================  */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortForward);
// MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")

ArgsortForward::ArgsortForward(
        VarNode* i0, const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{i0->owner_graph(), config, "argsort", {i0}}) {
    init_megdnn_opr(*this, param);
    add_input({i0});
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);  // sorted value
    output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);  // sorted index
    intl::MegDNNOprInitPostCtor<ArgsortForward>::apply(*this);
}

std::array<SymbolVar, 2> ArgsortForward::make(
        SymbolVar in_tensor, const Param& param, const OperatorNodeConfig& config) {
    auto node = in_tensor.node()->owner_graph()->insert_opr(
            std::make_unique<ArgsortForward>(in_tensor.node(), param, config));
    // outputs: sorted values, indices, and a workspace var
    mgb_assert(node->output().size() == 3);
    return {node->output(0), node->output(1)};
}

void ArgsortForward::scn_do_execute() {
    if (input(0)->dev_tensor().empty()) {
        mgb_assert(output(0)->dev_tensor().empty() && output(1)->dev_tensor().empty());
        return;
    }
    mgb_assert(!output(0)->dev_tensor().empty() && !output(1)->dev_tensor().empty());
    Super::scn_do_execute();
}

void ArgsortForward::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    mgb_assert(inp_shape.size() == 1 && out_shape.size() == 2);
    out_shape[0] = inp_shape[0];
    out_shape[1] = inp_shape[0];
}

ArgsortForward::NodeProp* ArgsortForward::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}

#if MGB_ENABLE_GRAD
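//! The gradient of the sorted values is routed back to their original input
//! positions by ArgsortBackward, using the index output of the forward pass.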
MGB_IMPL_OPR_GRAD(ArgsortForward) {
    mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
    if (!out_grad[0])
        return nullptr;
    return ArgsortBackward::make(out_grad[0], opr.output(1)).node();
}
#endif

/* ================= ArgsortBackward =================  */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortBackward);
MEGDNN_OPR_INIT3(ArgsortBackward, "argsort_bwd", 2, false)

/* ================= Cumsum =================  */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum);

Cumsum::Cumsum(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
        : Super{opr->owner_graph(), config, "Cumsum", {opr}} {
    init_megdnn_opr(*this, param);
    add_input({opr}, AddInputSortType::CUR_ADDED);
}

#if MGB_ENABLE_GRAD
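//! The gradient of a cumulative sum is a cumulative sum taken in the opposite
//! direction, so backward just flips param.reverse.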
MGB_IMPL_OPR_GRAD(Cumsum) {
    mgb_assert(out_grad[0] && !out_grad[1]);
    auto param = opr.param();
    param.reverse = !param.reverse;
    return Cumsum::make(out_grad[0], param).node();
}
#endif

SymbolVar Cumsum::make(
        SymbolVar opr, const Param& param, const OperatorNodeConfig& config) {
    return opr.insert_single_output_opr<Cumsum>(opr.node(), param, config);
}

void Cumsum::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

void Cumsum::add_input_layout_constraint() {
    input(0)->add_layout_constraint_contiguous();
}

void Cumsum::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
        auto ishp = iv.val.at(0).shape();
        dest = ishp;
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
    auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
        auto dtype = input(0)->dtype();
        auto ishp = iv.val.at(0).shape();
        TensorLayout ily(ishp, dtype);
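        // normalize a negative axis before querying MegDNN for the workspace size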
        Param real_param = param();
        if (real_param.axis < 0)
            real_param.axis += ishp.ndim;
        megdnn_opr()->param() = real_param;
        dest.ndim = 1;
        dest[0] = megdnn_opr()->get_workspace_in_bytes(ily, ily);
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(1),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}

/* ================= NvOf =================  */

#if MGB_CUDA
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);

NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
        : Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} {
    mgb_assert(opr->dtype() == dtype::Uint8());
    add_input({opr});
    //! NvOf has only one output
    add_output(None);
    mgb_log_debug("init nvof engine with precision: %u", m_param.precision);
}

void NvOf::init_output_dtype() {
    output(0)->dtype(dtype::Int16());
}

SymbolVar NvOf::make(
        SymbolVar opr, const Param& param, const OperatorNodeConfig& config) {
    return opr.insert_single_output_opr<NvOf>(opr.node(), param, config);
}

void NvOf::scn_do_execute() {
    auto input_shape = this->input()[0]->shape();
    for (size_t i = 0; i < 5; i++) {
        vshape.push_back(input_shape[i]);
    }
    auto c = this->comp_node();
    //! comp_node may be initialized on CUDA or CPU (e.g. lar with --cpu);
    //! if on CUDA, a sync is needed because NvOF runs on a different stream
    if (CompNode::DeviceType::CUDA == c.device_type()) {
        c.sync();
    } else {
        mgb_log_warn(
                "NvOf opr on non CUDA comp_node, which will trigger H2D and "
                "D2H!!");
    }

    //! Create the NvOF engine on the same device as comp_node. The device id
    //! is not available in the NvOf constructor, so creation is deferred to
    //! scn_do_execute.
    std::lock_guard<std::mutex> lock(m_lock);
    if (init_flag == false) {
        //! the NvOF SDK does not implement p2p copy, so the engine is created
        //! on the same device as the mgb comp_node
        nv_flow_extractor = std::make_shared<NVFlowExtractor>(
                c.locator().device, vshape, m_param.precision, true, true);
        init_flag = true;
    }

    nv_flow_extractor->extract_flow(
            static_cast<unsigned char*>(input(0)->dev_tensor().as_megdnn().raw_ptr),
            vshape,
            reinterpret_cast<int16_t*>(output(0)->dev_tensor().as_megdnn().raw_ptr));
}

void NvOf::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
        auto out_grid_size = NV_OF_OUTPUT_VECTOR_GRID_SIZE_4;
        auto ishp = iv.val.at(0).shape();
        //! NvOf input format: NTHWC4
        mgb_assert(ishp.ndim == 5);
        //! only RGBA-format channel data is supported for now
        mgb_assert(ishp[4] == 4);
        //! output shape: (N, T - 1, ceil(H / grid), ceil(W / grid), 2)
        SmallVector<size_t> tv;
        tv.push_back(ishp[0]);
        tv.push_back(ishp[1] - 1);
        tv.push_back((ishp[2] + out_grid_size - 1) / out_grid_size);
        tv.push_back((ishp[3] + out_grid_size - 1) / out_grid_size);
        tv.push_back(ishp[4] / 2);
        dest = tv;

        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
}
#endif

/* ================= CondTake =================  */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);

CondTake::CondTake(
        VarNode* data, VarNode* mask, const Param& param,
        const OperatorNodeConfig& config)
        : Super(data->owner_graph(), config, "cond_take", {data, mask}) {
    init_megdnn_opr(*this, param);
    add_input({data, mask});
    auto dtypes = megdnn_opr()->infer_dtype(data->dtype(), mask->dtype());
    for (int i = 0; i < 2; ++i) {
        output(i)
                ->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .dtype(dtypes[i]);
    }
}

CondTake::NodeProp* CondTake::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    ret->add_dep_type_existing_var(input(1), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}

#if MGB_ENABLE_GRAD
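//! Only the data input receives a gradient: the incoming gradient is scattered
//! into a zero-filled, flattened copy of the data at the taken indices
//! (output(1)), then reshaped back to the data shape.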
MGB_IMPL_OPR_GRAD(CondTake) {
    mgb_assert(out_grad.size() == 3 && !out_grad[2]);
    if (wrt_idx == 0 && out_grad[0]) {
        SymbolVar data_sym{opr.input(0)};
        auto inp_set = IndexingIncrMultiAxisVec::make(
                data_sym.flatten().fill_retain_dtype(0), out_grad[0],
                {indexing::AxisIndexer::make_index(0, opr.output(1))});
        return inp_set.reshape(data_sym.symshape()).node();
    }
    return nullptr;
}
#endif

std::array<SymbolVar, 2> CondTake::make(
        SymbolVar data, SymbolVar mask, const Param& param,
        const OperatorNodeConfig& config) {
    auto ov0 = data.insert_single_output_opr<CondTake>(
            data.node(), mask.node(), param, config);
    return {ov0, ov0.node()->owner_opr()->output(1)};
}

void CondTake::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) {
        auto dtype = input(0)->dtype();
        TensorLayout ily(iv.val[0].shape(), dtype);
        dest.ndim = 1;
        dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(ily);
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(2),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}

void CondTake::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void CondTake::scn_do_execute() {
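    // an empty input yields empty outputs and skips the MegDNN kernel entirely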
    auto&& data = input(0)->dev_tensor();
    auto&& mask = input(1)->dev_tensor();
    intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
    if (data.layout().is_empty()) {
        mgb_assert(
                data.layout().eq_shape(mask.layout()),
                "CondTake shape differs: data=%s mask=%s",
                data.layout().TensorShape::to_string().c_str(),
                mask.layout().TensorShape::to_string().c_str());
        dyn_malloc.alloc_output(0, data.layout().dtype, {0}, nullptr);
        dyn_malloc.alloc_output(1, dtype::Int32(), {0}, nullptr);
    } else {
        megdnn_opr()->exec(
                data.as_megdnn(), mask.as_megdnn(),
                intl::get_megdnn_workspace_from_var(output().back()), &dyn_malloc);
    }
}

/* ================= TopK =================  */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(TopK);

TopK::TopK(
        VarNode* data, VarNode* k, const Param& param, const OperatorNodeConfig& config)
        : Super(data->owner_graph(), config, "top_k", {data, k}) {
    init_megdnn_opr(*this, param);
    add_input({data, k});
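    // in KTH_ONLY mode output(1) carries no meaningful data, so it is marked
    // volatile and allowed to have an empty shape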
    if (param.mode == Param::Mode::KTH_ONLY) {
        output(1)
                ->add_flag(VarNode::Flag::VOLATILE_CONTENT)
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    }
}

std::array<SymbolVar, 2> TopK::make(
        SymbolVar data, SymbolVar k, const Param& param,
        const OperatorNodeConfig& config) {
    auto opr = data.node()->owner_graph()->insert_opr(
            std::make_unique<TopK>(data.node(), k.node(), param, config));
    auto o1 = opr->output(1);
    if (param.mode == Param::Mode::KTH_ONLY) {
        o1 = nullptr;
    }
    return {opr->output(0), o1};
}

void TopK::init_output_dtype() {
    mgb_assert(
            input(1)->dtype() == dtype::Int32{}, "k must be int32, got %s",
            input(1)->dtype().name());
    output(0)->dtype(input(0)->dtype());
    output(1)->dtype(dtype::Int32{});
}

void TopK::add_input_layout_constraint() {
    auto check = [](const TensorLayout& layout) {
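        // input must be 2-D with a contiguous last dimension (stride 1)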
        mgb_assert(
                layout.ndim == 2, "top-k input must be two-dim, got %s",
                layout.TensorShape::to_string().c_str());
        return layout.stride[1] == 1;
    };
    input(0)->add_layout_constraint(check);
}

void TopK::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();

    auto infer_oshp0 = [this](TensorShape& dst, const InpVal& iv) {
        auto&& k_tensor = iv.val[1].value();
        mgb_assert(
                k_tensor.shape().is_scalar(), "k must be scalar, got %s",
                k_tensor.shape().to_string().c_str());
        TensorLayout o0, o1;
        megdnn_opr()->deduce_layout(
                k_tensor.ptr<int>()[0], {iv.val[0].shape(), input(0)->dtype()}, o0, o1);
        dst = o0;
        return true;
    };
    mgr.register_shape_infer(
            output(0), {SourceType::DEP,
                        {{input(0), DepType::SHAPE}, {input(1), DepType::VALUE}},
                        infer_oshp0});

    if (param().mode == Param::Mode::KTH_ONLY) {
        mgr.register_shape_infer(output(1), ShapeInferDesc::make_const({}));
    } else {
        mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(output(0)));
    }

    auto infer_workspace = [this](TensorShape& dst, const InpVal& iv) {
        // activate comp_node so get_workspace_in_bytes can launch CUDA kernels
        comp_node().activate();
        auto k = iv.val[3].value().ptr<int>()[0];
        auto size = megdnn_opr()->get_workspace_in_bytes(
                k, {iv.val[0].shape(), input(0)->dtype()},
                {iv.val[1].shape(), output(0)->dtype()},
                {iv.val[2].shape(), output(1)->dtype()});
        dst.ndim = 1;
        dst.shape[0] = size;
        return true;
    };
    mgr.register_shape_infer(
            output(2), {SourceType::DEP,
                        {{input(0), DepType::SHAPE},
                         {output(0), DepType::SHAPE},
                         {output(1), DepType::SHAPE},
                         {input(1), DepType::VALUE}},
                        infer_workspace});
}

void TopK::scn_do_execute() {
    auto&& mgr = owner_graph()->static_infer_manager();
    auto k = mgr.infer_value(input(1)).ptr<int>()[0];
    megdnn_opr()->exec(
            k, input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            output(1)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output(2)));
}

void TopK::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TopK) {
    // TopK has no gradient on the input k
    if (wrt_idx)
        return nullptr;
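    // KTH_ONLY returns only the k-th value; its gradient is split evenly among
    // all input elements equal to that value (ties share the gradient)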
    if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
        mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
        auto add_axis = [](SymbolVar x) {
            return opr::AxisAddRemove::make(
                    x, {opr::AxisAddRemove::AxisDesc::make_add(1)});
        };
        SymbolVar mask = opr::eq(add_axis(opr.output(0)), opr.input(0)),
                  og = add_axis(out_grad[0]) / opr::reduce_ax_sum(mask, 1);
        return (og * mask).node();
    }
    if (!out_grad[0])
        return nullptr;
    return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0)).node();
}
#endif

/* ================= CheckNonFinite =================  */
namespace mgb {
namespace opr {
namespace intl {
template <>
struct MegDNNOprInitPostCtor<CheckNonFinite> {
    static void apply(cg::OperatorNodeBase& opr) {
        opr.output(0)->dtype(dtype::Int32());
    }
};
}  // namespace intl
}  // namespace opr
}  // namespace mgb
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}