未验证 提交 8ca88e56 编写于 作者: H houj04 提交者: GitHub

add xpu and npu support for text_recognition series. (#1622)

上级 d79f6c49
...@@ -6,7 +6,9 @@ import math ...@@ -6,7 +6,9 @@ import math
import os import os
import time import time
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image from PIL import Image
...@@ -53,6 +55,14 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -53,6 +55,14 @@ class ChineseOCRDBCRNN(hub.Module):
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config( self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
self.cls_pretrained_model_path) self.cls_pretrained_model_path)
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self, pretrained_model_path): def _set_config(self, pretrained_model_path):
""" """
predictor config path predictor config path
...@@ -60,35 +70,49 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -60,35 +70,49 @@ class ChineseOCRDBCRNN(hub.Module):
model_file_path = os.path.join(pretrained_model_path, 'model') model_file_path = os.path.join(pretrained_model_path, 'model')
params_file_path = os.path.join(pretrained_model_path, 'params') params_file_path = os.path.join(pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path) config = Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu: # detect npu
config.enable_use_gpu(8000, 0) npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
self.use_device = "npu"
config.enable_npu(device_id=npu_id)
else: else:
config.disable_gpu() # detect gpu
if self.enable_mkldnn: gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
# cache 10 different shapes for mkldnn to avoid memory leak if gpu_id != -1:
config.set_mkldnn_cache_capacity(10) # use gpu
config.enable_mkldnn() self.use_device = "gpu"
config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id)
else:
# detect xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
self.use_device = "xpu"
config.enable_xpu(100)
else:
self.use_device = "cpu"
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
predictor = create_paddle_predictor(config) predictor = create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0]) input_tensor = predictor.get_input_handle(input_names[0])
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
output_tensors = [] output_tensors = []
for output_name in output_names: for output_name in output_names:
output_tensor = predictor.get_output_tensor(output_name) output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors return predictor, input_tensor, output_tensors
...@@ -186,7 +210,8 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -186,7 +210,8 @@ class ChineseOCRDBCRNN(hub.Module):
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5, text_thresh=0.5,
angle_classification_thresh=0.9): angle_classification_thresh=0.9,
use_device=None):
""" """
Get the chinese texts in the predicted images. Get the chinese texts in the predicted images.
Args: Args:
...@@ -199,18 +224,22 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -199,18 +224,22 @@ class ChineseOCRDBCRNN(hub.Module):
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the chinese text recognition confidence text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence angle_classification_thresh(float): the threshold of the angle classification confidence
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
res (list): The result of chinese texts and save path of images. res (list): The result of chinese texts and save path of images.
""" """
if use_gpu: if use_device is not None:
try: # check 'use_device' match 'device on init'
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device != self.use_device:
int(_places[0])
except:
raise RuntimeError( raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." "the 'use_device' parameter when calling detect_text, does not match internal device found on init."
) )
else:
# check 'use_gpu' match 'device on init'
if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu':
raise RuntimeError(
"the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.")
self.use_gpu = use_gpu self.use_gpu = use_gpu
...@@ -224,7 +253,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -224,7 +253,7 @@ class ChineseOCRDBCRNN(hub.Module):
assert predicted_data != [], "There is not any image to be predicted. Please check the input data." assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
detection_results = self.text_detector_module.detect_text( detection_results = self.text_detector_module.detect_text(
images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh) images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh, use_device=use_device)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
all_results = [] all_results = []
...@@ -322,7 +351,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -322,7 +351,7 @@ class ChineseOCRDBCRNN(hub.Module):
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
self.cls_input_tensor.copy_from_cpu(norm_img_batch) self.cls_input_tensor.copy_from_cpu(norm_img_batch)
self.cls_predictor.zero_copy_run() self.cls_predictor.run()
prob_out = self.cls_output_tensors[0].copy_to_cpu() prob_out = self.cls_output_tensors[0].copy_to_cpu()
label_out = self.cls_output_tensors[1].copy_to_cpu() label_out = self.cls_output_tensors[1].copy_to_cpu()
...@@ -366,7 +395,7 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -366,7 +395,7 @@ class ChineseOCRDBCRNN(hub.Module):
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
self.rec_input_tensor.copy_from_cpu(norm_img_batch) self.rec_input_tensor.copy_from_cpu(norm_img_batch)
self.rec_predictor.zero_copy_run() self.rec_predictor.run()
rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu() rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
rec_idx_lod = self.rec_output_tensors[0].lod()[0] rec_idx_lod = self.rec_output_tensors[0].lod()[0]
...@@ -471,7 +500,11 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -471,7 +500,11 @@ class ChineseOCRDBCRNN(hub.Module):
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.recognize_text( results = self.recognize_text(
paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization,
use_device=args.use_device)
return results return results
def add_module_config_arg(self): def add_module_config_arg(self):
...@@ -484,6 +517,10 @@ class ChineseOCRDBCRNN(hub.Module): ...@@ -484,6 +517,10 @@ class ChineseOCRDBCRNN(hub.Module):
'--output_dir', type=str, default='ocr_result', help="The directory to save output images.") '--output_dir', type=str, default='ocr_result', help="The directory to save output images.")
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
...@@ -10,7 +10,9 @@ import math ...@@ -10,7 +10,9 @@ import math
import os import os
import time import time
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image from PIL import Image
...@@ -57,6 +59,14 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -57,6 +59,14 @@ class ChineseOCRDBCRNNServer(hub.Module):
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config( self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
self.cls_pretrained_model_path) self.cls_pretrained_model_path)
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self, pretrained_model_path): def _set_config(self, pretrained_model_path):
""" """
predictor config path predictor config path
...@@ -64,35 +74,49 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -64,35 +74,49 @@ class ChineseOCRDBCRNNServer(hub.Module):
model_file_path = os.path.join(pretrained_model_path, 'model') model_file_path = os.path.join(pretrained_model_path, 'model')
params_file_path = os.path.join(pretrained_model_path, 'params') params_file_path = os.path.join(pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path) config = Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu: # detect npu
config.enable_use_gpu(8000, 0) npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
self.use_device = "npu"
config.enable_npu(device_id=npu_id)
else: else:
config.disable_gpu() # detect gpu
if self.enable_mkldnn: gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
# cache 10 different shapes for mkldnn to avoid memory leak if gpu_id != -1:
config.set_mkldnn_cache_capacity(10) # use gpu
config.enable_mkldnn() self.use_device = "gpu"
config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id)
else:
# detect xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
self.use_device = "xpu"
config.enable_xpu(100)
else:
self.use_device = "cpu"
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
predictor = create_paddle_predictor(config) predictor = create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0]) input_tensor = predictor.get_input_handle(input_names[0])
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
output_tensors = [] output_tensors = []
for output_name in output_names: for output_name in output_names:
output_tensor = predictor.get_output_tensor(output_name) output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors return predictor, input_tensor, output_tensors
...@@ -190,7 +214,8 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -190,7 +214,8 @@ class ChineseOCRDBCRNNServer(hub.Module):
visualization=False, visualization=False,
box_thresh=0.5, box_thresh=0.5,
text_thresh=0.5, text_thresh=0.5,
angle_classification_thresh=0.9): angle_classification_thresh=0.9,
use_device=None):
""" """
Get the chinese texts in the predicted images. Get the chinese texts in the predicted images.
Args: Args:
...@@ -203,18 +228,22 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -203,18 +228,22 @@ class ChineseOCRDBCRNNServer(hub.Module):
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the chinese text recognition confidence text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence angle_classification_thresh(float): the threshold of the angle classification confidence
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
res (list): The result of chinese texts and save path of images. res (list): The result of chinese texts and save path of images.
""" """
if use_gpu: if use_device is not None:
try: # check 'use_device' match 'device on init'
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device != self.use_device:
int(_places[0])
except:
raise RuntimeError( raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." "the 'use_device' parameter when calling detect_text, does not match internal device found on init."
) )
else:
# check 'use_gpu' match 'device on init'
if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu':
raise RuntimeError(
"the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.")
self.use_gpu = use_gpu self.use_gpu = use_gpu
...@@ -228,7 +257,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -228,7 +257,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
assert predicted_data != [], "There is not any image to be predicted. Please check the input data." assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
detection_results = self.text_detector_module.detect_text( detection_results = self.text_detector_module.detect_text(
images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh) images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh, use_device=use_device)
boxes = [np.array(item['data']).astype(np.float32) for item in detection_results] boxes = [np.array(item['data']).astype(np.float32) for item in detection_results]
all_results = [] all_results = []
...@@ -326,7 +355,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -326,7 +355,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
self.cls_input_tensor.copy_from_cpu(norm_img_batch) self.cls_input_tensor.copy_from_cpu(norm_img_batch)
self.cls_predictor.zero_copy_run() self.cls_predictor.run()
prob_out = self.cls_output_tensors[0].copy_to_cpu() prob_out = self.cls_output_tensors[0].copy_to_cpu()
label_out = self.cls_output_tensors[1].copy_to_cpu() label_out = self.cls_output_tensors[1].copy_to_cpu()
...@@ -370,7 +399,7 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -370,7 +399,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
self.rec_input_tensor.copy_from_cpu(norm_img_batch) self.rec_input_tensor.copy_from_cpu(norm_img_batch)
self.rec_predictor.zero_copy_run() self.rec_predictor.run()
rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu() rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
rec_idx_lod = self.rec_output_tensors[0].lod()[0] rec_idx_lod = self.rec_output_tensors[0].lod()[0]
...@@ -475,7 +504,11 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -475,7 +504,11 @@ class ChineseOCRDBCRNNServer(hub.Module):
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.recognize_text( results = self.recognize_text(
paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization,
use_device=args.use_device)
return results return results
def add_module_config_arg(self): def add_module_config_arg(self):
...@@ -488,6 +521,10 @@ class ChineseOCRDBCRNNServer(hub.Module): ...@@ -488,6 +521,10 @@ class ChineseOCRDBCRNNServer(hub.Module):
'--output_dir', type=str, default='ocr_result', help="The directory to save output images.") '--output_dir', type=str, default='ocr_result', help="The directory to save output images.")
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
...@@ -9,7 +9,9 @@ import math ...@@ -9,7 +9,9 @@ import math
import os import os
import time import time
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image from PIL import Image
...@@ -53,6 +55,14 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -53,6 +55,14 @@ class ChineseTextDetectionDB(hub.Module):
'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.' 'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.'
) )
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self): def _set_config(self):
""" """
predictor config setting predictor config setting
...@@ -60,36 +70,49 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -60,36 +70,49 @@ class ChineseTextDetectionDB(hub.Module):
model_file_path = os.path.join(self.pretrained_model_path, 'model') model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params') params_file_path = os.path.join(self.pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path) config = Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu: # detect npu
config.enable_use_gpu(8000, 0) npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
self.use_device = "npu"
config.enable_npu(device_id=npu_id)
else: else:
config.disable_gpu() # detect gpu
config.set_cpu_math_library_num_threads(6) gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if self.enable_mkldnn: if gpu_id != -1:
# cache 10 different shapes for mkldnn to avoid memory leak # use gpu
config.set_mkldnn_cache_capacity(10) self.use_device = "gpu"
config.enable_mkldnn() config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id)
else:
# detect xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
self.use_device = "xpu"
config.enable_xpu(100)
else:
self.use_device = "cpu"
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
# use zero copy # use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config) self.predictor = create_predictor(config)
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_tensor(input_names[0]) self.input_tensor = self.predictor.get_input_handle(input_names[0])
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
self.output_tensors = [] self.output_tensors = []
for output_name in output_names: for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name) output_tensor = self.predictor.get_output_handle(output_name)
self.output_tensors.append(output_tensor) self.output_tensors.append(output_tensor)
def read_images(self, paths=[]): def read_images(self, paths=[]):
...@@ -162,7 +185,8 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -162,7 +185,8 @@ class ChineseTextDetectionDB(hub.Module):
use_gpu=False, use_gpu=False,
output_dir='detection_result', output_dir='detection_result',
visualization=False, visualization=False,
box_thresh=0.5): box_thresh=0.5,
use_device=None):
""" """
Get the text box in the predicted images. Get the text box in the predicted images.
Args: Args:
...@@ -172,21 +196,24 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -172,21 +196,24 @@ class ChineseTextDetectionDB(hub.Module):
output_dir (str): The directory to store output images. output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
res (list): The result of text detection box and save path of images. res (list): The result of text detection box and save path of images.
""" """
self.check_requirements() self.check_requirements()
from chinese_text_detection_db_mobile.processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext from chinese_text_detection_db_mobile.processor import DBProcessTest, DBPostProcess, draw_boxes, get_image_ext
if use_device is not None:
if use_gpu: # check 'use_device' match 'device on init'
try: if use_device != self.use_device:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
raise RuntimeError( raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." "the 'use_device' parameter when calling detect_text, does not match internal device found on init."
) )
else:
# check 'use_gpu' match 'device on init'
if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu':
raise RuntimeError(
"the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.")
if images != [] and isinstance(images, list) and paths == []: if images != [] and isinstance(images, list) and paths == []:
predicted_data = images predicted_data = images
...@@ -218,7 +245,7 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -218,7 +245,7 @@ class ChineseTextDetectionDB(hub.Module):
else: else:
im = im.copy() im = im.copy()
self.input_tensor.copy_from_cpu(im) self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run() self.predictor.run()
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
...@@ -304,7 +331,11 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -304,7 +331,11 @@ class ChineseTextDetectionDB(hub.Module):
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.detect_text( results = self.detect_text(
paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization,
use_device=args.use_device)
return results return results
def add_module_config_arg(self): def add_module_config_arg(self):
...@@ -317,6 +348,10 @@ class ChineseTextDetectionDB(hub.Module): ...@@ -317,6 +348,10 @@ class ChineseTextDetectionDB(hub.Module):
'--output_dir', type=str, default='detection_result', help="The directory to save output images.") '--output_dir', type=str, default='detection_result', help="The directory to save output images.")
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
...@@ -9,7 +9,9 @@ import math ...@@ -9,7 +9,9 @@ import math
import os import os
import time import time
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor, PaddleTensor from paddle.inference import Config
from paddle.inference import create_predictor
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from PIL import Image from PIL import Image
...@@ -53,6 +55,14 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -53,6 +55,14 @@ class ChineseTextDetectionDBServer(hub.Module):
'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.' 'This module requires the shapely, pyclipper tools. The running environment does not meet the requirements. Please install the two packages.'
) )
def _get_device_id(self, places):
try:
places = os.environ[places]
id = int(places)
except:
id = -1
return id
def _set_config(self): def _set_config(self):
""" """
predictor config setting predictor config setting
...@@ -60,33 +70,46 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -60,33 +70,46 @@ class ChineseTextDetectionDBServer(hub.Module):
model_file_path = os.path.join(self.pretrained_model_path, 'model') model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params') params_file_path = os.path.join(self.pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path) config = Config(model_file_path, params_file_path)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu: # detect npu
config.enable_use_gpu(8000, 0) npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
# use npu
self.use_device = "npu"
config.enable_npu(device_id=npu_id)
else: else:
config.disable_gpu() # detect gpu
if self.enable_mkldnn: gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
config.enable_mkldnn() if gpu_id != -1:
# use gpu
self.use_device = "gpu"
config.enable_use_gpu(memory_pool_init_size_mb=8000, device_id=gpu_id)
else:
# detect xpu
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
# use xpu
self.use_device = "xpu"
config.enable_xpu(100)
else:
self.use_device = "cpu"
config.disable_gpu()
if self.enable_mkldnn:
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
# use zero copy # use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config) self.predictor = create_predictor(config)
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_tensor(input_names[0]) self.input_tensor = self.predictor.get_input_handle(input_names[0])
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
self.output_tensors = [] self.output_tensors = []
for output_name in output_names: for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name) output_tensor = self.predictor.get_output_handle(output_name)
self.output_tensors.append(output_tensor) self.output_tensors.append(output_tensor)
def read_images(self, paths=[]): def read_images(self, paths=[]):
...@@ -151,7 +174,8 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -151,7 +174,8 @@ class ChineseTextDetectionDBServer(hub.Module):
use_gpu=False, use_gpu=False,
output_dir='detection_result', output_dir='detection_result',
visualization=False, visualization=False,
box_thresh=0.5): box_thresh=0.5,
use_device=None):
""" """
Get the text box in the predicted images. Get the text box in the predicted images.
Args: Args:
...@@ -161,6 +185,7 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -161,6 +185,7 @@ class ChineseTextDetectionDBServer(hub.Module):
output_dir (str): The directory to store output images. output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not. visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence box_thresh(float): the threshold of the detected text box's confidence
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns: Returns:
res (list): The result of text detection box and save path of images. res (list): The result of text detection box and save path of images.
""" """
...@@ -168,14 +193,17 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -168,14 +193,17 @@ class ChineseTextDetectionDBServer(hub.Module):
from chinese_text_detection_db_server.processor import DBPreProcess, DBPostProcess, draw_boxes, get_image_ext from chinese_text_detection_db_server.processor import DBPreProcess, DBPostProcess, draw_boxes, get_image_ext
if use_gpu: if use_device is not None:
try: # check 'use_device' match 'device on init'
_places = os.environ["CUDA_VISIBLE_DEVICES"] if use_device != self.use_device:
int(_places[0])
except:
raise RuntimeError( raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id." "the 'use_device' parameter when calling detect_text, does not match internal device found on init."
) )
else:
# check 'use_gpu' match 'device on init'
if use_gpu == True and self.use_device != 'gpu' or use_gpu == False and self.use_device == 'gpu':
raise RuntimeError(
"the 'use_gpu' parameter when calling detect_text, does not match internal device found on init.")
if images != [] and isinstance(images, list) and paths == []: if images != [] and isinstance(images, list) and paths == []:
predicted_data = images predicted_data = images
...@@ -202,7 +230,7 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -202,7 +230,7 @@ class ChineseTextDetectionDBServer(hub.Module):
im = im.copy() im = im.copy()
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(im) self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run() self.predictor.run()
data_out = self.output_tensors[0].copy_to_cpu() data_out = self.output_tensors[0].copy_to_cpu()
dt_boxes_list = postprocessor(data_out, [ratio_list]) dt_boxes_list = postprocessor(data_out, [ratio_list])
boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape) boxes = self.filter_tag_det_res(dt_boxes_list[0], original_image.shape)
...@@ -278,7 +306,11 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -278,7 +306,11 @@ class ChineseTextDetectionDBServer(hub.Module):
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.detect_text( results = self.detect_text(
paths=[args.input_path], use_gpu=args.use_gpu, output_dir=args.output_dir, visualization=args.visualization) paths=[args.input_path],
use_gpu=args.use_gpu,
output_dir=args.output_dir,
visualization=args.visualization,
use_device=args.use_device)
return results return results
def add_module_config_arg(self): def add_module_config_arg(self):
...@@ -291,6 +323,10 @@ class ChineseTextDetectionDBServer(hub.Module): ...@@ -291,6 +323,10 @@ class ChineseTextDetectionDBServer(hub.Module):
'--output_dir', type=str, default='detection_result', help="The directory to save output images.") '--output_dir', type=str, default='detection_result', help="The directory to save output images.")
self.arg_config_group.add_argument( self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument(
'--use_device',
choices=["cpu", "gpu", "xpu", "npu"],
help="use cpu, gpu, xpu or npu. overwrites use_gpu flag.")
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册