提交 7cde2d9e 编写于 作者: N nhzlx

fix trt engine test error.

test=develop
上级 d065b5bf
...@@ -82,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -82,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
} }
if (!calibration_mode_) { if (!calibration_mode_ && !engine_serialized_data_.empty()) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine( trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(), max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_)); device_id_));
...@@ -236,6 +236,9 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -236,6 +236,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope, TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
if (!trt_engine_) { if (!trt_engine_) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_));
PrepareTRTEngine(scope, trt_engine_.get()); PrepareTRTEngine(scope, trt_engine_.get());
} }
return trt_engine_.get(); return trt_engine_.get();
......
...@@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) { ...@@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) {
std::vector<std::string>({"z0"})); std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string("")); engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
LOG(INFO) << "create engine op"; LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
...@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
std::vector<std::string>({"z3"})); std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string("")); engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册