diff --git a/src/tensorrt/impl/opr_replace.cpp b/src/tensorrt/impl/opr_replace.cpp index b6ebb876b4c0128016c6d36d20f45e9813875333..bcd4d0d7a70264cac0271a87ed23e74d5a83f842 100644 --- a/src/tensorrt/impl/opr_replace.cpp +++ b/src/tensorrt/impl/opr_replace.cpp @@ -1404,6 +1404,7 @@ void TensorRTReplacePass::Impl::detect_replace() { m_graph_map[opr] = max; if (max > m_tensorrt_graphs.size()) { + opr->output(0)->comp_node().activate(); m_tensorrt_graphs.push_back( std::make_shared(feature_bits)); } diff --git a/src/tensorrt/impl/tensorrt_opr.cpp b/src/tensorrt/impl/tensorrt_opr.cpp index 95652025984566ba67057662174040df0e8a7681..00b3cd7b274fd4a464dbb502c55383b7cfc66704 100644 --- a/src/tensorrt/impl/tensorrt_opr.cpp +++ b/src/tensorrt/impl/tensorrt_opr.cpp @@ -533,6 +533,7 @@ void TensorRTOpr::get_output_var_shape(const TensorShapeArray& inp_shape, } if (!engine_valid) { + comp_node().activate(); // If a context created by a cuda engine, the context must be destroyed // before the corresponding cuda engine. Otherwise, a segmentfault will // occur. @@ -576,6 +577,7 @@ void TensorRTOpr::build_engine_from_cache() { TensorRTEngineCache::make_key_from_trt_opr(this)); if (!ret.valid()) return; + comp_node().activate(); auto engine = runtime->deserializeCudaEngine( reinterpret_cast(ret->ptr), ret->size, nullptr); mgb_assert(engine, "failed to deserialize ICudaEngine");