From 9336dd3ee9f6938507a67ab33d8b4b4b83f55249 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Mon, 8 Aug 2022 16:36:49 +0800 Subject: [PATCH] add trt int8 dynamic support (#44800) * add trt int8 dynamic support * just support trt7+ * just for trt7.1.3.a * Update tensorrt_subgraph_pass.cc * delete trt_engine when it not use --- .../ir_passes/tensorrt_subgraph_pass.cc | 107 +++++++----- paddle/fluid/inference/tensorrt/engine.cc | 5 + paddle/fluid/inference/tensorrt/engine.h | 4 + .../operators/tensorrt/tensorrt_engine_op.h | 156 ++++++++++++------ 4 files changed, 181 insertions(+), 91 deletions(-) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index e4fc52b6fa7..89b41fced12 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -40,21 +40,24 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( auto with_dynamic_shape = Get("with_dynamic_shape"); auto teller = [&](const framework::ir::Node *node) { if (!node->IsOp() || !node->Op()) return false; - if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(), + if (find(trt_disabled_ops.begin(), + trt_disabled_ops.end(), node->Op()->Type()) != trt_disabled_ops.end()) { VLOG(3) << node->Op()->Type().c_str() << " is diabled by config in TensorRT"; return false; } - bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8, - with_dynamic_shape); + bool is_ok = tensorrt::OpTeller::Global().Tell( + node, no_calib_int8, with_dynamic_shape); if (!is_ok) VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT"; return is_ok; }; framework::ir::SubGraphFuser fuser( - graph, teller, Get("min_subgraph_size") /*min subgraph size*/, + graph, + teller, + Get("min_subgraph_size") /*min subgraph size*/, "tensorrt_engine"); fuser(); @@ -102,26 +105,27 @@ std::string GenerateEngineKey(const std::set &engine_inputs, engine_hash_key += "#"; } engine_hash_key += predictor_id; - if (!for_calibration) { - engine_hash_key += "#"; - engine_hash_key += max_batch_size; - } engine_hash_key += "#"; engine_hash_key += precision; auto engine_key = std::to_string(std::hash()(engine_hash_key)); + if (!for_calibration) { + engine_key += max_batch_size; + } VLOG(2) << "TRT engine hash key: " << engine_hash_key; VLOG(2) << "TRT engine key: " << engine_key; return engine_key; } void TensorRtSubgraphPass::CreateTensorRTOp( - framework::ir::Node *node, framework::ir::Graph *graph, + framework::ir::Node *node, + framework::ir::Graph *graph, const std::vector &graph_params, std::vector *repetitive_params) const { auto *op_desc = node->Op(); auto &subgraph = *framework::ir::Agent(node).subgraph(); - PADDLE_ENFORCE_EQ(subgraph.empty(), false, + PADDLE_ENFORCE_EQ(subgraph.empty(), + false, platform::errors::PreconditionNotMet( "The subgraph should not be empty.")); @@ -208,7 +212,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( if (trt_tuned_dynamic_shape) { VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path; inference::DeserializeShapeRangeInfo(shape_range_info_path, - &min_input_shape, &max_input_shape, + &min_input_shape, + &max_input_shape, &opt_input_shape); } @@ -224,9 +229,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp( // input of a OP, but also the output of a Op, there will be problems. // So we have to rename the variable in the subgraph to make sure // it is either an OP's input or an OP's output. - RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id, - &output_names_with_id, &output_names, &output_name_map, - graph_var_map, !enable_int8); + RenameAndGetOutputs(subgraph_nodes, + &block_desc, + input_names_with_id, + &output_names_with_id, + &output_names, + &output_name_map, + graph_var_map, + !enable_int8); // When tensorrt engine runs at the end of the operation, // output_mapping help us copy the data from the renamed ITensor @@ -234,17 +244,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp( std::vector output_mapping; std::vector renamed_output_dims; for (auto name : output_names) { - PADDLE_ENFORCE_NE(output_name_map.count(name), 0, + PADDLE_ENFORCE_NE(output_name_map.count(name), + 0, platform::errors::PreconditionNotMet( "The output_name_map should have %s", name)); output_mapping.push_back(output_name_map[name]); renamed_output_dims.push_back(origin_name_output_dims[name]); } - PADDLE_ENFORCE_EQ(output_mapping.empty(), false, + PADDLE_ENFORCE_EQ(output_mapping.empty(), + false, platform::errors::PreconditionNotMet( "The output_mapping should not be empty.")); PADDLE_ENFORCE_EQ( - !block_desc.Proto()->vars().empty(), true, + !block_desc.Proto()->vars().empty(), + true, platform::errors::PreconditionNotMet("the block has no var-desc")); // Set attrs @@ -287,14 +300,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp( // when running in the 'use_serialize' mode, there is a bug. // serialization is affected by max_batch_size, but calibration is not. // So we use seperate engine keys in serialization and calibration. - auto engine_key = GenerateEngineKey( - input_names_with_id, output_names_with_id, std::to_string(0), - std::to_string(max_batch_size), - std::to_string(static_cast(precision_mode)), false); + auto engine_key = + GenerateEngineKey(input_names_with_id, + output_names_with_id, + std::to_string(0), + std::to_string(max_batch_size), + std::to_string(static_cast(precision_mode)), + false); auto calibration_engine_key = - GenerateEngineKey(input_names_with_id, output_names_with_id, - std::to_string(0), std::to_string(max_batch_size), - std::to_string(static_cast(precision_mode)), true); + GenerateEngineKey(input_names_with_id, + output_names_with_id, + std::to_string(0), + std::to_string(max_batch_size), + std::to_string(static_cast(precision_mode)), + true); auto predictor_id = Get("predictor_id"); // Get "" when there is no cached calibration table data. @@ -302,7 +321,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( if (enable_int8 && use_calib_mode) { calibration_data = GetTrtCalibTableData(Get("model_opt_cache_dir"), - calibration_engine_key, enable_int8); + calibration_engine_key, + enable_int8); } op_desc->SetAttr("calibration_data", calibration_data); op_desc->SetAttr("enable_int8", enable_int8); @@ -325,13 +345,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp( // calibration table data. bool calibration_mode = (enable_int8 && calibration_data.size() == 0 && use_calib_mode); - if (calibration_mode) { - // calibraion mode means generate int8 calibration table data process. - return; - } - std::copy(params_not_shared.begin(), params_not_shared.end(), - std::back_inserter(*repetitive_params)); + if (!calibration_mode) { + std::copy(params_not_shared.begin(), + params_not_shared.end(), + std::back_inserter(*repetitive_params)); + } // Check trt version for dynamic shape input. @@ -368,10 +387,16 @@ void TensorRtSubgraphPass::CreateTensorRTOp( bool disable_trt_plugin_fp16 = Get("disable_trt_plugin_fp16"); tensorrt::TensorRTEngine *trt_engine = inference::Singleton::Global() - .Create(engine_key + std::to_string(predictor_id), max_batch_size, - Get("workspace_size"), precision_mode, calibrator.get(), - Get("gpu_device_id"), min_input_shape, max_input_shape, - opt_input_shape, disable_trt_plugin_fp16); + .Create(engine_key + std::to_string(predictor_id), + max_batch_size, + Get("workspace_size"), + precision_mode, + calibrator.get(), + Get("gpu_device_id"), + min_input_shape, + max_input_shape, + opt_input_shape, + disable_trt_plugin_fp16); trt_engine->SetUseOSS(Get("use_oss")); trt_engine->SetWithInterleaved(Get("with_interleaved")); trt_engine->SetUseDLA(Get("trt_use_dla")); @@ -384,6 +409,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && graph->Has(framework::ir::kMultiheadMatmulPass))); + if (calibration_mode) { + // calibraion mode means generate int8 calibration table data process. + return; + } + if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( Get("model_opt_cache_dir"), engine_key); @@ -408,9 +438,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp( std::unordered_set param_set(params.begin(), params.end()); inference::Singleton::Global() .ConvertBlockToTRTEngine( - &block_desc_temp, *scope, + &block_desc_temp, + *scope, std::vector(input_names.begin(), input_names.end()), - param_set, output_mapping, trt_engine); + param_set, + output_mapping, + trt_engine); if (use_static_engine) { nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize(); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 3a9504d9c67..fdf1fa2142f 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -254,6 +254,11 @@ void TensorRTEngine::FreezeNetwork() { nvinfer1::OptProfileSelector::kOPT, Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true)); } +#if IS_TRT_VERSION_GE(7130) + if (enable_int8) { + infer_builder_config_->setCalibrationProfile(optim_profiles_[i]); + } +#endif infer_builder_config_->addOptimizationProfile(optim_profiles_[i]); } if (WithFp16() && disable_trt_plugin_fp16()) { diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index b4a0478925b..0ce1f7130e3 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -614,6 +614,10 @@ class TensorRTEngine { void SetUseInspector(bool use_inspector) { use_inspector_ = use_inspector; } + void SetCalibrator(TRTInt8Calibrator* calibrator) { + calibrator_ = calibrator; + } + private: // Each ICudaEngine object is bound to a specific GPU when it is instantiated, // ensure that the thread is associated with the correct device by calling diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 0a71875d893..ea5f2045fae 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -52,23 +52,28 @@ namespace operators { using inference::Singleton; using inference::tensorrt::TensorRTEngine; -using inference::tensorrt::TRTInt8Calibrator; using inference::tensorrt::TRTCalibratorEngine; using inference::tensorrt::TRTCalibratorEngineManager; +using inference::tensorrt::TRTInt8Calibrator; static void RuntimeStaticShapeCheck(std::vector runtime_input_shape, std::vector model_input_shape) { auto comma_fold = [](std::string a, int b) { return std::move(a) + ", " + std::to_string(b); }; - std::string model_input_shape_str = std::accumulate( - std::next(model_input_shape.begin()), model_input_shape.end(), - std::to_string(model_input_shape[0]), comma_fold); - std::string runtime_input_shape_str = std::accumulate( - std::next(runtime_input_shape.begin()), runtime_input_shape.end(), - std::to_string(runtime_input_shape[0]), comma_fold); + std::string model_input_shape_str = + std::accumulate(std::next(model_input_shape.begin()), + model_input_shape.end(), + std::to_string(model_input_shape[0]), + comma_fold); + std::string runtime_input_shape_str = + std::accumulate(std::next(runtime_input_shape.begin()), + runtime_input_shape.end(), + std::to_string(runtime_input_shape[0]), + comma_fold); PADDLE_ENFORCE_EQ( - model_input_shape == runtime_input_shape, true, + model_input_shape == runtime_input_shape, + true, platform::errors::InvalidArgument( "Input shapes are inconsistent with the model. Expect [%s] in " "model description, but got [%s] in runtime. TRT 5 " @@ -76,7 +81,8 @@ static void RuntimeStaticShapeCheck(std::vector runtime_input_shape, "does not support dynamic input shapes. Please check and " "modify " "your input shapes.", - model_input_shape_str, runtime_input_shape_str)); + model_input_shape_str, + runtime_input_shape_str)); } static paddle::experimental::DataType TRT2FluidDataType( @@ -102,7 +108,8 @@ static paddle::experimental::DataType TRT2FluidDataType( } static void RuntimeDynamicShapeCheck( - const std::string &x, const std::vector &runtime_input_shape, + const std::string &x, + const std::vector &runtime_input_shape, const std::vector &min_input_shape, const std::vector &max_input_shape) { // PADDLE_ENFORCE_EQ( @@ -111,10 +118,10 @@ static void RuntimeDynamicShapeCheck( // "TRT engine runtime input %s dims size(%d) inconsistent " // "with the dynamic shape size(%d)", // x, runtime_input_shape.size(), min_input_shape.size())); - auto is_input_shape_valid = [&]( - const std::vector &runtime_input_shape, - const std::vector &min_input_shape, - const std::vector &max_input_shape) -> bool { + auto is_input_shape_valid = + [&](const std::vector &runtime_input_shape, + const std::vector &min_input_shape, + const std::vector &max_input_shape) -> bool { for (size_t i = 0; i < runtime_input_shape.size(); i++) { if (runtime_input_shape[i] <= max_input_shape[i] && runtime_input_shape[i] >= min_input_shape[i]) { @@ -128,17 +135,23 @@ static void RuntimeDynamicShapeCheck( auto comma_fold = [](std::string a, int b) { return std::move(a) + ", " + std::to_string(b); }; - std::string runtime_input_shape_str = std::accumulate( - std::next(runtime_input_shape.begin()), runtime_input_shape.end(), - std::to_string(runtime_input_shape[0]), comma_fold); + std::string runtime_input_shape_str = + std::accumulate(std::next(runtime_input_shape.begin()), + runtime_input_shape.end(), + std::to_string(runtime_input_shape[0]), + comma_fold); std::string min_input_shape_str = - std::accumulate(std::next(min_input_shape.begin()), min_input_shape.end(), - std::to_string(min_input_shape[0]), comma_fold); + std::accumulate(std::next(min_input_shape.begin()), + min_input_shape.end(), + std::to_string(min_input_shape[0]), + comma_fold); std::string max_input_shape_str = - std::accumulate(std::next(max_input_shape.begin()), max_input_shape.end(), - std::to_string(max_input_shape[0]), comma_fold); - PADDLE_ENFORCE_EQ(is_input_shape_valid(runtime_input_shape, min_input_shape, - max_input_shape), + std::accumulate(std::next(max_input_shape.begin()), + max_input_shape.end(), + std::to_string(max_input_shape[0]), + comma_fold); + PADDLE_ENFORCE_EQ(is_input_shape_valid( + runtime_input_shape, min_input_shape, max_input_shape), true, platform::errors::InvalidArgument( "TRT runtime input shape of %s is invalid. Expect " @@ -146,7 +159,9 @@ static void RuntimeDynamicShapeCheck( "configured in SetTRTDynamicShapeInfo()," "but got runtime input shape = [%s], min input shape = " "[%s], max input shape = [%s].", - x, runtime_input_shape_str, min_input_shape_str, + x, + runtime_input_shape_str, + min_input_shape_str, max_input_shape_str)); } @@ -158,7 +173,7 @@ class TensorRTEngineOp : public framework::OperatorBase { mutable TensorRTEngine *trt_engine_{nullptr}; int max_batch_size_; int workspace_size_; - std::unique_ptr calibrator_; + mutable std::unique_ptr calibrator_; bool enable_int8_; bool enable_fp16_; bool use_calib_mode_; @@ -260,7 +275,7 @@ class TensorRTEngineOp : public framework::OperatorBase { inference::Singleton::Global() .Has(engine_key_ + std::to_string(predictor_id_)); - if (!calibration_mode_ && has_engine) { + if (has_engine) { trt_engine_ = inference::Singleton::Global() .Get(engine_key_ + std::to_string(predictor_id_)); @@ -287,8 +302,8 @@ class TensorRTEngineOp : public framework::OperatorBase { Attr>("output_name_mapping"); inference::Singleton::Global() - .ConvertBlockToTRTEngine(&block_desc, scope, inputs, param_names_, - outputs, engine); + .ConvertBlockToTRTEngine( + &block_desc, scope, inputs, param_names_, outputs, engine); } protected: @@ -304,11 +319,14 @@ class TensorRTEngineOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { + auto *trt_engine = GetEngine(scope, dev_place); if (calibration_mode_ == true) { - RunCalibration(scope, dev_place); + RunCalibration(scope, dev_place, trt_engine); + paddle::inference::Singleton< + inference::tensorrt::TRTEngineManager>::Global() + .DeleteKey(engine_key_ + std::to_string(predictor_id_)); return; } - auto *trt_engine = GetEngine(scope, dev_place); if (use_inspector_) { trt_engine->GetEngineInfo(); } @@ -331,15 +349,19 @@ class TensorRTEngineOp : public framework::OperatorBase { trt_engine->max_input_shape(); for (auto &x : runtime_input_names_) { PADDLE_ENFORCE_EQ( - min_input_shape.count(x), true, + min_input_shape.count(x), + true, platform::errors::InvalidArgument( "Input %s not found in TRT engine min_input_shape.", x)); PADDLE_ENFORCE_EQ( - max_input_shape.count(x), true, + max_input_shape.count(x), + true, platform::errors::InvalidArgument( "Input %s not found in TRT engine max_input_shape.", x)); - RuntimeDynamicShapeCheck(x, runtime_input_shape[x], - min_input_shape[x], max_input_shape[x]); + RuntimeDynamicShapeCheck(x, + runtime_input_shape[x], + min_input_shape[x], + max_input_shape[x]); } } else { // compare runtime_input_shape and trt_engine dynamic shapes. @@ -357,13 +379,18 @@ class TensorRTEngineOp : public framework::OperatorBase { if (anc == nullptr) { anc = &scope; } + if (enable_int8_ && calibration_data_.size()) { + calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); + trt_engine->SetCalibrator(calibrator_.get()); + } PrepareTRTEngine(*anc, trt_engine); // update shape_range_info_pbtxt if (!shape_range_info_path_.empty()) { - inference::UpdateShapeRangeInfo( - shape_range_info_path_, trt_engine->min_input_shape(), - trt_engine->max_input_shape(), trt_engine->optim_input_shape(), - shape_changed_name); + inference::UpdateShapeRangeInfo(shape_range_info_path_, + trt_engine->min_input_shape(), + trt_engine->max_input_shape(), + trt_engine->optim_input_shape(), + shape_changed_name); } if (use_static_engine_) { @@ -387,7 +414,8 @@ class TensorRTEngineOp : public framework::OperatorBase { } void RunCalibration(const framework::Scope &scope, - const platform::Place &dev_place) const { + const platform::Place &dev_place, + TensorRTEngine *trt_engine) const { // This process will builds a 32-bit trt engine, runs it on the calibration // set, and records a histogram for each // tensor of the distribution of activation values. @@ -412,9 +440,15 @@ class TensorRTEngineOp : public framework::OperatorBase { calib_res->calib_.reset(new TRTInt8Calibrator( calib_buffers, runtime_batch, calibration_engine_key_, dev_place)); calib_res->thr_.reset(new std::thread([&]() { - calib_res->engine_.reset(new TensorRTEngine( - max_batch_size_, workspace_size_, precision_mode_, - calib_res->calib_.get(), dev_place.device)); + calib_res->engine_.reset( + new TensorRTEngine(max_batch_size_, + workspace_size_, + precision_mode_, + calib_res->calib_.get(), + dev_place.device, + trt_engine->min_input_shape(), + trt_engine->max_input_shape(), + trt_engine->optim_input_shape())); VLOG(3) << "start the calib trt engine thread"; PrepareTRTEngine(scope, calib_res->engine_.get()); })); @@ -436,7 +470,8 @@ class TensorRTEngineOp : public framework::OperatorBase { RunNativeImpl(scope, dev_place); } - void RunTrt(const framework::Scope &scope, const platform::Place &dev_place, + void RunTrt(const framework::Scope &scope, + const platform::Place &dev_place, TensorRTEngine *engine) const { int runtime_batch = -1; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); @@ -445,7 +480,8 @@ class TensorRTEngineOp : public framework::OperatorBase { reinterpret_cast(dev_ctx).stream(); PADDLE_ENFORCE_EQ( - runtime_input_names_.empty(), false, + runtime_input_names_.empty(), + false, platform::errors::PreconditionNotMet( "TensorRT engine needs at least one input, but no input is found. " "Please check if you set the input correctly.")); @@ -488,13 +524,15 @@ class TensorRTEngineOp : public framework::OperatorBase { const int bind_index = engine->engine()->getBindingIndex(x.c_str()) + binding_offset; PADDLE_ENFORCE_LT( - bind_index, num_bindings, + bind_index, + num_bindings, platform::errors::InvalidArgument( "Wrong TRT engine input binding index. Expected The " "binding index of TRT engine input to be less than " "the number of inputs and outputs. Received binding " "index=%d >= total inputs and outputs=%d", - bind_index, num_bindings)); + bind_index, + num_bindings)); if (!engine->with_dynamic_shape()) { // check if the input shapes are consistent with model. if (HasAttr(x + "_shape")) { @@ -507,7 +545,8 @@ class TensorRTEngineOp : public framework::OperatorBase { RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape); if (runtime_batch != -1) { PADDLE_ENFORCE_EQ( - runtime_batch, t_shape[0], + runtime_batch, + t_shape[0], platform::errors::InvalidArgument( "Inputs of trt subgraphs has different batchsize. " "It's not allowed in static shape mode. " @@ -589,12 +628,14 @@ class TensorRTEngineOp : public framework::OperatorBase { auto *fluid_t = fluid_v->GetMutable(); fluid_t->Resize(phi::make_ddim(ddim)); - PADDLE_ENFORCE_LT(bind_index, num_bindings, + PADDLE_ENFORCE_LT(bind_index, + num_bindings, platform::errors::InvalidArgument( "The binding index in TRT engine should be less " "than the number of bindings, but got binding " "index = %d, number of bindings = %d.", - bind_index, num_bindings)); + bind_index, + num_bindings)); auto trt_type = engine->engine()->getBindingDataType(bind_index); // get adr and set type buffers[bind_index] = static_cast( @@ -604,7 +645,8 @@ class TensorRTEngineOp : public framework::OperatorBase { if (!engine->with_dynamic_shape()) { PADDLE_ENFORCE_LE( - runtime_batch, max_batch_size_, + runtime_batch, + max_batch_size_, platform::errors::InvalidArgument( "The runtime batch size (%d) is greater than the max batch " "size(%d).\n" @@ -623,7 +665,8 @@ class TensorRTEngineOp : public framework::OperatorBase { "\tThe min_subgraph_size shouble to be greater than the number " "of " "nodes in the inconsistent subgraph.\n", - runtime_batch, max_batch_size_)); + runtime_batch, + max_batch_size_)); } // Execute the engine. engine->Execute(runtime_batch, &buffers, stream); @@ -635,9 +678,14 @@ class TensorRTEngineOp : public framework::OperatorBase { trt_engine_ = inference::Singleton::Global() .Create(engine_key_ + std::to_string(predictor_id_), - max_batch_size_, workspace_size_, precision_mode_, - calibrator_.get(), device_id_, min_input_shape_, - max_input_shape_, opt_input_shape_); + max_batch_size_, + workspace_size_, + precision_mode_, + calibrator_.get(), + device_id_, + min_input_shape_, + max_input_shape_, + opt_input_shape_); PrepareTRTEngine(scope, trt_engine_); } return trt_engine_; -- GitLab