From 7cde2d9e8473fe2eb3845604f1b6a0a7d69907f4 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 29 Mar 2019 04:41:38 +0000 Subject: [PATCH] fix trt engine test error. test=develop --- paddle/fluid/operators/tensorrt/tensorrt_engine_op.h | 5 ++++- paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 8010bd8ecc..7f470924b3 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -82,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorBase { calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); } - if (!calibration_mode_) { + if (!calibration_mode_ && !engine_serialized_data_.empty()) { trt_engine_.reset(new inference::tensorrt::TensorRTEngine( max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), device_id_)); @@ -236,6 +236,9 @@ class TensorRTEngineOp : public framework::OperatorBase { TensorRTEngine *GetEngine(const framework::Scope &scope, const platform::Place &dev_place) const { if (!trt_engine_) { + trt_engine_.reset(new inference::tensorrt::TensorRTEngine( + max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), + device_id_)); PrepareTRTEngine(scope, trt_engine_.get()); } return trt_engine_.get(); diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index e7ad2f4fe0..cc4d8d6e6f 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) { std::vector({"z0"})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); + int device_id = 0; + engine_op_desc.SetAttr("gpu_id", device_id); LOG(INFO) << "create engine op"; auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); @@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { std::vector({"z3"})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); + int device_id = 0; + engine_op_desc.SetAttr("gpu_id", device_id); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); -- GitLab