diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 579549a4c3ec476b979f7dd7919a1b8a38850a7c..2f5da3c44b97fd41025e90e296be121027a9a379 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -438,11 +438,32 @@ class TensorRTEngineOp : public framework::OperatorBase { calib_res->calib_.reset(new TRTInt8Calibrator( calib_buffers, runtime_batch, calibration_engine_key_, dev_place)); calib_res->thr_.reset(new std::thread([&]() { + std::map> min_input_shape; + std::map> max_input_shape; + std::map> opt_input_shape; + std::map> min_shape_tensor; + std::map> max_shape_tensor; + std::map> opt_shape_tensor; + if (shape_range_info_path_.size()) + inference::DeserializeShapeRangeInfo(shape_range_info_path_, + &min_input_shape, + &max_input_shape, + &opt_input_shape, + &min_shape_tensor, + &max_shape_tensor, + &opt_shape_tensor); + calib_res->engine_.reset(new TensorRTEngine(max_batch_size_, workspace_size_, precision_mode_, calib_res->calib_.get(), - dev_place.device)); + dev_place.device, + min_input_shape, + max_input_shape, + opt_input_shape, + min_shape_tensor, + max_shape_tensor, + opt_shape_tensor)); VLOG(3) << "start the calib trt engine thread"; PrepareTRTEngine(scope, calib_res->engine_.get()); }));