Commit 67e4e834 authored by Megvii Engine Team, committed by XindaH

fix(lite): fix the force_output_use_user_specified_memory when out var not supported

GitOrigin-RevId: ffaf4c14164e6fff1ce6e9bc8bfe05ef1bd7673a
Parent 0d37bfb0
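This commit makes `NetworkImplDft::load_model()` degrade gracefully when `force_output_use_user_specified_memory` is requested but a network output is produced by a memory-forwarding operator (Reshape, Broadcast, Subtensor, AxisAddRemove, Dimshuffle): the option is reset to false and a warning is emitted instead of misbehaving at execution time. Below is a rough usage-side sketch, not part of the change itself; it assumes the public Lite C++ API (`lite::Config`, `lite::Network`, `lite::Tensor::reset`) and uses a placeholder model path.

```cpp
#include "lite/network.h"
#include "lite/tensor.h"

#include <cstdint>
#include <memory>
#include <vector>

int main() {
    // Ask Lite to write network outputs directly into user-provided buffers.
    lite::Config config;
    config.options.force_output_use_user_specified_memory = true;

    // "model.mge" is a placeholder path. With this commit, if an output var
    // is produced by Reshape/Broadcast/Subtensor/AxisAddRemove/Dimshuffle,
    // load_model() resets the option to false and prints a LITE_WARN message.
    auto network = std::make_shared<lite::Network>(config);
    network->load_model("model.mge");

    // When the option survives, the output tensor can be bound to user memory
    // before forward(); this assumes the output layout is statically
    // inferable after load. (Input setup is omitted from this sketch.)
    auto output = network->get_output_tensor(0);
    std::vector<uint8_t> user_buffer(output->get_tensor_total_size_in_byte());
    output->reset(user_buffer.data(), output->get_layout());

    network->forward();
    network->wait();
    return 0;
}
```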
@@ -26,6 +26,7 @@
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
#if MGB_OPENCL
@@ -340,6 +341,31 @@ void NetworkImplDft::cross_compnode_model_detect() {
    m_nr_device_type = nr_used_device_type.size();
}
void NetworkImplDft::adapt_option_valid() {
auto&& options = m_load_config.comp_graph->options();
if (m_user_config->options.force_output_use_user_specified_memory) {
for (auto&& out : m_load_result.output_var_list) {
auto opr = out.node()->owner_opr();
            //! dest operators that inherit from ReadonlyFwdHelper cannot
            //! support the force_output_use_user_specified_memory option
if (opr->try_cast_final<mgb::opr::Reshape>() ||
opr->try_cast_final<mgb::opr::Broadcast>() ||
opr->try_cast_final<mgb::opr::Subtensor>() ||
opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
opr->try_cast_final<mgb::opr::Dimshuffle>()) {
m_user_config->options.force_output_use_user_specified_memory = false;
options.force_output_use_user_specified_memory = false;
                LITE_WARN(
                        "detected unsupported dest operator %s with "
                        "force_output_use_user_specified_memory configured, "
                        "setting force_output_use_user_specified_memory to "
                        "false\n",
                        opr->cname());
break;
}
}
}
}
void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
@@ -378,6 +404,8 @@ void NetworkImplDft::load_model(
    m_load_result = m_loader->load(m_load_config, true);
adapt_option_valid();
    cross_compnode_model_detect();
    //! update the IO of the network
......
@@ -214,6 +214,9 @@ private:
    //! optimized output tensor copy
    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);
    //! adapt options to valid values; it should be called after update_io
    void adapt_option_valid();
private:
    bool m_async = false;
    bool m_is_cpu_inplace_mode = false;
......
@@ -250,14 +250,10 @@ std::unique_ptr<CompNodeSeqRecorder> ComputingGraphImpl::ComputingSequence::
                "graph.");
        return {};
    }
auto is_graph_dest_varnode = [&](VarNode* var) {
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(var).size() ==
0;
};
    for (auto i : *m_opr_seq) {
        for (auto j : i->output()) {
            if (!is_static_var_storage(j) && !is_graph_dest_varnode(j)) {
            if (!is_static_var_storage(j) && !j->is_graph_dest_varnode()) {
                mgb_log_error(
                        "can not enable CompNodeSeqRecorder because var "
                        "storage not static: %s",
......
@@ -504,6 +504,10 @@ public:
     */
    MGE_WIN_DECLSPEC_FUC bool capable_value_infer();
    //! whether the var is a graph output; if it is, the NO_SYS_MEM_ALLOC
    //! flag can be modified.
    MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode();
private:
    //! whether its memory should be allocated by mgb system during graph
    //! execution; initialized in VarNodeMemManager::reset_opr_seq()
@@ -552,10 +556,6 @@ private:
    MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag);
//! whether the var is graph output, if it is output, the Flag of
//! NO_SYS_MEM_ALLOC can be modified.
bool is_graph_dest_varnode();
    MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor(
            const DeviceTensorND& value);
......
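The VarNode hunks above make `is_graph_dest_varnode()` public so the computing sequence can query it directly, replacing the local lambda deleted from `ComputingSequence`. The definition itself is not part of this diff; a minimal sketch of what it presumably does, mirroring that lambda (a var with no receiver operators in its owner graph is treated as a graph output):

```cpp
// Sketch only: mirrors the lambda removed from ComputingSequence above; the
// actual definition is not shown in this diff. It has to live where
// ComputingGraphImpl (the internal graph implementation) is visible.
bool VarNode::is_graph_dest_varnode() {
    return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() ==
           0;
}
```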
@@ -919,7 +919,7 @@ Split::Options Split::Options::make_callback(
        int axis, size_t nr_part, callback_t callback) {
    mgb_assert(nr_part);
    Options rst;
    rst.method = Method::CALLBACK;
    rst.method = Method::CALL_BACK;
    rst.axis = axis;
    rst.callback = callback;
    rst.nr_part = nr_part;
@@ -955,7 +955,7 @@ Split::Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config)
    // disable dedup
    add_equivalence_component<ScalarHash<void*>>(this);
    mgb_assert(m_opt.method == Options::Method::CALLBACK);
    mgb_assert(m_opt.method == Options::Method::CALL_BACK);
    mgb_assert(m_opt.nr_part);
}
......
@@ -172,7 +172,7 @@ cg::OperatorNodeBase* opr_shallow_copy_split(
    auto option = opr.options();
    using Meth = Split::Options::Method;
    switch (option.method) {
        case Meth::CALLBACK:
        case Meth::CALL_BACK:
            mgb_assert(inputs.size() == 1);
            break;
        case Meth::SPECIFY:
......
@@ -408,8 +408,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Split, intl::OutshapeBySymvarOprBase) // {
public:
    struct Options {
        enum class Method {
            SPECIFY,   //!< specify output sizes
            CALLBACK   //!< output sizes obtained from callback
            CALL_BACK  //!< output sizes obtained from callback
        };
        Method method;
        size_t nr_part = 0;
......
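The four Split hunks rename the enumerator `Split::Options::Method::CALLBACK` to `CALL_BACK` and update its in-tree users. Out-of-tree code that spells the enumerator directly needs the same one-word change; a minimal hedged sketch (it assumes `mgb::opr::Split` is declared in `megbrain/opr/tensor_manip.h`, which this diff does not show):

```cpp
#include "megbrain/opr/tensor_manip.h"  // assumed location of mgb::opr::Split

using Method = mgb::opr::Split::Options::Method;

// Before this commit the callback-driven mode was spelled Method::CALLBACK;
// after it, the enumerator is Method::CALL_BACK. Nothing else about
// Split::Options changes in this commit.
Method pick_split_method(bool use_callback) {
    return use_callback ? Method::CALL_BACK : Method::SPECIFY;
}
```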