提交 67e4e834 编写于 作者: M Megvii Engine Team 提交者: XindaH

fix(lite): fix the force_output_use_user_specified_memory when out var not supported

GitOrigin-RevId: ffaf4c14164e6fff1ce6e9bc8bfe05ef1bd7673a
上级 0d37bfb0
......@@ -26,6 +26,7 @@
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"
#if MGB_OPENCL
......@@ -340,6 +341,31 @@ void NetworkImplDft::cross_compnode_model_detect() {
m_nr_device_type = nr_used_device_type.size();
}
void NetworkImplDft::adapt_option_valid() {
auto&& options = m_load_config.comp_graph->options();
if (m_user_config->options.force_output_use_user_specified_memory) {
for (auto&& out : m_load_result.output_var_list) {
auto opr = out.node()->owner_opr();
//! all the dest operator inherit from ReadonlyFwdHelper can't
//! support force_output_use_user_specified_memory options
if (opr->try_cast_final<mgb::opr::Reshape>() ||
opr->try_cast_final<mgb::opr::Broadcast>() ||
opr->try_cast_final<mgb::opr::Subtensor>() ||
opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
opr->try_cast_final<mgb::opr::Dimshuffle>()) {
m_user_config->options.force_output_use_user_specified_memory = false;
options.force_output_use_user_specified_memory = false;
LITE_WARN(
"detect the unsupported dest operator %s when config "
"force_output_use_user_specified_memory, set "
"force_output_use_user_specified_memory to false\n",
opr->cname());
break;
}
}
}
}
void NetworkImplDft::load_model(
std::shared_ptr<void> model_mem, size_t size,
std::unordered_map<std::string, LiteAny> separate_config_map) {
......@@ -378,6 +404,8 @@ void NetworkImplDft::load_model(
m_load_result = m_loader->load(m_load_config, true);
adapt_option_valid();
cross_compnode_model_detect();
//! update the IO of the network
......
......@@ -214,6 +214,9 @@ private:
//! optimized output tensor copy
void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);
//! adapt option valid, it should call after update_io
void adapt_option_valid();
private:
bool m_async = false;
bool m_is_cpu_inplace_mode = false;
......
......@@ -250,14 +250,10 @@ std::unique_ptr<CompNodeSeqRecorder> ComputingGraphImpl::ComputingSequence::
"graph.");
return {};
}
auto is_graph_dest_varnode = [&](VarNode* var) {
return ComputingGraphImpl::downcast(owner_graph())->var_receiver(var).size() ==
0;
};
for (auto i : *m_opr_seq) {
for (auto j : i->output()) {
if (!is_static_var_storage(j) && !is_graph_dest_varnode(j)) {
if (!is_static_var_storage(j) && !j->is_graph_dest_varnode()) {
mgb_log_error(
"can not enable CompNodeSeqRecorder because var "
"storage not static: %s",
......
......@@ -504,6 +504,10 @@ public:
*/
MGE_WIN_DECLSPEC_FUC bool capable_value_infer();
//! whether the var is graph output, if it is output, the Flag of
//! NO_SYS_MEM_ALLOC can be modified.
MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode();
private:
//! whether its memory should be allocated by mgb system during graph
//! execution; initialized in VarNodeMemManager::reset_opr_seq()
......@@ -552,10 +556,6 @@ private:
MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag);
//! whether the var is graph output, if it is output, the Flag of
//! NO_SYS_MEM_ALLOC can be modified.
bool is_graph_dest_varnode();
MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor(
const DeviceTensorND& value);
......
......@@ -919,7 +919,7 @@ Split::Options Split::Options::make_callback(
int axis, size_t nr_part, callback_t callback) {
mgb_assert(nr_part);
Options rst;
rst.method = Method::CALLBACK;
rst.method = Method::CALL_BACK;
rst.axis = axis;
rst.callback = callback;
rst.nr_part = nr_part;
......@@ -955,7 +955,7 @@ Split::Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config)
// disable dedup
add_equivalence_component<ScalarHash<void*>>(this);
mgb_assert(m_opt.method == Options::Method::CALLBACK);
mgb_assert(m_opt.method == Options::Method::CALL_BACK);
mgb_assert(m_opt.nr_part);
}
......
......@@ -172,7 +172,7 @@ cg::OperatorNodeBase* opr_shallow_copy_split(
auto option = opr.options();
using Meth = Split::Options::Method;
switch (option.method) {
case Meth::CALLBACK:
case Meth::CALL_BACK:
mgb_assert(inputs.size() == 1);
break;
case Meth::SPECIFY:
......
......@@ -408,8 +408,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Split, intl::OutshapeBySymvarOprBase) // {
public:
struct Options {
enum class Method {
SPECIFY, //!< specify output sizes
CALLBACK //!< output sizes obtained from callback
SPECIFY, //!< specify output sizes
CALL_BACK //!< output sizes obtained from callback
};
Method method;
size_t nr_part = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册