提交 b17b6c9a 编写于 作者: W wuyefeilin 提交者: wuzewu

update infer.py (#116)

* update infer.py
上级 bd3b285e
...@@ -33,6 +33,7 @@ gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model") ...@@ -33,6 +33,7 @@ gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model") gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.FLAGS = gflags.FLAGS gflags.FLAGS = gflags.FLAGS
# Generate ColorMap for visualization # Generate ColorMap for visualization
def generate_colormap(num_classes): def generate_colormap(num_classes):
color_map = num_classes * [0, 0, 0] color_map = num_classes * [0, 0, 0]
...@@ -45,9 +46,10 @@ def generate_colormap(num_classes): ...@@ -45,9 +46,10 @@ def generate_colormap(num_classes):
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1 j += 1
lab >>= 3 lab >>= 3
color_map = [color_map[i:i+3] for i in range(0, len(color_map), 3)] color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map return color_map
# Paddle-TRT Precision Map # Paddle-TRT Precision Map
trt_precision_map = { trt_precision_map = {
"int8": fluid.core.AnalysisConfig.Precision.Int8, "int8": fluid.core.AnalysisConfig.Precision.Int8,
...@@ -55,6 +57,7 @@ trt_precision_map = { ...@@ -55,6 +57,7 @@ trt_precision_map = {
"fp16": fluid.core.AnalysisConfig.Precision.Half "fp16": fluid.core.AnalysisConfig.Precision.Half
} }
# scan a directory and get all images with support extensions # scan a directory and get all images with support extensions
def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"): def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"):
if (not os.path.exists(img_dir) or not os.path.isdir(img_dir)): if (not os.path.exists(img_dir) or not os.path.isdir(img_dir)):
...@@ -67,6 +70,7 @@ def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"): ...@@ -67,6 +70,7 @@ def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"):
imgs.append(item_path) imgs.append(item_path)
return imgs return imgs
# Deploy Configuration File Parser # Deploy Configuration File Parser
class DeployConfig: class DeployConfig:
def __init__(self, conf_file): def __init__(self, conf_file):
...@@ -77,7 +81,8 @@ class DeployConfig: ...@@ -77,7 +81,8 @@ class DeployConfig:
configs = yaml.load(fp, Loader=yaml.FullLoader) configs = yaml.load(fp, Loader=yaml.FullLoader)
deploy_conf = configs["DEPLOY"] deploy_conf = configs["DEPLOY"]
# 1. get eval_crop_size # 1. get eval_crop_size
self.eval_crop_size = ast.literal_eval(deploy_conf["EVAL_CROP_SIZE"]) self.eval_crop_size = ast.literal_eval(
deploy_conf["EVAL_CROP_SIZE"])
# 2. get mean # 2. get mean
self.mean = deploy_conf["MEAN"] self.mean = deploy_conf["MEAN"]
# 3. get std # 3. get std
...@@ -85,10 +90,10 @@ class DeployConfig: ...@@ -85,10 +90,10 @@ class DeployConfig:
# 4. get class_num # 4. get class_num
self.class_num = deploy_conf["NUM_CLASSES"] self.class_num = deploy_conf["NUM_CLASSES"]
# 5. get paddle model and params file path # 5. get paddle model and params file path
self.model_file = os.path.join( self.model_file = os.path.join(deploy_conf["MODEL_PATH"],
deploy_conf["MODEL_PATH"], deploy_conf["MODEL_FILENAME"]) deploy_conf["MODEL_FILENAME"])
self.param_file = os.path.join( self.param_file = os.path.join(deploy_conf["MODEL_PATH"],
deploy_conf["MODEL_PATH"], deploy_conf["PARAMS_FILENAME"]) deploy_conf["PARAMS_FILENAME"])
# 6. use_gpu # 6. use_gpu
self.use_gpu = deploy_conf["USE_GPU"] self.use_gpu = deploy_conf["USE_GPU"]
# 7. predictor_mode # 7. predictor_mode
...@@ -98,6 +103,7 @@ class DeployConfig: ...@@ -98,6 +103,7 @@ class DeployConfig:
# 9. channels # 9. channels
self.channels = deploy_conf["CHANNELS"] self.channels = deploy_conf["CHANNELS"]
class ImageReader: class ImageReader:
def __init__(self, configs): def __init__(self, configs):
self.config = configs self.config = configs
...@@ -133,7 +139,7 @@ class ImageReader: ...@@ -133,7 +139,7 @@ class ImageReader:
im = im[:, :, :].astype('float32') / 255.0 im = im[:, :, :].astype('float32') / 255.0
im -= im_mean im -= im_mean
im /= im_std im /= im_std
im = im[np.newaxis,:,:,:] im = im[np.newaxis, :, :, :]
info = [image_path, im, (ori_w, ori_h)] info = [image_path, im, (ori_w, ori_h)]
return info return info
...@@ -141,12 +147,15 @@ class ImageReader: ...@@ -141,12 +147,15 @@ class ImageReader:
def process(self, imgs, use_pr=False): def process(self, imgs, use_pr=False):
imgs_data = [] imgs_data = []
with ThreadPoolExecutor(max_workers=self.config.batch_size) as exec: with ThreadPoolExecutor(max_workers=self.config.batch_size) as exec:
tasks = [exec.submit(self.process_worker, imgs, idx, use_pr) tasks = [
for idx in range(len(imgs))] exec.submit(self.process_worker, imgs, idx, use_pr)
for idx in range(len(imgs))
]
for task in as_completed(tasks): for task in as_completed(tasks):
imgs_data.append(task.result()) imgs_data.append(task.result())
return imgs_data return imgs_data
class Predictor: class Predictor:
def __init__(self, conf_file): def __init__(self, conf_file):
self.config = DeployConfig(conf_file) self.config = DeployConfig(conf_file)
...@@ -168,7 +177,7 @@ class Predictor: ...@@ -168,7 +177,7 @@ class Predictor:
precision_type = trt_precision_map[gflags.FLAGS.trt_mode] precision_type = trt_precision_map[gflags.FLAGS.trt_mode]
use_calib = (gflags.FLAGS.trt_mode == "int8") use_calib = (gflags.FLAGS.trt_mode == "int8")
predictor_config.enable_tensorrt_engine( predictor_config.enable_tensorrt_engine(
workspace_size=1<<30, workspace_size=1 << 30,
max_batch_size=self.config.batch_size, max_batch_size=self.config.batch_size,
min_subgraph_size=40, min_subgraph_size=40,
precision_mode=precision_type, precision_mode=precision_type,
...@@ -184,15 +193,15 @@ class Predictor: ...@@ -184,15 +193,15 @@ class Predictor:
im_tensor = fluid.core.PaddleTensor() im_tensor = fluid.core.PaddleTensor()
im_tensor.name = "image" im_tensor.name = "image"
if not use_pr: if not use_pr:
im_tensor.shape = [batch_size, im_tensor.shape = [
self.config.channels, batch_size, self.config.channels, self.config.eval_crop_size[1],
self.config.eval_crop_size[1], self.config.eval_crop_size[0]
self.config.eval_crop_size[0]] ]
else: else:
im_tensor.shape = [batch_size, im_tensor.shape = [
self.config.eval_crop_size[1], batch_size, self.config.eval_crop_size[1],
self.config.eval_crop_size[0], self.config.eval_crop_size[0], self.config.channels
self.config.channels] ]
im_tensor.dtype = fluid.core.PaddleDType.FLOAT32 im_tensor.dtype = fluid.core.PaddleDType.FLOAT32
im_tensor.data = fluid.core.PaddleBuf(inputs.ravel().astype("float32")) im_tensor.data = fluid.core.PaddleBuf(inputs.ravel().astype("float32"))
return [im_tensor] return [im_tensor]
...@@ -225,7 +234,11 @@ class Predictor: ...@@ -225,7 +234,11 @@ class Predictor:
vis_result_name = img_name_fix + "_result.png" vis_result_name = img_name_fix + "_result.png"
result_png = score_png result_png = score_png
# if not use_pr: # if not use_pr:
result_png = cv2.resize(result_png, ori_shape, fx=0, fy=0, result_png = cv2.resize(
result_png,
ori_shape,
fx=0,
fy=0,
interpolation=cv2.INTER_CUBIC) interpolation=cv2.INTER_CUBIC)
cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1]) cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1])
print("save result of [" + img_name + "] done.") print("save result of [" + img_name + "] done.")
...@@ -248,7 +261,8 @@ class Predictor: ...@@ -248,7 +261,8 @@ class Predictor:
if i + batch_size >= len(images): if i + batch_size >= len(images):
real_batch_size = len(images) - i real_batch_size = len(images) - i
reader_start = time.time() reader_start = time.time()
img_datas = self.image_reader.process(images[i: i + real_batch_size]) img_datas = self.image_reader.process(images[i:i + real_batch_size],
gflags.FLAGS.use_pr)
input_data = np.concatenate([item[1] for item in img_datas]) input_data = np.concatenate([item[1] for item in img_datas])
input_data = self.create_tensor( input_data = self.create_tensor(
input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr) input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr)
...@@ -268,15 +282,17 @@ class Predictor: ...@@ -268,15 +282,17 @@ class Predictor:
total_end = time.time() total_end = time.time()
# compute whole processing time # compute whole processing time
total_runtime = (total_end - total_start) total_runtime = (total_end - total_start)
print("images_num=[%d],preprocessing_time=[%f],infer_time=[%f],postprocessing_time=[%f],total_runtime=[%f]" print(
"images_num=[%d],preprocessing_time=[%f],infer_time=[%f],postprocessing_time=[%f],total_runtime=[%f]"
% (len(images), reader_time, infer_time, post_time, total_runtime)) % (len(images), reader_time, infer_time, post_time, total_runtime))
def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"): def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"):
# 1. scan and get all images with valid extensions in directory imgs_dir # 1. scan and get all images with valid extensions in directory imgs_dir
imgs = get_images_from_dir(imgs_dir) imgs = get_images_from_dir(imgs_dir)
if len(imgs) == 0: if len(imgs) == 0:
print("No Image (with extensions : %s) found in [%s]" print("No Image (with extensions : %s) found in [%s]" %
% (support_extensions, imgs_dir)) (support_extensions, imgs_dir))
return -1 return -1
# 2. create a predictor # 2. create a predictor
seg_predictor = Predictor(deploy_conf) seg_predictor = Predictor(deploy_conf)
...@@ -284,17 +300,19 @@ def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"): ...@@ -284,17 +300,19 @@ def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"):
seg_predictor.predict(imgs) seg_predictor.predict(imgs)
return 0 return 0
if __name__ == "__main__": if __name__ == "__main__":
# 0. parse the arguments # 0. parse the arguments
gflags.FLAGS(sys.argv) gflags.FLAGS(sys.argv)
if (gflags.FLAGS.conf == "" or gflags.FLAGS.input_dir == ""): if (gflags.FLAGS.conf == "" or gflags.FLAGS.input_dir == ""):
print("Usage: python infer.py --conf=/config/path/to/your/model " print("Usage: python infer.py --conf=/config/path/to/your/model " +
+"--input_dir=/directory/of/your/input/images [--use_pr=True]") "--input_dir=/directory/of/your/input/images [--use_pr=True]")
exit(-1) exit(-1)
# set empty to turn off as default # set empty to turn off as default
trt_mode = gflags.FLAGS.trt_mode trt_mode = gflags.FLAGS.trt_mode
if (trt_mode != "" and trt_mode not in trt_precision_map): if (trt_mode != "" and trt_mode not in trt_precision_map):
print("Invalid trt_mode [%s], only support[int8, fp16, fp32]" % trt_mode) print(
"Invalid trt_mode [%s], only support[int8, fp16, fp32]" % trt_mode)
exit(-1) exit(-1)
# run inference # run inference
run(gflags.FLAGS.conf, gflags.FLAGS.input_dir) run(gflags.FLAGS.conf, gflags.FLAGS.input_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册