提交 fb9d73b2 编写于 作者: M Megvii Engine Team

fix(src/tensorrt): trt7 manage all workspace

GitOrigin-RevId: b78c80c8f1b8e6b2080284f945ddb950aeb42475
上级 dee02289
......@@ -151,7 +151,11 @@ void TensorRTManager::create_trt_context(
nvinfer1::ICudaEngine* engine) {
bool has_no_context = (!m_context);
if (has_no_context) {
#if TENSOR_RT_MANAGE_ALL_WORKSPACE
m_context = {engine->createExecutionContext(), {}};
#else
m_context = {engine->createExecutionContextWithoutDeviceMemory(), {}};
#endif
}
MGB_MARK_USED_VAR(cn);
#if NV_TENSOR_RT_VERSION >= 6001
......@@ -333,6 +337,9 @@ void TensorRTManager::exec(
}
}
MGB_MARK_USED_VAR(is_trt_opr);
#if TENSOR_RT_MANAGE_ALL_WORKSPACE
MGB_MARK_USED_VAR(should_reinit_device_memory);
#else
if (should_reinit_device_memory) {
mgb_assert(
opr->output().back()->shape()[0] == intl::workspace_size(engine) &&
......@@ -340,7 +347,7 @@ void TensorRTManager::exec(
m_context->setDeviceMemory(workspace_ptr);
m_device_workspace_memory_ptr = workspace_ptr;
}
#endif
auto&& env = mgb::CompNodeEnv::from_comp_node(comp_node);
bool exec_success = false;
......
......@@ -28,6 +28,18 @@ enum class Empty : int32_t {};
#define TENSORRT_NO_EXCEPT(api)
#endif
#if (NV_TENSOR_RT_VERSION >= 7000)
//! FIXME: trt7.2.2.3 leak memory in setDeviceMemory API, now trt malloc workspace
//! self, megengine do not alloc any workspace
#define TENSOR_RT_MANAGE_ALL_WORKSPACE 1
#else
#define TENSOR_RT_MANAGE_ALL_WORKSPACE 0
#endif
#if NV_TENSOR_RT_VERSION >= 8000
#error "if trt8 fix https://github.com/NVIDIA/TensorRT/issues/2290, try TENSOR_RT_MANAGE_ALL_WORKSPACE=0"
#endif
namespace mgb {
namespace opr {
......@@ -73,7 +85,12 @@ public:
};
static inline size_t workspace_size(nvinfer1::ICudaEngine* engine) {
#if TENSOR_RT_MANAGE_ALL_WORKSPACE
MGB_MARK_USED_VAR(engine);
return 0;
#else
return engine->getDeviceMemorySize();
#endif
}
} // namespace intl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册