提交 ac012cbb 编写于 作者: C chenguowei01

update infer.py

上级 f3a53654
......@@ -33,6 +33,7 @@ gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.FLAGS = gflags.FLAGS
# Generate ColorMap for visualization
def generate_colormap(num_classes):
color_map = num_classes * [0, 0, 0]
......@@ -45,9 +46,10 @@ def generate_colormap(num_classes):
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i+3] for i in range(0, len(color_map), 3)]
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
# Paddle-TRT Precision Map
trt_precision_map = {
"int8": fluid.core.AnalysisConfig.Precision.Int8,
......@@ -55,6 +57,7 @@ trt_precision_map = {
"fp16": fluid.core.AnalysisConfig.Precision.Half
}
# scan a directory and get all images with support extensions
def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"):
if (not os.path.exists(img_dir) or not os.path.isdir(img_dir)):
......@@ -67,6 +70,7 @@ def get_images_from_dir(img_dir, support_ext=".jpg|.jpeg"):
imgs.append(item_path)
return imgs
# Deploy Configuration File Parser
class DeployConfig:
def __init__(self, conf_file):
......@@ -77,7 +81,8 @@ class DeployConfig:
configs = yaml.load(fp, Loader=yaml.FullLoader)
deploy_conf = configs["DEPLOY"]
# 1. get eval_crop_size
self.eval_crop_size = ast.literal_eval(deploy_conf["EVAL_CROP_SIZE"])
self.eval_crop_size = ast.literal_eval(
deploy_conf["EVAL_CROP_SIZE"])
# 2. get mean
self.mean = deploy_conf["MEAN"]
# 3. get std
......@@ -85,10 +90,10 @@ class DeployConfig:
# 4. get class_num
self.class_num = deploy_conf["NUM_CLASSES"]
# 5. get paddle model and params file path
self.model_file = os.path.join(
deploy_conf["MODEL_PATH"], deploy_conf["MODEL_FILENAME"])
self.param_file = os.path.join(
deploy_conf["MODEL_PATH"], deploy_conf["PARAMS_FILENAME"])
self.model_file = os.path.join(deploy_conf["MODEL_PATH"],
deploy_conf["MODEL_FILENAME"])
self.param_file = os.path.join(deploy_conf["MODEL_PATH"],
deploy_conf["PARAMS_FILENAME"])
# 6. use_gpu
self.use_gpu = deploy_conf["USE_GPU"]
# 7. predictor_mode
......@@ -98,6 +103,7 @@ class DeployConfig:
# 9. channels
self.channels = deploy_conf["CHANNELS"]
class ImageReader:
def __init__(self, configs):
self.config = configs
......@@ -133,7 +139,7 @@ class ImageReader:
im = im[:, :, :].astype('float32') / 255.0
im -= im_mean
im /= im_std
im = im[np.newaxis,:,:,:]
im = im[np.newaxis, :, :, :]
info = [image_path, im, (ori_w, ori_h)]
return info
......@@ -141,12 +147,15 @@ class ImageReader:
def process(self, imgs, use_pr=False):
imgs_data = []
with ThreadPoolExecutor(max_workers=self.config.batch_size) as exec:
tasks = [exec.submit(self.process_worker, imgs, idx, use_pr)
for idx in range(len(imgs))]
tasks = [
exec.submit(self.process_worker, imgs, idx, use_pr)
for idx in range(len(imgs))
]
for task in as_completed(tasks):
imgs_data.append(task.result())
return imgs_data
class Predictor:
def __init__(self, conf_file):
self.config = DeployConfig(conf_file)
......@@ -168,7 +177,7 @@ class Predictor:
precision_type = trt_precision_map[gflags.FLAGS.trt_mode]
use_calib = (gflags.FLAGS.trt_mode == "int8")
predictor_config.enable_tensorrt_engine(
workspace_size=1<<30,
workspace_size=1 << 30,
max_batch_size=self.config.batch_size,
min_subgraph_size=40,
precision_mode=precision_type,
......@@ -184,15 +193,15 @@ class Predictor:
im_tensor = fluid.core.PaddleTensor()
im_tensor.name = "image"
if not use_pr:
im_tensor.shape = [batch_size,
self.config.channels,
self.config.eval_crop_size[1],
self.config.eval_crop_size[0]]
im_tensor.shape = [
batch_size, self.config.channels, self.config.eval_crop_size[1],
self.config.eval_crop_size[0]
]
else:
im_tensor.shape = [batch_size,
self.config.eval_crop_size[1],
self.config.eval_crop_size[0],
self.config.channels]
im_tensor.shape = [
batch_size, self.config.eval_crop_size[1],
self.config.eval_crop_size[0], self.config.channels
]
im_tensor.dtype = fluid.core.PaddleDType.FLOAT32
im_tensor.data = fluid.core.PaddleBuf(inputs.ravel().astype("float32"))
return [im_tensor]
......@@ -225,8 +234,12 @@ class Predictor:
vis_result_name = img_name_fix + "_result.png"
result_png = score_png
# if not use_pr:
result_png = cv2.resize(result_png, ori_shape, fx=0, fy=0,
interpolation=cv2.INTER_CUBIC)
result_png = cv2.resize(
result_png,
ori_shape,
fx=0,
fy=0,
interpolation=cv2.INTER_CUBIC)
cv2.imwrite(vis_result_name, result_png, [cv2.CV_8UC1])
print("save result of [" + img_name + "] done.")
......@@ -248,7 +261,8 @@ class Predictor:
if i + batch_size >= len(images):
real_batch_size = len(images) - i
reader_start = time.time()
img_datas = self.image_reader.process(images[i: i + real_batch_size])
img_datas = self.image_reader.process(images[i:i + real_batch_size],
gflags.FLAGS.use_pr)
input_data = np.concatenate([item[1] for item in img_datas])
input_data = self.create_tensor(
input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr)
......@@ -268,15 +282,17 @@ class Predictor:
total_end = time.time()
# compute whole processing time
total_runtime = (total_end - total_start)
print("images_num=[%d],preprocessing_time=[%f],infer_time=[%f],postprocessing_time=[%f],total_runtime=[%f]"
% (len(images), reader_time, infer_time, post_time, total_runtime))
print(
"images_num=[%d],preprocessing_time=[%f],infer_time=[%f],postprocessing_time=[%f],total_runtime=[%f]"
% (len(images), reader_time, infer_time, post_time, total_runtime))
def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"):
# 1. scan and get all images with valid extensions in directory imgs_dir
imgs = get_images_from_dir(imgs_dir)
if len(imgs) == 0:
print("No Image (with extensions : %s) found in [%s]"
% (support_extensions, imgs_dir))
print("No Image (with extensions : %s) found in [%s]" %
(support_extensions, imgs_dir))
return -1
# 2. create a predictor
seg_predictor = Predictor(deploy_conf)
......@@ -284,17 +300,19 @@ def run(deploy_conf, imgs_dir, support_extensions=".jpg|.jpeg"):
seg_predictor.predict(imgs)
return 0
if __name__ == "__main__":
# 0. parse the arguments
gflags.FLAGS(sys.argv)
if (gflags.FLAGS.conf == "" or gflags.FLAGS.input_dir == ""):
print("Usage: python infer.py --conf=/config/path/to/your/model "
+"--input_dir=/directory/of/your/input/images [--use_pr=True]")
print("Usage: python infer.py --conf=/config/path/to/your/model " +
"--input_dir=/directory/of/your/input/images [--use_pr=True]")
exit(-1)
# set empty to turn off as default
trt_mode = gflags.FLAGS.trt_mode
if (trt_mode != "" and trt_mode not in trt_precision_map):
print("Invalid trt_mode [%s], only support[int8, fp16, fp32]" % trt_mode)
print(
"Invalid trt_mode [%s], only support[int8, fp16, fp32]" % trt_mode)
exit(-1)
# run inference
run(gflags.FLAGS.conf, gflags.FLAGS.input_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册