From 867d4a5bca4e828e98b62b31697f18420ff8415e Mon Sep 17 00:00:00 2001 From: chliang <317000130@qq.com> Date: Sun, 24 May 2020 22:13:32 +0800 Subject: [PATCH] Read and use items of "use_pr" in deploy configuration files (deploy.yaml ) (#264) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 陈亮 --- deploy/python/infer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index eb5cfa4b..05e84eb1 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -29,7 +29,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed gflags.DEFINE_string("conf", default="", help="Configuration File Path") gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images") -gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model") gflags.DEFINE_string("trt_mode", default="", help="Use optimized model") gflags.DEFINE_string( "ext", default=".jpeg|.jpg", help="Input Image File Extensions") @@ -104,6 +103,9 @@ class DeployConfig: self.batch_size = deploy_conf["BATCH_SIZE"] # 9. channels self.channels = deploy_conf["CHANNELS"] + # 10. use_pr + self.use_pr = deploy_conf["USE_PR"] + class ImageReader: @@ -258,23 +260,24 @@ class Predictor: # record starting time point total_start = time.time() batch_size = self.config.batch_size + use_pr = self.config.use_pr for i in range(0, len(images), batch_size): real_batch_size = batch_size if i + batch_size >= len(images): real_batch_size = len(images) - i reader_start = time.time() img_datas = self.image_reader.process(images[i:i + real_batch_size], - gflags.FLAGS.use_pr) + use_pr) input_data = np.concatenate([item[1] for item in img_datas]) input_data = self.create_tensor( - input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr) + input_data, real_batch_size, use_pr=use_pr) reader_end = time.time() infer_start = time.time() output_data = self.predictor.run(input_data)[0] infer_end = time.time() output_data = output_data.as_ndarray() post_start = time.time() - self.output_result(img_datas, output_data, gflags.FLAGS.use_pr) + self.output_result(img_datas, output_data, use_pr) post_end = time.time() reader_time += (reader_end - reader_start) infer_time += (infer_end - infer_start) -- GitLab