Commit 618faf64 authored by Megvii Engine Team

feat(mgb/profiler): dump static infer info

GitOrigin-RevId: bb9150eb8320169236a786168290d368ac03cd5b
Parent e05c795b
@@ -537,14 +537,28 @@ std::shared_ptr<json::Value> ComputingGraphImpl::ComputingSequence::to_json()
         comp_seq->add(json::String::make(i->id_str()));
     }
 
-    // expand opr and var nodes that do not appear in comp seq
+    // expand opr and var nodes that do not appear in comp seq,
+    // also expand var nodes which are only used in static infer
     {
         VarNodeArray new_var_node;
+        auto&& mgr = m_owner_graph->static_infer_manager_impl();
         auto check_opr_input = [&](OperatorNodeBase* opr) {
+            auto update = [&](VarNode* var) {
+                if (!(all_var_node.count(var))) {
+                    all_var_node.insert(var);
+                    new_var_node.push_back(var);
+                }
+            };
             for (auto i : opr->input()) {
-                if (!(all_var_node.count(i))) {
-                    all_var_node.insert(i);
-                    new_var_node.push_back(i);
-                }
+                update(i);
+            }
+            for (auto&& out : opr->output()) {
+                using DepType = static_infer::DepType;
+                for (auto&& i : mgr.get_deps({out, DepType::SHAPE})) {
+                    update(i.dest);
+                }
+                for (auto&& i : mgr.get_deps({out, DepType::VALUE})) {
+                    update(i.dest);
+                }
             }
         };
......
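As the updated comment says, the computing-sequence JSON dump now also expands vars that are reachable only through static-infer dependencies. Below is a minimal sketch of that traversal, assuming the internal graph-impl headers are visible; the helper name collect_static_infer_deps is made up here for illustration and is not part of the commit.

// Sketch only: gather the vars that static inference of `opr`'s outputs
// reads; this is exactly the set the patched loop feeds into `update`.
mgb::cg::VarNodeArray collect_static_infer_deps(
        mgb::cg::ComputingGraphImpl* graph, mgb::cg::OperatorNodeBase* opr) {
    using mgb::cg::static_infer::DepType;
    auto&& mgr = graph->static_infer_manager_impl();
    mgb::cg::VarNodeArray ret;
    for (auto&& out : opr->output()) {
        // shape deps and value deps are queried separately, as in the patch
        for (auto&& dep : mgr.get_deps({out, DepType::SHAPE}))
            ret.push_back(dep.dest);
        for (auto&& dep : mgr.get_deps({out, DepType::VALUE}))
            ret.push_back(dep.dest);
    }
    return ret;
}
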
@@ -245,6 +245,9 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase,
         return m_infer_withoutexc_ret;
     }
 
+    //! original deps given in the InferDesc by the caller
+    virtual const DepVal& raw_deps() = 0;
+
 protected:
     //! current infer result, to be used by dependents
     InpElement m_inp_element;
@@ -300,9 +303,6 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase,
     //! all missing inputs
     SharedSet<TagHandler*, TagHandlerSet> m_missing_input;
 
-    //! original deps given in the InferDesc by the caller
-    virtual const DepVal& raw_deps() = 0;
-
     //! recursively set m_inp_element_synced of this and all receivers to
     //! false
     void reset_inp_element_synced();
@@ -1027,6 +1027,14 @@ void StaticInferManagerImpl::update_mutable_src_shape(Tag dest) {
     MGB_CATCH(MegBrainError & exc, { update_rethrow_exc(dest, exc); })
 }
 
+DepVal StaticInferManagerImpl::get_deps(const DepElement &elem) {
+    auto trait_base = get_tag_trait_container(elem.dest).select(elem.type);
+    if (!trait_base || trait_base->is_const())
+        return {};
+    return trait_base->as_mutable_safe()->raw_deps();
+}
+
 /* ===================== CompSeqManager ===================== */
 
 class CompSeqManager::VersionedTagTrait {
......
@@ -99,6 +99,17 @@ class StaticInferManagerImpl final: public StaticInferManager {
      */
     void update_mutable_src_shape(Tag tag);
 
+    /*!
+     * \brief get the original deps given in the InferDesc registered by
+     *        register_shape_infer or register_value_infer
+     *
+     * Note: an \p elem with DepType::SHAPE and InferType::CONST yields no
+     * deps, since StaticInferManagerImpl folds the inference chain of
+     * const var shapes
+     */
+    DepVal get_deps(const DepElement &elem);
+
 private:
     friend class CompSeqManager;
......
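A hedged usage sketch of the new StaticInferManagerImpl::get_deps accessor; `graph` and `var` are assumed to be a ComputingGraphImpl* and a VarNode*, the internal impl headers are assumed to be included, and everything else comes from the diff above. Per the doc comment, a shape tag whose infer type is CONST has had its inference chain folded, so it reports an empty dep list.

#include <cstdio>

// Sketch only: print where the statically inferred shape of `var` comes from.
void print_shape_deps(mgb::cg::ComputingGraphImpl* graph, mgb::cg::VarNode* var) {
    using namespace mgb::cg::static_infer;
    auto&& mgr = graph->static_infer_manager_impl();
    DepVal deps = mgr.get_deps({var, DepType::SHAPE});
    for (auto&& d : deps) {
        // d.dest is the source var; d.type says whether its SHAPE or its
        // VALUE is what the inference actually consumes
        printf("%s depends on %s (%s)\n", var->id_str().c_str(),
               d.dest->id_str().c_str(),
               d.type == DepType::SHAPE ? "SHAPE" : "VALUE");
    }
    // for a CONST-folded shape tag, `deps` is empty here
}
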
@@ -396,6 +396,108 @@ VarNode& VarNode::comp_node(const CompNode &cn) {
 }
 
 #if MGB_ENABLE_JSON
+std::shared_ptr<json::Value>
+VarNode::dump_static_infer_info_to_json() const {
+    using namespace cg::static_infer;
+    auto&& mgr = static_cast<cg::ComputingGraphImpl*>(
+            owner_graph())->static_infer_manager_impl();
+    auto get_dep_type = [](const DepType& type) -> std::string {
+        switch (type) {
+#define cb(name)        \
+    case DepType::name: \
+        return #name;
+            cb(SHAPE)
+            cb(VALUE)
+#undef cb
+            default:
+                mgb_throw(MegBrainError, "unknown dep type");
+        }
+    };
+    auto get_infer_type = [](const InferType::Flag& type) {
+        switch (type) {
+#define cb(name)                \
+    case InferType::Flag::name: \
+        return json::String::make(#name);
+            cb(NO_DESC)
+            cb(CONST)
+            cb(RT_STATIC)
+            cb(MISSING_INP)
+#undef cb
+            default:
+                mgb_throw(MegBrainError, "unknown infer type");
+        }
+    };
+    auto make_tag = [&](const DepType& type) {
+        VarNode* self = const_cast<VarNode*>(this);
+        auto c_deps = mgr.get_deps({self, type});
+        auto deps = json::Array::make();
+        for (auto&& i : c_deps) {
+            mgb_assert(i.dest);
+            deps->add(json::Object::make({
+                {"var", json::String::make(i.dest->id_str())},
+                {"dep_type", json::String::make(get_dep_type(i.type))}
+            }));
+        }
+        auto infer_type_handle = mgr.get_infer_type(self);
+        auto inferred_result = json::Null::make();
+        auto infer_type = type == DepType::SHAPE ? infer_type_handle.shape
+                                                 : infer_type_handle.value;
+        if (infer_type != InferType::Flag::NO_DESC) {
+            if (type == DepType::SHAPE) {
+                if (auto shape = mgr.infer_shape_fallible(self)) {
+                    auto inferred_shape = json::Array::make();
+                    for (size_t i = 0; i < shape->ndim; ++ i) {
+                        inferred_shape->add(json::Number::make((*shape)[i]));
+                    }
+                    inferred_result = inferred_shape;
+                }
+            } else {
+                if (auto p = mgr.infer_value_fallible(self)) {
+                    auto&& dev = *p;
+                    if (dev.shape().ndim == 1 &&
+                        dev.shape(0) < TensorShape::MAX_NDIM &&
+                        mgb_likely(dev.comp_node() == CompNode::default_cpu())) {
+                        MGB_TRY {
+                            size_t nr_elems = dev.shape(0);
+                            auto&& dtype = dev.dtype();
+                            void* vptr = dev.raw_ptr();
+                            double data[nr_elems];
+                            HostTensorND contig;
+                            if (!dev.layout().is_contiguous()) {
+                                // both src and dst are placed on default cpu,
+                                // no need for sync
+                                contig.copy_from(dev);
+                                mgb_assert(contig.layout().is_contiguous());
+                                vptr = contig.raw_ptr();
+                            }
+                            static_cast_dtype(data, dtype, vptr, nr_elems);
+                            auto inferred_value = json::Array::make();
+                            for (size_t i = 0; i < nr_elems; ++ i) {
+                                inferred_value->add(json::Number::make(data[i]));
+                            }
+                            inferred_result = inferred_value;
+                        }
+                        MGB_CATCH(ConversionError&, {});
+                    } else {
+                        inferred_result = json::String::make("Large Array");
+                    }
+                }
+            }
+        }
+        return json::Object::make({
+            {"node_type", json::String::make("static_infer_tag")},
+            {"infer_type", get_infer_type(infer_type)},
+            {"inferred_result", inferred_result},
+            {"deps", deps}
+        });
+    };
+    return json::Object::make({
+#define TAG(type) {get_dep_type(type), make_tag(type)}
+        TAG(DepType::SHAPE), TAG(DepType::VALUE)
+#undef TAG
+    });
+}
+
 std::shared_ptr<json::Value> VarNode::to_json() const {
     auto get_var = [](VarNode *p) -> std::shared_ptr<json::Value> {
         if(p)
@@ -443,8 +545,10 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
             {"dev_ptr", json::Null::make()},
             {"prev_dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>(
                     m_prev_dev_ptr))},
-            {"flag", flag}
+            {"flag", flag},
+            {"static_infer_tags", dump_static_infer_info_to_json()}
     });
 
     if (m_prev_dev_ptr) {
         (*rst)["prev_dev_ptr_end"] = json::NumberInt::make(
                 reinterpret_cast<size_t>(m_prev_dev_ptr) +
......
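For reference, this is the layout of the new "static_infer_tags" entry produced by dump_static_infer_info_to_json above, reconstructed from the code; the concrete var names and numbers are invented for illustration.

// auto rst = var->to_json();   // now also carries "static_infer_tags":
//
// "static_infer_tags": {
//     "SHAPE": {
//         "node_type": "static_infer_tag",
//         "infer_type": "CONST",                 // or RT_STATIC / MISSING_INP / NO_DESC
//         "inferred_result": [32, 3, 224, 224],  // null if shape inference failed
//         "deps": [{"var": "var123", "dep_type": "SHAPE"}]
//     },
//     "VALUE": {
//         "node_type": "static_infer_tag",
//         "infer_type": "RT_STATIC",
//         "inferred_result": "Large Array",      // numeric array for small 1-d
//                                                // cpu tensors, null on failure
//         "deps": [{"var": "var42", "dep_type": "VALUE"}]
//     }
// }
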
@@ -575,6 +575,10 @@ class VarNode final: public GraphNodeBase {
     void assign_dev_tensor_from_tensor(const DeviceTensorND &value);
 
+#if MGB_ENABLE_JSON
+    std::shared_ptr<json::Value> dump_static_infer_info_to_json() const;
+#endif
+
     friend class static_infer::StaticInferManagerImpl;
     friend class VarNodeMemManager;
     friend class VarDevMemDefragmenter;
......