Commit d42cf4cd authored by Megvii Engine Team

refactor(mgb): replace static_cast<ComputingGraphImpl*> with a checked version

GitOrigin-RevId: d05b114668f84f9b4045abe55af11bce0ff3bd4a
Parent 4d56371e
......@@ -129,7 +129,7 @@ else()
if(ANDROID)
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -DNDEBUG -g")
else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -DNDEBUG -g")
......@@ -224,6 +224,7 @@ endif()
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON)
option(MGE_BUILD_XXX "Build _xxx.so instead of mgb.so " OFF)
# When building the _xxx.so variant, turn on the imperative-runtime code paths
# and require C++17 for the whole build.
if(MGE_BUILD_XXX)
# NOTE(review): add_compile_definitions defines MGB_ENABLE_IMPERATIVE_RUNTIME
# without a value, so C++ sources must test it with #ifdef, not #if.
add_compile_definitions(MGB_ENABLE_IMPERATIVE_RUNTIME)
set(CMAKE_CXX_STANDARD 17)
endif()
......@@ -662,14 +663,14 @@ endif()
configure_file(cmake/megengine.pc.in
${CMAKE_CURRENT_BINARY_DIR}/megengine.pc
@ONLY)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/megengine.pc
DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
# Do not export targets if MGE_WITH_DISTRIBUTED is on. MegRay is not ready.
if (NOT MGE_WITH_DISTRIBUTED)
include(CMakePackageConfigHelpers)
set (MGE_INSTALL_CMAKEDIR ${CMAKE_INSTALL_LIBDIR}/cmake/MegEngine)
configure_package_config_file(cmake/MegEngineConfig.cmake.in
configure_package_config_file(cmake/MegEngineConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/MegEngineConfig.cmake
INSTALL_DESTINATION ${MGE_INSTALL_CMAKEDIR}
)
......
......@@ -674,7 +674,7 @@ void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
mgb_assert(
!m_current_comp_seq,
"share_device_memory_with must be called before compiling graph");
auto&& oimpl = static_cast<ComputingGraphImpl&>(other);
auto&& oimpl = *ComputingGraphImpl::downcast(&other);
var_node_mem_manager().static_device_memory_manager(
oimpl.var_node_mem_manager().static_device_memory_manager());
}
......@@ -707,7 +707,7 @@ size_t ComputingGraphImpl::clear_device_memory() {
}
void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
m_parent_graph = static_cast<ComputingGraphImpl*>(&par_graph);
m_parent_graph = ComputingGraphImpl::downcast(&par_graph);
m_parent_graph->m_subgraphs.emplace_back(this);
m_node_id_counter = m_parent_graph->m_node_id_counter;
options().var_sanity_check_first_run =
......
......@@ -122,6 +122,15 @@ public:
ComputingGraphImpl();
~ComputingGraphImpl();
//! Checked replacement for bare static_cast<ComputingGraphImpl*>.
//! The deleted template overload makes any pointer type other than
//! ComputingGraph* a compile error, so callers cannot accidentally
//! downcast an unrelated type.
template<typename T> static ComputingGraphImpl* downcast(T* ptr) = delete;
//! Downcast a ComputingGraph* to the concrete ComputingGraphImpl*.
//! When the imperative runtime is compiled in, assert the graph is not
//! an imperative proxy graph (presumably such graphs are not backed by
//! ComputingGraphImpl, so the static_cast would be invalid -- confirm).
//! NOTE(review): other hunks in this commit call ComputingGraphImpl::cast()
//! (operator_node.cpp); no `cast` member is visible here -- verify it exists.
inline static ComputingGraphImpl* downcast(ComputingGraph* graph) {
#ifdef MGB_ENABLE_IMPERATIVE_RUNTIME
// give a diagnosable failure instead of a bare assert, matching the
// mgb_assert(cond, msg) style used elsewhere in this file
mgb_assert(!graph->options().imperative_proxy_graph,
           "downcast to ComputingGraphImpl failed: graph %p is an "
           "imperative proxy graph", graph);
#endif
return static_cast<ComputingGraphImpl*>(graph);
}
friend struct ComputingGraph::Options;
std::unique_ptr<AsyncExecutable> compile(
......
......@@ -100,7 +100,7 @@ class ComputingGraphImpl::ComputingSequence final : public AsyncExecutable {
public:
ComputingSequence(const std::shared_ptr<ComputingGraph>& graph)
: m_owner_graph_refkeep{graph},
m_owner_graph{static_cast<ComputingGraphImpl*>(graph.get())},
m_owner_graph{ComputingGraphImpl::downcast(graph.get())},
m_have_parent_graph{m_owner_graph->m_parent_graph} {}
GraphExecutable::ExecEnv& exec_env() { return m_exec_env; }
......
......@@ -76,7 +76,7 @@ EagerEvalManager::~EagerEvalManager() noexcept {
if (m_first_opr_enable_status == 1) {
m_var_sync_mgr_pool.disable_freelist();
for (auto&& i :
static_cast<ComputingGraphImpl*>(m_owner_graph)->all_oprs()) {
ComputingGraphImpl::downcast(m_owner_graph)->all_oprs()) {
for (auto var : i->output()) {
auto mgr = VarNodeMemManager::var_node_cn_sync_manager(var);
if (mgr) {
......@@ -223,7 +223,7 @@ void EagerEvalManager::prepare_for_exec(OperatorNodeBase* opr) {
}
void EagerEvalManager::update_static_infer_result(OperatorNodeBase* opr) {
auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
->static_infer_manager_impl();
auto sync_missing_trait =
[&](static_infer::StaticInferManagerImpl::TagHandler* handler) {
......@@ -260,7 +260,7 @@ void EagerEvalManager::update_static_infer_result(OperatorNodeBase* opr) {
}
void EagerEvalManager::ensure_input_layout(VarNode* var) {
auto&& mem_mgr = static_cast<ComputingGraphImpl*>(var->owner_graph())
auto&& mem_mgr = ComputingGraphImpl::downcast(var->owner_graph())
->var_node_mem_manager();
auto trait = mem_mgr.get_var_node_mem_trait_nullable(var);
......@@ -287,7 +287,7 @@ void EagerEvalManager::alloc_output_mem(OperatorNodeBase* opr) {
}
}
auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
->var_node_mem_manager();
OprNodeArray opr_seq{opr};
......@@ -348,7 +348,7 @@ void EagerEvalManager::do_on_opr_insert(OperatorNodeBase* opr) {
if (status) {
update_static_infer_result(opr);
alloc_output_mem(opr);
auto&& mgr = static_cast<ComputingGraphImpl*>(m_owner_graph)
auto&& mgr = ComputingGraphImpl::downcast(m_owner_graph)
->var_node_mem_manager();
mgr.on_graph_compile_finished();
opr->execute(*m_exec_env);
......
......@@ -40,7 +40,7 @@ class GradShapeChecker {
void do_on_var_shape(VarNode *var) {
MGB_MARK_USED_VAR(m_opr);
auto graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
auto graph = ComputingGraphImpl::downcast(var->owner_graph());
auto seq = graph->current_comp_seq();
if (seq) {
......@@ -90,7 +90,7 @@ class GradShapeChecker {
}
static void make(OperatorNodeBase *opr, VarNode *wrt, VarNode *grad) {
if (static_cast<ComputingGraphImpl*>(wrt->owner_graph())
if (ComputingGraphImpl::downcast(wrt->owner_graph())
->eager_eval_manager().enabled())
return;
using namespace std::placeholders;
......@@ -650,13 +650,13 @@ void GradManager::add_var_virtual_receiver(
}
void cg::add_grad_transformer(VarNode *var, const GradTransformer &cb) {
static_cast<ComputingGraphImpl*>(var->owner_graph())->
ComputingGraphImpl::downcast(var->owner_graph())->
grad_manager().
add_grad_transformer(var, cb);
}
void cg::add_extra_dep_for_grad(VarNode *inp, VarNode *out) {
static_cast<ComputingGraphImpl*>(inp->owner_graph())->grad_manager().
ComputingGraphImpl::downcast(inp->owner_graph())->grad_manager().
add_extra_dep_for_grad(inp, out);
}
......@@ -667,7 +667,7 @@ void cg::add_var_virtual_receiver(
desc->inputs = inputs;
desc->outputs = outputs;
desc->grad = grad;
static_cast<ComputingGraphImpl*>(inputs.at(0)->owner_graph())->
ComputingGraphImpl::downcast(inputs.at(0)->owner_graph())->
grad_manager().
add_var_virtual_receiver(desc);
}
......
......@@ -99,8 +99,8 @@ SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_w
grads.reserve(wrts_.size());
VarNodeArray dest_vars;
auto&& graph = target->owner_graph();
auto&& eager_mgr = static_cast<ComputingGraphImpl*>(graph)->eager_eval_manager();
auto&& grad_mgr = static_cast<ComputingGraphImpl*>(graph)->grad_manager();
auto&& eager_mgr = ComputingGraphImpl::downcast(graph)->eager_eval_manager();
auto&& grad_mgr = ComputingGraphImpl::downcast(graph)->grad_manager();
bool already_recorded = eager_mgr.enter_record_mode();
for (auto&& wrt_ : wrts_) {
auto wrt = wrt_.node();
......@@ -139,7 +139,7 @@ SymbolVarArray cg::grad(SymbolVar target_, SymbolVarArray wrts_, bool warn_mid_w
SymbolVar cg::current_grad_target(ComputingGraph &graph) {
#if MGB_ENABLE_GRAD
auto var = static_cast<ComputingGraphImpl&>(graph).grad_manager(
auto var = ComputingGraphImpl::downcast(&graph)->grad_manager(
).current_grad_target();
mgb_throw_if(!var, GraphError, "current_grad_target() called outside "
"grad computing environment");
......
......@@ -93,7 +93,7 @@ OperatorNodeBase::OperatorNodeBase(ComputingGraph *owner,
}
OperatorNodeBase::~OperatorNodeBase() noexcept {
auto &&pool = static_cast<ComputingGraphImpl*>(
auto &&pool = ComputingGraphImpl::cast(
owner_graph())->var_node_pool();
for (auto i: m_output) {
pool.free(i);
......@@ -124,7 +124,7 @@ void OperatorNodeBase::execute(ExecEnv &env) {
}
// allocate output with dynamic storage
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager()
.alloc_var_node_mem_dynamic(env, this);
......@@ -135,11 +135,11 @@ void OperatorNodeBase::execute(ExecEnv &env) {
// static_infer_manager so the value would be up-to-date; however for shape
// deps, oprs would access the shape directly, so we need to insert some
// code here to ensure it is up-to-date.
if (!static_cast<ComputingGraphImpl*>(owner_graph())
if (!ComputingGraphImpl::downcast(owner_graph())
->eager_eval_manager()
.enabled()) {
VarNodeArray vars_to_set;
auto cg = static_cast<ComputingGraphImpl*>(owner_graph());
auto cg = ComputingGraphImpl::downcast(owner_graph());
auto step_cur = cg->opr_step_num_in_cur_comp_seq(this).val();
mgb_assert(step_cur < std::numeric_limits<size_t>::max());
using DT = NodeProp::DepType;
......@@ -264,7 +264,7 @@ VarNode* OperatorNodeBase::add_output(const Maybe<std::string> &name) {
mgb_assert(!m_inserted_in_graph && !m_node_prop.valid(),
"add output on opr after it has been inserted into graph");
auto ptr = static_cast<ComputingGraphImpl*>(
auto ptr = ComputingGraphImpl::cast(
owner_graph())->var_node_pool().alloc(
name.valid() ? this->name() + ":" + name.val() : name, this);
m_output.push_back(ptr);
......@@ -676,7 +676,7 @@ void mixin::IOSameShapeOperatorNode::get_output_var_shape(
void PostExecActions::add(VarNode* var) {
mgb_assert(m_comp_node == var->comp_node());
auto graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
auto graph = ComputingGraphImpl::downcast(var->owner_graph());
auto&& infer_mgr = graph->static_infer_manager_impl();
auto&& extra_info = graph->current_comp_seq_extra_info();
......
......@@ -813,7 +813,7 @@ StaticInferManagerImpl::~StaticInferManagerImpl() noexcept {
m_mem_pool_value_trait.disable_freelist();
for (auto &&i: m_dtor_callbacks)
i.second();
for (auto &&i: static_cast<ComputingGraphImpl*>(
for (auto &&i: ComputingGraphImpl::downcast(
m_owner_graph)->all_oprs()) {
for (auto j: i->output()) {
clear_tag_handler(j);
......@@ -1212,7 +1212,7 @@ class StaticInferManagerImpl::SubgraphStaticInferHelperImpl final:
void check_graph_par(VarNode *var) {
if (mgb_unlikely(!m_par_graph)) {
m_par_graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
m_par_graph = ComputingGraphImpl::downcast(var->owner_graph());
mgb_assert(m_par_graph != m_sub_graph);
auto cb = [this]() {
......@@ -1230,7 +1230,7 @@ class StaticInferManagerImpl::SubgraphStaticInferHelperImpl final:
void check_graph_sub(VarNode *var) {
if (mgb_unlikely(!m_sub_graph)) {
m_sub_graph = static_cast<ComputingGraphImpl*>(var->owner_graph());
m_sub_graph = ComputingGraphImpl::downcast(var->owner_graph());
mgb_assert(m_sub_graph != m_par_graph);
} else {
mgb_assert(m_sub_graph == var->owner_graph());
......
......@@ -132,7 +132,7 @@ const DeviceTensorND& SymbolVar::eager_eval_get_value() const {
#if MGB_BUILD_SLIM_SERVING
mgb_throw(MegBrainError, "eager eval disabled at compile time");
#else
auto og = static_cast<ComputingGraphImpl*>(node()->owner_graph());
auto og = ComputingGraphImpl::downcast(node()->owner_graph());
mgb_assert(og->options().eager_evaluation);
return node()->dev_tensor();
#endif
......
......@@ -260,7 +260,7 @@ void TopoSorter::DFSDepDiscover::proc_add_dep_comp_order1() {
void TopoSorter::DFSDepDiscover::proc_find_missing_inp() {
auto frame = m_cur_frame;
auto opr = frame->opr;
auto&& mgr = static_cast<ComputingGraphImpl*>(opr->owner_graph())
auto&& mgr = ComputingGraphImpl::downcast(opr->owner_graph())
->static_infer_manager_impl();
auto&& missing_inp = frame->missing_inputs;
......
......@@ -233,12 +233,12 @@ bool VarNode::set_fwd_in2out_readonly(
if (owner_graph()->options().imperative_proxy_graph) {
return false;
}
return static_cast<ComputingGraphImpl*>(owner_graph())
return ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager().fwd_in2out_readonly(input, sub, this);
}
VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager().fwd_in2out_writable(input, this);
return *this;
}
......@@ -246,20 +246,20 @@ VarNode& VarNode::set_fwd_in2out_writable(VarNode *input) {
VarNode& VarNode::set_fwd_in2out_writable_force(VarNode *input) {
mgb_assert(!owner_graph()->options().imperative_proxy_graph);
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager().fwd_in2out_writable_force(input, this);
return *this;
}
VarNode& VarNode::add_layout_constraint(LayoutConstraintCallback callback) {
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager().add_layout_constraint(
this, std::move(callback));
return *this;
}
VarNode& VarNode::add_layout_constraint_contiguous() {
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager()
.add_layout_constraint_level(
this, VarNodeMemManager::LayoutConstraintLevel::CONTIG);
......@@ -267,7 +267,7 @@ VarNode& VarNode::add_layout_constraint_contiguous() {
}
VarNode& VarNode::add_layout_constraint_monotone() {
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager()
.add_layout_constraint_level(
this, VarNodeMemManager::LayoutConstraintLevel::MONOTONE);
......@@ -315,7 +315,7 @@ VarNode& VarNode::shape_alloc(const TensorShape &shape) {
"shape_alloc() could only be used for vars with"
" NO_SYS_MEM_ALLOC flag; actual var: %s",
cg::dump_var_info({this}).c_str());
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager().var_alloc_with_shape(this, shape);
return *this;
}
......@@ -330,7 +330,7 @@ bool VarNode::reset_dev_tensor_from_other_var(VarNode* src_var) {
"dynamic storage on src is required for dynamic readonly "
"forwarding: vars=%s",
dump_var_info({src_var, this}).c_str());
auto&& trait = static_cast<ComputingGraphImpl*>(owner_graph())
auto&& trait = ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager()
.get_var_node_mem_trait_at(src_var);
if (trait.seq_force_update_dest ||
......@@ -403,7 +403,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
return json::Null::make();
};
auto &&trait = static_cast<ComputingGraphImpl*>(owner_graph()
auto &&trait = ComputingGraphImpl::downcast(owner_graph()
)->var_node_mem_manager().get_var_node_mem_trait(this);
auto flag = json::Array::make();
{
......@@ -459,7 +459,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
#endif
MemAllocPlan& VarNode::init_mem_plan(const DeviceTensorND* fixed_alloc) {
static_cast<ComputingGraphImpl*>(owner_graph())
ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager()
.init_single_var_mem_plan(this, fixed_alloc);
return m_mem_plan;
......@@ -477,7 +477,7 @@ void VarNode::modify_flag(Flag delta, Flag new_flag) {
Flag::NO_SYS_STATIC_MEM_ALLOC |
Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
mgb_assert(!static_cast<ComputingGraphImpl*>(owner_graph())->
mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
var_node_mem_manager().optimize_started(),
"could not modify var flags after optimization started");
}
......
......@@ -340,7 +340,7 @@ VarNodeMemManager::DynamicAllocOprInfo::DynamicAllocOprInfo(
prev_dev_val_input.clear();
static_infer_inp.clear();
dev_val_input.clear();
auto &&mgr = static_cast<ComputingGraphImpl*>(opr->owner_graph())->
auto &&mgr = ComputingGraphImpl::downcast(opr->owner_graph())->
static_infer_manager_impl();
CompNode single_cn;
......
......@@ -73,7 +73,7 @@ void VarDevMemDefragmenter::defrag(VarNode* req_var,
const CompNodeInfo& cn_info,
size_t extra_size) {
// pause all other comp nodes before calling defrag_impl()
auto exec_env = static_cast<ComputingGraphImpl*>(req_var->owner_graph())
auto exec_env = ComputingGraphImpl::downcast(req_var->owner_graph())
->current_exec_env();
mgb_assert(exec_env);
exec_env->pause_exec();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register