diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index 89a30a165588a7fc42a4edc6adf7da3e817e6e79..8d2cb9b3d41abb44fb3904c187469cf3758b4925 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -537,14 +537,28 @@ std::shared_ptr ComputingGraphImpl::ComputingSequence::to_json() comp_seq->add(json::String::make(i->id_str())); } - // expand opr and var nodes that do not appear in comp seq + // expand opr and var nodes that do not appear in comp seq, + // also expand var nodes which are only used in static infer { VarNodeArray new_var_node; + auto&& mgr = m_owner_graph->static_infer_manager_impl(); auto check_opr_input = [&](OperatorNodeBase* opr) { + auto update = [&](VarNode* var) { + if (!(all_var_node.count(var))) { + all_var_node.insert(var); + new_var_node.push_back(var); + } + }; for (auto i : opr->input()) { - if (!(all_var_node.count(i))) { - all_var_node.insert(i); - new_var_node.push_back(i); + update(i); + } + for (auto &&out : opr->output()) { + using DepType = static_infer::DepType; + for (auto&& i : mgr.get_deps({out, DepType::SHAPE})) { + update(i.dest); + } + for (auto&& i : mgr.get_deps({out, DepType::VALUE})) { + update(i.dest); } } }; diff --git a/src/core/impl/graph/static_infer_impl.cpp b/src/core/impl/graph/static_infer_impl.cpp index ed3ffbc6cb54704629a46460d1b9a3cf737c09d0..15808a018ab2bb2388881d0a5c7e110a0dbb3981 100644 --- a/src/core/impl/graph/static_infer_impl.cpp +++ b/src/core/impl/graph/static_infer_impl.cpp @@ -245,6 +245,9 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase, return m_infer_withoutexc_ret; } + //! original deps given in the InferDesc by the caller + virtual const DepVal& raw_deps() = 0; + protected: //! current infer result, to be used by dependents InpElement m_inp_element; @@ -300,9 +303,6 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase, //! all missing inputs SharedSet m_missing_input; - //! 
original deps given in the InferDesc by the caller - virtual const DepVal& raw_deps() = 0; - //! recursively set m_inp_element_synced of this and all receivers to //! false void reset_inp_element_synced(); @@ -1027,6 +1027,14 @@ void StaticInferManagerImpl::update_mutable_src_shape(Tag dest) { MGB_CATCH(MegBrainError & exc, { update_rethrow_exc(dest, exc); }) } +DepVal StaticInferManagerImpl::get_deps(const DepElement &elem) { + auto trait_base = get_tag_trait_container(elem.dest).select(elem.type); + if (!trait_base || trait_base->is_const()) + return {}; + + return trait_base->as_mutable_safe()->raw_deps(); +} + /* ===================== CompSeqManager ===================== */ class CompSeqManager::VersionedTagTrait { diff --git a/src/core/impl/graph/static_infer_impl.h b/src/core/impl/graph/static_infer_impl.h index ca052846d072dc9f779ae81c30a19a3758be35ea..e8adc4aef8ac47b8a0bfb00f6d87dbcc97cff29a 100644 --- a/src/core/impl/graph/static_infer_impl.h +++ b/src/core/impl/graph/static_infer_impl.h @@ -99,6 +99,17 @@ class StaticInferManagerImpl final: public StaticInferManager { */ void update_mutable_src_shape(Tag tag); + + /*! 
+ * \brief get original deps given in the InferDesc which is registered + * by register_shape_infer or register_value_infer + * + * Note: the \p elem with DepType::SHAPE and InferType::CONST shows no + * deps since the StaticInferManagerImpl folds the inference chain of + * the const var shape + */ + DepVal get_deps(const DepElement &elem); + private: friend class CompSeqManager; diff --git a/src/core/impl/graph/var_node.cpp b/src/core/impl/graph/var_node.cpp index 1a136d9bb6923cf8638f642374148c04dd14347b..6cfae871ae6b5c9bf2c3b71b3c6e9f0d440c7808 100644 --- a/src/core/impl/graph/var_node.cpp +++ b/src/core/impl/graph/var_node.cpp @@ -396,6 +396,108 @@ VarNode& VarNode::comp_node(const CompNode &cn) { } #if MGB_ENABLE_JSON +std::shared_ptr +VarNode::dump_static_infer_info_to_json() const { + using namespace cg::static_infer; + auto&& mgr = static_cast( + owner_graph())->static_infer_manager_impl(); + auto get_dep_type = [](const DepType& type) -> std::string { + switch (type) { +#define cb(name) \ +case DepType::name: \ + return #name; + cb(SHAPE) + cb(VALUE) +#undef cb + default: + mgb_throw(MegBrainError, "unknown dep type"); + } + }; + auto get_infer_type = [](const InferType::Flag& type) { + switch (type) { +#define cb(name) \ +case InferType::Flag::name: \ + return json::String::make(#name); + cb(NO_DESC) + cb(CONST) + cb(RT_STATIC) + cb(MISSING_INP) +#undef cb + default: + mgb_throw(MegBrainError, "unknown infer type"); + } + }; + auto make_tag = [&](const DepType& type) { + VarNode* self = const_cast(this); + auto c_deps = mgr.get_deps({self, type}); + auto deps = json::Array::make(); + for (auto&& i : c_deps) { + mgb_assert(i.dest); + deps->add(json::Object::make({ + {"var", json::String::make(i.dest->id_str())}, + {"dep_type", json::String::make(get_dep_type(i.type))} + })); + } + auto infer_type_handle = mgr.get_infer_type(self); + auto inferred_result = json::Null::make(); + auto infer_type = type == DepType::SHAPE ? 
infer_type_handle.shape + : infer_type_handle.value; + if (infer_type != InferType::Flag::NO_DESC) { + if (type == DepType::SHAPE) { + if (auto shape = mgr.infer_shape_fallible(self)) { + auto inferred_shape = json::Array::make(); + for (size_t i = 0; i < shape->ndim; ++ i) { + inferred_shape->add(json::Number::make((*shape)[i])); + } + inferred_result = inferred_shape; + } + } else { + if (auto p = mgr.infer_value_fallible(self)) { + auto&& dev = *p; + if (dev.shape().ndim == 1 && + dev.shape(0) < TensorShape::MAX_NDIM && + mgb_likely(dev.comp_node() == CompNode::default_cpu())) { + MGB_TRY { + size_t nr_elems = dev.shape(0); + auto&& dtype = dev.dtype(); + void* vptr = dev.raw_ptr(); + double data[nr_elems]; + HostTensorND contig; + if (!dev.layout().is_contiguous()) { + // both src and dst are placed on default cpu, + // no need for sync + contig.copy_from(dev); + mgb_assert(contig.layout().is_contiguous()); + vptr = contig.raw_ptr(); + } + static_cast_dtype(data, dtype, vptr, nr_elems); + auto inferred_value = json::Array::make(); + for (size_t i = 0; i < nr_elems; ++ i) { + inferred_value->add(json::Number::make(data[i])); + } + inferred_result = inferred_value; + } + MGB_CATCH(ConversionError&, {}); + } else { + inferred_result = json::String::make("Large Array"); + } + } + } + } + return json::Object::make({ + {"node_type", json::String::make("static_infer_tag")}, + {"infer_type", get_infer_type(infer_type)}, + {"inferred_result", inferred_result}, + {"deps", deps} + }); + }; + return json::Object::make({ +#define TAG(type) {get_dep_type(type), make_tag(type)} + TAG(DepType::SHAPE), TAG(DepType::VALUE) +#undef TAG + }); +} + std::shared_ptr VarNode::to_json() const { auto get_var = [](VarNode *p) -> std::shared_ptr { if(p) @@ -443,8 +545,10 @@ std::shared_ptr VarNode::to_json() const { {"dev_ptr", json::Null::make()}, {"prev_dev_ptr", json::NumberInt::make(reinterpret_cast( m_prev_dev_ptr))}, - {"flag", flag} + {"flag", flag}, + {"static_infer_tags", 
dump_static_infer_info_to_json()} }); + if (m_prev_dev_ptr) { (*rst)["prev_dev_ptr_end"] = json::NumberInt::make( reinterpret_cast(m_prev_dev_ptr) + diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index 93182edc6da13ca96911140c2a19dde61201298e..f9a5d4d262fb47f2039b587fdf6457e2d5d169dd 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -575,6 +575,10 @@ class VarNode final: public GraphNodeBase { void assign_dev_tensor_from_tensor(const DeviceTensorND &value); +#if MGB_ENABLE_JSON + std::shared_ptr dump_static_infer_info_to_json() const; +#endif + friend class static_infer::StaticInferManagerImpl; friend class VarNodeMemManager; friend class VarDevMemDefragmenter;