/**
 * \file imperative/src/impl/ops/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"

namespace mgb::imperative::rng {

namespace {

/*!
 * \brief CRTP base managing a cache of megdnn RNG operators keyed by
 * (handle, op type).
 *
 * The derived class (\p HandleFactory) supplies handle allocation via
 * do_new_handle()/do_delete_handle(). Cached dnn ops are dropped when the
 * comp node system is finalized (CompNodeDepedentObject hook).
 */
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    //! Forward handle creation to the derived class.
    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    /*!
     * \brief Remove all dnn ops cached for \p handle, then free the handle.
     * \return number of cache entries erased (0 after finalization).
     */
    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    /*!
     * \brief Get (or lazily create) the dnn op of type \p DnnOp cached under
     * (\p handle, \p tpinfo) for comp node \p cn.
     *
     * \return tuple of (already-initialized flag, raw op pointer, held lock).
     * The returned unique_lock keeps the per-op mutex held so the caller may
     * mutate the op's param and exec it thread-safely.
     */
    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle =
                MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            // cached op must live on the same megdnn handle
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, THandle>;
    DnnOpManagerT() = default;

private:
    //! A type-erased dnn op guarded by its own mutex.
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex() : op{nullptr} {}
    };

    //! Drop all cached dnn ops when comp nodes are finalized.
    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
            m_handle2ops;  //! protected by m_mtx
    std::mutex m_mtx;
};

/*!
 * \brief Singleton manager of RNG handles; each handle carries a
 * (comp_node, seed) pair allocated from a MemPool.
 *
 * Global state (default seed, per-comp-node default handles) is protected by
 * the static sm_mtx.
 */
class RNGDnnOpManager final
        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    //! Allocate handle storage from the pool; the Handle value is the
    //! reinterpreted HandleData pointer.
    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    //! Seed bound to \p handle; a null handle means the global default seed.
    static uint64_t get_seed(Handle handle) {
        if (!handle) {
            return glob_default_seed;
        }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    //! Get (creating on first use) the default handle for \p comp_node,
    //! seeded with the current global default seed.
    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    //! Re-seed all existing default handles and update the global seed.
    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                // call the base directly: sm_mtx is already held, and the
                // derived delete_handle would try to lock it again
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    //! Payload a Handle points to.
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    static std::mutex sm_mtx;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

/*!
 * \brief Per-op traits: maps an imperative RNG OpDef to its megdnn operator,
 * param type and graph operator node, and builds the dnn param from the op.
 *
 * Every make_param() asserts that the seed recorded on the op matches the
 * seed of the handle it refers to.
 */
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   handle_seed, rng.seed);
        return {handle_seed};
    }
};

//! Output-layout inference, specialized on whether the dnn op takes its
//! target shape as a (host) tensor input (true) or infers it from inputs.
template <bool>
struct _InferLayout;

//! Builds the graph operator node for an RNG op with nr_in mgb inputs.
template <int nr_in>
struct _RNGOprMaker;

//! Invokes the dnn op (workspace query + exec) for nr_in dnn inputs.
template <int nr_in>
struct _RNGOprInvoker;

template <>
struct _InferLayout<true> {
    //! Eager path: read the target shape from the input host value.
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    //! Fallible path: the shape tensor must be a known 1-d int32 vector;
    //! otherwise report an unknown (ndim=0) layout.
    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        return out_layout;
    }
};

template <>
struct _InferLayout<false> {
    //! Output layout equals the (first) input layout.
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};

// _FOR_EACH_IN(subfix) expands to "inputs[0] subfix, inputs[1] subfix, ..."
// for the current arity, so one macro body serves 0/1/2-input ops.
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \
template<> \
struct _RNGOprInvoker<DNN_NR_INPUTS> { \
    template<typename Opr> \
    static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \
        size_t wk_size = 0; \
        wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \
        auto workspace = Blob::make(dest->comp_node(), wk_size); \
        megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
        dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
                     dest->dev_tensor().as_megdnn(), dnn_wk); \
    } \
};

#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template<> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \
    template<typename Op> \
    static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \
        auto param = OpMeth<Op>::make_param(rng); \
        OperatorNodeConfig config; \
        if (rng.handle) { \
            config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
        } else { \
            config = {rng.make_name()}; \
        } \
        return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
    } \
};

#define _FOR_EACH_IN(subfix)
_INST_RNG_INVOLKER(0)
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix,
_INST_RNG_INVOLKER(1)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
_INST_RNG_INVOLKER(2)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER

/*!
 * \brief Run the RNG op \p op, writing into outputs[0].
 *
 * Falls back to the comp node's default handle when the op carries none.
 * The cached dnn op's seed is checked against the handle seed before the
 * param is refreshed.
 */
template <typename Op>
void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
          const SmallVector<TensorPtr>& outputs,
          const SmallVector<TensorPtr>& workspace) {
    auto&& rng = op.cast_final_safe<Op>();
    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }

    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(dnn_op->param().seed == handle_seed,
                   "inconsistent rng seed: handle: %lu, dnn_op: %lu",
                   handle_seed, dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op, inputs, dest);
}

/*!
 * \brief Infer the output descriptor; the output lives on the handle's comp
 * node (or inputs[0]'s when the op has no handle).
 *
 * For shape-input ops all tensor inputs must already be on that comp node.
 */
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        for (size_t i = 0; i < inputs.size(); ++i) {
            mgb_assert(inputs[i]->comp_node() == dest.comp_node,
                       "%s expects the device of inputs[%zu] to be same as the "
                       "device of handle; "
                       "got %s and %s actually",
                       rng.dyn_typeinfo()->name, i,
                       inputs[i]->comp_node().to_string().c_str(),
                       dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}

//! Describe the single output buffer for the memory planner.
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& dest = infer_output_attrs<Op>(def, inputs_tensors);
    SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node,
                                        StorageIdentifier::make(1)}};
    return {outputs, {}};
}

//! Allocate the output tensor(s) and run the op eagerly.
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<TensorPtr> outputs;
    SmallVector<LogicalTensorDesc> desc;
    desc = infer_output_attrs<Op>(def, inputs);
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
    exec<Op>(def, inputs, outputs, {});
    return outputs;
}

//! Execute into pre-allocated outputs (workspace is unused).
template <typename Op>
void execute(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<TensorPtr> outputs, SmallVector<TensorPtr> workspace) {
    exec<Op>(def, inputs, outputs, {});
}

/*!
 * \brief Lower the imperative op to a graph operator node.
 *
 * Shape-input RNG ops (0 dnn inputs) still take exactly one mgb input — the
 * target shape — hence mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp.
 */
template <typename Op>
SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    if (dnn_nr_inp == 0) {
        mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
                   rng.dyn_typeinfo()->name, nr_inp);
    }
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}

//! Fallible descriptor inference used before values are materialized.
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (rng_with_shape) {
        mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
                   xxx_rng_def.dyn_typeinfo()->name, nr_inp);
    }
    dest.comp_node = inputs[0].comp_node;
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    return {{dest}, true};
}

}  // anonymous namespace

//! Public API: allocate an RNG handle bound to (comp_node, seed).
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

//! Public API: free a handle and drop its cached dnn ops.
size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

//! Public API: reset the global default seed (re-seeds default handles).
void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}

// Register each RNG op's trait entries with the imperative dispatcher.
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
    .apply_on_var_node(apply_on_var_node<NAME>) \
    .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
    .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
    .infer_output_mem_desc(infer_output_mem_desc<NAME>) \
    .execute(execute<NAME>) \
    .fallback(); \
}

REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
REG_RNG_OP(GammaRNG)
REG_RNG_OP(PermutationRNG)
REG_RNG_OP(PoissonRNG)
REG_RNG_OP(BetaRNG)
#undef REG_RNG_OP

}  // namespace mgb::imperative::rng

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}