提交 39bd66fc 编写于 作者: M Megvii Engine Team

fix(mgb): fix TensorRT missing cudaSetDevice

GitOrigin-RevId: 40eb119e48042bbcf287872481e4adb381930bad
上级 ab9dfbce
......@@ -1404,6 +1404,7 @@ void TensorRTReplacePass::Impl::detect_replace() {
m_graph_map[opr] = max;
if (max > m_tensorrt_graphs.size()) {
opr->output(0)->comp_node().activate();
m_tensorrt_graphs.push_back(
std::make_shared<TensorRTGraph>(feature_bits));
}
......
......@@ -533,6 +533,7 @@ void TensorRTOpr::get_output_var_shape(const TensorShapeArray& inp_shape,
}
if (!engine_valid) {
comp_node().activate();
// If a context created by a cuda engine, the context must be destroyed
// before the corresponding cuda engine. Otherwise, a segmentfault will
// occur.
......@@ -576,6 +577,7 @@ void TensorRTOpr::build_engine_from_cache() {
TensorRTEngineCache::make_key_from_trt_opr(this));
if (!ret.valid())
return;
comp_node().activate();
auto engine = runtime->deserializeCudaEngine(
reinterpret_cast<const void*>(ret->ptr), ret->size, nullptr);
mgb_assert(engine, "failed to deserialize ICudaEngine");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册