提交 067ba31f 编写于 作者: S syyxsxx

fix mkldnn

上级 654adfe9
...@@ -61,9 +61,9 @@ int main(int argc, char** argv) { ...@@ -61,9 +61,9 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_use_mkl, FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_mkl_thread_num);
// Predict // Predict
int imgs = 1; int imgs = 1;
......
...@@ -66,9 +66,9 @@ int main(int argc, char** argv) { ...@@ -66,9 +66,9 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_use_mkl, FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_mkl_thread_num);
int imgs = 1; int imgs = 1;
std::string save_dir = "output"; std::string save_dir = "output";
// Predict // Predict
......
...@@ -63,9 +63,9 @@ int main(int argc, char** argv) { ...@@ -63,9 +63,9 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_use_mkl, FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_mkl_thread_num);
int imgs = 1; int imgs = 1;
// Predict // Predict
if (FLAGS_image_list != "") { if (FLAGS_image_list != "") {
......
...@@ -67,9 +67,9 @@ int main(int argc, char** argv) { ...@@ -67,9 +67,9 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_use_mkl, FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_mkl_thread_num);
// Open video // Open video
cv::VideoCapture capture; cv::VideoCapture capture;
......
...@@ -69,9 +69,9 @@ int main(int argc, char** argv) { ...@@ -69,9 +69,9 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_use_mkl, FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_mkl_thread_num);
// Open video // Open video
cv::VideoCapture capture; cv::VideoCapture capture;
if (FLAGS_use_camera) { if (FLAGS_use_camera) {
......
...@@ -67,9 +67,9 @@ int main(int argc, char** argv) { ...@@ -67,9 +67,9 @@ int main(int argc, char** argv) {
FLAGS_use_gpu, FLAGS_use_gpu,
FLAGS_use_trt, FLAGS_use_trt,
FLAGS_use_mkl, FLAGS_use_mkl,
FLAGS_mkl_thread_num,
FLAGS_gpu_id, FLAGS_gpu_id,
FLAGS_key, FLAGS_key);
FLAGS_mkl_thread_num);
// Open video // Open video
cv::VideoCapture capture; cv::VideoCapture capture;
if (FLAGS_use_camera) { if (FLAGS_use_camera) {
......
...@@ -70,6 +70,8 @@ class Model { ...@@ -70,6 +70,8 @@ class Model {
* @param model_dir: the directory which contains model.yml * @param model_dir: the directory which contains model.yml
* @param use_gpu: use gpu or not when infering * @param use_gpu: use gpu or not when infering
* @param use_trt: use Tensor RT or not when infering * @param use_trt: use Tensor RT or not when infering
* @param use_mkl: use mkl or not when infering
* @param mkl_thread_num: the threads of mkl when infering
* @param gpu_id: the id of gpu when infering with using gpu * @param gpu_id: the id of gpu when infering with using gpu
* @param key: the key of encryption when using encrypted model * @param key: the key of encryption when using encrypted model
* @param use_ir_optim: use ir optimization when infering * @param use_ir_optim: use ir optimization when infering
...@@ -78,28 +80,27 @@ class Model { ...@@ -78,28 +80,27 @@ class Model {
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false, bool use_trt = false,
bool use_mkl = true, bool use_mkl = true,
int mkl_thread_num = 4,
int gpu_id = 0, int gpu_id = 0,
std::string key = "", std::string key = "",
int mkl_thread_num = 4,
bool use_ir_optim = true) { bool use_ir_optim = true) {
create_predictor( create_predictor(
model_dir, model_dir,
use_gpu, use_gpu,
use_trt, use_trt,
use_mkl, use_mkl,
mkl_thread_num,
gpu_id, gpu_id,
key, key,
mkl_thread_num,
use_ir_optim); use_ir_optim);
} }
void create_predictor(const std::string& model_dir, void create_predictor(const std::string& model_dir,
bool use_gpu = false, bool use_gpu = false,
bool use_trt = false, bool use_trt = false,
bool use_mkl = true, bool use_mkl = true,
int mkl_thread_num = 4,
int gpu_id = 0, int gpu_id = 0,
std::string key = "", std::string key = "",
int mkl_thread_num = 4,
bool use_ir_optim = true); bool use_ir_optim = true);
/* /*
......
...@@ -29,9 +29,9 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -29,9 +29,9 @@ void Model::create_predictor(const std::string& model_dir,
bool use_gpu, bool use_gpu,
bool use_trt, bool use_trt,
bool use_mkl, bool use_mkl,
int mkl_thread_num,
int gpu_id, int gpu_id,
std::string key, std::string key,
int mkl_thread_num,
bool use_ir_optim) { bool use_ir_optim) {
paddle::AnalysisConfig config; paddle::AnalysisConfig config;
std::string model_file = model_dir + OS_PATH_SEP + "__model__"; std::string model_file = model_dir + OS_PATH_SEP + "__model__";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册