Unverified · Commit 2758efec · Authored by: Steffy-zxf · Committed by: GitHub

Update ocr (#911)

Parent 34e60467
......@@ -36,7 +36,8 @@ def recognize_text(images=[],
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5)
text_thresh=0.5,
angle_classification_thresh=0.9)
```
Prediction API that locates all Chinese text in the input images and recognizes it.
......@@ -48,6 +49,7 @@ def recognize_text(images=[],
* use\_gpu (bool): whether to use the GPU; **if you use the GPU, set the CUDA_VISIBLE_DEVICES environment variable first**
* box\_thresh (float): confidence threshold for detected text boxes;
* text\_thresh (float): confidence threshold for recognized Chinese text;
* angle\_classification\_thresh (float): confidence threshold for text angle classification;
* visualization (bool): whether to save the recognition results as image files;
* output\_dir (str): directory where result images are saved; defaults to ocr\_result;
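A minimal usage sketch for this API. Hedged: the module name `chinese_ocr_db_crnn_mobile`, the local image path, and the result keys `text`/`confidence` are assumptions based on common PaddleHub conventions, not taken from this diff.

```python
# hedged sketch -- module name, image path, and result keys are assumptions
import cv2
import paddlehub as hub

ocr = hub.Module(name="chinese_ocr_db_crnn_mobile")  # assumed module name
results = ocr.recognize_text(
    images=[cv2.imread('test.jpg')],   # hypothetical local image
    use_gpu=False,
    box_thresh=0.5,
    text_thresh=0.5,
    angle_classification_thresh=0.9,   # parameter introduced in this commit
    visualization=True,
    output_dir='ocr_result')

for result in results:
    print(result['save_path'])
    for item in result['data']:        # assumed keys: 'text', 'confidence'
        print(item['text'], item['confidence'])
```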
......
......@@ -203,7 +203,8 @@ class ChineseOCRDBCRNN(hub.Module):
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5):
text_thresh=0.5,
angle_classification_thresh=0.9):
"""
Get the chinese texts in the predicted images.
Args:
......@@ -214,7 +215,9 @@ class ChineseOCRDBCRNN(hub.Module):
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the recognize chinese texts' confidence
text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
Returns:
res (list): The result of chinese texts and save path of images.
"""
......@@ -259,7 +262,9 @@ class ChineseOCRDBCRNN(hub.Module):
img_crop = self.get_rotate_crop_image(
original_image, tmp_box)
img_crop_list.append(img_crop)
img_crop_list, angle_list = self._classify_text(img_crop_list)
img_crop_list, angle_list = self._classify_text(
img_crop_list,
angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it
......@@ -294,12 +299,14 @@ class ChineseOCRDBCRNN(hub.Module):
results = self.recognize_text(images_decode, **kwargs)
return results
def save_result_image(self,
original_image,
detection_boxes,
rec_results,
output_dir='ocr_result',
text_thresh=0.5):
def save_result_image(
self,
original_image,
detection_boxes,
rec_results,
output_dir='ocr_result',
text_thresh=0.5,
):
image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
txts = [item[0] for item in rec_results]
scores = [item[1] for item in rec_results]
......@@ -320,7 +327,7 @@ class ChineseOCRDBCRNN(hub.Module):
cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path
def _classify_text(self, image_list):
def _classify_text(self, image_list, angle_classification_thresh=0.9):
img_list = copy.deepcopy(image_list)
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
......@@ -360,7 +367,7 @@ class ChineseOCRDBCRNN(hub.Module):
score = prob_out[rno][label_idx]
label = label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > 0.9999:
if '180' in label and score > angle_classification_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res
......
......@@ -35,7 +35,8 @@ def recognize_text(images=[],
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5)
text_thresh=0.5,
angle_classification_thresh=0.9)
```
Prediction API that locates all Chinese text in the input images and recognizes it.
......@@ -47,6 +48,7 @@ def recognize_text(images=[],
* use\_gpu (bool): whether to use the GPU; **if you use the GPU, set the CUDA_VISIBLE_DEVICES environment variable first**
* box\_thresh (float): confidence threshold for detected text boxes;
* text\_thresh (float): confidence threshold for recognized Chinese text;
* angle\_classification\_thresh (float): confidence threshold for text angle classification;
* visualization (bool): whether to save the recognition results as image files;
* output\_dir (str): directory where result images are saved; defaults to ocr\_result;
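Since this module also exposes a serving endpoint (the code below feeds base64-decoded images into `recognize_text`), a client sketch may help. Hedged: the host, port, and `/predict/chinese_ocr_db_crnn_server` route follow the usual PaddleHub serving convention (after `hub serving start -m chinese_ocr_db_crnn_server`) and are assumptions here.

```python
# hedged client sketch; endpoint URL follows PaddleHub serving conventions
import base64
import json

import requests

with open('test.jpg', 'rb') as f:      # hypothetical local image
    img_b64 = base64.b64encode(f.read()).decode('utf8')

r = requests.post(
    url='http://127.0.0.1:8866/predict/chinese_ocr_db_crnn_server',
    headers={'Content-type': 'application/json'},
    data=json.dumps({'images': [img_b64]}))
print(r.json())
```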
......@@ -141,3 +143,7 @@ pyclipper
* 1.0.1
Added MKL-DNN acceleration for CPU computation
* 1.1.0
Recognizes text in images with a three-stage pipeline (text box detection - angle classification - text recognition).
......@@ -22,17 +22,23 @@ class CharacterOps(object):
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config:
add_space = config['use_space_char']
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n")
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
if add_space:
self.character_str += " "
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
......@@ -46,6 +52,8 @@ class CharacterOps(object):
self.end_str = "eos"
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
......@@ -90,7 +98,7 @@ class CharacterOps(object):
if is_remove_duplicate:
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
char_list.append(self.character[int(text_index[idx])])
text = ''.join(char_list)
return text
......@@ -139,6 +147,42 @@ def cal_predicts_accuracy(char_ops,
return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops,
preds,
labels,
max_text_len,
is_debug=False):
acc_num = 0
img_num = 0
char_num = char_ops.get_char_num()
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
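# walk each sample's label up to the end token, then count the sample as
# correct only if the prediction matches at every position and terminates
# exactly where the label does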
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
if labels[j + i * max_text_len] != int(char_num - 1):  # char_num - 1 is the end/padding index
cur_label.append(labels[j + i * max_text_len][0])
else:
break
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[
j + i * max_text_len][0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(
char_num - 1):
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
target_lod = [0]
......
......@@ -25,7 +25,7 @@ from chinese_ocr_db_crnn_server.utils import base64_to_cv2, draw_ocr, get_image_
@moduleinfo(
name="chinese_ocr_db_crnn_server",
version="1.0.3",
version="1.1.0",
summary=
"The module can recognize the chinese texts in an image. Firstly, it will detect the text box positions based on the differentiable_binarization_chn module. Then it recognizes the chinese texts. ",
author="paddle-dev",
......@@ -41,24 +41,31 @@ class ChineseOCRDBCRNNServer(hub.Module):
char_ops_params = {
'character_type': 'ch',
'character_dict_path': self.character_dict_path,
'loss_type': 'ctc'
'loss_type': 'ctc',
'max_text_length': 25,
'use_space_char': True
}
self.char_ops = CharacterOps(char_ops_params)
self.rec_image_shape = [3, 32, 320]
self._text_detector_module = text_detector_module
self.font_file = os.path.join(self.directory, 'assets', 'simfang.ttf')
self.pretrained_model_path = os.path.join(self.directory, 'assets',
'ch_rec_r34_vd_crnn')
self.enable_mkldnn = enable_mkldnn
self._set_config()
self.rec_pretrained_model_path = os.path.join(
self.directory, 'inference_model', 'character_rec')
self.cls_pretrained_model_path = os.path.join(
self.directory, 'inference_model', 'angle_cls')
self.rec_predictor, self.rec_input_tensor, self.rec_output_tensors = self._set_config(
self.rec_pretrained_model_path)
self.cls_predictor, self.cls_input_tensor, self.cls_output_tensors = self._set_config(
self.cls_pretrained_model_path)
def _set_config(self):
def _set_config(self, pretrained_model_path):
"""
predictor config setting
predictor config path
"""
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
model_file_path = os.path.join(pretrained_model_path, 'model')
params_file_path = os.path.join(pretrained_model_path, 'params')
config = AnalysisConfig(model_file_path, params_file_path)
try:
......@@ -73,21 +80,25 @@ class ChineseOCRDBCRNNServer(hub.Module):
else:
config.disable_gpu()
if self.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.disable_glog_info()
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
input_names = self.predictor.get_input_names()
self.input_tensor = self.predictor.get_input_tensor(input_names[0])
output_names = self.predictor.get_output_names()
self.output_tensors = []
predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
output_names = predictor.get_output_names()
output_tensors = []
for output_name in output_names:
output_tensor = self.predictor.get_output_tensor(output_name)
self.output_tensors.append(output_tensor)
output_tensor = predictor.get_output_tensor(output_name)
output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors
@property
def text_detector_module(self):
......@@ -98,7 +109,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
self._text_detector_module = hub.Module(
name='chinese_text_detection_db_server',
enable_mkldnn=self.enable_mkldnn,
version='1.0.1')
version='1.0.2')
return self._text_detector_module
def read_images(self, paths=[]):
......@@ -114,6 +125,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
return images
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
......@@ -122,23 +134,51 @@ class ChineseOCRDBCRNNServer(hub.Module):
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0],\
[img_crop_width, img_crop_height], [0, img_crop_height]])
'''
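# size the destination rectangle from the longer of each pair of
# opposite edges, so skewed quadrilaterals are not squeezed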
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
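# a crop much taller than it is wide is probably vertical text;
# rotate it to horizontal before recognition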
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def resize_norm_img(self, img, max_wh_ratio):
def resize_norm_img_rec(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
imgW = int(32 * max_wh_ratio)
assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_cls(self, img):
cls_image_shape = [3, 48, 192]
imgC, imgH, imgW = cls_image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
......@@ -148,7 +188,11 @@ class ChineseOCRDBCRNNServer(hub.Module):
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
if cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
......@@ -162,7 +206,8 @@ class ChineseOCRDBCRNNServer(hub.Module):
output_dir='ocr_result',
visualization=False,
box_thresh=0.5,
text_thresh=0.5):
text_thresh=0.5,
angle_classification_thresh=0.9):
"""
Get the chinese texts in the predicted images.
Args:
......@@ -173,7 +218,9 @@ class ChineseOCRDBCRNNServer(hub.Module):
output_dir (str): The directory to store output images.
visualization (bool): Whether to save image or not.
box_thresh(float): the threshold of the detected text box's confidence
text_thresh(float): the threshold of the recognize chinese texts' confidence
text_thresh(float): the threshold of the chinese text recognition confidence
angle_classification_thresh(float): the threshold of the angle classification confidence
Returns:
res (list): The result of chinese texts and save path of images.
"""
......@@ -199,6 +246,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
detection_results = self.text_detector_module.detect_text(
images=predicted_data, use_gpu=self.use_gpu, box_thresh=box_thresh)
boxes = [
np.array(item['data']).astype(np.float32)
for item in detection_results
......@@ -207,7 +255,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
for index, img_boxes in enumerate(boxes):
original_image = predicted_data[index].copy()
result = {'save_path': ''}
if img_boxes is None:
if img_boxes.size == 0:
result['data'] = []
else:
img_crop_list = []
......@@ -217,8 +265,11 @@ class ChineseOCRDBCRNNServer(hub.Module):
img_crop = self.get_rotate_crop_image(
original_image, tmp_box)
img_crop_list.append(img_crop)
img_crop_list, angle_list = self._classify_text(
img_crop_list,
angle_classification_thresh=angle_classification_thresh)
rec_results = self._recognize_text(img_crop_list)
# if the recognized text confidence score is lower than text_thresh, then drop it
rec_res_final = []
for index, res in enumerate(rec_results):
......@@ -251,12 +302,14 @@ class ChineseOCRDBCRNNServer(hub.Module):
results = self.recognize_text(images_decode, **kwargs)
return results
def save_result_image(self,
original_image,
detection_boxes,
rec_results,
output_dir='ocr_result',
text_thresh=0.5):
def save_result_image(
self,
original_image,
detection_boxes,
rec_results,
output_dir='ocr_result',
text_thresh=0.5,
):
image = Image.fromarray(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
txts = [item[0] for item in rec_results]
scores = [item[1] for item in rec_results]
......@@ -277,32 +330,86 @@ class ChineseOCRDBCRNNServer(hub.Module):
cv2.imwrite(save_file_path, draw_img[:, :, ::-1])
return save_file_path
def _recognize_text(self, image_list):
img_num = len(image_list)
def _classify_text(self, image_list, angle_classification_thresh=0.9):
img_list = copy.deepcopy(image_list)
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
cls_res = [['', 0.0]] * img_num
batch_num = 30
rec_res = []
predict_time = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = image_list[ino].shape[0:2]
wh_ratio = w / h
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(image_list[ino], max_wh_ratio)
norm_img = self.resize_norm_img_cls(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0]
predict_batch = self.output_tensors[1].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0]
self.cls_input_tensor.copy_from_cpu(norm_img_batch)
self.cls_predictor.zero_copy_run()
prob_out = self.cls_output_tensors[0].copy_to_cpu()
label_out = self.cls_output_tensors[1].copy_to_cpu()
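# the label tensor is expected to be 1-D; if the two output tensors
# arrive in the opposite order, swap them back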
if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out
label_list = ['0', '180']
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > angle_classification_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res
def _recognize_text(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = 30
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img_rec(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy()
self.rec_input_tensor.copy_from_cpu(norm_img_batch)
self.rec_predictor.zero_copy_run()
rec_idx_batch = self.rec_output_tensors[0].copy_to_cpu()
rec_idx_lod = self.rec_output_tensors[0].lod()[0]
predict_batch = self.rec_output_tensors[1].copy_to_cpu()
predict_lod = self.rec_output_tensors[1].lod()[0]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
......@@ -317,9 +424,10 @@ class ChineseOCRDBCRNNServer(hub.Module):
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
# rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
return rec_res
return rec_res
def save_inference_model(self,
dirname,
......@@ -327,9 +435,12 @@ class ChineseOCRDBCRNNServer(hub.Module):
params_filename=None,
combined=True):
detector_dir = os.path.join(dirname, 'text_detector')
classifier_dir = os.path.join(dirname, 'angle_classifier')
recognizer_dir = os.path.join(dirname, 'text_recognizer')
self._save_detector_model(detector_dir, model_filename, params_filename,
combined)
self._save_classifier_model(classifier_dir, model_filename,
params_filename, combined)
self._save_recognizer_model(recognizer_dir, model_filename,
params_filename, combined)
logger.info("The inference model has been saved in the path {}".format(
......@@ -354,10 +465,40 @@ class ChineseOCRDBCRNNServer(hub.Module):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.pretrained_model_path, 'model')
params_file_path = os.path.join(self.pretrained_model_path, 'params')
model_file_path = os.path.join(self.rec_pretrained_model_path, 'model')
params_file_path = os.path.join(self.rec_pretrained_model_path,
'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.rec_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
def _save_classifier_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
model_file_path = os.path.join(self.cls_pretrained_model_path, 'model')
params_file_path = os.path.join(self.cls_pretrained_model_path,
'params')
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.pretrained_model_path,
dirname=self.cls_pretrained_model_path,
model_filename=model_file_path,
params_filename=params_file_path,
executor=exe)
......@@ -429,8 +570,7 @@ class ChineseOCRDBCRNNServer(hub.Module):
if __name__ == '__main__':
ocr = ChineseOCRDBCRNNServer(enable_mkldnn=True)
print(ocr.name)
ocr = ChineseOCRDBCRNNServer(enable_mkldnn=False)
image_path = [
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/11.jpg',
'/mnt/zhangxuefei/PaddleOCR/doc/imgs/12.jpg',
......
......@@ -175,8 +175,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
......
......@@ -132,3 +132,15 @@ pyclipper
* 1.0.1
Fixed a failure when calling the model through the online service
* 1.0.2
Added MKL-DNN acceleration for CPU computation
* 1.0.3
Added more pretraining data and updated the pretrained parameters
* 1.1.0
Recognizes text in images with an ultra-lightweight three-stage pipeline (text box detection - angle classification - text recognition).
......@@ -129,3 +129,7 @@ pyclipper
* 1.0.2
Added MKL-DNN acceleration for CPU computation
* 1.0.3
Added more pretraining data and updated the pretrained parameters
......@@ -29,7 +29,7 @@ def base64_to_cv2(b64str):
@moduleinfo(
name="chinese_text_detection_db_server",
version="1.0.1",
version="1.0.2",
summary=
"The module aims to detect chinese text position in the image, which is based on differentiable_binarization algorithm.",
author="paddle-dev",
......@@ -41,7 +41,7 @@ class ChineseTextDetectionDBServer(hub.Module):
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory,
'ch_det_r50_vd_db')
'inference_model')
self.enable_mkldnn = enable_mkldnn
self._set_config()
......