提交 8817841c 编写于 作者: N nhzlx

fix unit test bug

test=develop
上级 b95f2ff8
...@@ -58,7 +58,7 @@ class TensorRTEngine : public EngineBase { ...@@ -58,7 +58,7 @@ class TensorRTEngine : public EngineBase {
TensorRTEngine(int max_batch, int max_workspace, TensorRTEngine(int max_batch, int max_workspace,
cudaStream_t* stream = nullptr, int device = 0, cudaStream_t* stream = nullptr, int device = 0,
bool enable_int8 = "false", bool enable_int8 = false,
TRTInt8Calibrator* calibrator = nullptr, TRTInt8Calibrator* calibrator = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global()) nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch), : max_batch_(max_batch),
......
...@@ -96,19 +96,20 @@ TEST(TensorRTEngineOp, manual) { ...@@ -96,19 +96,20 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetType("tensorrt_engine"); engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"})); engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"})); engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString()); engine_op_desc.SetBlockAttr("sub_block", &block_desc);
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2); engine_op_desc.SetAttr("max_batch_size", static_cast<int>(2));
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10); engine_op_desc.SetAttr("workspace_size", static_cast<int>(2 << 10));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters", engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
std::vector<std::string>({})); engine_op_desc.SetAttr("calibration_data", std::string(""));
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
"output_name_mapping", engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"})); std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
LOG(INFO) << "create engine op"; LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto()); auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
LOG(INFO) << "engine_op " << engine_op.get(); LOG(INFO) << "engine_op " << engine_op.get();
framework::Scope scope; framework::Scope scope;
...@@ -190,20 +191,19 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -190,20 +191,19 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"})); engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"})); engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph", engine_op_desc.SetBlockAttr("sub_block", &block_desc);
block_->SerializeAsString()); engine_op_desc.SetAttr("max_batch_size", static_cast<int>(batch_size));
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size); engine_op_desc.SetAttr("workspace_size", static_cast<int>(2 << 10));
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10); engine_op_desc.SetAttr("parameters",
SetAttr<std::vector<std::string>>( std::vector<std::string>({"y0", "y1", "y2", "y3"}));
engine_op_desc.Proto(), "parameters", engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
std::vector<std::string>({"y0", "y1", "y2", "y3"})); engine_op_desc.SetAttr("calibration_data", std::string(""));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine"); engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), std::vector<std::string>({"z3"}));
"output_name_mapping", engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
std::vector<std::string>({"z3"}));
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
// Execute them. // Execute them.
engine_op->Run(scope, place); engine_op->Run(scope, place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册