提交 b6ce02a1 编写于 作者: M Megvii Engine Team

fix(subgraph): fallback back to cg if jit unsupported

GitOrigin-RevId: 853a00a4025d6e4fefa0f3ac2fe3d4cad4c3f8b9
上级 21f5a7fc
...@@ -22,6 +22,7 @@ from .._imperative_rt.core2 import ( ...@@ -22,6 +22,7 @@ from .._imperative_rt.core2 import (
make_shape_tuple, make_shape_tuple,
) )
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._imperative_rt.ops import jit_supported
from .._wrap import as_device from .._wrap import as_device
from ..autodiff.grad import Function from ..autodiff.grad import Function
from ..ops import builtin from ..ops import builtin
...@@ -234,6 +235,10 @@ def subgraph( ...@@ -234,6 +235,10 @@ def subgraph(
gopt_level = None # disable jit and compile gopt_level = None # disable jit and compile
jit_fusion = False jit_fusion = False
if jit_fusion and not jit_supported:
jit_fusion = False # jit unusable, fallback to graph compile
gopt_level = 2
def as_op(op, nargs): def as_op(op, nargs):
if isinstance(op, str): if isinstance(op, str):
assert (op, nargs) in _opr_map, "unknown operator" assert (op, nargs) in _opr_map, "unknown operator"
......
...@@ -652,6 +652,11 @@ void init_ops(py::module m) { ...@@ -652,6 +652,11 @@ void init_ops(py::module m) {
}); });
m.def("set_jit_enabled", &JITFusionOp::set_enabled); m.def("set_jit_enabled", &JITFusionOp::set_enabled);
bool jit_supported = false;
#if MGB_JIT
jit_supported = true;
#endif
m.attr("jit_supported") = jit_supported;
auto custom = submodule(m, "_custom"); auto custom = submodule(m, "_custom");
init_custom(custom); init_custom(custom);
......
...@@ -9,21 +9,26 @@ ...@@ -9,21 +9,26 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include <queue> #include <deque>
#include "../op_trait.h"
#include "megbrain/imperative/graph_cache.h" #include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/opr_utility.h" #include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/subgraph_detail.h" #include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h" #include "megbrain/opr/utility.h"
#if MGB_JIT
#include "megbrain/jit/executor_opr.h"
#endif
#include "../event_pool.h"
#include "../op_trait.h"
namespace mgb::imperative { namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
...@@ -309,7 +314,7 @@ struct ComputingGraphHolder { ...@@ -309,7 +314,7 @@ struct ComputingGraphHolder {
SmallVector<VarNode*> input_vars; SmallVector<VarNode*> input_vars;
SmallVector<VarNode*> output_vars; SmallVector<VarNode*> output_vars;
std::shared_ptr<DeviceMemoryAllocatorImpl> allocator; std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
SmallVector<std::unique_ptr<CompNode::Event>> events; SmallVector<std::shared_ptr<CompNode::Event>> events;
std::unique_ptr<cg::static_infer::StaticInferUpdater> updater; std::unique_ptr<cg::static_infer::StaticInferUpdater> updater;
void initialize( void initialize(
...@@ -402,7 +407,7 @@ struct ComputingGraphHolder { ...@@ -402,7 +407,7 @@ struct ComputingGraphHolder {
return true; return true;
}); });
for (auto&& comp_node : comp_nodes) { for (auto&& comp_node : comp_nodes) {
events.push_back(comp_node.create_event()); events.push_back(EventPool::without_timer().alloc_shared(comp_node));
events.back()->record(); events.back()->record();
} }
} }
...@@ -510,7 +515,7 @@ ComputingGraphHolder<Kind>& get_computing_graph( ...@@ -510,7 +515,7 @@ ComputingGraphHolder<Kind>& get_computing_graph(
std::shared_ptr<OpDef> compiled_op, std::shared_ptr<OpDef> compiled_op,
const SmallVector<LogicalTensorDesc>& descs) { const SmallVector<LogicalTensorDesc>& descs) {
using ComputingGraphHolderCache = using ComputingGraphHolderCache =
OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder<Kind>>>>; OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>;
thread_local auto cache = std::make_unique<ComputingGraphHolderCache>(); thread_local auto cache = std::make_unique<ComputingGraphHolderCache>();
thread_local size_t nr_cg_holders = 0; thread_local size_t nr_cg_holders = 0;
typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs}; typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
...@@ -540,20 +545,28 @@ ComputingGraphHolder<Kind>& get_computing_graph( ...@@ -540,20 +545,28 @@ ComputingGraphHolder<Kind>& get_computing_graph(
} }
} }
if (holder) { if (holder) {
cg_holder_queue.pop(); cg_holder_queue.pop_front();
} }
} }
if (!holder) { if (!holder) {
// create new computing graph // create new computing graph
holder = std::make_unique<ComputingGraphHolder<Kind>>(); auto create_holder = [&] {
auto& cg_holder = *holder; auto holder = std::make_unique<ComputingGraphHolder<Kind>>();
cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs); auto& cg_holder = *holder;
nr_cg_holders++; cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs);
mgb_log_debug( nr_cg_holders++;
"add new computing graph for compiled op, now %zu graphs", mgb_log_debug(
nr_cg_holders); "add new computing graph for compiled op, now %zu graphs",
nr_cg_holders);
return holder;
};
size_t nr_graphs = std::max(cg_holder_queue.size(), (size_t)1);
for (size_t i = 1; i < nr_graphs; ++i) {
cg_holder_queue.push_front(create_holder());
}
holder = create_holder();
} }
cg_holder_queue.push(std::move(holder)); cg_holder_queue.push_back(std::move(holder));
return *cg_holder_queue.back(); return *cg_holder_queue.back();
} }
...@@ -670,6 +683,7 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -670,6 +683,7 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
// skip for dump (JITExecutor can not be dumped) // skip for dump (JITExecutor can not be dumped)
return outputs; return outputs;
} }
#if MGB_JIT
for (auto& output : outputs) { for (auto& output : outputs) {
jit::InternalGraphGenerator igg{output->owner_opr()}; jit::InternalGraphGenerator igg{output->owner_opr()};
std::vector<cg::OperatorNodeBase*> reverse_order; std::vector<cg::OperatorNodeBase*> reverse_order;
...@@ -686,6 +700,9 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -686,6 +700,9 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto ig = igg.generate(); auto ig = igg.generate();
output = jit::JITExecutor::make(ig, igg.orig_inps()).node(); output = jit::JITExecutor::make(ig, igg.orig_inps()).node();
} }
#else
mgb_assert(false, "MGB_WITH_JIT was disabled");
#endif
return outputs; return outputs;
} }
......
...@@ -216,11 +216,11 @@ void CudaExecutable::FuncCache::compile( ...@@ -216,11 +216,11 @@ void CudaExecutable::FuncCache::compile(
ptx = NVRTCCompile(cuda_exe->m_source, major, minor); ptx = NVRTCCompile(cuda_exe->m_source, major, minor);
ptx_cache = PersistentCache::Blob{ptx.data(), ptx.size()}; ptx_cache = PersistentCache::Blob{ptx.data(), ptx.size()};
cache.put(cache_category, key, ptx_cache.val()); cache.put(cache_category, key, ptx_cache.val());
mgb_log("NVRTC JIT: compile %s for %d.%d: source_len=%zu ptx_len=%zu "
"time=%.3fms",
cuda_exe->m_name.c_str(), major, minor, key.size, ptx.size(),
timer.get_msecs());
} }
mgb_log("NVRTC JIT: compile %s for %d.%d: source_len=%zu ptx_len=%zu "
"time=%.3fms",
cuda_exe->m_name.c_str(), major, minor, key.size, ptx.size(),
timer.get_msecs());
} }
void CudaExecutable::FuncCache::exec( void CudaExecutable::FuncCache::exec(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册