Commit 87c845fd authored by Megvii Engine Team

feat(mgb): tensorrt runtime opr support dynamic batch trt model

GitOrigin-RevId: 7461de704e2bc8dd12bcf2f783f4218437489ed4
Parent f7cf3e34
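
This commit lets the TensorRT runtime operator load engines built with explicit-batch / dynamic shapes, where some binding dimensions are serialized as -1 and must be resolved against the concrete input shapes before shape inference and execution. The sketch below (not MegEngine code; `resolve_dynamic_bindings` and `input_shapes` are hypothetical names, TensorRT >= 6 assumed) shows the underlying TensorRT pattern; the diff that follows adds it as `TensorRTManager::create_trt_context` and wires it into both `exec` and compile-time shape inference.

```cpp
// Minimal sketch of the pattern: fill every -1 ("dynamic") entry of the
// engine's input bindings from the shapes known at runtime, then fix them
// on the execution context.
#include <NvInfer.h>

#include <vector>

bool resolve_dynamic_bindings(
        nvinfer1::IExecutionContext* ctx,
        const std::vector<std::vector<int>>& input_shapes) {
    for (int i = 0; i < static_cast<int>(input_shapes.size()); ++i) {
        nvinfer1::Dims dims = ctx->getBindingDimensions(i);
        for (int j = 0; j < dims.nbDims; ++j) {
            if (dims.d[j] == -1) {  // dynamic axis, e.g. the batch dimension
                dims.d[j] = input_shapes[i][j];
            }
        }
        if (!ctx->setBindingDimensions(i, dims)) {
            return false;  // shape falls outside the engine's optimization profiles
        }
    }
    // Once all inputs are fixed, the output bindings become queryable through
    // IExecutionContext::getBindingDimensions() as well.
    return ctx->allInputDimensionsSpecified();
}
```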
@@ -153,6 +153,65 @@ void TensorRTOpr::GpuAllocator::free(void* memory) {
 }

 /* ========================== TensorRTManager ========================== */
+const intl::TensorRTUniquePtr<nvinfer1::IExecutionContext>& TensorRTManager::
+        create_trt_context(
+                const TensorShapeArray& inp_shape, nvinfer1::ICudaEngine* engine) {
+    if (!m_context) {
+        m_context = {engine->createExecutionContextWithoutDeviceMemory(), {}};
+#if NV_TENSOR_RT_VERSION >= 6001
+        for (size_t i = 0; i < inp_shape.size(); ++i) {
+            auto dims = m_context->getBindingDimensions(i);
+            for (int j = 0; j < dims.nbDims; j++) {
+                if (dims.d[j] == -1) {
+                    dims.d[j] = inp_shape.at(i)[j];
+                }
+            }
+            m_context->setBindingDimensions(i, dims);
+        }
+        // check if input shape is set correctly
+        for (int i = inp_shape.size(); i < engine->getNbBindings(); ++i) {
+            auto dims = m_context->getBindingDimensions(i);
+            if (dims.nbDims == -1) {
+                for (int j = 0; j < engine->getNbOptimizationProfiles(); j++) {
+                    mgb_log_debug("TensorRT profile %d:\n", j);
+                    for (size_t k = 0; k < inp_shape.size(); k++) {
+                        mgb_log_debug(
+                                "input[%zu]'s minimum shape is: %s\n", k,
+                                TensorRTOpr::dims2shape(
+                                        engine->getProfileDimensions(
+                                                j, k,
+                                                nvinfer1::OptProfileSelector::kMIN))
+                                        .to_string()
+                                        .c_str());
+                        mgb_log_debug(
+                                "input[%zu]'s optimum shape is: %s\n", k,
+                                TensorRTOpr::dims2shape(
+                                        engine->getProfileDimensions(
+                                                j, k,
+                                                nvinfer1::OptProfileSelector::kOPT))
+                                        .to_string()
+                                        .c_str());
+                        mgb_log_debug(
+                                "input[%zu]'s maximum shape is: %s\n", k,
+                                TensorRTOpr::dims2shape(
+                                        engine->getProfileDimensions(
+                                                j, k,
+                                                nvinfer1::OptProfileSelector::kMAX))
+                                        .to_string()
+                                        .c_str());
+                    }
+                }
+                mgb_throw(
+                        MegBrainError,
+                        "Invalid network output, this might be caused by inconsistent "
+                        "input shapes. The valid input optimization profiles are listed above.");
+            }
+        }
+#endif
+    }
+    return m_context;
+}
+
 void TensorRTManager::exec(
         cg::SingleCNOperatorNodeBase* opr, CompNode comp_node_check,
         nvinfer1::ICudaEngine* engine, size_t batch, bool use_trt_profiler) {
@@ -169,9 +228,11 @@ void TensorRTManager::exec(
     auto workspace_ptr = opr->output().back()->dev_tensor().raw_ptr();
     bool should_reinit_device_memory =
             !m_context || m_device_workspace_memory_ptr != workspace_ptr;
-    if (!m_context) {
-        m_context = {engine->createExecutionContextWithoutDeviceMemory(), {}};
+    TensorShapeArray arr;
+    for (auto&& i : opr->input()) {
+        arr.push_back(i->shape());
     }
+    create_trt_context(arr, engine);
     m_trt_iobuf.resize(opr->input().size() + opr->output().size() - 1);
     bool is_trt_opr = false;
     if (opr->same_type<TensorRTOpr>()) {
......
@@ -101,7 +101,8 @@ TensorRTRuntimeOpr::TensorRTRuntimeOpr(
 void TensorRTRuntimeOpr::get_output_var_shape(
         const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
     auto batch = inp_shape.at(0)[0];
-    auto get_mgb_shape = [this, batch](int binding_idx) -> TensorShape {
+    auto&& context = m_manager.create_trt_context(inp_shape, m_engine.get());
+    auto get_mgb_shape = [&](int binding_idx) -> TensorShape {
         auto dims = m_engine->getBindingDimensions(binding_idx);
 #if NV_TENSOR_RT_VERSION >= 6001
         auto format = m_engine->getBindingFormat(binding_idx);
@@ -121,8 +122,25 @@ void TensorRTRuntimeOpr::get_output_var_shape(
             dims.d[dims.nbDims - 1] = 4;
         }
 #endif
-        return m_trt_engine_has_batch ? TensorRTOpr::dims2shape(dims)
+        auto shape = m_trt_engine_has_batch ? TensorRTOpr::dims2shape(dims)
                                       : TensorRTOpr::dims2shape(dims, batch);
+#if NV_TENSOR_RT_VERSION >= 6001
+        if (static_cast<size_t>(binding_idx) < inp_shape.size()) {
+            for (int i = 0; i < dims.nbDims; i++) {
+                if (dims.d[i] == -1) {
+                    shape[i] = inp_shape.at(binding_idx)[i];
+                }
+            }
+        } else {
+            auto trt_infer_dims = context->getBindingDimensions(binding_idx);
+            for (int i = 0; i < dims.nbDims; i++) {
+                if (dims.d[i] == -1) {
+                    shape[i] = trt_infer_dims.d[i];
+                }
+            }
+        }
+#endif
+        return shape;
     };
     for (size_t i = 0; i < inp_shape.size(); ++i) {
         mgb_assert(batch == inp_shape[i][0], "input batchsize not equal");
@@ -135,6 +153,8 @@ void TensorRTRuntimeOpr::get_output_var_shape(
         out_shape[i] = get_mgb_shape(i + input().size());
     }
     out_shape.back() = {intl::workspace_size(m_engine.get())};
+    // the context must be cleared here, otherwise it may cause an unknown error.
+    m_manager.clear_trt_context();
 }

 void TensorRTRuntimeOpr::add_input_layout_constraint() {
......
@@ -52,6 +52,8 @@ class TensorRTManager {
     void* m_device_workspace_memory_ptr;

 public:
+    const TensorRTUniquePtr<nvinfer1::IExecutionContext>& create_trt_context(
+            const TensorShapeArray& inp_shape, nvinfer1::ICudaEngine* engine);
     void exec(
             cg::SingleCNOperatorNodeBase* opr, CompNode comp_node_check,
             nvinfer1::ICudaEngine* engine, size_t batch = 1,
......
@@ -86,7 +86,7 @@ private:
     // note: gpu allocator must be released after other trt objects
     std::shared_ptr<TensorRTOpr::GpuAllocator> m_gpu_allocator;
     std::shared_ptr<nvinfer1::ICudaEngine> m_engine;
-    intl::TensorRTManager m_manager;
+    mutable intl::TensorRTManager m_manager;
     // if m_engine's dims with batch
     bool m_trt_engine_has_batch;
 };  // namespace mgb
......
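
For context, the engines this change targets are serialized with an explicit batch dimension and at least one optimization profile that bounds the dynamic axis. A hedged sketch of how such a "dynamic batch trt model" is typically produced (not part of this commit; the input name "data", the tensor shapes, and the profile ranges are illustrative assumptions, TensorRT >= 6):

```cpp
#include <NvInfer.h>

// Build a toy engine whose batch axis is dynamic (-1), bounded by an
// optimization profile. The network body is a single ReLU for brevity.
nvinfer1::ICudaEngine* build_dynamic_batch_engine(nvinfer1::ILogger& logger) {
    auto* builder = nvinfer1::createInferBuilder(logger);
    auto* network = builder->createNetworkV2(
            1U << static_cast<uint32_t>(
                    nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
    // batch axis left dynamic
    auto* data = network->addInput(
            "data", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{-1, 3, 224, 224});
    auto* relu = network->addActivation(*data, nvinfer1::ActivationType::kRELU);
    network->markOutput(*relu->getOutput(0));

    // the profile tells TensorRT which concrete batch sizes to expect
    auto* config = builder->createBuilderConfig();
    auto* profile = builder->createOptimizationProfile();
    profile->setDimensions(
            "data", nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{1, 3, 224, 224});
    profile->setDimensions(
            "data", nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{8, 3, 224, 224});
    profile->setDimensions(
            "data", nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{32, 3, 224, 224});
    config->addOptimizationProfile(profile);
    return builder->buildEngineWithConfig(*network, *config);
}
```

At execution time, the kMIN/kOPT/kMAX dimensions registered here are exactly what `create_trt_context` prints via `mgb_log_debug` when the resolved input shapes do not satisfy any profile.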