提交 db43e2e4 编写于 作者: F FlyingQianMM

clse mkldnn for ppyolo

上级 9a703419
......@@ -66,9 +66,14 @@ void Model::create_predictor(const std::string& model_dir,
if (key == "") {
config.SetModel(model_file, params_file);
}
if (use_mkl && name != "HRNet" && name != "DeepLabv3p") {
config.EnableMKLDNN();
config.SetCpuMathLibraryNumThreads(mkl_thread_num);
if (use_mkl) {
if (name != "HRNet" && name != "DeepLabv3p" && name != "PPYOLO") {
config.EnableMKLDNN();
config.SetCpuMathLibraryNumThreads(mkl_thread_num);
} else {
std::cerr << "HRNet/DeepLabv3p/PPYOLO are not supported "
<< "for the use of mkldnn" << std::endl;
}
}
if (use_gpu) {
config.EnableUseGpu(100, gpu_id);
......
......@@ -23,6 +23,7 @@ from paddlex.cv.transforms import build_transforms
from paddlex.cv.models import BaseClassifier
from paddlex.cv.models import PPYOLO, FasterRCNN, MaskRCNN
from paddlex.cv.models import DeepLabv3p
import paddlex.utils.logging as logging
class Predictor:
......@@ -108,9 +109,13 @@ class Predictor:
else:
config.disable_gpu()
if use_mkl:
if self.model_name not in ["HRNet", "DeepLabv3p"]:
if self.model_name not in ["HRNet", "DeepLabv3p", "PPYOLO"]:
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(mkl_thread_num)
else:
logging.warning(
"HRNet/DeepLabv3p/PPYOLO are not supported for the use of mkldnn\n"
)
if use_glog:
config.enable_glog_info()
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册