diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 8010bd8ecc63bedb4d69a0fd4b42bd6d6cad23e3..7f470924b337d59943c04ab0ff2820555f961732 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -82,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorBase { calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); } - if (!calibration_mode_) { + if (!calibration_mode_ && !engine_serialized_data_.empty()) { trt_engine_.reset(new inference::tensorrt::TensorRTEngine( max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), device_id_)); @@ -236,6 +236,9 @@ class TensorRTEngineOp : public framework::OperatorBase { TensorRTEngine *GetEngine(const framework::Scope &scope, const platform::Place &dev_place) const { if (!trt_engine_) { + trt_engine_.reset(new inference::tensorrt::TensorRTEngine( + max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), + device_id_)); PrepareTRTEngine(scope, trt_engine_.get()); } return trt_engine_.get(); diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index e7ad2f4fe0c654d8928f5793c1ad8052ab766fb5..cc4d8d6e6f7e24dcb04ed0f58e63cb13ce176bdb 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) { std::vector({"z0"})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); + int device_id = 0; + engine_op_desc.SetAttr("gpu_id", device_id); LOG(INFO) << "create engine op"; auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); @@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { std::vector({"z3"})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); + int device_id = 0; + engine_op_desc.SetAttr("gpu_id", device_id); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);