/**
 * \file src/opr/impl/internal/megdnn_opr_wrapper.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

namespace mgb {
namespace opr {
namespace intl {

/*!
 * \brief template that can be specialized so that the inputs of an operator
 *      can be modified in-place
 *
 * Invoked by the MEGDNN_OPR_INIT* macros
 *
 * \tparam Opr a megbrain opr final class
 */
template <class Opr>
struct MegDNNOprInitInputsModifier {
    static inline void apply(const typename Opr::Param &param,
            std::initializer_list<SymbolVar*> inputs) {
        MGB_MARK_USED_VAR(param);
        MGB_MARK_USED_VAR(inputs);
    }
};

/*!
 * \brief template that can be specialized to be called in the opr constructor
 *
 * Invoked by the MEGDNN_OPR_INIT* macros
 */
template <class Opr>
struct MegDNNOprInitPostCtor {
    static inline void apply(cg::OperatorNodeBase &opr) {
        MGB_MARK_USED_VAR(opr);
    }
};

//! get a megdnn Workspace object from a workspace var
megdnn::Workspace get_megdnn_workspace_from_var(VarNode *var);

/*!
 * \brief a UserData object associated with the computing graph to get the
 *      maximal usable workspace
 *
 * It works by first limiting the workspace to 0 and allocating, to measure
 * the amount of free memory; it then assumes the workspace can use all of
 * that free memory.
 *
 * It produces a var node, which should be taken as a value dep by workspace
 * static-infer functors so the memory manager can re-allocate.
 */
class WorkspaceLimitGetter {
    class Impl;

    static Impl* get_impl(ComputingGraph *graph);

public:
    /*!
     * \brief get usable workspace size in bytes for a comp node
     *
     * Can only be called after is_prealloc_run() returns false.
     *
     * \param old_limit workspace limit set by the user, which is an upper
     *      bound for the return value
     */
    static size_t get_workspace_limit(
            ComputingGraph *graph, CompNode cn, size_t old_limit);

    //! return whether the current run is the pre-allocation run, in which
    //! case workspace sizes should be reported as 0
    static bool is_prealloc_run(ComputingGraph *graph);

    /*!
     * \brief register WorkspaceLimitGetter in a graph
     * \return a var to be added as an extra value dep for workspace infer;
     *      it is null if WorkspaceLimitGetter is disabled at compile time
     */
    static VarNode* register_to_graph(ComputingGraph *graph);
};

/*!
 * \brief a template that can be specialized to indicate whether
 *      WorkspaceLimitGetter is needed for an operator class
 *
 * \tparam MegDNNOpr a megdnn opr class
 */
template <class MegDNNOpr>
struct AutoAddWorkspaceNeedLimitGetter {
    static constexpr bool val = false;
};
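/*!
 * Example (a sketch with a hypothetical megdnn operator \c megdnn::FooForward;
 * the real specializations live next to the operators that need them): an
 * operator whose workspace requirement should be bounded by the free memory
 * reported by WorkspaceLimitGetter can opt in like this:
 * \code
 * template <>
 * struct AutoAddWorkspaceNeedLimitGetter<megdnn::FooForward> {
 *     static constexpr bool val = true;
 * };
 * \endcode
 */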
/*!
 * \brief implement megdnn::DynOutMallocPolicy using the memory management
 *      system in megbrain
 */
class MegDNNDynOutMallocImpl final : public megdnn::DynOutMallocPolicy {
    cg::OperatorNodeBase *m_opr;
    CompNode m_cn;

public:
    MegDNNDynOutMallocImpl(cg::OperatorNodeBase *opr, CompNode cn)
            : m_opr{opr}, m_cn{cn} {}

    megdnn::TensorND alloc_output(
            size_t id, DType dtype, const TensorShape &shape,
            void *user_data) override;

    void* alloc_workspace(size_t sz, void *user_data) override;
    void free_workspace(void *ptr, void *user_data) override;
};

/* ======================= MegDNNOprMethInvoker ======================= */

namespace {

template <int nr_inputs, int nr_outputs>
struct _MegDNNOprMethInvoker;

template <class MegDNNOpr>
using MegDNNOprMethInvoker =
        _MegDNNOprMethInvoker<MegDNNOpr::NR_INPUTS, MegDNNOpr::NR_OUTPUTS>;

#define _NR_INPUTS 1
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 1
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 1
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 2
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 2
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 3
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 3
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 3
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 4
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 5
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 5
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 6
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

} // anonymous namespace
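/*!
 * Each _NR_INPUTS/_NR_OUTPUTS/_FOREACH_IO triple above parameterizes one
 * inclusion of the impl .inl, which specializes _MegDNNOprMethInvoker for
 * that input/output arity. Conceptually (a simplified sketch, not the
 * literal generated code), the <2 inputs, 1 output> specialization forwards
 * graph vars to the megdnn operator like this:
 * \code
 * template <>
 * struct _MegDNNOprMethInvoker<2, 1> {
 *     template <class Opr>
 *     static void exec(Opr *opr, cg::OperatorNodeBase *self) {
 *         opr->exec(self->input(0)->dev_tensor().as_megdnn(),
 *                 self->input(1)->dev_tensor().as_megdnn(),
 *                 self->output(0)->dev_tensor().as_megdnn(),
 *                 get_megdnn_workspace_from_var(self->output().back()));
 *     }
 *     // deduce_layout() and get_workspace_in_bytes() follow the same
 *     // fixed-arity pattern
 * };
 * \endcode
 */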
/* ======================= MegDNNOprWrapperFwd ======================= */

template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::init_output_static_infer_desc() {
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    this->init_output_static_infer_desc_workspace(
            AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val);
}

template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::scn_do_execute() {
    MegDNNOprMethInvoker<MegDNNOpr>::exec(this->megdnn_opr(), this);
}

template <class MegDNNOpr>
size_t MegDNNOprWrapperFwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray &input_shapes,
        const TensorShapeArray &output_shapes) const {
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            *this, input_shapes, output_shapes);
}

template <class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::get_output_var_shape(
        const TensorShapeArray &inp_shape, TensorShapeArray &out_shape) const {
    MegDNNOprMethInvoker<MegDNNOpr>::deduce_layout(
            this->megdnn_opr(), this, inp_shape, out_shape);
}

/* ======================= MegDNNOprWrapperBwd ======================= */

template <class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::init_output_static_infer_desc() {
    this->mixin_init_output_static_infer_desc_bwd(*this);
    this->init_output_static_infer_desc_workspace(
            AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val);
}

template <class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::scn_do_execute() {
    MegDNNOprMethInvoker<MegDNNOpr>::exec(this->megdnn_opr(), this);
}

template <class MegDNNOpr>
size_t MegDNNOprWrapperBwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray &input_shapes,
        const TensorShapeArray &output_shapes) const {
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            *this, input_shapes, output_shapes);
}

template <class MegDNNOpr>
typename MegDNNOprWrapperBwd<MegDNNOpr>::Super::NodeProp*
MegDNNOprWrapperBwd<MegDNNOpr>::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    this->mixin_update_node_prop(*this, prop);
    return prop;
}

} // namespace intl

namespace mixin {

/* ======================= MegDNNOprHolderImpl ======================= */

template <class MegDNNOpr, bool add_workspace, class OprHolder>
size_t MegDNNOprHolderImpl<MegDNNOpr, add_workspace, OprHolder>::
        mixin_get_workspace_size_bytes_by_megdnn(
                const OperatorNodeBase &opr,
                const TensorShapeArray &input_shapes,
                const TensorShapeArray &output_shapes) const {
    static_assert(add_workspace, "must add_workspace");
    return intl::MegDNNOprMethInvoker<MegDNNOpr>::get_workspace_in_bytes(
            this->megdnn_opr(), &opr, input_shapes, output_shapes);
}

} // namespace mixin

} // namespace opr
} // namespace mgb
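/*!
 * Typical usage of the wrappers above together with the MEGDNN_OPR_INIT*
 * macros below (a sketch with a hypothetical megdnn operator
 * \c megdnn::FooForward; the dyn-type registration and other per-operator
 * boilerplate of real operators is omitted):
 * \code
 * // header, inside namespace mgb::opr
 * class FooForward final : public intl::MegDNNOprWrapperFwd<megdnn::FooForward> {
 * public:
 *     FooForward(VarNode *i0, const Param &param,
 *             const OperatorNodeConfig &config);
 *     static SymbolVar make(SymbolVar i0, const Param &param = {},
 *             const OperatorNodeConfig &config = {});
 * };
 *
 * // implementation file: generate the constructor and make() declared above
 * MEGDNN_OPR_INIT1(FooForward, "foo_fwd")
 * \endcode
 */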
//! generate opr constructor, with 1 arg
#define MEGDNN_OPR_CTOR_INIT1(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0,                                               \
            const Param &param, const OperatorNodeConfig &config):          \
        Super(OperatorNodeBaseCtorParam{                                    \
            i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__)    \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0});                                                    \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }

//! generate opr constructor and ::make, with 1 arg
#define MEGDNN_OPR_INIT1(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT1(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0,                                     \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0});      \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), param, config);                                  \
    }

//! generate opr constructor, with 2 args
#define MEGDNN_OPR_CTOR_INIT2(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0, VarNode *i1,                                  \
            const Param &param, const OperatorNodeConfig &config):          \
        Super(OperatorNodeBaseCtorParam{                                    \
            i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__)    \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1});                                                \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }

//! generate opr constructor and ::make, with 2 args
#define MEGDNN_OPR_INIT2(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT2(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0, SymbolVar i1,                       \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1}); \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), param, config);                       \
    }

//! generate opr constructor, with 3 args
#define MEGDNN_OPR_CTOR_INIT3(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0, VarNode *i1, VarNode *i2,                     \
            const Param &param, const OperatorNodeConfig &config):          \
        Super(OperatorNodeBaseCtorParam{                                    \
            i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__)    \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1, i2});                                            \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }

//! generate opr constructor and ::make, with 3 args
#define MEGDNN_OPR_INIT3(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT3(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,         \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1, &i2}); \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), i2.node(), param, config);            \
    }

//! generate opr constructor, with 4 args
#define MEGDNN_OPR_CTOR_INIT4(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0, VarNode *i1, VarNode *i2, VarNode *i3,        \
            const Param &param, const OperatorNodeConfig &config):          \
        Super(OperatorNodeBaseCtorParam{                                    \
            i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__)    \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1, i2, i3});                                        \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }

//! generate opr constructor and ::make, with 4 args
#define MEGDNN_OPR_INIT4(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT4(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0, SymbolVar i1, SymbolVar i2, SymbolVar i3, \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(                    \
                param, {&i0, &i1, &i2, &i3});                               \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), i2.node(), i3.node(), param, config); \
    }

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}