diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 47e3204540fb6a05b6faff3f1a8656f67678c3d3..c7db5f4cead64c7a8e76cadbe23f8d0f2425b4de 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -26,6 +26,7 @@ #include "megbrain/graph.h" #include "megbrain/graph/cg.h" #include "megbrain/opr/io.h" +#include "megbrain/opr/tensor_manip.h" #include "megbrain/tensor.h" #if MGB_OPENCL @@ -340,6 +341,31 @@ void NetworkImplDft::cross_compnode_model_detect() { m_nr_device_type = nr_used_device_type.size(); } +void NetworkImplDft::adapt_option_valid() { + auto&& options = m_load_config.comp_graph->options(); + if (m_user_config->options.force_output_use_user_specified_memory) { + for (auto&& out : m_load_result.output_var_list) { + auto opr = out.node()->owner_opr(); + //! dest operators that inherit from ReadonlyFwdHelper can't + //! support the force_output_use_user_specified_memory option + if (opr->try_cast_final() || + opr->try_cast_final() || + opr->try_cast_final() || + opr->try_cast_final() || + opr->try_cast_final()) { + m_user_config->options.force_output_use_user_specified_memory = false; + options.force_output_use_user_specified_memory = false; + LITE_WARN( + "detect the unsupported dest operator %s when config " + "force_output_use_user_specified_memory, set " + "force_output_use_user_specified_memory to false\n", + opr->cname()); + break; + } + } + } +} + void NetworkImplDft::load_model( std::shared_ptr model_mem, size_t size, std::unordered_map separate_config_map) { @@ -378,6 +404,8 @@ void NetworkImplDft::load_model( m_load_result = m_loader->load(m_load_config, true); + adapt_option_valid(); + cross_compnode_model_detect(); //! 
update the IO of the network diff --git a/lite/src/mge/network_impl.h b/lite/src/mge/network_impl.h index 903f92f02141921403a8677493e6da347cfde52d..ec9e61c2337c5dc096db4dffa181566b0fc5b495 100644 --- a/lite/src/mge/network_impl.h +++ b/lite/src/mge/network_impl.h @@ -214,6 +214,9 @@ private: //! optimized output tensor copy void output_tensor_copy_optimize(Var var, std::shared_ptr tensor); + //! adapt option valid, it should be called after the model is loaded + void adapt_option_valid(); + private: bool m_async = false; bool m_is_cpu_inplace_mode = false; diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index b146969bcb8606b11050a2734613fcb20da4d7ab..40a6ef01c75cb45b66bcaa27382642e80a6525d2 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -250,14 +250,10 @@ std::unique_ptr ComputingGraphImpl::ComputingSequence:: "graph."); return {}; } - auto is_graph_dest_varnode = [&](VarNode* var) { - return ComputingGraphImpl::downcast(owner_graph())->var_receiver(var).size() == - 0; - }; for (auto i : *m_opr_seq) { for (auto j : i->output()) { - if (!is_static_var_storage(j) && !is_graph_dest_varnode(j)) { + if (!is_static_var_storage(j) && !j->is_graph_dest_varnode()) { mgb_log_error( "can not enable CompNodeSeqRecorder because var " "storage not static: %s", diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index b09b4157dd64a411cffdbad67ce9730d5063bfae..7a9ff05512ef58ac760b69153769f705668f74c8 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -504,6 +504,10 @@ public: */ MGE_WIN_DECLSPEC_FUC bool capable_value_infer(); + //! whether the var is graph output, if it is output, the Flag of + //! NO_SYS_MEM_ALLOC can be modified. + MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode(); + private: //! whether its memory should be allocated by mgb system during graph //! 
execution; initialized in VarNodeMemManager::reset_opr_seq() @@ -552,10 +556,6 @@ private: MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag); - //! whether the var is graph output, if it is output, the Flag of - //! NO_SYS_MEM_ALLOC can be modified. - bool is_graph_dest_varnode(); - MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor( const DeviceTensorND& value); diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index cfac8d64d3586b1fe089833b47548d0ac69c4f34..cc4302dec1efdf33ab7e96fe07e35468af5a3e3c 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -919,7 +919,7 @@ Split::Options Split::Options::make_callback( int axis, size_t nr_part, callback_t callback) { mgb_assert(nr_part); Options rst; - rst.method = Method::CALLBACK; + rst.method = Method::CALL_BACK; rst.axis = axis; rst.callback = callback; rst.nr_part = nr_part; @@ -955,7 +955,7 @@ Split::Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config) // disable dedup add_equivalence_component>(this); - mgb_assert(m_opt.method == Options::Method::CALLBACK); + mgb_assert(m_opt.method == Options::Method::CALL_BACK); mgb_assert(m_opt.nr_part); } diff --git a/src/opr/impl/tensor_manip.sereg.h b/src/opr/impl/tensor_manip.sereg.h index 6f178d9b9f350e3d2c7141ff7d9834cb8395abb8..07915e757eaecc6f410a3a840c9b925de495ddd9 100644 --- a/src/opr/impl/tensor_manip.sereg.h +++ b/src/opr/impl/tensor_manip.sereg.h @@ -172,7 +172,7 @@ cg::OperatorNodeBase* opr_shallow_copy_split( auto option = opr.options(); using Meth = Split::Options::Method; switch (option.method) { - case Meth::CALLBACK: + case Meth::CALL_BACK: mgb_assert(inputs.size() == 1); break; case Meth::SPECIFY: diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index baee4baa34d297dc6fbc1e0112eebca58ac8c63d..01b2fa99793c1735b85893012d8b9adf61cd24a1 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ 
b/src/opr/include/megbrain/opr/tensor_manip.h @@ -408,8 +408,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Split, intl::OutshapeBySymvarOprBase) // { public: struct Options { enum class Method { - SPECIFY, //!< specify output sizes - CALLBACK //!< output sizes obtained from callback + SPECIFY, //!< specify output sizes + CALL_BACK //!< output sizes obtained from callback }; Method method; size_t nr_part = 0;