#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative::rng {

namespace {

template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
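    // CRTP base that caches megdnn operators per handle: one lazily created
    // operator per (handle, operator type) pair, each guarded by its own mutex
    // so work on different cached operators does not serialize.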
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
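        // Return the cached megdnn operator for (handle, tpinfo), creating it on
        // first use. The tuple carries (already_initialized, operator pointer,
        // unique_lock on the per-operator mutex).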
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex() : op{nullptr} {}
    };

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
            m_handle2ops;
    std::mutex m_mtx;
};

class RNGDnnOpManager final : public DnnOpManagerT<RNGDnnOpManager, Handle> {
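    // Singleton manager of RNG handles. A handle is a pooled (comp_node, seed)
    // record; one default handle per comp node is created lazily with the
    // global default seed and rebuilt whenever that seed changes.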
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    static uint64_t get_seed(Handle handle) {
        if (!handle) {
            return glob_default_seed;
        }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
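        // Rebuild every per-comp-node default handle so subsequent ops pick up
        // the new seed rather than the seed cached on the old handle.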
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    static std::mutex sm_mtx;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

template <typename Op>
struct OpMeth;
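
// OpMeth<Op> maps each imperative RNG op onto its megdnn operator type, its
// parameter type and the graph operator used for lowering, plus a make_param()
// helper that checks the seed recorded on the op against the seed of its handle.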

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed,
                rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<Dropout> {
    using DnnOp = megdnn::Dropout;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::Dropout;
    static Param make_param(const Dropout& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent dropout seed: dropout op: %lu handle: %lu", handle_seed,
                opdef.seed);
        return {opdef.drop_prob, handle_seed};
    }
};

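// Implementation helpers:
//  * _InferLayout<true> derives the output layout from a shape tensor (used by
//    ops whose megdnn operator takes no runtime inputs); _InferLayout<false>
//    reuses the input layout unchanged.
//  * _RNGOprMaker<n> builds the graph operator from n graph-level inputs.
//  * _RNGOprInvoker<n_in, n_out> queries workspace size and runs the megdnn op.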
template <bool>
struct _InferLayout;

template <int nr_in>
struct _RNGOprMaker;

template <int nr_in, int nr_out>
struct _RNGOprInvoker;

template <>
struct _InferLayout<true> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        out_layout.init_contiguous_stride();
        return out_layout;
    }
};

template <>
struct _InferLayout<false> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};

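// _INST_RNG_INVOLKER(n_in, n_out) instantiates _RNGOprInvoker for one
// input/output arity; _FOR_EACH_IN / _FOR_EACH_OUT are redefined before each
// instantiation below to expand the matching argument lists.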
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                  \
    template <>                                                            \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                 \
        template <typename Opr>                                            \
        static void exec(                                                  \
                Opr* dnn_op, const SmallVector<TensorPtr>& inputs,         \
                const SmallVector<TensorPtr>& outputs) {                   \
            size_t wk_size = 0;                                            \
            wk_size = dnn_op->get_workspace_in_bytes(                      \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));   \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
            dnn_op->exec(                                                  \
                    _FOR_EACH_IN(->dev_tensor().as_megdnn())               \
                            _FOR_EACH_OUT(->dev_tensor().as_megdnn()),     \
                    dnn_wk);                                               \
        }                                                                  \
    };

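// _INST_RNG_MAKER(n) instantiates _RNGOprMaker for n graph-level inputs; the
// comp node comes from the op's handle when one is attached.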
#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                                \
    template <>                                                                       \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                              \
        template <typename Op>                                                        \
        static auto make(const VarNodeArray& inputs, const Op& rng) {                 \
            auto param = OpMeth<Op>::make_param(rng);                                 \
            OperatorNodeConfig config;                                                \
            if (rng.handle) {                                                         \
                config = {                                                            \
                        rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
            } else {                                                                  \
                config = {rng.make_name()};                                           \
            }                                                                         \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);            \
        }                                                                             \
    };

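// Instantiate only the arities used in this file:
// (0 in, 1 out), (1 in, 1 out), (1 in, 2 out) and (2 in, 1 out).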
#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT

#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER

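// Eager-mode execution: pick the op's handle (or the per-comp-node default),
// fetch the cached megdnn operator for it, refresh the operator's param from
// the op definition and run it on the pre-allocated outputs.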
template <typename Op>
void exec(
        const OpDef& op, const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspace) {
    auto&& rng = op.cast_final_safe<Op>();

    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }

    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
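    // An operator fetched from the cache must still carry the seed recorded on
    // its handle; the param is refreshed from the op definition right after.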
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(
                dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
                dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(
            dnn_op, inputs, outputs);
}

template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
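    // Concrete-tensor descriptor inference: the output lives on the handle's
    // comp node (or on the input's comp node when the default handle is used);
    // shape-driven ops read the target shape from the first input's value.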
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        for (int i = 0; i < inputs.size(); ++i) {
            mgb_assert(
                    inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%d] to be the same as the "
                    "device of the handle; got %s and %s actually",
                    rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}

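// ShuffleRNG has two outputs: the shuffled tensor and an int32 index vector
// whose length equals the first dimension of the input.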
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& rng = op.cast_final_safe<ShuffleRNG>();
    auto handle = rng.handle;
    if (handle) {
        dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
        dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dests[0].comp_node = inputs[0]->comp_node();
        dests[1].comp_node = inputs[0]->comp_node();
    }
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;
    dests[1].layout =
            TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
    return dests;
}

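// Dropout has two outputs: the masked tensor and a byte mask whose size is
// queried from the megdnn Dropout operator for the given input layout.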
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& cn = inputs[0]->comp_node();

    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0]->layout());
    dests[0].layout.dtype = inputs[0]->layout().dtype;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0]->layout());
    };
    dests[1].comp_node = cn;
    dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    return dests;
}

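// Allocate outputs from freshly inferred descriptors and run the op; the
// dispatcher-provided output_descs and validated flag are not used here.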
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
    exec<Op>(def, inputs, outputs, {});
    return outputs;
}

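// Graph-mode lowering. Shape-driven ops take exactly one (shape) input at the
// graph level even though their megdnn operator takes none, hence
// mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp.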
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    if (dnn_nr_inp == 0) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 input; got %lu actually",
                rng.dyn_typeinfo()->name, nr_inp);
    }
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}

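// Fallible inference used before input values are known: when the shape input
// has no value yet, only the dtype is filled in and the descriptors are
// reported as not yet validated.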
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (rng_with_shape) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 input; got %lu actually",
                xxx_rng_def.dyn_typeinfo()->name, nr_inp);
    }
    dest.comp_node = inputs[0].comp_node;
    if (success) {
        dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    } else {
        dest.layout = TensorLayout(inputs[0].layout.dtype);
    }
    return {{dest}, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;

    SmallVector<LogicalTensorDesc> dests(2);
    dests[0].comp_node = inputs[0].comp_node;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;
    dests[1].comp_node = inputs[0].comp_node;
    if (success) {
        dests[1].layout =
                TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
    } else {
        dests[1].layout = TensorLayout(dtype::Int32());
    }
    return {dests, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
        const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;

    SmallVector<LogicalTensorDesc> dests(2);
    auto cn = inputs[0].comp_node;
    dests[0].comp_node = cn;
    dests[0].layout = TensorLayout(inputs[0].layout);
    dests[0].layout.dtype = inputs[0].layout.dtype;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0].layout);
    };
    dests[1].comp_node = cn;
    if (success) {
        dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    } else {
        dests[1].layout = TensorLayout(dtype::Byte());
    }

    return {dests, success};
}

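// RNG ops impose no special layout constraints on their inputs.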
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

}  // anonymous namespace

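// Free functions wrapping RNGDnnOpManager for code outside this file.
// Typical use (sketch): auto h = new_handle(cn, seed); ... delete_handle(h);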
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}

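// Register the imperative op traits for every RNG op; Output is SymbolVar for
// single-output ops and SymbolVarArray for ShuffleRNG and Dropout.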
#define REG_RNG_OP(NAME, Output)                                            \
    namespace {                                                             \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .get_input_layout_constraint(get_input_layout_constraint<NAME>) \
            .fallback();                                                    \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
REG_RNG_OP(Dropout, SymbolVarArray)
#undef REG_RNG_OP

}  // namespace mgb::imperative::rng

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}