From e1c83d8d51704a7f70f9a4ca6274fb7e87838d4e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 5 Mar 2021 13:29:24 +0800 Subject: [PATCH] fix(mgb/core): add warning information about const_var_shape when record mode GitOrigin-RevId: a99f9c4e5ddf92c62aaa17fb7b1baf10b68a5411 --- src/core/impl/comp_node/cpu/comp_node.cpp | 26 +++++++++++------------ src/core/impl/comp_node/cpu/comp_node.h | 2 +- src/core/impl/graph/cg_impl_seq.cpp | 17 +++++++++++++++ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index ca274fd6f..68cbba901 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -243,7 +243,7 @@ public: }; using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl; -using CompNodeNoRecorderImpl = CpuCompNode::CompNodeNoRecorderImpl; +using CompNodeDefaultImpl = CpuCompNode::CompNodeDefaultImpl; using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl; //! ==================== CompNodeBaseImpl ====================== @@ -466,29 +466,29 @@ public: } }; -//! ==================== CompNodeNoRecorderImpl ====================== +//! ==================== CompNodeDefaultImpl ====================== /** - * \note: CompNodeNoRecorderImpl will use most implements in base including: + * \note: CompNodeDefaultImpl will use most implements in base including: * alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to, * add_callback ... */ -class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl { +class CpuCompNode::CompNodeDefaultImpl final : public CompNodeBaseImpl { MGB_DYN_TYPE_OBJ_FINAL_DECL; public: //! ptr to default cpu, only used by check_global_finalized - static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr; + static CompNodeDefaultImpl* sm_default_cpu_comp_node_ptr; static void static_free_device(ImplBase* self, void* ptr) { - static_cast(self)->free_device(ptr); + static_cast(self)->free_device(ptr); } static void static_free_host(ImplBase* self, void* ptr) { - static_cast(self)->free_host(ptr); + static_cast(self)->free_host(ptr); } using CpuEventImpl = CpuDispatchableBase::EventImpl; - CompNodeNoRecorderImpl(const Locator& locator, + CompNodeDefaultImpl(const Locator& locator, const Locator& locator_logical) : CompNodeBaseImpl(locator, locator_logical, static_free_device, static_free_host) { @@ -501,7 +501,7 @@ public: sm_default_cpu_comp_node_ptr = this; } - ~CompNodeNoRecorderImpl() { + ~CompNodeDefaultImpl() { m_env.fini(); sm_default_cpu_comp_node_ptr = nullptr; } @@ -551,8 +551,8 @@ public: SeqRecorderImpl* cur_recorder() const override { return nullptr; } }; -MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl); -CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr = +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeDefaultImpl); +CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr = nullptr; //! ==================== CompNodeRecorderImpl ====================== @@ -746,7 +746,7 @@ public: void peer_copy_to(Impl* dest_impl, void* dest, const void* src, size_t size) override { //! copy to default_cpu - if (dest_impl->same_type()) { + if (dest_impl->same_type()) { CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size); return; } @@ -986,7 +986,7 @@ void CpuCompNode::sync_all() { // CpuCompNode::Pool CompNode CompNode::default_cpu() { static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}}; - static CompNodeNoRecorderImpl impl{locator, locator}; + static CompNodeDefaultImpl impl{locator, locator}; return &impl; } diff --git a/src/core/impl/comp_node/cpu/comp_node.h b/src/core/impl/comp_node/cpu/comp_node.h index 0db224c9b..a7d228089 100644 --- a/src/core/impl/comp_node/cpu/comp_node.h +++ b/src/core/impl/comp_node/cpu/comp_node.h @@ -55,7 +55,7 @@ namespace mgb { }; class CompNodeBaseImpl; - class CompNodeNoRecorderImpl; + class CompNodeDefaultImpl; class CompNodeRecorderImpl; static void foreach(thin_function callback); diff --git a/src/core/impl/graph/cg_impl_seq.cpp b/src/core/impl/graph/cg_impl_seq.cpp index 540a3f27b..1ccead74b 100644 --- a/src/core/impl/graph/cg_impl_seq.cpp +++ b/src/core/impl/graph/cg_impl_seq.cpp @@ -11,6 +11,7 @@ #include "./cg_impl_seq.h" #include "megbrain/graph/exc_extra_info.h" +#include "megbrain/opr/tensor_manip.h" using namespace mgb; using namespace cg; @@ -255,6 +256,22 @@ ComputingGraphImpl::ComputingSequence::check_enable_comp_node_seq_recorder() { } } } + auto check_const_shape = [&]() { + for (auto i : *m_opr_seq) { + for (auto j : i->output()) { + if (j->shape().ndim && !is_const_var_shape(j)) { + mgb_log_warn( + "Non-const var shape detected. Make sure all " + "shapes are constant. Check whether " + "'const_var_shape' is set " + "in GraphLoadConfig under record mode"); + return; + } + } + } + }; + check_const_shape(); + auto cn = *m_used_comp_node.begin(); auto rec = cn.create_seq_recorder(m_owner_graph); if (!rec) { -- GitLab