/**
 * \file imperative/src/impl/ops/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/rng.h"

// NOTE(review): <cstdint> supplies uint64_t; the remaining standard headers
// are listed explicitly for the std:: facilities used in this file.
#include <cstdint>
#include <memory>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <utility>

#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
//#include "megbrain/common.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

/*!
 * \brief CRTP base that keeps one cached megdnn operator per handle.
 *
 * \tparam HandleFactory the derived class; it must provide do_new_handle()
 *      and do_delete_handle(), which new_handle() and delete_handle()
 *      forward to
 * \tparam THandle handle type, used as the key of the operator cache
 *
 * m_mtx guards lookups and erasures on the cache map; every cache entry has
 * its own mutex, which get_dnn_op() locks and hands back to the caller.
 */
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using Handle = THandle;

    //! create a handle by forwarding all arguments to
    //! HandleFactory::do_new_handle()
    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    /*!
     * \brief drop the operator cached for \p handle and release the handle
     * \return number of cache entries erased; always 0 once this object has
     *      been finalized, because the cache is not touched in that case
     */
    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2op.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    /*!
     * \brief fetch the operator of type \p DnnOp cached for \p handle,
     *      creating it on the megdnn handle of \p cn when needed
     *
     * A cached operator of a different dynamic type is replaced by a newly
     * created \p DnnOp.
     *
     * \return tuple (initialized, dnn_op, lock): \c initialized is true when
     *      an operator of the requested type was already cached; \c lock
     *      owns the per-entry mutex, so the entry stays locked for as long
     *      as the caller keeps the tuple alive
     */
    template <typename DnnOp>
    auto get_dnn_op(Handle handle, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            // only the map access needs m_mtx; operator[] creates an empty
            // entry on first use of a handle
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2op[handle];
        }
        auto dnn_handle =
                MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        DnnOp* dnn_op;
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
            nullptr) {
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->template create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    //! a cached operator together with the mutex that serializes its use
    struct DnnOpWithMutex {
        std::mutex mtx;
        // NOTE(review): element type assumed to be the polymorphic megdnn
        // operator base (required by the dynamic_cast above) -- confirm
        std::unique_ptr<megdnn::OperatorBase> op;
    };

    //! drop every cached operator; nothing is handed back to the caller
    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2op.clear();
        return {};
    }

    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
    std::mutex m_mtx;
};

/*!
 * \brief manager of RNG handles and of the megdnn RNG operators cached for
 *      them
 *
 * A handle is the address of a pool-allocated HandleData (comp node + seed).
 * A handle whose comp node is invalid is "partial"; get_full_handle() maps a
 * (partial handle, comp node) pair to a "full" handle with the same seed.
 */
class RNGDnnOpManager final
        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
public:
    /*!
     * \brief delete \p handle and every full handle derived from it
     * \return total number of cached operators that were erased
     */
    size_t delete_handle(Handle handle) {
        size_t ret = 0;
        {
            MGB_LOCK_GUARD(sm_mtx);
            auto iter = sm_partial2full.find(handle);
            if (iter != sm_partial2full.end()) {
                for (auto&& h : iter->second) {
                    ret += DnnOpManagerBase::delete_handle(h.second);
                }
                sm_partial2full.erase(iter);
            }
        }
        ret += DnnOpManagerBase::delete_handle(handle);
        return ret;
    }

    //! allocate a HandleData from the pool and expose its address as handle
    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    //! give the HandleData behind \p handle back to the pool
    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    //! seed stored in \p handle
    static uint64_t get_seed(Handle handle) {
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    //! comp node stored in \p handle
    static CompNode get_comp_node(Handle handle) {
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    /*!
     * \brief resolve \p handle to a handle bound to a comp node
     *
     * A handle that already has a valid comp node is returned unchanged.
     * Otherwise the full handle for (\p handle, \p comp_node) is looked up
     * under sm_mtx and created with the seed of \p handle on first use.
     */
    static Handle get_full_handle(Handle handle, CompNode comp_node) {
        if (get_comp_node(handle).valid()) {
            return handle;
        }
        MGB_LOCK_GUARD(sm_mtx);
        auto&& full = sm_partial2full[handle][comp_node];
        if (!full) {
            full = inst().new_handle(comp_node, get_seed(handle));
        }
        return full;
    }

    /*!
     * \brief handle used by operators that were not given an explicit one
     *
     * NOTE(review): the partial handle is created once, on the first call,
     * from the value glob_default_seed has at that moment; later calls to
     * set_glob_default_seed() do not change the seed of this handle.
     */
    static Handle get_default_handle(CompNode comp_node) {
        static Handle glob_partial_handle =
                inst().new_handle(CompNode{}, glob_default_seed);
        if (!comp_node.valid()) {
            return glob_partial_handle;
        }
        return get_full_handle(glob_partial_handle, comp_node);
    }

    //! the process-wide manager instance
    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    //! set the seed that the default handle is created with
    static void set_glob_default_seed(uint64_t seed) {
        glob_default_seed = seed;
    }

private:
    //! payload that a handle points to
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    //! guards sm_partial2full
    static std::mutex sm_mtx;
    //! partial handle -> (comp node -> full handle)
    // NOTE(review): inner map type assumed to be CompNode::UnorderedMap --
    // confirm against comp_node.h
    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
            sm_partial2full;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
std::unordered_map<RNGDnnOpManager::Handle,
                   CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
        RNGDnnOpManager::sm_partial2full;

/*!
 * \brief per-op traits: the megdnn operator, its param type, the graph
 *      operator node, and how to build the param from the imperative op
 */
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        return {RNGDnnOpManager::get_seed(rng.handle())};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
    }
};

/*!
 * \brief fill outputs[0] with random values using the cached megdnn operator
 *
 * \param op the RNG op definition (must be of type \p Op)
 * \param inputs not read here; the output tensor is already allocated
 * \param outputs outputs[0] is the destination tensor
 */
template <typename Op>
void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
          const SmallVector<TensorPtr>& outputs) {
    auto&& rng = op.cast_final_safe<Op>();
    auto dest = outputs[0];

    auto cn = dest->comp_node();
    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
    {
        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
        // arguments follow the order of the labels in the message
        mgb_assert(cn == handle_cn,
                   "inconsistent comp_node: handle: %s, output: %s",
                   handle_cn.to_string().c_str(), cn.to_string().c_str());
    }

    // retrieve dnn_op from glob cache
    // the returned tuple also owns the lock of the cache entry, so the
    // operator stays locked until this function returns
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst()
                    .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        // cast so that the specifier is correct wherever uint64_t and
        // size_t are different types
        mgb_assert(dnn_op->param().seed == handle_seed,
                   "inconsistent rng seed: handle: %llu, dnn_op: %llu",
                   static_cast<unsigned long long>(handle_seed),
                   static_cast<unsigned long long>(dnn_op->param().seed));
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);

    // allocate workspace
    size_t wk_size = dnn_op->get_workspace_in_bytes(dest->layout());
    auto workspace = Blob::make(cn, wk_size);
    megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size);

    dnn_op->exec(dest->dev_tensor().as_megdnn(), dnn_wk);
}

/*!
 * \brief describe the single output from the concrete value of inputs[0]
 *
 * The comp node is the one of the op, falling back to the comp node of
 * inputs[0] when the op has none; the shape is read from the value of
 * inputs[0] and the dtype is Float32.
 */
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    dest.comp_node = op.cast_final_safe<Op>().comp_node();
    if (!dest.comp_node.valid())
        dest.comp_node = inputs[0]->comp_node();

    auto hv = inputs[0]->get_value().proxy_to_default_cpu();
    TensorShape tshape;
    cg::copy_tensor_value_to_shape(tshape, hv);
    dest.layout = TensorLayout(tshape, dtype::Float32());
    return {dest};
}

//! allocate the outputs described by infer_output_attrs() and run exec()
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto desc = infer_output_attrs<Op>(def, inputs);
    SmallVector<TensorPtr> outputs;
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
    exec<Op>(def, inputs, outputs);
    return outputs;
}

/*!
 * \brief build the graph operator node for \p def
 * \param inputs exactly one var: the target shape
 */
template <typename Op>
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& rng = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    // report the actual op type; this template serves every RNG op
    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %zu actually",
               rng.dyn_typeinfo()->name, nr_inp);
    auto param = OpMeth<Op>::make_param(rng);
    return OpMeth<Op>::OpNode::make(
            inputs[0], param, {rng.comp_node()}).node()->owner_opr();
}

/*!
 * \brief infer the output description from the description of the shape
 *      input, without requiring its value to be available
 *
 * \return the description of the single output plus a flag. When the shape
 *      input has ndim == 0 or no value, the output layout gets ndim = 0.
 */
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %zu actually",
               xxx_rng_def.dyn_typeinfo()->name,
               nr_inp);

    auto&& tshp = inputs[0];

    TensorLayout out_layout = tshp.layout;
    out_layout.dtype = dtype::Float32();
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_layout.ndim = 0;
        // NOTE(review): the flag is true here as well as on the fully
        // inferred path below; confirm that callers can tell an unknown
        // shape from ndim == 0 alone.
        return {{{out_layout, tshp.comp_node}}, true};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of %s expects ndim=1; got ndim=%zu actually",
            xxx_rng_def.dyn_typeinfo()->name, tshp.layout.ndim);

    size_t target_ndim = tshp.layout.shape[0];
    // shape[] has a fixed capacity; refuse to write past it
    const size_t max_ndim = TensorShape::MAX_NDIM;
    mgb_assert(target_ndim <= max_ndim,
               "target shape of %s has too many dims: %zu (max %zu)",
               xxx_rng_def.dyn_typeinfo()->name, target_ndim, max_ndim);
    out_layout.ndim = target_ndim;
    // NOTE(review): the shape tensor is assumed to hold int32 values --
    // confirm against the callers that build it
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_layout.shape[i] = ptr[i];
    }
    return {{{out_layout, tshp.comp_node}}, true};
}

}  // anonymous namespace

//! bind to the default handle for \p cn (see
//! RNGDnnOpManager::get_default_handle)
RNGMixin::RNGMixin(CompNode cn)
        : m_handle(RNGDnnOpManager::get_default_handle(cn)) {}

//! seed stored in the handle of this op
uint64_t RNGMixin::seed() const {
    return RNGDnnOpManager::get_seed(m_handle);
}

//! comp node stored in the handle of this op
CompNode RNGMixin::comp_node() const {
    return RNGDnnOpManager::get_comp_node(m_handle);
}

//! create a new handle for the given comp node and seed
RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

//! delete \p handle; returns the number of cached operators that were erased
size_t RNGMixin::delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

//! set the seed that the default RNG handle is created with
void set_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

// register the op trait of an RNG op and emit its dyn-type implementation
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
    .apply_on_var_node(apply_on_var_node<NAME>) \
    .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
    .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
    .fallback(); \
} \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);

REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}