提交 7cde2d9e 编写于 作者: N nhzlx

fix trt engine test error.

test=develop
上级 d065b5bf
......@@ -82,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
}
if (!calibration_mode_) {
if (!calibration_mode_ && !engine_serialized_data_.empty()) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_));
......@@ -236,6 +236,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
TensorRTEngine *GetEngine(const framework::Scope &scope,
const platform::Place &dev_place) const {
if (!trt_engine_) {
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
device_id_));
PrepareTRTEngine(scope, trt_engine_.get());
}
return trt_engine_.get();
......
......@@ -108,6 +108,8 @@ TEST(TensorRTEngineOp, manual) {
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
......@@ -204,6 +206,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
int device_id = 0;
engine_op_desc.SetAttr("gpu_id", device_id);
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册