From 8ca88e56d6890cb6ea33be1eabd0b0447e9ca09a Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Tue, 19 Oct 2021 16:29:57 +0800 Subject: [PATCH] add xpu and npu support for text_recognition series. (#1622) --- .../chinese_ocr_db_crnn_mobile/module.py | 95 +++++++++++++------ .../chinese_ocr_db_crnn_server/module.py | 95 +++++++++++++------ .../module.py | 93 ++++++++++++------ .../module.py | 86 ++++++++++++----- 4 files changed, 257 insertions(+), 112 deletions(-) diff --git a/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py b/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py index 371e8f97..8a32b999 100644 --- a/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py +++ b/modules/image/text_recognition/chinese_ocr_db_crnn_mobile/module.py @@ -6,7 +6,9 @@ import math import os import time -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor +from paddle.inference import Config +from paddle.inference import create_predictor + from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving from PIL import Image @@ -53,6 +55,14 @@ class ChineseOCRDBCRNN(hub.Module): self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config( self.cls_pretrained_model_path) + def _get_device_id(self, places): + try: + places = os.environ[places] + id = int(places) + except: + id = -1 + return id + def _set_config(self, pretrained_model_path): """ predictor config path @@ -60,35 +70,49 @@ class ChineseOCRDBCRNN(hub.Module): model_file_path = os.path.join(pretrained_model_path, 'model') params_file_path = os.path.join(pretrained_model_path, 'params') - config = AnalysisConfig(model_file_path, params_file_path) - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - use_gpu = True - except: - use_gpu = False + config = Config(model_file_path, params_file_path) - if use_gpu: - config.enable_use_gpu(8000, 0) + # detect npu + npu_id = self._get_device_id("FLAGS_selected_npus") + if npu_id != -1: + # use npu + self.use_device = "npu" + config.enable_npu(device_id=npu_id) else: - config.disable_gpu() - if self.enable_mkldnn: - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() + # detect gpu + gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES") + if gpu_id != -1: + # use gpu + self.use_device = "gpu" + config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id) + else: + # detect xpu + xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES") + if xpu_id != -1: + # use xpu + self.use_device = "xpu" + config.enable_xpu(100) + else: + self.use_device = "cpu" + config.disable_gpu() + config.set_cpu_math_library_num_threads(6) + if self.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - predictor = create_paddle_predictor(config) + predictor = create_predictor(config) input_names = predictor.get_input_names() - input_tensor = predictor.get_input_tensor(input_names[0]) + input_tensor = predictor.get_input_handle(input_names[0]) output_names = predictor.get_output_names() output_tensors = [] for output_name in output_names: - output_tensor = predictor.get_output_tensor(output_name) + output_tensor = predictor.get_output_handle(output_name) output_tensors.append(output_tensor) return predictor, input_tensor, output_tensors @@ -186,7 +210,8 @@ class ChineseOCRDBCRNN(hub.Module): visualization=False, box_thresh=0.5, text_thresh=0.5, - angle_classification_thresh=0.9): + angle_classification_thresh=0.9, + use_device=None): """ Get the chinese texts in the predicted images. Args: @@ -199,18 +224,22 @@ class ChineseOCRDBCRNN(hub.Module): box_thresh(float): the threshold of the detected text box's confidence text_thresh(float): the threshold of the chinese text recognition confidence angle_classification_thresh(float): the threshold of the angle classification confidence + use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag. Returns: res (list): The result of chinese texts and save path of images. """ - if use_gpu: - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - except: + if use_device is not None: + # check 'use_device' match 'device on init' + if use_device != self.use_device: raise RuntimeError( - "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." + "the 'use_device' parameter when calling detect_text, does not match internal device found on init." ) + else: + # check 'use_gpu' match 'device on init' + if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu': + raise RuntimeError( + "the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.") self.use_gpu = use_gpu @@ -224,7 +253,7 @@ class ChineseOCRDBCRNN(hub.Module): assert predicted_data != [], "There is not any image to be predicted. Please check the input data." detection_results = self.text_detector_module.detect_text( - images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh) + images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh, use_device=use_device) boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] all_results = [] @@ -322,7 +351,7 @@ class ChineseOCRDBCRNN(hub.Module): norm_img_batch = norm_img_batch.copy() self.cls_input_tensor.copy_from_cpu(norm_img_batch) - self.cls_predictor.zero_copy_run() + self.cls_predictor.run() prob_out = self.cls_output_tensors[0].copy_to_cpu() label_out = self.cls_output_tensors[1].copy_to_cpu() @@ -366,7 +395,7 @@ class ChineseOCRDBCRNN(hub.Module): norm_img_batch = norm_img_batch.copy() self.rec_input_tensor.copy_from_cpu(norm_img_batch) - self.rec_predictor.zero_copy_run() + self.rec_predictor.run() rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu() rec_idx_lod = self.rec_output_tensors[0].lod()[0] @@ -471,7 +500,11 @@ class ChineseOCRDBCRNN(hub.Module): args = self.parser.parse_args(argvs) results = self.recognize_text( - paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) + paths=[args.input_path], + use_gpu=args.use_gpu, + output_dir=args.output_dir, + visualization=args.visualization, + use_device=args.use_device) return results def add_module_config_arg(self): @@ -484,6 +517,10 @@ class ChineseOCRDBCRNN(hub.Module): '--output_dir', type=str, default='ocr_result', help="The directory to save output images.") self.arg_config_group.add_argument( '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") + self.arg_config_group.add_argument( + '--use_device', + choices=["cpu", "gpu", "xpu", "npu"], + help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.") def add_module_input_arg(self): """ diff --git a/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py b/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py index a96673f3..cebd6812 100644 --- a/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py +++ b/modules/image/text_recognition/chinese_ocr_db_crnn_server/module.py @@ -10,7 +10,9 @@ import math import os import time -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor +from paddle.inference import Config +from paddle.inference import create_predictor + from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving from PIL import Image @@ -57,6 +59,14 @@ class ChineseOCRDBCRNNServer(hub.Module): self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config( self.cls_pretrained_model_path) + def _get_device_id(self, places): + try: + places = os.environ[places] + id = int(places) + except: + id = -1 + return id + def _set_config(self, pretrained_model_path): """ predictor config path @@ -64,35 +74,49 @@ class ChineseOCRDBCRNNServer(hub.Module): model_file_path = os.path.join(pretrained_model_path, 'model') params_file_path = os.path.join(pretrained_model_path, 'params') - config = AnalysisConfig(model_file_path, params_file_path) - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - use_gpu = True - except: - use_gpu = False + config = Config(model_file_path, params_file_path) - if use_gpu: - config.enable_use_gpu(8000, 0) + # detect npu + npu_id = self._get_device_id("FLAGS_selected_npus") + if npu_id != -1: + # use npu + self.use_device = "npu" + config.enable_npu(device_id=npu_id) else: - config.disable_gpu() - if self.enable_mkldnn: - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() + # detect gpu + gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES") + if gpu_id != -1: + # use gpu + self.use_device = "gpu" + config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id) + else: + # detect xpu + xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES") + if xpu_id != -1: + # use xpu + self.use_device = "xpu" + config.enable_xpu(100) + else: + self.use_device = "cpu" + config.disable_gpu() + config.set_cpu_math_library_num_threads(6) + if self.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - predictor = create_paddle_predictor(config) + predictor = create_predictor(config) input_names = predictor.get_input_names() - input_tensor = predictor.get_input_tensor(input_names[0]) + input_tensor = predictor.get_input_handle(input_names[0]) output_names = predictor.get_output_names() output_tensors = [] for output_name in output_names: - output_tensor = predictor.get_output_tensor(output_name) + output_tensor = predictor.get_output_handle(output_name) output_tensors.append(output_tensor) return predictor, input_tensor, output_tensors @@ -190,7 +214,8 @@ class ChineseOCRDBCRNNServer(hub.Module): visualization=False, box_thresh=0.5, text_thresh=0.5, - angle_classification_thresh=0.9): + angle_classification_thresh=0.9, + use_device=None): """ Get the chinese texts in the predicted images. Args: @@ -203,18 +228,22 @@ class ChineseOCRDBCRNNServer(hub.Module): box_thresh(float): the threshold of the detected text box's confidence text_thresh(float): the threshold of the chinese text recognition confidence angle_classification_thresh(float): the threshold of the angle classification confidence + use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag. Returns: res (list): The result of chinese texts and save path of images. """ - if use_gpu: - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - except: + if use_device is not None: + # check 'use_device' match 'device on init' + if use_device != self.use_device: raise RuntimeError( - "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." + "the 'use_device' parameter when calling detect_text, does not match internal device found on init." ) + else: + # check 'use_gpu' match 'device on init' + if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu': + raise RuntimeError( + "the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.") self.use_gpu = use_gpu @@ -228,7 +257,7 @@ class ChineseOCRDBCRNNServer(hub.Module): assert predicted_data != [], "There is not any image to be predicted. Please check the input data." detection_results = self.text_detector_module.detect_text( - images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh) + images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh, use_device=use_device) boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] all_results = [] @@ -326,7 +355,7 @@ class ChineseOCRDBCRNNServer(hub.Module): norm_img_batch = norm_img_batch.copy() self.cls_input_tensor.copy_from_cpu(norm_img_batch) - self.cls_predictor.zero_copy_run() + self.cls_predictor.run() prob_out = self.cls_output_tensors[0].copy_to_cpu() label_out = self.cls_output_tensors[1].copy_to_cpu() @@ -370,7 +399,7 @@ class ChineseOCRDBCRNNServer(hub.Module): norm_img_batch = norm_img_batch.copy() self.rec_input_tensor.copy_from_cpu(norm_img_batch) - self.rec_predictor.zero_copy_run() + self.rec_predictor.run() rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu() rec_idx_lod = self.rec_output_tensors[0].lod()[0] @@ -475,7 +504,11 @@ class ChineseOCRDBCRNNServer(hub.Module): args = self.parser.parse_args(argvs) results = self.recognize_text( - paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) + paths=[args.input_path], + use_gpu=args.use_gpu, + output_dir=args.output_dir, + visualization=args.visualization, + use_device=args.use_device) return results def add_module_config_arg(self): @@ -488,6 +521,10 @@ class ChineseOCRDBCRNNServer(hub.Module): '--output_dir', type=str, default='ocr_result', help="The directory to save output images.") self.arg_config_group.add_argument( '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") + self.arg_config_group.add_argument( + '--use_device', + choices=["cpu", "gpu", "xpu", "npu"], + help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.") def add_module_input_arg(self): """ diff --git a/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py b/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py index aaae4aea..4114bbb7 100644 --- a/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py +++ b/modules/image/text_recognition/chinese_text_detection_db_mobile/module.py @@ -9,7 +9,9 @@ import math import os import time -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor +from paddle.inference import Config +from paddle.inference import create_predictor + from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving from PIL import Image @@ -53,6 +55,14 @@ class ChineseTextDetectionDB(hub.Module): 'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.' ) + def _get_device_id(self, places): + try: + places = os.environ[places] + id = int(places) + except: + id = -1 + return id + def _set_config(self): """ predictor config setting @@ -60,36 +70,49 @@ class ChineseTextDetectionDB(hub.Module): model_file_path = os.path.join(self.pretrained_model_path, 'model') params_file_path = os.path.join(self.pretrained_model_path, 'params') - config = AnalysisConfig(model_file_path, params_file_path) - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - use_gpu = True - except: - use_gpu = False + config = Config(model_file_path, params_file_path) - if use_gpu: - config.enable_use_gpu(8000, 0) + # detect npu + npu_id = self._get_device_id("FLAGS_selected_npus") + if npu_id != -1: + # use npu + self.use_device = "npu" + config.enable_npu(device_id=npu_id) else: - config.disable_gpu() - config.set_cpu_math_library_num_threads(6) - if self.enable_mkldnn: - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() + # detect gpu + gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES") + if gpu_id != -1: + # use gpu + self.use_device = "gpu" + config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id) + else: + # detect xpu + xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES") + if xpu_id != -1: + # use xpu + self.use_device = "xpu" + config.enable_xpu(100) + else: + self.use_device = "cpu" + config.disable_gpu() + config.set_cpu_math_library_num_threads(6) + if self.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() config.disable_glog_info() # use zero copy config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - self.predictor = create_paddle_predictor(config) + self.predictor = create_predictor(config) input_names = self.predictor.get_input_names() - self.input_tensor = self.predictor.get_input_tensor(input_names[0]) + self.input_tensor = self.predictor.get_input_handle(input_names[0]) output_names = self.predictor.get_output_names() self.output_tensors = [] for output_name in output_names: - output_tensor = self.predictor.get_output_tensor(output_name) + output_tensor = self.predictor.get_output_handle(output_name) self.output_tensors.append(output_tensor) def read_images(self, paths=[]): @@ -162,7 +185,8 @@ class ChineseTextDetectionDB(hub.Module): use_gpu=False, output_dir='detection_result', visualization=False, - box_thresh=0.5): + box_thresh=0.5, + use_device=None): """ Get the text box in the predicted images. Args: @@ -172,21 +196,24 @@ class ChineseTextDetectionDB(hub.Module): output_dir (str): The directory to store output images. visualization (bool): Whether to save image or not. box_thresh(float): the threshold of the detected text box's confidence + use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag. Returns: res (list): The result of text detection box and save path of images. """ self.check_requirements() from chinese_text_detection_db_mobile.processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext - - if use_gpu: - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - except: + if use_device is not None: + # check 'use_device' match 'device on init' + if use_device != self.use_device: raise RuntimeError( - "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." + "the 'use_device' parameter when calling detect_text, does not match internal device found on init." ) + else: + # check 'use_gpu' match 'device on init' + if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu': + raise RuntimeError( + "the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.") if images != [] and isinstance(images, list) and paths == []: predicted_data = images @@ -218,7 +245,7 @@ class ChineseTextDetectionDB(hub.Module): else: im = im.copy() self.input_tensor.copy_from_cpu(im) - self.predictor.zero_copy_run() + self.predictor.run() outputs = [] for output_tensor in self.output_tensors: @@ -304,7 +331,11 @@ class ChineseTextDetectionDB(hub.Module): args = self.parser.parse_args(argvs) results = self.detect_text( - paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) + paths=[args.input_path], + use_gpu=args.use_gpu, + output_dir=args.output_dir, + visualization=args.visualization, + use_device=args.use_device) return results def add_module_config_arg(self): @@ -317,6 +348,10 @@ class ChineseTextDetectionDB(hub.Module): '--output_dir', type=str, default='detection_result', help="The directory to save output images.") self.arg_config_group.add_argument( '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") + self.arg_config_group.add_argument( + '--use_device', + choices=["cpu", "gpu", "xpu", "npu"], + help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.") def add_module_input_arg(self): """ diff --git a/modules/image/text_recognition/chinese_text_detection_db_server/module.py b/modules/image/text_recognition/chinese_text_detection_db_server/module.py index 52295bef..10973ed0 100644 --- a/modules/image/text_recognition/chinese_text_detection_db_server/module.py +++ b/modules/image/text_recognition/chinese_text_detection_db_server/module.py @@ -9,7 +9,9 @@ import math import os import time -from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor +from paddle.inference import Config +from paddle.inference import create_predictor + from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving from PIL import Image @@ -53,6 +55,14 @@ class ChineseTextDetectionDBServer(hub.Module): 'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.' ) + def _get_device_id(self, places): + try: + places = os.environ[places] + id = int(places) + except: + id = -1 + return id + def _set_config(self): """ predictor config setting @@ -60,33 +70,46 @@ class ChineseTextDetectionDBServer(hub.Module): model_file_path = os.path.join(self.pretrained_model_path, 'model') params_file_path = os.path.join(self.pretrained_model_path, 'params') - config = AnalysisConfig(model_file_path, params_file_path) - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - use_gpu = True - except: - use_gpu = False + config = Config(model_file_path, params_file_path) - if use_gpu: - config.enable_use_gpu(8000, 0) + # detect npu + npu_id = self._get_device_id("FLAGS_selected_npus") + if npu_id != -1: + # use npu + self.use_device = "npu" + config.enable_npu(device_id=npu_id) else: - config.disable_gpu() - if self.enable_mkldnn: - config.enable_mkldnn() + # detect gpu + gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES") + if gpu_id != -1: + # use gpu + self.use_device = "gpu" + config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id) + else: + # detect xpu + xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES") + if xpu_id != -1: + # use xpu + self.use_device = "xpu" + config.enable_xpu(100) + else: + self.use_device = "cpu" + config.disable_gpu() + if self.enable_mkldnn: + config.enable_mkldnn() config.disable_glog_info() # use zero copy config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - self.predictor = create_paddle_predictor(config) + self.predictor = create_predictor(config) input_names = self.predictor.get_input_names() - self.input_tensor = self.predictor.get_input_tensor(input_names[0]) + self.input_tensor = self.predictor.get_input_handle(input_names[0]) output_names = self.predictor.get_output_names() self.output_tensors = [] for output_name in output_names: - output_tensor = self.predictor.get_output_tensor(output_name) + output_tensor = self.predictor.get_output_handle(output_name) self.output_tensors.append(output_tensor) def read_images(self, paths=[]): @@ -151,7 +174,8 @@ class ChineseTextDetectionDBServer(hub.Module): use_gpu=False, output_dir='detection_result', visualization=False, - box_thresh=0.5): + box_thresh=0.5, + use_device=None): """ Get the text box in the predicted images. Args: @@ -161,6 +185,7 @@ class ChineseTextDetectionDBServer(hub.Module): output_dir (str): The directory to store output images. visualization (bool): Whether to save image or not. box_thresh(float): the threshold of the detected text box's confidence + use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag. Returns: res (list): The result of text detection box and save path of images. """ @@ -168,14 +193,17 @@ class ChineseTextDetectionDBServer(hub.Module): from chinese_text_detection_db_server.processor import DBPreProcess, DBPostProcess, draw_boxes, get_image_ext - if use_gpu: - try: - _places = os.environ["CUDA_VISIBLE_DEVICES"] - int(_places[0]) - except: + if use_device is not None: + # check 'use_device' match 'device on init' + if use_device != self.use_device: raise RuntimeError( - "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." + "the 'use_device' parameter when calling detect_text, does not match internal device found on init." ) + else: + # check 'use_gpu' match 'device on init' + if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu': + raise RuntimeError( + "the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.") if images != [] and isinstance(images, list) and paths == []: predicted_data = images @@ -202,7 +230,7 @@ class ChineseTextDetectionDBServer(hub.Module): im = im.copy() starttime = time.time() self.input_tensor.copy_from_cpu(im) - self.predictor.zero_copy_run() + self.predictor.run() data_out = self.output_tensors[0].copy_to_cpu() dt_boxes_list = postprocessor(data_out, [ratio_list]) boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape) @@ -278,7 +306,11 @@ class ChineseTextDetectionDBServer(hub.Module): args = self.parser.parse_args(argvs) results = self.detect_text( - paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) + paths=[args.input_path], + use_gpu=args.use_gpu, + output_dir=args.output_dir, + visualization=args.visualization, + use_device=args.use_device) return results def add_module_config_arg(self): @@ -291,6 +323,10 @@ class ChineseTextDetectionDBServer(hub.Module): '--output_dir', type=str, default='detection_result', help="The directory to save output images.") self.arg_config_group.add_argument( '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") + self.arg_config_group.add_argument( + '--use_device', + choices=["cpu", "gpu", "xpu", "npu"], + help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.") def add_module_input_arg(self): """ -- GitLab