提交 b6ce02a1 编写于 作者: M Megvii Engine Team

fix(subgraph): fallback back to cg if jit unsupported

GitOrigin-RevId: 853a00a4025d6e4fefa0f3ac2fe3d4cad4c3f8b9
上级 21f5a7fc
......@@ -22,6 +22,7 @@ from .._imperative_rt.core2 import (
make_shape_tuple,
)
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._imperative_rt.ops import jit_supported
from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
......@@ -234,6 +235,10 @@ def subgraph(
gopt_level = None # disable jit and compile
jit_fusion = False
if jit_fusion and not jit_supported:
jit_fusion = False # jit unusable, fallback to graph compile
gopt_level = 2
def as_op(op, nargs):
if isinstance(op, str):
assert (op, nargs) in _opr_map, "unknown operator"
......
......@@ -652,6 +652,11 @@ void init_ops(py::module m) {
});
m.def("set_jit_enabled", &JITFusionOp::set_enabled);
bool jit_supported = false;
#if MGB_JIT
jit_supported = true;
#endif
m.attr("jit_supported") = jit_supported;
auto custom = submodule(m, "_custom");
init_custom(custom);
......
......@@ -9,21 +9,26 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <queue>
#include <deque>
#include "../op_trait.h"
#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#if MGB_JIT
#include "megbrain/jit/executor_opr.h"
#endif
#include "../event_pool.h"
#include "../op_trait.h"
namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
......@@ -309,7 +314,7 @@ struct ComputingGraphHolder {
SmallVector<VarNode*> input_vars;
SmallVector<VarNode*> output_vars;
std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
SmallVector<std::unique_ptr<CompNode::Event>> events;
SmallVector<std::shared_ptr<CompNode::Event>> events;
std::unique_ptr<cg::static_infer::StaticInferUpdater> updater;
void initialize(
......@@ -402,7 +407,7 @@ struct ComputingGraphHolder {
return true;
});
for (auto&& comp_node : comp_nodes) {
events.push_back(comp_node.create_event());
events.push_back(EventPool::without_timer().alloc_shared(comp_node));
events.back()->record();
}
}
......@@ -510,7 +515,7 @@ ComputingGraphHolder<Kind>& get_computing_graph(
std::shared_ptr<OpDef> compiled_op,
const SmallVector<LogicalTensorDesc>& descs) {
using ComputingGraphHolderCache =
OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder<Kind>>>>;
OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>;
thread_local auto cache = std::make_unique<ComputingGraphHolderCache>();
thread_local size_t nr_cg_holders = 0;
typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
......@@ -540,20 +545,28 @@ ComputingGraphHolder<Kind>& get_computing_graph(
}
}
if (holder) {
cg_holder_queue.pop();
cg_holder_queue.pop_front();
}
}
if (!holder) {
// create new computing graph
holder = std::make_unique<ComputingGraphHolder<Kind>>();
auto create_holder = [&] {
auto holder = std::make_unique<ComputingGraphHolder<Kind>>();
auto& cg_holder = *holder;
cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs);
nr_cg_holders++;
mgb_log_debug(
"add new computing graph for compiled op, now %zu graphs",
nr_cg_holders);
return holder;
};
size_t nr_graphs = std::max(cg_holder_queue.size(), (size_t)1);
for (size_t i = 1; i < nr_graphs; ++i) {
cg_holder_queue.push_front(create_holder());
}
holder = create_holder();
}
cg_holder_queue.push(std::move(holder));
cg_holder_queue.push_back(std::move(holder));
return *cg_holder_queue.back();
}
......@@ -670,6 +683,7 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
// skip for dump (JITExecutor can not be dumped)
return outputs;
}
#if MGB_JIT
for (auto& output : outputs) {
jit::InternalGraphGenerator igg{output->owner_opr()};
std::vector<cg::OperatorNodeBase*> reverse_order;
......@@ -686,6 +700,9 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto ig = igg.generate();
output = jit::JITExecutor::make(ig, igg.orig_inps()).node();
}
#else
mgb_assert(false, "MGB_WITH_JIT was disabled");
#endif
return outputs;
}
......
......@@ -216,11 +216,11 @@ void CudaExecutable::FuncCache::compile(
ptx = NVRTCCompile(cuda_exe->m_source, major, minor);
ptx_cache = PersistentCache::Blob{ptx.data(), ptx.size()};
cache.put(cache_category, key, ptx_cache.val());
}
mgb_log("NVRTC JIT: compile %s for %d.%d: source_len=%zu ptx_len=%zu "
"time=%.3fms",
cuda_exe->m_name.c_str(), major, minor, key.size, ptx.size(),
timer.get_msecs());
}
}
void CudaExecutable::FuncCache::exec(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册