#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"

#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative::rng {

namespace {

/*!
 * \brief CRTP base that caches megdnn operators keyed by (handle, op type).
 *
 * \tparam HandleFactory the derived class; it must provide
 *      `do_new_handle(...)` and `do_delete_handle(Handle)`.
 * \tparam THandle the handle type used as the outer cache key.
 *
 * Every cached operator is paired with its own mutex, so users of different
 * (handle, op type) entries do not contend on one another; m_mtx only guards
 * the map itself. The whole cache is dropped on comp node finalization.
 */
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;

    //! create a handle by forwarding \p args to HandleFactory::do_new_handle
    template <typename... Args>
    Handle new_handle(Args&&... args) {
        return static_cast<HandleFactory*>(this)->do_new_handle(
                std::forward<Args>(args)...);
    }

    /*!
     * \brief drop the operators cached for \p handle, then release the handle
     * \return number of outer cache entries erased (0 or 1); always 0 once
     *      comp nodes are finalized because the cache has been cleared
     */
    size_t delete_handle(Handle handle) {
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }

    /*!
     * \brief fetch the operator cached for (\p handle, \p tpinfo), creating it
     *      on the megdnn handle of \p cn on first use
     * \return tuple (initialized, dnn_op, lock): `initialized` is false iff the
     *      operator was just created; `lock` owns the entry's mutex and must
     *      be kept alive for as long as `dnn_op` is being used
     */
    template <typename DnnOp>
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
            // references to unordered_map elements survive rehashing, so the
            // pointer stays valid after m_mtx is released
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            // a cached operator must belong to the megdnn handle of `cn`
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
            auto new_op = dnn_handle->create_operator<DnnOp>();
            dnn_op = new_op.get();
            dnn_op_with_mtx->op = std::move(new_op);
        }
        return std::make_tuple(initialized, dnn_op, std::move(lock));
    }

protected:
    using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
    DnnOpManagerT() = default;

private:
    //! one cached operator plus the mutex serializing its users
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex() : op{nullptr} {}
    };

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2ops.clear();
        return {};
    }

    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
            m_handle2ops;
    std::mutex m_mtx;
};

/*!
 * \brief singleton owning RNG handles and their cached megdnn operators
 *
 * A handle is the address of a pool-allocated HandleData (comp node + seed).
 * A null handle means "global default": per-comp-node default handles are
 * created lazily and all use glob_default_seed. sm_mtx is taken by handle
 * creation/deletion and by the accessors of the global-default state.
 */
class RNGDnnOpManager final : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }

    size_t delete_handle(Handle handle) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }

    //! CRTP hook: allocate the HandleData and use its address as the handle
    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
        auto handle = m_handle_pool.alloc(comp_node, seed);
        return reinterpret_cast<Handle>(handle);
    }

    //! CRTP hook: return the HandleData behind \p handle to the pool
    void do_delete_handle(Handle handle) {
        m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
    }

    //! seed of \p handle; a null handle yields the global default seed
    static uint64_t get_seed(Handle handle) {
        if (!handle) {
            return glob_default_seed;
        }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }

    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }

    //! default handle of \p comp_node, created on first request
    static Handle get_default_handle(CompNode comp_node) {
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        mgb_assert(get_seed(glob_handle) == glob_default_seed);
        return glob_handle;
    }

    static RNGDnnOpManager& inst() {
        static RNGDnnOpManager mgr;
        return mgr;
    }

    //! replace every existing default handle by a fresh one seeded by \p seed
    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        for (auto&& elem : glob_default_handles) {
            mgb_assert(elem.first.valid());
            if (elem.second) {
                // call the base directly: sm_mtx is already held here
                inst().DnnOpManagerBase::delete_handle(elem.second);
            }
            elem.second = inst().do_new_handle(elem.first, seed);
        }
        glob_default_seed = seed;
    }

    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }

private:
    struct HandleData {
        CompNode comp_node;
        uint64_t seed;
        HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
    };

    MemPool<HandleData> m_handle_pool;

    static std::mutex sm_mtx;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};

uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

/*!
 * \brief per-OpDef traits: the megdnn operator, the graph operator node, and
 *      conversion of the OpDef into the megdnn param
 *
 * Every make_param asserts that the seed stored on the OpDef agrees with the
 * seed of its handle (a null handle maps to the global default seed).
 */
template <typename Op>
struct OpMeth;

template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};

template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};

template <>
struct OpMeth<Dropout> {
    using DnnOp = megdnn::Dropout;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::Dropout;
    static Param make_param(const Dropout& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent dropout seed: dropout op: %lu handle: %lu",
                opdef.seed, handle_seed);
        return {opdef.drop_prob, handle_seed};
    }
};

template <>
struct OpMeth<MultiHeadAttn> {
    using DnnOp = megdnn::MultiHeadAttn;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::MultiHeadAttn;
    static Param make_param(const MultiHeadAttn& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent multiheadattn seed: multiheadattn op: %lu handle: %lu",
                opdef.seed, handle_seed);
        return {opdef.num_heads,    opdef.sm_scaler,    opdef.input_order,
                opdef.reslink,      opdef.training,     opdef.bias,
                opdef.attn_mask,    opdef.enable_qproj, opdef.enable_kproj,
                opdef.enable_vproj, opdef.enable_oproj, handle_seed,
                opdef.attn_prob,    opdef.out_prob};
    }
};

//! output layout inference; `true` selects RNGs whose single input holds the
//! target shape, `false` those whose output layout equals inputs[0]'s
template <bool>
struct _InferLayout;

//! builds the graph operator node from a given number of input vars
template <size_t nr_in>
struct _RNGOprMaker;

//! runs the megdnn operator for a given number of inputs/outputs
template <size_t nr_in, size_t nr_out>
struct _RNGOprInvoker;

template <>
struct _InferLayout<true> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    //! yields an ndim == 0 layout (dtype = rng.dtype) if the shape is unknown
    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%lu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        // out_layout.shape is a fixed-size array; reject oversized targets
        mgb_assert(
                target_ndim <= TensorShape::MAX_NDIM,
                "target shape of %s has too many dimensions: %zu (max %zu)",
                rng.dyn_typeinfo()->name, target_ndim,
                static_cast<size_t>(TensorShape::MAX_NDIM));
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        out_layout.init_contiguous_stride();
        return out_layout;
    }
};

template <>
struct _InferLayout<false> {
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        return inp->layout();
    }

    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        return inp.layout;
    }
};

// _FOR_EACH_IN / _FOR_EACH_OUT are (re)defined before every instantiation
// below to spell out the argument list for that input/output count.
#define _INST_RNG_INVOKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS)                  \
    template <>                                                           \
    struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> {                \
        template <typename Opr>                                           \
        static void exec(                                                 \
                Opr* dnn_op, const SmallVector<TensorPtr>& inputs,        \
                const SmallVector<TensorPtr>& outputs) {                  \
            size_t wk_size = dnn_op->get_workspace_in_bytes(              \
                    _FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout()));  \
            auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
            megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
            dnn_op->exec(                                                 \
                    _FOR_EACH_IN(->dev_tensor().as_megdnn())              \
                            _FOR_EACH_OUT(->dev_tensor().as_megdnn()),    \
                    dnn_wk);                                              \
        }                                                                 \
    };

#define _INST_RNG_MAKER(MGB_NR_INPUTS)                                       \
    template <>                                                              \
    struct _RNGOprMaker<MGB_NR_INPUTS> {                                     \
        template <typename Op>                                               \
        static auto make(const VarNodeArray& inputs, const Op& rng) {        \
            auto param = OpMeth<Op>::make_param(rng);                        \
            OperatorNodeConfig config;                                       \
            if (rng.handle) {                                                \
                config = {                                                   \
                        rng.make_name(),                                     \
                        RNGDnnOpManager::get_comp_node(rng.handle)};         \
            } else {                                                         \
                config = {rng.make_name()};                                  \
            }                                                                \
            return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config);   \
        }                                                                    \
    };

#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOKER(1, 1)
#undef _FOR_EACH_OUT

#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix)  inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#define _FOR_EACH_IN(subfix) \
    inputs[0] subfix, inputs[1] subfix, inputs[2] subfix, inputs[3] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOKER(4, 2)
_INST_RNG_MAKER(4)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN

#undef _INST_RNG_INVOKER
#undef _INST_RNG_MAKER

/*!
 * \brief run the RNG described by \p op into the preallocated \p outputs
 *
 * Uses the op's handle, or the default handle of outputs[0]'s comp node when
 * the op has none. Returns immediately if outputs[0] is empty.
 */
template <typename Op>
void exec(
        const OpDef& op, const SmallVector<TensorPtr>& inputs,
        const SmallVector<TensorPtr>& outputs,
        const SmallVector<TensorPtr>& workspace) {
    auto&& rng = op.cast_final_safe<Op>();

    auto dest = outputs[0];
    if (dest->layout().is_empty())
        return;
    auto cn = dest->comp_node();
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }

    // retrieve dnn_op from glob cache; the returned lock (third element)
    // stays alive in dnn_op_thread_safe until the end of this function
    auto dnn_op_thread_safe =
            RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(
                dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
                dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
    _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(
            dnn_op, inputs, outputs);
}

/*!
 * \brief output descriptor for single-output RNGs on concrete input tensors
 *
 * The output lives on the handle's comp node (or inputs[0]'s when there is no
 * handle); for RNGs that take data inputs, every input must be on that node.
 */
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if (!rng_with_shape) {
        for (size_t i = 0; i < inputs.size(); ++i) {
            mgb_assert(
                    inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%zu] to be same as the device "
                    "of handle; got %s and %s actually",
                    rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}

//! ShuffleRNG: shuffled copy of inputs[0] plus Int32 indices along axis 0
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& rng = op.cast_final_safe<ShuffleRNG>();
    auto handle = rng.handle;
    CompNode cn = handle ? RNGDnnOpManager::get_comp_node(handle)
                         : inputs[0]->comp_node();
    dests[0].comp_node = cn;
    dests[1].comp_node = cn;
    dests[0].layout = inputs[0]->layout();
    dests[1].layout =
            TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
    return dests;
}

//! Dropout: output shaped like inputs[0] plus a Byte mask sized by megdnn
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& cn = inputs[0]->comp_node();

    dests[0].comp_node = cn;
    dests[0].layout = inputs[0]->layout();

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0]->layout());
    };
    dests[1].comp_node = cn;
    dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    return dests;
}

//! MultiHeadAttn: output shaped like inputs[0] plus a Byte reserve space
//! whose size is queried from the cached megdnn operator
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<MultiHeadAttn>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& cn = inputs[0]->comp_node();

    dests[0].comp_node = cn;
    dests[0].layout = inputs[0]->layout();

    auto get_reservespace_in_bytes = [&]() -> size_t {
        // retrieve dnn_op from glob cache
        auto&& rng = op.cast_final_safe<MultiHeadAttn>();
        auto handle = rng.handle;
        if (!handle) {
            handle = RNGDnnOpManager::get_default_handle(cn);
        }
        auto dnn_op_thread_safe =
                RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
                        handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
        auto dnn_op = std::get<1>(dnn_op_thread_safe);
        dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);

        return dnn_op->get_reservespace_in_bytes(
                inputs[0]->layout(), inputs[1]->layout(), inputs[2]->layout(),
                inputs[3]->layout(), {}, {});
    };

    dests[1].comp_node = cn;
    dests[1].layout =
            TensorLayout(TensorShape({get_reservespace_in_bytes()}), dtype::Byte());
    return dests;
}

//! allocate outputs from infer_output_attrs<Op>, then run exec<Op>
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
    SmallVector<TensorPtr> outputs;
    outputs.reserve(desc.size());
    for (auto&& i : desc) {
        outputs.push_back(Tensor::make(i.layout, i.comp_node));
    }
    exec<Op>(def, inputs, outputs, {});
    return outputs;
}

/*!
 * \brief build the graph operator node for \p def
 *
 * Shape-driven RNGs (megdnn NR_INPUTS == 0) take exactly one var, the target
 * shape; the others take exactly NR_INPUTS vars.
 */
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    // a shape-driven RNG still takes one var (the shape) in the graph
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    auto&& rng = def.cast_final_safe<Op>();
    // the maker indexes inputs[0 .. mgb_nr_inp - 1] unconditionally
    mgb_assert(
            nr_inp == mgb_nr_inp, "%s expects %zu inputs; got %zu actually",
            rng.dyn_typeinfo()->name, mgb_nr_inp, nr_inp);
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}

/*!
 * \brief best-effort output descriptor inference for single-output RNGs
 * \return (descs, success) where success is true only when the output layout
 *      is fully determined by \p inputs
 */
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    size_t nr_inp = inputs.size();
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    dest.comp_node = inputs[0].comp_node;
    bool success = inputs[0].layout.ndim != 0;
    if constexpr (rng_with_shape) {
        mgb_assert(
                nr_inp == 1, "%s expects 1 inputs; got %lu actually",
                xxx_rng_def.dyn_typeinfo()->name, nr_inp);
        // the output shape is read from the *value* of the shape tensor;
        // do_infer itself yields an ndim == 0 layout of rng.dtype when that
        // value is unavailable, so call it unconditionally
        success = success && !inputs[0].value.empty();
        dest.layout = _InferLayout<true>::do_infer(inputs[0], xxx_rng_def);
    } else {
        if (success) {
            dest.layout = _InferLayout<false>::do_infer(inputs[0], xxx_rng_def);
        } else {
            dest.layout = TensorLayout(inputs[0].layout.dtype);
        }
    }
    return {{dest}, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;
    SmallVector<LogicalTensorDesc> dests(2);
    dests[0].comp_node = inputs[0].comp_node;
    dests[0].layout = inputs[0].layout;
    dests[1].comp_node = inputs[0].comp_node;
    if (success) {
        dests[1].layout =
                TensorLayout(TensorShape({inputs[0].layout.shape[0]}), dtype::Int32());
    } else {
        dests[1].layout = TensorLayout(dtype::Int32());
    }
    return {dests, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
        const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    bool success = inputs[0].layout.ndim != 0;
    SmallVector<LogicalTensorDesc> dests(2);
    auto cn = inputs[0].comp_node;
    dests[0].comp_node = cn;
    dests[0].layout = inputs[0].layout;

    auto get_mask_size = [&]() -> size_t {
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                inputs[0].layout);
    };
    dests[1].comp_node = cn;
    if (success) {
        dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
    } else {
        dests[1].layout = TensorLayout(dtype::Byte());
    }
    return {dests, success};
}

template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        MultiHeadAttn>(const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    // the reserve space size depends on every input layout, so all of them
    // must be known before megdnn can be queried
    bool success = true;
    for (auto&& inp : inputs) {
        success = success && inp.layout.ndim != 0;
    }
    SmallVector<LogicalTensorDesc> dests(2);
    auto cn = inputs[0].comp_node;
    dests[0].comp_node = cn;
    dests[0].layout = inputs[0].layout;

    auto get_reservespace_in_bytes = [&]() -> size_t {
        auto&& rng = op.cast_final_safe<MultiHeadAttn>();
        auto handle = rng.handle;
        if (!handle) {
            handle = RNGDnnOpManager::get_default_handle(cn);
        }
        auto dnn_op_thread_safe =
                RNGDnnOpManager::inst().get_dnn_op<megdnn::MultiHeadAttn>(
                        handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
        auto dnn_op = std::get<1>(dnn_op_thread_safe);
        dnn_op->param() = OpMeth<MultiHeadAttn>::make_param(rng);

        return dnn_op->get_reservespace_in_bytes(
                inputs[0].layout, inputs[1].layout, inputs[2].layout,
                inputs[3].layout, {}, {});
    };
    dests[1].comp_node = cn;
    if (success) {
        dests[1].layout = TensorLayout(
                TensorShape({get_reservespace_in_bytes()}), dtype::Byte());
    } else {
        dests[1].layout = TensorLayout(dtype::Byte());
    }
    return {dests, success};
}

//! RNG ops impose no layout constraint on their inputs
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    return layout_checker;
}

}  // anonymous namespace

Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}

size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}

void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}

uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}

CompNode get_rng_handle_compnode(Handle handle) {
    return RNGDnnOpManager::get_comp_node(handle);
}

#define REG_RNG_OP(NAME, Output)                                            \
    namespace {                                                             \
    OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode)                          \
            .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
            .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
            .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
            .get_input_layout_constraint(get_input_layout_constraint<NAME>) \
            .fallback();                                                    \
    }

REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
REG_RNG_OP(Dropout, SymbolVarArray)
REG_RNG_OP(MultiHeadAttn, SymbolVarArray)
#undef REG_RNG_OP

}  // namespace mgb::imperative::rng

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}