Commit d42cf4cd, authored by Megvii Engine Team

refactor(mgb): replace static_cast<ComputingGraphImpl*> with a checked version

GitOrigin-RevId: d05b114668f84f9b4045abe55af11bce0ff3bd4a
Parent 4d56371e
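The "checked version" introduced below keeps the cast itself a plain static_cast, but first asserts a precondition: when the build defines MGB_ENABLE_IMPERATIVE_RUNTIME, the helper refuses graphs whose options().imperative_proxy_graph flag is set before treating them as a ComputingGraphImpl. A minimal standalone sketch of that pattern (the Graph/GraphImpl names and the is_proxy flag are illustrative stand-ins, not MegEngine's API):

    #include <cassert>

    struct Graph {
        virtual ~Graph() = default;
        bool is_proxy = false;  // stands in for options().imperative_proxy_graph
    };

    struct GraphImpl : Graph {
        static GraphImpl* downcast(Graph* g) {
            // refuse graphs flagged as proxies before downcasting
            assert(!g->is_proxy);
            return static_cast<GraphImpl*>(g);
        }
    };

    int main() {
        GraphImpl impl;
        Graph* g = &impl;
        GraphImpl* p = GraphImpl::downcast(g);  // precondition checked, then static_cast
        (void)p;
    }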
@@ -129,7 +129,7 @@ else()
   if(ANDROID)
     set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -DNDEBUG")
     set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -DNDEBUG -g")
   else()
     set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
     set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -DNDEBUG -g")
@@ -224,6 +224,7 @@ endif()
 option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
 option(MGE_BUILD_XXX "Build _xxx.so instead of mgb.so " OFF)
 if(MGE_BUILD_XXX)
+    add_compile_definitions(MGB_ENABLE_IMPERATIVE_RUNTIME)
     set(CMAKE_CXX_STANDARD 17)
 endif()
@@ -662,14 +663,14 @@ endif()
 configure_file(cmake/megengine.pc.in
     ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc
     @ONLY)
 install(FILES ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc
     DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
 # Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready.
 if (NOT MGE_WITH_DISTRIBUTED)
     include(CMakePackageConfigHelpers)
     set (MGE_INSTALL_CMAKEDIR ${CMAKE_INSTALL_LIBDIR}/cmake/MegEngine)
     configure_package_config_file(cmake/MegEngineConfig.cmake.in
         ${CMAKE_CURRENT_BINARY_DIR}/MegEngineConfig.cmake
         INSTALL_DESTINATION ${MGE_INSTALL_CMAKEDIR}
     )
......
@@ -674,7 +674,7 @@ void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
     mgb_assert(
             !m_current_comp_seq,
             "share_device_memory_with must be called before compiling graph");
-    auto&& oimpl = static_cast<ComputingGraphImpl&>(other);
+    auto&& oimpl = *ComputingGraphImpl::downcast(&other);
     var_node_mem_manager().static_device_memory_manager(
             oimpl.var_node_mem_manager().static_device_memory_manager());
 }
@@ -707,7 +707,7 @@ size_t ComputingGraphImpl::clear_device_memory() {
 }
 void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
-    m_parent_graph = static_cast<ComputingGraphImpl*>(&par_graph);
+    m_parent_graph = ComputingGraphImpl::downcast(&par_graph);
     m_parent_graph->m_subgraphs.emplace_back(this);
     m_node_id_counter = m_parent_graph->m_node_id_counter;
     options().var_sanity_check_first_run =
......
@@ -122,6 +122,15 @@ public:
     ComputingGraphImpl();
     ~ComputingGraphImpl();
+    template<typename T> static ComputingGraphImpl* downcast(T* ptr) = delete;
+    inline static ComputingGraphImpl* downcast(ComputingGraph* graph) {
+#ifdef MGB_ENABLE_IMPERATIVE_RUNTIME
+        mgb_assert(!graph->options().imperative_proxy_graph);
+#endif
+        return static_cast<ComputingGraphImpl*>(graph);
+    }
     friend struct ComputingGraph::Options;
     std::unique_ptr<AsyncExecutable> compile(
......
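The deleted template overload in the hunk above is what lets this helper be substituted at every call site safely: any argument type other than ComputingGraph* deduces into the deleted template and fails to compile, instead of silently reinterpreting memory. A minimal sketch of that overload-resolution trick, again with illustrative Graph/GraphImpl names rather than MegEngine's types:

    struct Graph { virtual ~Graph() = default; };

    struct GraphImpl : Graph {
        // Any pointer type other than Graph* matches this deleted template exactly,
        // so the call is rejected at compile time.
        template <typename T>
        static GraphImpl* downcast(T*) = delete;
        static GraphImpl* downcast(Graph* g) { return static_cast<GraphImpl*>(g); }
    };

    // GraphImpl::downcast(graph_ptr)  -> OK: exact non-template match wins
    // GraphImpl::downcast(&some_int)  -> compile error: deleted overload selected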
@@ -100,7 +100,7 @@ class ComputingGraphImpl::ComputingSequence final : public AsyncExecutable {
 public:
     ComputingSequence(const std::shared_ptr<ComputingGraph>& graph)
             : m_owner_graph_refkeep{graph},
-              m_owner_graph{static_cast<ComputingGraphImpl*>(graph.get())},
+              m_owner_graph{ComputingGraphImpl::downcast(graph.get())},
               m_have_parent_graph{m_owner_graph->m_parent_graph} {}
     GraphExecutable::ExecEnv& exec_env() { return m_exec_env; }
......
@@ -76,7 +76,7 @@ EagerEvalManager::~EagerEvalManager() noexcept {
     if (m_first_opr_enable_status == 1) {
         m_var_sync_mgr_pool.disable_freelist();
         for (auto&& i :
-             static_cast<ComputingGraphImpl*>(m_owner_graph)->all_oprs()) {
+             ComputingGraphImpl::downcast(m_owner_graph)->all_oprs()) {
             for (auto var : i->output()) {
                 auto mgr = VarNodeMemManager::var_node_cn_sync_manager(var);
                 if (mgr) {
@@ -223,7 +223,7 @@ void EagerEvalManager::prepare_for_exec(OperatorNodeBase* opr) {
 }
 void EagerEvalManager::update_static_infer_result(OperatorNodeBase* opr) {
-    auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
+    auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
                          ->static_infer_manager_impl();
     auto sync_missing_trait =
             [&](static_infer::StaticInferManagerImpl::TagHandler* handler) {
@@ -260,7 +260,7 @@ void EagerEvalManager::update_static_infer_result(OperatorNodeBase* opr) {
 }
 void EagerEvalManager::ensure_input_layout(VarNode* var) {
-    auto&& mem_mgr = static_cast<ComputingGraphImpl*>(var->owner_graph())
+    auto&& mem_mgr = ComputingGraphImpl::downcast(var->owner_graph())
                              ->var_node_mem_manager();
     auto trait = mem_mgr.get_var_node_mem_trait_nullable(var);
@@ -287,7 +287,7 @@ void EagerEvalManager::alloc_output_mem(OperatorNodeBase* opr) {
         }
     }
-    auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
+    auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
                          ->var_node_mem_manager();
     OprNodeArray opr_seq{opr};
@@ -348,7 +348,7 @@ void EagerEvalManager::do_on_opr_insert(OperatorNodeBase* opr) {
     if (status) {
         update_static_infer_result(opr);
         alloc_output_mem(opr);
-        auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
+        auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
                              ->var_node_mem_manager();
         mgr.on_graph_compile_finished();
         opr->execute(*m_exec_env);
......
@@ -40,7 +40,7 @@ class GradShapeChecker {
     void do_on_var_shape(VarNode *var) {
         MGB_MARK_USED_VAR(m_opr);
-        auto graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
+        auto graph = ComputingGraphImpl::downcast(var->owner_graph());
         auto seq = graph->current_comp_seq();
         if (seq) {
@@ -90,7 +90,7 @@ class GradShapeChecker {
     }
     static void make(OperatorNodeBase *opr, VarNode *wrt, VarNode *grad) {
-        if (static_cast<ComputingGraphImpl*>(wrt->owner_graph())
+        if (ComputingGraphImpl::downcast(wrt->owner_graph())
                     ->eager_eval_manager().enabled())
             return;
         using namespace std::placeholders;
@@ -650,13 +650,13 @@ void GradManager::add_var_virtual_receiver(
 }
 void cg::add_grad_transformer(VarNode *var, const GradTransformer &cb) {
-    static_cast<ComputingGraphImpl*>(var->owner_graph())->
+    ComputingGraphImpl::downcast(var->owner_graph())->
         grad_manager().
         add_grad_transformer(var, cb);
 }
 void cg::add_extra_dep_for_grad(VarNode *inp, VarNode *out) {
-    static_cast<ComputingGraphImpl*>(inp->owner_graph())->grad_manager().
+    ComputingGraphImpl::downcast(inp->owner_graph())->grad_manager().
         add_extra_dep_for_grad(inp, out);
 }
@@ -667,7 +667,7 @@ void cg::add_var_virtual_receiver(
     desc->inputs = inputs;
     desc->outputs = outputs;
     desc->grad = grad;
-    static_cast<ComputingGraphImpl*>(inputs.at(0)->owner_graph())->
+    ComputingGraphImpl::downcast(inputs.at(0)->owner_graph())->
         grad_manager().
         add_var_virtual_receiver(desc);
 }
......
@@ -99,8 +99,8 @@ SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_w
     grads.reserve(wrts_.size());
     VarNodeArray dest_vars;
     auto&& graph = target->owner_graph();
-    auto&& eager_mgr = static_cast<ComputingGraphImpl*>(graph)->eager_eval_manager();
-    auto&& grad_mgr = static_cast<ComputingGraphImpl*>(graph)->grad_manager();
+    auto&& eager_mgr = ComputingGraphImpl::downcast(graph)->eager_eval_manager();
+    auto&& grad_mgr = ComputingGraphImpl::downcast(graph)->grad_manager();
     bool already_recorded = eager_mgr.enter_record_mode();
     for (auto&& wrt_ : wrts_) {
         auto wrt = wrt_.node();
@@ -139,7 +139,7 @@ SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_w
 SymbolVar cg::current_grad_target(ComputingGraph &graph) {
 #if MGB_ENABLE_GRAD
-    auto var = static_cast<ComputingGraphImpl&>(graph).grad_manager(
+    auto var = ComputingGraphImpl::downcast(&graph)->grad_manager(
             ).current_grad_target();
     mgb_throw_if(!var, GraphError, "current_grad_target() called outside "
             "grad computing environment");
......
@@ -93,7 +93,7 @@ OperatorNodeBase::OperatorNodeBase(ComputingGraph *owner,
 }
 OperatorNodeBase::~OperatorNodeBase() noexcept {
-    auto &&pool = static_cast<ComputingGraphImpl*>(
+    auto &&pool = ComputingGraphImpl::cast(
             owner_graph())->var_node_pool();
     for (auto i: m_output) {
         pool.free(i);
@@ -124,7 +124,7 @@ void OperatorNodeBase::execute(ExecEnv &env) {
     }
     // allocate output with dynamic storage
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .alloc_var_node_mem_dynamic(env, this);
@@ -135,11 +135,11 @@ void OperatorNodeBase::execute(ExecEnv &env) {
     // static_infer_manager so the value would be up-to-date; however for shape
     // deps, oprs would access the shape directly, so we need to insert some
     // code here to ensure it is up-to-date.
-    if (!static_cast<ComputingGraphImpl*>(owner_graph())
+    if (!ComputingGraphImpl::downcast(owner_graph())
                  ->eager_eval_manager()
                  .enabled()) {
         VarNodeArray vars_to_set;
-        auto cg = static_cast<ComputingGraphImpl*>(owner_graph());
+        auto cg = ComputingGraphImpl::downcast(owner_graph());
         auto step_cur = cg->opr_step_num_in_cur_comp_seq(this).val();
         mgb_assert(step_cur < std::numeric_limits<size_t>::max());
         using DT = NodeProp::DepType;
@@ -264,7 +264,7 @@ VarNode* OperatorNodeBase::add_output(const Maybe<std::string> &name) {
     mgb_assert(!m_inserted_in_graph && !m_node_prop.valid(),
             "add output on opr after it has been inserted into graph");
-    auto ptr = static_cast<ComputingGraphImpl*>(
+    auto ptr = ComputingGraphImpl::cast(
             owner_graph())->var_node_pool().alloc(
             name.valid() ? this->name() + ":" + name.val() : name, this);
     m_output.push_back(ptr);
@@ -676,7 +676,7 @@ void mixin::IOSameShapeOperatorNode::get_output_var_shape(
 void PostExecActions::add(VarNode* var) {
     mgb_assert(m_comp_node == var->comp_node());
-    auto graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
+    auto graph = ComputingGraphImpl::downcast(var->owner_graph());
     auto&& infer_mgr = graph->static_infer_manager_impl();
     auto&& extra_info = graph->current_comp_seq_extra_info();
......
@@ -813,7 +813,7 @@ StaticInferManagerImpl::~StaticInferManagerImpl() noexcept {
     m_mem_pool_value_trait.disable_freelist();
     for (auto &&i: m_dtor_callbacks)
         i.second();
-    for (auto &&i: static_cast<ComputingGraphImpl*>(
+    for (auto &&i: ComputingGraphImpl::downcast(
                 m_owner_graph)->all_oprs()) {
         for (auto j: i->output()) {
             clear_tag_handler(j);
@@ -1212,7 +1212,7 @@ class StaticInferManagerImpl::SubgraphStaticInferHelperImpl final:
     void check_graph_par(VarNode *var) {
         if (mgb_unlikely(!m_par_graph)) {
-            m_par_graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
+            m_par_graph = ComputingGraphImpl::downcast(var->owner_graph());
             mgb_assert(m_par_graph != m_sub_graph);
             auto cb = [this]() {
@@ -1230,7 +1230,7 @@ class StaticInferManagerImpl::SubgraphStaticInferHelperImpl final:
     void check_graph_sub(VarNode *var) {
         if (mgb_unlikely(!m_sub_graph)) {
-            m_sub_graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
+            m_sub_graph = ComputingGraphImpl::downcast(var->owner_graph());
             mgb_assert(m_sub_graph != m_par_graph);
         } else {
             mgb_assert(m_sub_graph == var->owner_graph());
......
@@ -132,7 +132,7 @@ const DeviceTensorND& SymbolVar::eager_eval_get_value() const {
 #if MGB_BUILD_SLIM_SERVING
     mgb_throw(MegBrainError, "eager eval disabled at compile time");
 #else
-    auto og = static_cast<ComputingGraphImpl*>(node()->owner_graph());
+    auto og = ComputingGraphImpl::downcast(node()->owner_graph());
     mgb_assert(og->options().eager_evaluation);
     return node()->dev_tensor();
 #endif
......
@@ -260,7 +260,7 @@ void TopoSorter::DFSDepDiscover::proc_add_dep_comp_order1() {
 void TopoSorter::DFSDepDiscover::proc_find_missing_inp() {
     auto frame = m_cur_frame;
     auto opr = frame->opr;
-    auto&& mgr = static_cast<ComputingGraphImpl*>(opr->owner_graph())
+    auto&& mgr = ComputingGraphImpl::downcast(opr->owner_graph())
                          ->static_infer_manager_impl();
     auto&& missing_inp = frame->missing_inputs;
......
@@ -233,12 +233,12 @@ bool VarNode::set_fwd_in2out_readonly(
     if (owner_graph()->options().imperative_proxy_graph) {
         return false;
     }
-    return static_cast<ComputingGraphImpl*>(owner_graph())
+    return ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager().fwd_in2out_readonly(input, sub, this);
 }
 VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager().fwd_in2out_writable(input, this);
     return *this;
 }
@@ -246,20 +246,20 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
 VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
     mgb_assert(!owner_graph()->options().imperative_proxy_graph);
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager().fwd_in2out_writable_force(input, this);
     return *this;
 }
 VarNode& VarNode::add_layout_constraint(LayoutConstraintCallback callback) {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager().add_layout_constraint(
                     this, std::move(callback));
     return *this;
 }
 VarNode& VarNode::add_layout_constraint_contiguous() {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .add_layout_constraint_level(
                     this, VarNodeMemManager::LayoutConstraintLevel::CONTIG);
@@ -267,7 +267,7 @@ VarNode& VarNode::add_layout_constraint_contiguous() {
 }
 VarNode& VarNode::add_layout_constraint_monotone() {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .add_layout_constraint_level(
                     this, VarNodeMemManager::LayoutConstraintLevel::MONOTONE);
@@ -315,7 +315,7 @@ VarNode& VarNode::shape_alloc(const TensorShape &shape) {
             "shape_alloc() could only be used for vars with"
             " NO_SYS_MEM_ALLOC flag; actual var: %s",
             cg::dump_var_info({this}).c_str());
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager().var_alloc_with_shape(this, shape);
     return *this;
 }
@@ -330,7 +330,7 @@ bool VarNode::reset_dev_tensor_from_other_var(VarNode* src_var) {
             "dynamic storage on src is required for dynamic readonly "
             "forwarding: vars=%s",
             dump_var_info({src_var, this}).c_str());
-    auto&& trait = static_cast<ComputingGraphImpl*>(owner_graph())
+    auto&& trait = ComputingGraphImpl::downcast(owner_graph())
                            ->var_node_mem_manager()
                            .get_var_node_mem_trait_at(src_var);
     if (trait.seq_force_update_dest ||
@@ -403,7 +403,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
         return json::Null::make();
     };
-    auto &&trait = static_cast<ComputingGraphImpl*>(owner_graph()
+    auto &&trait = ComputingGraphImpl::downcast(owner_graph()
             )->var_node_mem_manager().get_var_node_mem_trait(this);
     auto flag = json::Array::make();
     {
@@ -459,7 +459,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
 #endif
 MemAllocPlan& VarNode::init_mem_plan(const DeviceTensorND* fixed_alloc) {
-    static_cast<ComputingGraphImpl*>(owner_graph())
+    ComputingGraphImpl::downcast(owner_graph())
             ->var_node_mem_manager()
             .init_single_var_mem_plan(this, fixed_alloc);
     return m_mem_plan;
@@ -477,7 +477,7 @@ void VarNode::modify_flag(Flag delta, Flag new_flag) {
                     Flag::NO_SYS_STATIC_MEM_ALLOC |
                     Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
-    mgb_assert(!static_cast<ComputingGraphImpl*>(owner_graph())->
+    mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
             var_node_mem_manager().optimize_started(),
             "could not modify var flags after optimization started");
 }
......
@@ -340,7 +340,7 @@ VarNodeMemManager::DynamicAllocOprInfo::DynamicAllocOprInfo(
     prev_dev_val_input.clear();
     static_infer_inp.clear();
     dev_val_input.clear();
-    auto &&mgr = static_cast<ComputingGraphImpl*>(opr->owner_graph())->
+    auto &&mgr = ComputingGraphImpl::downcast(opr->owner_graph())->
         static_infer_manager_impl();
     CompNode single_cn;
......
@@ -73,7 +73,7 @@ void VarDevMemDefragmenter::defrag(VarNode* req_var,
                                    const CompNodeInfo& cn_info,
                                    size_t extra_size) {
     // pause all other comp nodes before calling defrag_impl()
-    auto exec_env = static_cast<ComputingGraphImpl*>(req_var->owner_graph())
+    auto exec_env = ComputingGraphImpl::downcast(req_var->owner_graph())
                             ->current_exec_env();
     mgb_assert(exec_env);
     exec_env->pause_exec();
......