提交 4c52be07 编写于 作者: N nhzlx

fix ut error

上级 dcc09dce
...@@ -66,10 +66,13 @@ TEST(SubGraphSplitter, Split) { ...@@ -66,10 +66,13 @@ TEST(SubGraphSplitter, Split) {
TEST(SubGraphSplitter, Fuse) { TEST(SubGraphSplitter, Fuse) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
Argument argument;
int* minmum_subgraph_size = new int(3);
argument.Set<int>("minmum_subgraph_size", minmum_subgraph_size);
size_t count0 = dfg.nodes.size(); size_t count0 = dfg.nodes.size();
SubGraphFuse fuse(&dfg, teller); SubGraphFuse fuse(&dfg, teller, &argument);
fuse(); fuse();
int count1 = 0; int count1 = 0;
......
...@@ -58,8 +58,6 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block, ...@@ -58,8 +58,6 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
using inference::analysis::SetAttr; using inference::analysis::SetAttr;
TEST(TensorRTEngineOp, manual) { TEST(TensorRTEngineOp, manual) {
FLAGS_tensorrt_engine_batch_size = 2;
FLAGS_tensorrt_max_batch_size = 2;
framework::ProgramDesc program; framework::ProgramDesc program;
auto* block_ = program.Proto()->add_blocks(); auto* block_ = program.Proto()->add_blocks();
block_->set_idx(0); block_->set_idx(0);
...@@ -101,6 +99,8 @@ TEST(TensorRTEngineOp, manual) { ...@@ -101,6 +99,8 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"})); engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph", SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString()); block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters", SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({})); std::vector<std::string>({}));
...@@ -129,8 +129,6 @@ TEST(TensorRTEngineOp, manual) { ...@@ -129,8 +129,6 @@ TEST(TensorRTEngineOp, manual) {
} }
void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
FLAGS_tensorrt_engine_batch_size = batch_size;
FLAGS_tensorrt_max_batch_size = batch_size;
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CUDAPlace place; platform::CUDAPlace place;
...@@ -195,8 +193,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { ...@@ -195,8 +193,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph", SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString()); block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch", batch_size); SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 2 << 10); SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
SetAttr<std::vector<std::string>>( SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters", engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"})); std::vector<std::string>({"y0", "y1", "y2", "y3"}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册