未验证 提交 57158544 编写于 作者: L lidanqing 提交者: GitHub

Add mkldnn bfloat16 option in inference scripts (#5212)

* add mkldnn bfloat16 args

* add mkldnn_bfloat16 to static deploy

* update

* update
上级 50da62fd
......@@ -89,6 +89,7 @@ class Detector(object):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
"""
......@@ -105,6 +106,7 @@ class Detector(object):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='output',
threshold=0.5, ):
self.pred_config = self.set_config(model_dir)
......@@ -120,7 +122,8 @@ class Detector(object):
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
......@@ -323,6 +326,7 @@ class DetectorSOLOv2(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
......@@ -340,6 +344,7 @@ class DetectorSOLOv2(Detector):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5, ):
super(DetectorSOLOv2, self).__init__(
......@@ -353,6 +358,7 @@ class DetectorSOLOv2(Detector):
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold, )
......@@ -399,7 +405,8 @@ class DetectorPicoDet(Detector):
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
enable_mkldnn (bool): whether to turn on MKLDNN
enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
"""
def __init__(
......@@ -414,6 +421,7 @@ class DetectorPicoDet(Detector):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5, ):
super(DetectorPicoDet, self).__init__(
......@@ -427,6 +435,7 @@ class DetectorPicoDet(Detector):
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold, )
......@@ -571,7 +580,8 @@ def load_predictor(model_dir,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
......@@ -611,6 +621,8 @@ def load_predictor(model_dir,
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if enable_mkldnn_bfloat16:
config.enable_mkldnn_bfloat16()
except Exception as e:
print(
"The current environment does not support `mkldnn`, so disable mkldnn."
......@@ -747,6 +759,7 @@ def main():
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
......@@ -781,4 +794,6 @@ if __name__ == '__main__':
], "device should be CPU, GPU or XPU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
assert not (FLAGS.enable_mkldnn==False and FLAGS.enable_mkldnn_bfloat16==True), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
main()
......@@ -80,6 +80,11 @@ def argsparser():
type=ast.literal_eval,
default=False,
help="Whether use mkldnn with CPU.")
parser.add_argument(
"--enable_mkldnn_bfloat16",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn bfloat16 inference with CPU.")
parser.add_argument(
"--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
parser.add_argument(
......
......@@ -60,6 +60,7 @@ class Detector(object):
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
threshold (float): threshold to reserve the result for output.
enable_mkldnn (bool): whether use mkldnn with CPU.
enable_mkldnn_bfloat16 (bool): whether use mkldnn bfloat16 with CPU.
"""
def __init__(self,
......@@ -69,7 +70,8 @@ class Detector(object):
run_mode='fluid',
threshold=0.5,
trt_calib_mode=False,
enable_mkldnn=False):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False):
self.config = config
if self.config.use_python_inference:
self.executor, self.program, self.fecth_targets = load_executor(
......@@ -81,7 +83,8 @@ class Detector(object):
min_subgraph_size=self.config.min_subgraph_size,
device=device,
trt_calib_mode=trt_calib_mode,
enable_mkldnn=enable_mkldnn)
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16)
def preprocess(self, im):
preprocess_ops = []
......@@ -229,7 +232,8 @@ class DetectorSOLOv2(Detector):
run_mode='fluid',
threshold=0.5,
trt_calib_mode=False,
enable_mkldnn=False):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False):
super(DetectorSOLOv2, self).__init__(
config=config,
model_dir=model_dir,
......@@ -237,7 +241,8 @@ class DetectorSOLOv2(Detector):
run_mode=run_mode,
threshold=threshold,
trt_calib_mode=trt_calib_mode,
enable_mkldn=enable_mkldnn)
enable_mkldn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16)
def predict(self,
image,
......@@ -391,7 +396,8 @@ def load_predictor(model_dir,
device='CPU',
min_subgraph_size=3,
trt_calib_mode=False,
enable_mkldnn=False):
enable_mkldnn=False,
enable_mkldnn_bfloat16=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
......@@ -399,6 +405,7 @@ def load_predictor(model_dir,
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
enable_mkldnn (bool): Whether use mkldnn with CPU, default is False
enable_mkldnn_bfloat16 (bool): Whether use mkldnn bfloat16 with CPU, default is False
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
......@@ -430,6 +437,8 @@ def load_predictor(model_dir,
config.set_mkldnn_cache_capacity(0)
config.enable_mkldnn()
config.pass_builder().append_pass("interpolate_mkldnn_pass")
if enable_mkldnn_bfloat16:
config.enable_mkldnn_bfloat16()
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 10,
......@@ -557,7 +566,8 @@ def main():
device=FLAGS.device,
run_mode=FLAGS.run_mode,
trt_calib_mode=FLAGS.trt_calib_mode,
enable_mkldnn=FLAGS.enable_mkldnn)
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16)
if config.arch == 'SOLOv2':
detector = DetectorSOLOv2(
config,
......@@ -565,7 +575,8 @@ def main():
device=FLAGS.device,
run_mode=FLAGS.run_mode,
trt_calib_mode=FLAGS.trt_calib_mode,
enable_mkldnn=FLAGS.enable_mkldnn)
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16)
# predict from image
if FLAGS.image_file != '':
predict_image(detector)
......@@ -636,6 +647,11 @@ if __name__ == '__main__':
type=ast.literal_eval,
default=False,
help="Whether use mkldnn with CPU.")
parser.add_argument(
"--enable_mkldnn_bfloat16",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn bfloat16 with CPU.")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
if FLAGS.image_file != '' and FLAGS.video_file != '':
......@@ -644,5 +660,6 @@ if __name__ == '__main__':
assert FLAGS.device in ['CPU', 'GPU', 'XPU'
], "device should be CPU, GPU or XPU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
assert not (FLAGS.enable_mkldnn==False and FLAGS.enable_mkldnn_bfloat16==True),"To turn on mkldnn_bfloat, please set both enable_mkldnn and enable_mkldnn_bfloat16 True"
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册