提交 8817841c 编写于 作者: N nhzlx

fix unit test bug

test=develop
上级 b95f2ff8
......@@ -58,7 +58,7 @@ class TensorRTEngine : public EngineBase {
TensorRTEngine(int max_batch, int max_workspace,
cudaStream_t* stream = nullptr, int device = 0,
bool enable_int8 = "false",
bool enable_int8 = false,
TRTInt8Calibrator* calibrator = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
......
......@@ -96,19 +96,20 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({}));
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
"output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(2));
engine_op_desc.SetAttr("workspace_size", static_cast<int>(2 << 10));
engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z0"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
LOG(INFO) << "engine_op " << engine_op.get();
framework::Scope scope;
......@@ -190,20 +191,19 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(),
"output_name_mapping",
std::vector<std::string>({"z3"}));
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
engine_op_desc.SetBlockAttr("sub_block", &block_desc);
engine_op_desc.SetAttr("max_batch_size", static_cast<int>(batch_size));
engine_op_desc.SetAttr("workspace_size", static_cast<int>(2 << 10));
engine_op_desc.SetAttr("parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
engine_op_desc.SetAttr("calibration_data", std::string(""));
engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
engine_op_desc.SetAttr("output_name_mapping",
std::vector<std::string>({"z3"}));
engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
// Execute them.
engine_op->Run(scope, place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册