/**
 * \file imperative/src/impl/ops/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

18
#include "../dnn_op_helper.h"
19
#include "../op_trait.h"
20

21
namespace mgb::imperative::rng {
22 23 24 25 26 27

namespace {

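// DnnOpManagerT caches one megdnn operator per (handle, op-type) pair. The map
// itself is guarded by m_mtx and each cached operator by its own mutex; the
// whole cache is dropped when the comp node environment is finalized. The
// HandleFactory parameter (CRTP) supplies handle allocation and deallocation.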
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex() : op{nullptr} {}
    };

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
            m_handle2ops;
    std::mutex m_mtx;
};

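// Concrete singleton manager for RNG handles: HandleData (comp node + seed) is
// pooled via MemPool, one default handle per comp node is created lazily, and
// set_glob_default_seed() rebuilds those default handles with the new seed.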
class RNGDnnOpManager final : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    static uint64_t get_seed(Handle handle) {
        if (!handle) {
            return glob_default_seed;
        }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    static std::mutex sm_mtx;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

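// OpMeth binds each imperative RNG op to its megdnn operator (DnnOp), the
// corresponding graph operator (OpNode), and a make_param() that translates
// the OpDef attributes into the megdnn param while asserting that the seed
// recorded on the op matches the seed stored in its handle.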
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<Dropout> {
    using DnnOp = megdnn::Dropout;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::Dropout;
    static Param make_param(const Dropout& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent dropout seed: dropout op: %lu handle: %lu", handle_seed,
                opdef.seed);
        return {opdef.drop_prob, handle_seed};
    }
};

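// The helpers below dispatch on whether the megdnn operator takes a shape
// tensor as its sole mgb input (NR_INPUTS == 0): _InferLayout<true> reads the
// target shape from the input value, _InferLayout<false> reuses the input
// layout; _RNGOprMaker / _RNGOprInvoker are specialized per input/output arity.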
template <bool>
struct _InferLayout;

template <int nr_in>
struct _RNGOprMaker;

template <int nr_in, int nr_out>
struct _RNGOprInvoker;

template <>
struct _InferLayout<true> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        out_layout.init_contiguous_stride();
        return out_layout;
    }
};

template <>
struct _InferLayout<false> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};

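// _INST_RNG_INVOLKER stamps out one _RNGOprInvoker specialization per
// (nr_inputs, nr_outputs) pair: it queries the workspace size, allocates a
// Blob for it, and calls dnn_op->exec() with the megdnn tensors.
// _INST_RNG_MAKER builds the graph operator, pinning the comp node to the
// handle's device when a handle is given. _FOR_EACH_IN / _FOR_EACH_OUT expand
// the respective argument lists.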
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                  \
    template <>                                                            \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                 \
        template <typename Opr>                                            \
        static void exec(                                                  \
                Opr* dnn_op, const SmallVector<TensorPtr>& inputs,         \
                const SmallVector<TensorPtr>& outputs) {                   \
            size_t wk_size = 0;                                            \
            wk_size = dnn_op->get_workspace_in_bytes(                      \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));   \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
            dnn_op->exec(                                                  \
                    _FOR_EACH_IN(->dev_tensor().as_megdnn())               \
                            _FOR_EACH_OUT(->dev_tensor().as_megdnn()),     \
                    dnn_wk);                                               \
        }                                                                  \
    };

#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                                \
    template <>                                                                       \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                              \
        template <typename Op>                                                        \
        static auto make(const VarNodeArray& inputs, const Op& rng) {                 \
            auto param = OpMeth<Op>::make_param(rng);                                 \
            OperatorNodeConfig config;                                                \
            if (rng.handle) {                                                         \
                config = {                                                            \
                        rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
            } else {                                                                  \
                config = {rng.make_name()};                                           \
            }                                                                         \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);            \
        }                                                                             \
    };

#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT

#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER

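// Run the cached megdnn operator for this (handle, op type) on pre-allocated
// outputs. Nothing is done for an empty output, and an operator that was
// already initialized must still carry the seed recorded in the handle.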
template <typename Op>
void exec(
        const OpDef& op, const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspace) {
    auto&& rng = op.cast_final_safe<Op>();

    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }

    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(
                dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
                dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(
            dnn_op, inputs, outputs);
}

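// The output lives on the handle's comp node (or on the first input's comp
// node when no handle is set); for ops with real tensor inputs, every input
// must already be on that device.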
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        for (int i = 0; i < inputs.size(); ++i) {
            mgb_assert(
                    inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%d] to be the same as the "
                    "device of the handle; "
                    "got %s and %s actually",
                    rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}

template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& rng = op.cast_final_safe<ShuffleRNG>();
    auto handle = rng.handle;
    if (handle) {
        dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
        dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dests[0].comp_node = inputs[0]->comp_node();
        dests[1].comp_node = inputs[0]->comp_node();
    }
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;
    dests[1].layout =
            TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
    return dests;
}

template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& cn = inputs[0]->comp_node();

    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0]->layout());
    };
    dests[1].comp_node = cn;
    dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    return dests;
}

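// Imperative execution path: infer the output descriptors, allocate the output
// tensors, then run the dnn operator into them.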
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
    exec<Op>(def, inputs, outputs, {});
    return outputs;
}

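// Graph construction path. Ops whose dnn counterpart takes no input still take
// exactly one mgb input (the target shape), hence mgb_nr_inp below.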
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    if (dnn_nr_inp == 0) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 input; got %lu actually",
                rng.dyn_typeinfo()->name, nr_inp);
    }
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}

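// Shape inference without concrete tensors; ShuffleRNG and Dropout are
// specialized below because they produce a second output (indices / mask).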
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (rng_with_shape) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 input; got %lu actually",
                xxx_rng_def.dyn_typeinfo()->name, nr_inp);
    }
    dest.comp_node = inputs[0].comp_node;
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    return {{dest}, true};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    dests[0].comp_node = inputs[0].comp_node;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;
    dests[1].comp_node = inputs[0].comp_node;
    dests[1].layout =
            TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
    return {dests, true};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
        const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto cn = inputs[0].comp_node;
    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0].layout);
    };
    dests[1].comp_node = cn;
    dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    return {dests, true};
}

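// No layout constraint is imposed on any input; the callbacks are left
// default-constructed (empty).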
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

}  // anonymous namespace

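// A minimal usage sketch of the exported handle API (illustrative only; the
// comp node spec and seed values below are arbitrary assumptions):
//
//     Handle h = new_handle(CompNode::load("xpux"), /*seed=*/42);
//     mgb_assert(get_rng_handle_compnode(h) == CompNode::load("xpux"));
//     delete_handle(h);            // also drops any cached dnn operators
//     set_global_rng_seed(2021);   // reseeds the lazily-created default handles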
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}

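// Register every RNG op with the op-trait system; ShuffleRNG and Dropout bind
// to SymbolVarArray because they produce two outputs.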
#define REG_RNG_OP(NAME, Output)                                            \
    namespace {                                                             \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .get_input_layout_constraint(get_input_layout_constraint<NAME>) \
            .fallback();                                                    \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
REG_RNG_OP(Dropout, SymbolVarArray)
#undef REG_RNG_OP

}  // namespace mgb::imperative::rng

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}