/** * \file src/opr/impl/rand.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "megbrain/opr/rand.h" #include "megbrain/opr/utility.h" #include "megbrain/graph/grad_impl.h" #include "./internal/megdnn_opr_wrapper.inl" using namespace mgb; using namespace opr; using namespace intl; template RNGOprBase::RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param ¶m): Super(opr),m_param(param) { } template UniqPtrWithCN RNGOprBase::create_megdnn_opr() { auto opr = intl::create_megdnn_opr(comp_node()); opr->param() = param(); return opr; } template void RNGOprBase::ensure_megdnn_opr() { if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) { // activate comp_node for curandCreateGenerator in create_megdnn_opr comp_node().activate(); m_dnn_opr = create_megdnn_opr(); } } /* ================= RNG with shape ================= */ #define _INST_RNG_OPR_WITH_SHAPE(RNGOpr, name) \ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \ cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \ auto prop = Super::do_make_node_prop(); \ prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \ prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \ for (auto i: input()) { \ prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \ } \ return prop; \ } \ RNGOpr::RNGOpr(VarNode *shape, const Param ¶m, \ const OperatorNodeConfig &config): \ Super({shape->owner_graph(), config, (name), {shape}}, param) \ { \ DType dtype = DType::from_enum(param.dtype); \ add_input({shape}); \ add_output(None)->dtype(dtype).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \ cg::add_workspace_output(this); \ add_equivalence_component>(this); \ } \ SymbolVar RNGOpr::make(SymbolVar shape, const Param ¶m, \ const OperatorNodeConfig &config){ \ return shape.insert_single_output_opr(shape.node(), param, config); \ } \ void RNGOpr::init_output_static_infer_desc() { \ using namespace cg::static_infer; \ auto &&mgr = owner_graph()->static_infer_manager(); \ auto infer_out = [](TensorShape &dest, const InpVal &inp) { \ cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); \ return true; \ }; \ auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \ ensure_megdnn_opr(); \ dest.ndim = 1; \ dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \ {inp.val.at(0).shape(), output(0)->dtype()}); \ return true; \ }; \ mgr.register_shape_infer(output(0), \ {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out}); \ mgr.register_shape_infer(output(1), \ {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \ } \ void RNGOpr::scn_do_execute() { \ auto&& ret = output(0); \ if (ret->layout().is_empty()) { \ mgb_assert(ret->dev_tensor().empty()); \ return; \ } \ m_dnn_opr->exec(ret->dev_tensor().as_megdnn(), \ get_megdnn_workspace_from_var(output(1))); \ } _INST_RNG_OPR_WITH_SHAPE(UniformRNG,"uniform_rng") _INST_RNG_OPR_WITH_SHAPE(GaussianRNG,"gaussian_rng") _INST_RNG_OPR_WITH_SHAPE(PermutationRNG,"permutation_rng") #undef _INST_RNG_OPR_WITH_SHAPE /* ================= RNG with input ================= */ #define _AS_MEGDNN(idx) input((idx))->dev_tensor().as_megdnn() #define _INFER_WK_DEPS(idx) {input((idx)), DepType::SHAPE} #define _INFER_WK_ARGS(idx) {inp.val.at((idx)).shape(), input((idx))->dtype()} #define _INST_RNG_OPR_WITH_INPUT(RNGOpr, name) \ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \ RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param ¶m, \ const OperatorNodeConfig &config): \ Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \ { \ add_input({_INPUTS(,)}); \ add_output(None)->dtype(i0->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \ cg::add_workspace_output(this); \ add_equivalence_component>(this); \ } \ SymbolVar RNGOpr::make(_INPUTS(SymbolVar,), const Param ¶m, \ const OperatorNodeConfig &config){ \ return i0.insert_single_output_opr(_INPUTS(,.node()), param, config); \ } \ void RNGOpr::init_output_static_infer_desc() { \ using namespace cg::static_infer; \ auto &&mgr = owner_graph()->static_infer_manager(); \ auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \ ensure_megdnn_opr(); \ dest.ndim = 1; \ dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \ _FOR_EACH(_INFER_WK_ARGS), \ {output(0)->shape(), output(0)->dtype()}); \ return true; \ }; \ mgr.register_shape_infer(output(0),ShapeInferDesc::make_identity(input(0))); \ mgr.register_shape_infer(output(1),{SourceType::DEP, {_FOR_EACH(_INFER_WK_DEPS)}, \ infer_wk}); \ } \ void RNGOpr::add_input_layout_constraint(){ \ for (auto i : input()) i->add_layout_constraint_contiguous(); \ }; \ void RNGOpr::scn_do_execute() { \ auto&& ret = output(0); \ if (ret->layout().is_empty()) { \ mgb_assert(ret->dev_tensor().empty()); \ return; \ } \ m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN),output(0)->dev_tensor().as_megdnn(), \ get_megdnn_workspace_from_var(output(1))); \ } \ cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \ auto prop = Super::do_make_node_prop(); \ prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \ for (auto i: input()) { \ prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \ } \ return prop; \ } /* ================= 1 input ================= */ #define _INPUTS(prefix, subfix) prefix i0 subfix #define _FOR_EACH(cb) cb(0) _INST_RNG_OPR_WITH_INPUT(PoissonRNG,"poisson_rng") #undef _INPUTS #undef _FOR_EACH /* ================= 2 input ================= */ #define _INPUTS(prefix,subfix) prefix i0 subfix, prefix i1 subfix #define _FOR_EACH(cb) cb(0), cb(1) _INST_RNG_OPR_WITH_INPUT(BetaRNG,"beta_rng") _INST_RNG_OPR_WITH_INPUT(GammaRNG,"gamma_rng") #undef _INPUTS #undef _FOR_EACH #undef _AS_MEGDNN #undef _INFER_WK_DEPS #undef _INFER_WK_ARGS #undef _INST_RNG_OPR_WITH_INPUT #define IMPL(_cls) \ MGB_IMPL_OPR_GRAD(_cls) { \ MGB_MARK_USED_VAR(out_grad); \ return InvalidGrad::make(opr, wrt_idx); \ } namespace mgb { namespace opr { namespace intl { template class RNGOprBase<::megdnn::GaussianRNG>; template class RNGOprBase<::megdnn::UniformRNG>; template class RNGOprBase<::megdnn::GammaRNG>; template class RNGOprBase<::megdnn::PermutationRNG>; template class RNGOprBase<::megdnn::BetaRNG>; template class RNGOprBase<::megdnn::PoissonRNG>; template class RNGOprBase<::megdnn::ShuffleRNGForward>; template class RNGOprBase<::megdnn::ShuffleRNGBackward>; #if MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); IMPL(GammaRNG); IMPL(PoissonRNG); IMPL(PermutationRNG); IMPL(BetaRNG); #endif } // namespace intl } // namespace opr } // namespace mgb /* ================= ShuffleRNGForward ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward); ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, const OperatorNodeConfig& config) : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { add_input({data}); add_output(None) ->dtype(data->dtype()) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_output(None) ->dtype(dtype::Int32{}) .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); cg::add_workspace_output(this); add_equivalence_component>(this); } SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param, const OperatorNodeConfig& config) { auto node = in_tensor.node()->owner_graph()->insert_opr( std::make_unique(in_tensor.node(), param, config)); mgb_assert(node->output().size() == 3); return {node->output(0), node->output(1)}; } void ShuffleRNGForward::init_output_static_infer_desc() { using namespace cg::static_infer; auto&& mgr = owner_graph()->static_infer_manager(); mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) { TensorLayout o0, o1; m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0, o1); dest = o1; return true; }; mgr.register_shape_infer( output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1}); auto infer_wk = [this](TensorShape& dest, const InpVal& inp) { ensure_megdnn_opr(); dest.ndim = 1; dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( {inp.val[0].shape(), input(0)->dtype()}, {output(0)->shape(), output(0)->dtype()}, {output(1)->shape(), output(1)->dtype()}); return true; }; mgr.register_shape_infer( output(2), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk}); } void ShuffleRNGForward::add_input_layout_constraint() { input(0)->add_layout_constraint_contiguous(); }; void ShuffleRNGForward::scn_do_execute() { auto&& ret = output(0); if (ret->layout().is_empty()) { mgb_assert(ret->dev_tensor().empty()); return; } m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), output(1)->dev_tensor().as_megdnn(), get_megdnn_workspace_from_var(output(2))); } cg::OperatorNodeBase::NodeProp* ShuffleRNGForward::do_make_node_prop() const { auto prop = Super::do_make_node_prop(); prop->add_flag(NodeProp::Flag::IMPURE_FUNC); for (auto i : input()) { prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); } return prop; } #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); if (!out_grad[0]) return nullptr; return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0)).node(); } #endif MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward); MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true) // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}