#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative::rng {

namespace {

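// CRTP base that lazily creates and caches one megdnn operator per
// (handle, operator type) pair. Each cached operator carries its own mutex so
// concurrent executors can reuse it safely; the whole cache is dropped when
// the comp node is finalized.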
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex() : op{nullptr} {}
    };

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
            m_handle2ops;
    std::mutex m_mtx;
};

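// Singleton manager for RNG handles. A handle is a (comp_node, seed) pair
// allocated from a MemPool; a per-comp-node default handle seeded with the
// global seed is created on demand and rebuilt whenever the global seed
// changes.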
class RNGDnnOpManager final : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    static uint64_t get_seed(Handle handle) {
        if (!handle) {
            return glob_default_seed;
        }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    static std::mutex sm_mtx;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

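// OpMeth<Op> maps an imperative RNG OpDef onto its megdnn operator, its graph
// operator node, and the conversion from OpDef attributes to the megdnn Param.
// Every make_param() also checks that the seed recorded on the OpDef matches
// the seed stored in its handle.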
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<Dropout> {
    using DnnOp = megdnn::Dropout;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::Dropout;
    static Param make_param(const Dropout& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent dropout seed: dropout op: %lu handle: %lu", handle_seed,
                opdef.seed);
        return {opdef.drop_prob, handle_seed};
    }
};

template <>
struct OpMeth<MultiHeadAttn> {
    using DnnOp = megdnn::MultiHeadAttn;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::MultiHeadAttn;
    static Param make_param(const MultiHeadAttn& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent multiheadattn seed: dropout op: %lu handle: %lu",
                handle_seed, opdef.seed);

        return {opdef.num_heads,      opdef.embeding_size,
                opdef.k_size,         opdef.v_size,
                opdef.qproj_size,     opdef.kproj_size,
                opdef.vproj_size,     opdef.oproj_size,
                opdef.qbias,          opdef.kbias,
                opdef.vbias,          opdef.obias,
                opdef.sm_scaler,      opdef.input_order,
                opdef.attn_mask_type, opdef.tensor_combination_type,
                opdef.add_zero_attn,  opdef.need_weights,
                opdef.reslink,        opdef.training,
                handle_seed,          opdef.attn_prob,
                opdef.out_prob};
    }
};

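// _InferLayout<true> is used by RNG ops whose single input is a target-shape
// tensor (the dnn op itself takes no inputs); _InferLayout<false> simply
// reuses the input layout. _RNGOprInvoker/_RNGOprMaker are specialized per
// input/output arity by the _INST_RNG_* macros below.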
template <bool>
struct _InferLayout;

template <int nr_in>
struct _RNGOprMaker;

template <int nr_in, int nr_out>
struct _RNGOprInvoker;

template <>
struct _InferLayout<true> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        out_layout.init_contiguous_stride();
        return out_layout;
    }
};

template <>
struct _InferLayout<false> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};

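// _INST_RNG_INVOLKER(nin, nout) defines an invoker that queries the workspace
// size from the input/output layouts and calls dnn_op->exec(); _INST_RNG_MAKER
// defines the graph-building counterpart, pinning the comp node of the handle
// when one is given. _FOR_EACH_IN/_FOR_EACH_OUT are redefined before each
// instantiation to splice the argument list of that arity.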
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                  \
    template <>                                                            \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                 \
        template <typename Opr>                                            \
        static void exec(                                                  \
                Opr* dnn_op, const SmallVector<TensorPtr>& inputs,         \
                const SmallVector<TensorPtr>& outputs) {                   \
            size_t wk_size = 0;                                            \
            wk_size = dnn_op->get_workspace_in_bytes(                      \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));   \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
            dnn_op->exec(                                                  \
                    _FOR_EACH_IN(->dev_tensor().as_megdnn())               \
                            _FOR_EACH_OUT(->dev_tensor().as_megdnn()),     \
                    dnn_wk);                                               \
        }                                                                  \
    };

#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                                \
    template <>                                                                       \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                              \
        template <typename Op>                                                        \
        static auto make(const VarNodeArray& inputs, const Op& rng) {                 \
            auto param = OpMeth<Op>::make_param(rng);                                 \
            OperatorNodeConfig config;                                                \
            if (rng.handle) {                                                         \
                config = {                                                            \
                        rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
            } else {                                                                  \
                config = {rng.make_name()};                                           \
            }                                                                         \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);            \
        }                                                                             \
    };

#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT

#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) \
    inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(4, 2)
_INST_RNG_MAKER(4)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)                                                \
    inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \
            inputs[4] subfix,
_INST_RNG_MAKER(5)
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)                                                \
    inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \
            inputs[4] subfix, inputs[5] subfix,
_INST_RNG_MAKER(6)
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)                                                \
    inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix, \
            inputs[4] subfix, inputs[5] subfix, inputs[6] subfix,
#define _FOR_EACH_OUT(subfix) \
    outputs[0] subfix, outputs[1] subfix, outputs[2] subfix, outputs[3] subfix
_INST_RNG_INVOLKER(7, 4)
_INST_RNG_MAKER(7)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER

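// Runs an RNG op on physical tensors: pick the op's handle (or the comp
// node's default one), fetch the cached megdnn operator keyed by the op type,
// check that a previously created operator still agrees with the handle's
// seed, then refresh the param and dispatch through the arity-matched invoker.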
template <typename Op>
void exec(
        const OpDef& op, const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspace) {
    auto&& rng = op.cast_final_safe<Op>();

    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }

    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(
                dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
                dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(
            dnn_op, inputs, outputs);
}

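// Generic single-output descriptor inference: the comp node comes from the
// op's handle if it has one, otherwise from inputs[0]; ops that consume
// tensor inputs require all inputs to live on that comp node. The layout is
// inferred from either the target-shape tensor or the input layout.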
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        for (int i = 0; i < inputs.size(); ++i) {
            mgb_assert(
                    inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%d] to be same as the device of "
                    "handle; "
                    "got %s and %s actually",
                    rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}

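// ShuffleRNG has two outputs: the shuffled tensor (same layout as the input)
// and an Int32 index vector of length inputs[0]->layout()[0].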
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& rng = op.cast_final_safe<ShuffleRNG>();
    auto handle = rng.handle;
    if (handle) {
        dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
        dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dests[0].comp_node = inputs[0]->comp_node();
        dests[1].comp_node = inputs[0]->comp_node();
    }
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;
    dests[1].layout =
            TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
    return dests;
}

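// Dropout has two outputs: the masked tensor and a Byte mask whose size is
// reported by megdnn::Dropout::get_mask_size_in_bytes().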
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& cn = inputs[0]->comp_node();

    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0]->layout());
    };
    dests[1].comp_node = cn;
    dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    return dests;
}

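// MultiHeadAttn relies on megdnn to deduce its four output layouts (out,
// attn_weight and the two reserve spaces); only the MultiHeadAttn
// specialization below is ever instantiated. When the input shape is still
// unknown, placeholder layouts are returned together with success == false.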
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> _infer_output_attrs(
        const OpDef& op, const SmallVector<TensorLayout>& inputs, const CompNode cn);

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> _infer_output_attrs<MultiHeadAttn>(
        const OpDef& op, const SmallVector<TensorLayout>& inputs, const CompNode cn) {
    bool success = inputs[0].ndim != 0;

    SmallVector<LogicalTensorDesc> dests(4);

    // retrieve dnn_op from glob cache
    auto&& rng = op.cast_final_safe<MultiHeadAttn>();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }
    auto dnn_op_thread_safe = RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
            handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);

    TensorLayout out, attn_weight, mask_layout, othr_layout;
    dnn_op->deduce_layout(
            inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], inputs[6],
            out, attn_weight, mask_layout, othr_layout);

    dests[0].comp_node = cn;
    dests[0].layout = out;
    dests[0].layout.dtype = inputs[0].dtype;
    dests[1].comp_node = cn;
    dests[1].layout = attn_weight;
    if (success) {
        dests[2].comp_node = cn;
        dests[2].layout = mask_layout;
        dests[3].comp_node = cn;
        dests[3].layout = othr_layout;
    } else {
        dests[2].comp_node = cn;
        dests[2].layout = TensorLayout(dtype::Byte());
        dests[3].comp_node = cn;
        dests[3].layout = TensorLayout(inputs[0].dtype);
    }

    return {dests, success};
}

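// Assemble the seven-layout argument list expected by _infer_output_attrs,
// filling the optional attn_mask / bias_k / bias_v slots with empty layouts
// according to tensor_combination_type.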
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
    auto&& cn = inputs[0]->comp_node();
    auto input_type = op.cast_final_safe<MultiHeadAttn>().tensor_combination_type;

    std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
    TensorLayout empty_layout;
    if (input_type == InputType::NONE)
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                 inputs[3]->layout(), empty_layout, empty_layout, empty_layout},
                cn);
    else if (input_type == InputType::ONLY_MASK)
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                 inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout},
                cn);
    else if (input_type == InputType::ONLY_BIASKV)
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                 inputs[3]->layout(), empty_layout, inputs[4]->layout(),
                 inputs[5]->layout()},
                cn);
    else
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                 inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(),
                 inputs[6]->layout()},
                cn);

    return std::get<0>(ret);
}

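// Default execution path: allocate outputs from the inferred descriptors and
// run the op through exec<Op>().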
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
    exec<Op>(def, inputs, outputs, {});
    return outputs;
}

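// MultiHeadAttn bypasses the generic invoker because its optional inputs
// (attn_mask, bias_k, bias_v) must be replaced by empty TensorNDs depending
// on tensor_combination_type, which changes the exec() argument list.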
template <>
SmallVector<TensorPtr> apply_on_physical_tensor<MultiHeadAttn>(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc =
            infer_output_attrs<MultiHeadAttn>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }

    auto&& rng = def.cast_final_safe<MultiHeadAttn>();
    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return outputs;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }

    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<MultiHeadAttn>::DnnOp>(
                    handle, reinterpret_cast<size_t>(def.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(
                dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
                dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);

    auto input_type = rng.tensor_combination_type;
    std::shared_ptr<Tensor> empty_dnn(nullptr);
    size_t wk_size = 0;
    TensorLayout empty_layout;
    megdnn::TensorND empty_tensor;

    if (input_type == InputType::ALL) {
        wk_size = dnn_op->get_workspace_in_bytes(
                inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                inputs[3]->layout(), inputs[4]->layout(), inputs[5]->layout(),
                inputs[6]->layout(), outputs[0]->layout(), outputs[1]->layout(),
                outputs[2]->layout(), outputs[3]->layout());
        auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
        dnn_op->exec(
                inputs[0]->dev_tensor().as_megdnn(),
                inputs[1]->dev_tensor().as_megdnn(),
                inputs[2]->dev_tensor().as_megdnn(),
                inputs[3]->dev_tensor().as_megdnn(),
                inputs[4]->dev_tensor().as_megdnn(),
                inputs[5]->dev_tensor().as_megdnn(),
                inputs[6]->dev_tensor().as_megdnn(),
                outputs[0]->dev_tensor().as_megdnn(),
                outputs[1]->dev_tensor().as_megdnn(),
                outputs[2]->dev_tensor().as_megdnn(),
                outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
    } else if (input_type == InputType::ONLY_MASK) {
        wk_size = dnn_op->get_workspace_in_bytes(
                inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                inputs[3]->layout(), inputs[4]->layout(), empty_layout, empty_layout,
                outputs[0]->layout(), outputs[1]->layout(), outputs[2]->layout(),
                outputs[3]->layout());
        auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
        dnn_op->exec(
                inputs[0]->dev_tensor().as_megdnn(),
                inputs[1]->dev_tensor().as_megdnn(),
                inputs[2]->dev_tensor().as_megdnn(),
                inputs[3]->dev_tensor().as_megdnn(),
                inputs[4]->dev_tensor().as_megdnn(), empty_tensor, empty_tensor,
                outputs[0]->dev_tensor().as_megdnn(),
                outputs[1]->dev_tensor().as_megdnn(),
                outputs[2]->dev_tensor().as_megdnn(),
                outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
    } else if (input_type == InputType::ONLY_BIASKV) {
        wk_size = dnn_op->get_workspace_in_bytes(
                inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                inputs[3]->layout(), empty_layout, inputs[4]->layout(),
                inputs[5]->layout(), outputs[0]->layout(), outputs[1]->layout(),
                outputs[2]->layout(), outputs[3]->layout());
        auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
        dnn_op->exec(
                inputs[0]->dev_tensor().as_megdnn(),
                inputs[1]->dev_tensor().as_megdnn(),
                inputs[2]->dev_tensor().as_megdnn(),
                inputs[3]->dev_tensor().as_megdnn(), empty_tensor,
                inputs[4]->dev_tensor().as_megdnn(),
                inputs[5]->dev_tensor().as_megdnn(),
                outputs[0]->dev_tensor().as_megdnn(),
                outputs[1]->dev_tensor().as_megdnn(),
                outputs[2]->dev_tensor().as_megdnn(),
                outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
    } else {
        wk_size = dnn_op->get_workspace_in_bytes(
                inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                inputs[3]->layout(), empty_layout, empty_layout, empty_layout,
                outputs[0]->layout(), outputs[1]->layout(), outputs[2]->layout(),
                outputs[3]->layout());
        auto workspace = Blob::make(outputs[0]->comp_node(), wk_size);
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);
        dnn_op->exec(
                inputs[0]->dev_tensor().as_megdnn(),
                inputs[1]->dev_tensor().as_megdnn(),
                inputs[2]->dev_tensor().as_megdnn(),
                inputs[3]->dev_tensor().as_megdnn(), empty_tensor, empty_tensor,
                empty_tensor, outputs[0]->dev_tensor().as_megdnn(),
                outputs[1]->dev_tensor().as_megdnn(),
                outputs[2]->dev_tensor().as_megdnn(),
                outputs[3]->dev_tensor().as_megdnn(), dnn_wk);
    }
    return outputs;
}

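// Graph building: shape-only RNG ops still take exactly one (shape) input var
// even though the dnn op has zero inputs, hence mgb_nr_inp = dnn_nr_inp or 1.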
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    if (dnn_nr_inp == 0) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 input; got %lu actually",
                rng.dyn_typeinfo()->name, nr_inp);
    }
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}

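// For MultiHeadAttn the graph arity depends on which optional tensors were
// provided, so pick the matching _RNGOprMaker instantiation.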
template <>
SymbolVarArray apply_on_var_node<MultiHeadAttn, SymbolVarArray>(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& rng = def.cast_final_safe<MultiHeadAttn>();
    using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
    auto input_type = rng.tensor_combination_type;
    if (input_type == InputType::ALL) {
        return _RNGOprMaker<7>::make(inputs, rng);
    } else if (input_type == InputType::ONLY_BIASKV) {
        return _RNGOprMaker<6>::make(inputs, rng);
    } else if (input_type == InputType::ONLY_MASK) {
        return _RNGOprMaker<5>::make(inputs, rng);
    } else {
        return _RNGOprMaker<4>::make(inputs, rng);
    }
}

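// Fallible inference on logical descriptors: returns success == false and a
// dtype-only layout while the input shape is still unknown.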
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (rng_with_shape) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 input; got %lu actually",
                xxx_rng_def.dyn_typeinfo()->name, nr_inp);
    }
    dest.comp_node = inputs[0].comp_node;
    if (success) {
        dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    } else {
        dest.layout = TensorLayout(inputs[0].layout.dtype);
    }
    return {{dest}, inputs[0].layout.ndim != 0};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;

    SmallVector<LogicalTensorDesc> dests(2);
    dests[0].comp_node = inputs[0].comp_node;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;
    dests[1].comp_node = inputs[0].comp_node;
    if (success) {
        dests[1].layout =
                TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
    } else {
        dests[1].layout = TensorLayout(dtype::Int32());
    }
    return {dests, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
        const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;

    SmallVector<LogicalTensorDesc> dests(2);
    auto cn = inputs[0].comp_node;
    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0].layout);
    };
    dests[1].comp_node = cn;
    if (success) {
        dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    } else {
        dests[1].layout = TensorLayout(dtype::Byte());
    }

    return {dests, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    using InputType = opr::MultiHeadAttn::Param::TensorCombinationType;
    auto&& cn = inputs[0].comp_node;
    auto input_type = op.cast_final_safe<MultiHeadAttn>().tensor_combination_type;

    std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
    TensorLayout empty_layout;
    if (input_type == InputType::NONE)
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
                 empty_layout, empty_layout, empty_layout},
                cn);
    else if (input_type == InputType::ONLY_MASK)
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
                 inputs[4].layout, empty_layout, empty_layout},
                cn);
    else if (input_type == InputType::ONLY_BIASKV)
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
                 empty_layout, inputs[4].layout, inputs[5].layout},
                cn);
    else
        ret = _infer_output_attrs<MultiHeadAttn>(
                op,
                {inputs[0].layout, inputs[1].layout, inputs[2].layout, inputs[3].layout,
                 inputs[4].layout, inputs[5].layout, inputs[6].layout},
                cn);

    return ret;
}

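// RNG ops impose no layout constraints on their inputs.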
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

}  // anonymous namespace

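// Public handle API: handles wrap a (comp_node, seed) pair owned by
// RNGDnnOpManager.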
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}

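// Register each RNG op in the op-trait table with the helpers defined above;
// Output selects SymbolVar vs SymbolVarArray for apply_on_var_node.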
#define REG_RNG_OP(NAME, Output)                                            \
    namespace {                                                             \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .get_input_layout_constraint(get_input_layout_constraint<NAME>) \
            .fallback();                                                    \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
REG_RNG_OP(Dropout, SymbolVarArray)
REG_RNG_OP(MultiHeadAttn, SymbolVarArray)
#undef REG_RNG_OP

}  // namespace mgb::imperative::rng

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}