diff --git a/benchmark/tm_benchmark.cc b/benchmark/tm_benchmark.cc index d0d92e805551e4aa32cd6273a12d1f7a92fb3b50..c2a90f4f98132a1cfbb52fbba17196572681912e 100644 --- a/benchmark/tm_benchmark.cc +++ b/benchmark/tm_benchmark.cc @@ -37,12 +37,14 @@ int benchmark_threads = 1; int benchmark_model = -1; int benchmark_cluster = 0; int benchmark_mask = 0xFFFF; +std::string benchmark_device = ""; +context_t s_context; int benchmark_graph(options_t* opt, const char* name, const char* file, int height, int width, int channel, int batch) { // create graph, load tengine model xxx.tmfile - graph_t graph = create_graph(nullptr, "tengine", file); + graph_t graph = create_graph(s_context, "tengine", file); if (nullptr == graph) { fprintf(stderr, "Tengine Benchmark: Create graph failed.\n"); @@ -144,6 +146,7 @@ int main(int argc, char* argv[]) cmd.add("cpu_cluster", 'p', "cpu cluster [0:auto, 1:big, 2:middle, 3:little]", false, 0); cmd.add("model", 's', "benchmark which model, \"-1\" means all models", false, -1); cmd.add("cpu_mask", 'a', "benchmark on masked cpu core(s)", false, -1); + cmd.add("device", 'd', "device name (should be upper-case)", false); cmd.parse_check(argc, argv); @@ -152,12 +155,25 @@ int main(int argc, char* argv[]) benchmark_model = cmd.get("model"); benchmark_cluster = cmd.get("cpu_cluster"); benchmark_mask = cmd.get("cpu_mask"); + benchmark_device = cmd.get("device"); + if (benchmark_device.empty()) + { + benchmark_device = "CPU"; + } + else + { + for (int i = 0; i < benchmark_device.length(); i++) + { + benchmark_device[i] = ::toupper(benchmark_device[i]); + } + } fprintf(stdout, "Tengine benchmark:\n"); fprintf(stdout, " loops: %d\n", benchmark_loop); fprintf(stdout, " threads: %d\n", benchmark_threads); fprintf(stdout, " cluster: %d\n", benchmark_cluster); fprintf(stdout, " affinity: 0x%X\n", benchmark_mask); + fprintf(stdout, " device: %s\n", benchmark_device.c_str()); // initialize tengine if (0 != init_tengine()) @@ -167,6 +183,17 @@ int main(int argc, char* argv[]) } fprintf(stdout, "Tengine-lite library version: %s\n", get_tengine_version()); + s_context = create_context("ctx", benchmark_device.empty() ? 0 : 1); + if (!benchmark_device.empty()) + { + int ret = set_context_device(s_context, benchmark_device.c_str(), nullptr, 0); + if (0 != ret) + { + fprintf(stderr, "Set context device failed: %d.\n", ret); + return false; + } + } + struct options opt; opt.num_thread = benchmark_threads; opt.precision = TENGINE_MODE_FP32; @@ -253,6 +280,7 @@ int main(int argc, char* argv[]) } /* release tengine */ + destroy_context(s_context); release_tengine(); fprintf(stderr, "ALL TEST DONE.\n");