#include "../mgb_cg_impl.h" #include "./mini_graph.h" #include "megbrain/opr/io.h" using LayoutConstraintLevel = mgb::cg::VarNodeMemManager::LayoutConstraintLevel; using LayoutConstraintCallback = mgb::VarNode::LayoutConstraintCallback; namespace mgb::imperative::proxy_graph { MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder); thread_local std::unique_ptr ProxyGraphTypeI::sm_instance = {}; } // namespace mgb::imperative::proxy_graph namespace mgb::imperative::proxy_graph_detail { std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible( def, inputs); return ret; } SmallVector apply_on_physical_tensor( const OpDef& def, SmallVector inputs, SmallVector& output_descs, const bool& validated) { auto ret = proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor( def, inputs, output_descs, validated); return ret; } std::unordered_map> input_layout_constraints_cache; SmallVector get_input_layout_constraint( const OpDef& def, const SmallVector& inputs) { auto get_input_layout_constraint_hash_key = [](const OpDef& def, const SmallVector& inputs) { XXHash state; size_t length = 0, data[1 + inputs.size()]; data[length++] = def.hash(); for (auto&& i : inputs) { data[length++] = mgb::hash(i->comp_node()); } state.update(data, length * sizeof(size_t)); return state.digest(); }; auto hash_key = get_input_layout_constraint_hash_key(def, inputs); auto&& iter = input_layout_constraints_cache.find(hash_key); if (iter != input_layout_constraints_cache.end()) { return iter->second; } static cg::ComputingGraphImpl* graph = imperative::ResourceManager::create_global(); VarNodeArray vinputs(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { OperatorNodeConfig config; auto layout = inputs[i]->layout(); layout.init_contiguous_stride(); vinputs[i] = graph->insert_opr(std::make_unique( *graph, std::make_shared( inputs[i]->comp_node(), layout), false, config)) ->output(0); } auto&& opr = OpDef::apply_on_var_node(def, vinputs)[0]->owner_opr(); opr->add_input_layout_constraint(); SmallVector res(inputs.size()); auto& mem_mgr = graph->var_node_mem_manager(); for (size_t i = 0; i < vinputs.size(); ++i) { auto& trait = mem_mgr.get_var_node_mem_trait(vinputs[i]); switch (trait.layout_constraint.level) { case LayoutConstraintLevel::CONTIG: res[i] = [](const TensorLayout& layout) { return layout.is_contiguous(); }; break; case LayoutConstraintLevel::MONOTONE: res[i] = [&trait](const TensorLayout& layout) { if (!layout.is_abs_monotonous_allow_brdcst()) { return false; } for (auto&& i : trait.layout_constraint.custom) if (!i(layout)) return false; return true; }; break; case LayoutConstraintLevel::NONE: if (!trait.layout_constraint.custom.empty()) { res[i] = [&trait](const TensorLayout& layout) { for (auto&& i : trait.layout_constraint.custom) if (!i(layout)) return false; return true; }; } break; default: mgb_throw(InternalError, "invalid layout_constraint_level"); } } input_layout_constraints_cache.emplace(hash_key, res); return res; } } // namespace mgb::imperative::proxy_graph_detail