/**
 * \file imperative/src/impl/ops/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"

namespace mgb::imperative::rng {

namespace {

template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
27
    using DT = CompNode::DeviceType;
28
    using Handle = THandle;
29
    using OpTypeInfo = size_t;
30 31 32 33 34 35 36 37 38 39 40

    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
41
            removed = m_handle2ops.erase(handle);
42 43 44 45 46 47
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    template <typename DnnOp>
48
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
49 50 51 52
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
53
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
54 55 56 57 58
        }
        auto dnn_handle =
                MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
59 60
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
79
        DnnOpWithMutex(): op{nullptr} {}
80 81 82 83
    };

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
84
        m_handle2ops.clear();
85 86 87
        return {};
    }

88
    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
89 90 91 92
    std::mutex m_mtx;
};

class RNGDnnOpManager final
93
        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
94
public:
95 96 97 98 99
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

100
    size_t delete_handle(Handle handle) {
101 102
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
103 104 105 106 107 108 109 110 111 112 113 114
    }

    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    static uint64_t get_seed(Handle handle) {
115
        if (!handle) { return glob_default_seed; }
116 117 118 119
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
120
        mgb_assert(handle, "invalid handle");
121 122 123 124
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    static Handle get_default_handle(CompNode comp_node) {
125 126 127 128 129
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
130
        }
131
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
132
        return glob_handle;
133 134 135 136 137 138 139 140
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    static void set_glob_default_seed(uint64_t seed) {
141
        MGB_LOCK_GUARD(sm_mtx);
142 143 144 145 146 147 148
        for(auto && elem : glob_default_handles){
            mgb_assert(elem.first.valid());
            if(elem.second){
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
149 150 151
        glob_default_seed = seed;
    }

152 153 154 155 156
    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

157 158 159 160 161 162 163 164 165 166
private:
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    static std::mutex sm_mtx;
167
    static CompNode::UnorderedMap<Handle> glob_default_handles;
168 169 170 171 172
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
173
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;
174 175 176 177 178 179 180 181 182 183

//! Per-op metadata: maps an imperative rng op to its megdnn operator,
//! its param type, and its graph (mgb::opr) node; specialized below.
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    //! Build the dnn param; the op's recorded seed must agree with the
    //! seed stored on its handle.
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
            "inconsistent rng seed: rng op: %lu handle: %lu",
            handle_seed, rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;
    //! Build the dnn param; the op's recorded seed must agree with the
    //! seed stored on its handle.
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
            "inconsistent rng seed: rng op: %lu handle: %lu",
            handle_seed, rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    //! Build the dnn param; the op's recorded seed must agree with the
    //! seed stored on its handle.
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
            "inconsistent rng seed: rng op: %lu handle: %lu",
            handle_seed, rng.seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;
    //! Build the dnn param; the op's recorded seed must agree with the
    //! seed stored on its handle.
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
            "inconsistent rng seed: rng op: %lu handle: %lu",
            handle_seed, rng.seed);
        return {handle_seed};
    }
};
template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    //! Build the dnn param; the op's recorded seed must agree with the
    //! seed stored on its handle.
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
            "inconsistent rng seed: rng op: %lu handle: %lu",
            handle_seed, rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;
    //! Build the dnn param; the op's recorded seed must agree with the
    //! seed stored on its handle.
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
            "inconsistent rng seed: rng op: %lu handle: %lu",
            handle_seed, rng.seed);
        return {handle_seed};
    }
};

// Helper templates, specialized below:
//  - _InferLayout<rng_with_shape>: output layout inference;
//  - _RNGOprMaker<nr_in>: builds the graph operator (symbolic path);
//  - _RNGOprInvoker<nr_in>: executes the dnn op (physical path).
template <bool>
struct _InferLayout;

template <int nr_in>
struct _RNGOprMaker;

template <int nr_in>
struct _RNGOprInvoker;

// Layout inference for rng ops whose single input is a *shape tensor*
// (the dnn op itself takes no tensor inputs, NR_INPUTS == 0).
template<>
struct _InferLayout<true>
{
    // Physical path: read the target shape out of the input tensor's value.
    template<typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    // Symbolic path: the shape tensor must be rank-1; its int32 elements
    // give the output shape. If the value is not yet known, return an
    // unknown (ndim == 0) layout.
    template<typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            // shape not known yet; caller treats ndim == 0 as "unknown"
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name,
                inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        return out_layout;
    }
};

// Layout inference for rng ops with real tensor inputs: the output simply
// takes the first input's layout.
template<>
struct _InferLayout<false>
{
    template<typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){
        return inp->layout();
    }

    template<typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){
        // symbolic path requires a known input shape
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};
                                   
// Instantiates _RNGOprInvoker for a given dnn-op input arity: queries the
// workspace size, allocates a workspace blob on the output's comp node,
// then executes the dnn op. _FOR_EACH_IN(subfix) (redefined before each
// instantiation) expands to the per-input argument list.
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS)                                                              \
template<>                                                                                             \
struct _RNGOprInvoker<DNN_NR_INPUTS> {                                                                 \
    template<typename Opr>                                                                             \
    static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){         \
        size_t wk_size = 0;                                                                            \
        wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout());              \
        auto workspace = Blob::make(dest->comp_node(), wk_size);                                       \
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);                                 \
        dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn())                                          \
                                 dest->dev_tensor().as_megdnn(), dnn_wk);                              \
    }                                                                                                  \
};
// Instantiates _RNGOprMaker for a given graph-level input arity: builds
// the mgb::opr node for the symbolic path, pinning the comp node to the
// handle's device when an explicit handle is set.
#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                                                 \
template<>                                                                                             \
struct _RNGOprMaker<MGB_NR_INPUTS> {                                                                   \
    template<typename Op>                                                                              \
    static SymbolVar make(const VarNodeArray& inputs, const Op& rng){                                  \
        auto param = OpMeth<Op>::make_param(rng);                                                      \
        OperatorNodeConfig config;                                                                     \
        if (rng.handle) {                                                                              \
            config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};                    \
        } else {                                                                                       \
            config = {rng.make_name()};                                                                \
        }                                                                                              \
        return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);                                 \
    }                                                                                                  \
};

// Instantiate invoker/maker for each supported arity (0, 1, 2 inputs).
// _FOR_EACH_IN is redefined for each arity to expand the input list.
#define _FOR_EACH_IN(subfix)   
_INST_RNG_INVOLKER(0)
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix,
_INST_RNG_INVOLKER(1)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
_INST_RNG_INVOLKER(2)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER

368 369
template <typename Op>
void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
370
          const SmallVector<TensorPtr>& outputs, const SmallVector<TensorPtr>& workspace) {
371
    auto&& rng = op.cast_final_safe<Op>();
372
 
373
    auto dest = outputs[0];
374
    if (dest->layout().is_empty()) return;
375
    auto cn = dest->comp_node();
376 377 378
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
379 380 381 382
    }

    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe = RNGDnnOpManager::inst()
383 384 385
            .get_dnn_op<typename OpMeth<Op>::DnnOp>(
                handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
                cn);
386 387 388 389 390
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(dnn_op->param().seed == handle_seed,
391
            "inconsistent rng seed: handle: %lu, dnn_op: %lu",
392 393 394
            handle_seed, dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
395
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest);
396 397 398 399 400 401
}

/**
 * Infer the single output descriptor (comp node + layout) from physical
 * inputs. The comp node comes from the explicit handle if set, otherwise
 * from inputs[0]; with real tensor inputs, every input must live on that
 * same comp node.
 */
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        // use size_t (and %zu) to avoid the signed/unsigned comparison of
        // `int i < inputs.size()`
        for (size_t i = 0; i < inputs.size(); ++i) {
            mgb_assert(inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%zu] to be same as the device of handle; "
                    "got %s and %s actually", rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}

423 424 425 426 427 428 429 430 431 432 433 434
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto &&dest = infer_output_attrs<Op>(def, inputs_tensors);
    SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}};
    
    return {outputs, {}};    
}


435 436 437 438
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<TensorPtr> outputs;
439 440
    SmallVector<LogicalTensorDesc> desc; 
    desc = infer_output_attrs<Op>(def, inputs);
441 442 443
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
444
    exec<Op>(def, inputs, outputs, {});
445 446 447
    return outputs;
}

448 449 450 451 452 453 454 455 456
template <typename Op>
void execute(
        const OpDef& def,
        SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
    exec<Op>(def, inputs, outputs, {});
}

457
template<typename Op>
458 459 460
SymbolVar apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
461
    size_t nr_inp = inputs.size();
462
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
463
    auto&& rng = def.cast_final_safe<Op>();
464 465 466 467
    if(dnn_nr_inp == 0){
        mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
                rng.dyn_typeinfo()->name,
                nr_inp);
468
    }
469 470
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
471 472
}

473
template<typename Op>
474 475
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
476 477
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
478
    size_t nr_inp = inputs.size();
479 480 481 482 483
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (rng_with_shape){
        mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
                xxx_rng_def.dyn_typeinfo()->name,
                nr_inp);
484
    }
485 486 487
    dest.comp_node = inputs[0].comp_node;
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    return {{dest}, true};
488 489 490 491
}

} // anonymous namespace

492
Handle new_handle(CompNode comp_node, uint64_t seed) {
493 494 495
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

496
size_t delete_handle(Handle handle) {
497 498 499
    return RNGDnnOpManager::inst().delete_handle(handle);
}

500
void set_global_rng_seed(uint64_t seed) {
501 502
    RNGDnnOpManager::set_glob_default_seed(seed);
}
503 504 505 506 507

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

508 509 510 511
CompNode get_rng_handle_compnode(Handle handle){
    return RNGDnnOpManager::get_comp_node(handle);
}

// Register all trait entry points for each rng op type.
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
    .apply_on_var_node(apply_on_var_node<NAME>) \
    .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
    .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
    .infer_output_mem_desc(infer_output_mem_desc<NAME>) \
    .execute(execute<NAME>) \
    .fallback(); \
} \

REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
REG_RNG_OP(GammaRNG)
REG_RNG_OP(PermutationRNG)
REG_RNG_OP(PoissonRNG)
REG_RNG_OP(BetaRNG)
#undef REG_RNG_OP
}  // namespace mgb::imperative::rng

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}