未验证 提交 a48dac50 编写于 作者: Z zhoujun 提交者: GitHub

Merge branch 'dygraph' into lite

......@@ -147,6 +147,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.itemsToShapesbox = {}
self.shapesToItemsbox = {}
self.prevLabelText = getStr('tempLabel')
self.noLabelText = getStr('nullLabel')
self.model = 'paddle'
self.PPreader = None
self.autoSaveNum = 5
......@@ -1020,7 +1021,7 @@ class MainWindow(QMainWindow, WindowMixin):
item.setText(str([(int(p.x()), int(p.y())) for p in shape.points]))
self.updateComboBox()
def updateComboBox(self): # TODO:貌似没用
def updateComboBox(self):
# Get the unique labels and add them to the Combobox.
itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]
......@@ -1040,7 +1041,7 @@ class MainWindow(QMainWindow, WindowMixin):
return dict(label=s.label, # str
line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(),
points=[(p.x(), p.y()) for p in s.points], # QPonitF
points=[(int(p.x()), int(p.y())) for p in s.points], # QPonitF
# add chris
difficult=s.difficult) # bool
......@@ -1069,7 +1070,7 @@ class MainWindow(QMainWindow, WindowMixin):
# print('Image:{0} -> Annotation:{1}'.format(self.filePath, annotationFilePath))
return True
except:
self.errorMessage(u'Error saving label data')
self.errorMessage(u'Error saving label data', u'Error saving label data')
return False
def copySelectedShape(self):
......@@ -1802,7 +1803,11 @@ class MainWindow(QMainWindow, WindowMixin):
result.insert(0, box)
print('result in reRec is ', result)
self.result_dic.append(result)
if result[1][0] == shape.label:
else:
print('Can not recognise the box')
self.result_dic.append([box,(self.noLabelText,0)])
if self.noLabelText == shape.label or result[1][0] == shape.label:
print('label no change')
else:
rec_flag += 1
......@@ -1836,9 +1841,14 @@ class MainWindow(QMainWindow, WindowMixin):
print('label no change')
else:
shape.label = result[1][0]
else:
print('Can not recognise the box')
if self.noLabelText == shape.label:
print('label no change')
else:
shape.label = self.noLabelText
self.singleLabel(shape)
self.setDirty()
print(box)
def autolcm(self):
vbox = QVBoxLayout()
......
......@@ -45,7 +45,7 @@ class Canvas(QWidget):
CREATE, EDIT = list(range(2))
_fill_drawing = False # draw shadows
epsilon = 11.0
epsilon = 5.0
def __init__(self, *args, **kwargs):
super(Canvas, self).__init__(*args, **kwargs)
......
此差异已折叠。
......@@ -87,6 +87,7 @@ creatPolygon=四点标注
drawSquares=正方形标注
saveRec=保存识别结果
tempLabel=待识别
nullLabel=无法识别
steps=操作步骤
choseModelLg=选择模型语言
cancel=取消
......
......@@ -77,7 +77,7 @@ IR=Image Resize
autoRecognition=Auto Recognition
reRecognition=Re-recognition
mfile=File
medit=Eidt
medit=Edit
mview=View
mhelp=Help
iconList=Icon List
......@@ -87,6 +87,7 @@ creatPolygon=Create Quadrilateral
drawSquares=Draw Squares
saveRec=Save Recognition Result
tempLabel=TEMPORARY
nullLabel=NULL
steps=Steps
choseModelLg=Choose Model Language
cancel=Cancel
......
......@@ -32,7 +32,8 @@ PaddleOCR supports both dynamic graph and static graph programming paradigm
<div align="center">
<img src="doc/imgs_results/ch_ppocr_mobile_v2.0/test_add_91.jpg" width="800">
<img src="doc/imgs_results/ch_ppocr_mobile_v2.0/00018069.jpg" width="800">
<img src="doc/imgs_results/multi_lang/img_01.jpg" width="800">
<img src="doc/imgs_results/multi_lang/img_02.jpg" width="800">
</div>
The above pictures are the visualizations of the general ppocr_server model. For more effect pictures, please see [More visualizations](./doc/doc_en/visualization_en.md).
......
......@@ -62,20 +62,21 @@ PostProcess:
mode: fast # fast or slow two ways
Metric:
name: E2EMetric
gt_mat_dir: # the dir of gt_mat
gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e
Train:
dataset:
name: PGDataSet
label_file_list: [.././train_data/total_text/train/]
data_dir: ./train_data/total_text/train
label_file_list: [./train_data/total_text/train/]
ratio_list: [1.0]
data_format: icdar #two data format: icdar/textnet
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncode:
- PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
......@@ -92,13 +93,12 @@ Train:
Eval:
dataset:
name: PGDataSet
data_dir: ./train_data/
data_dir: ./train_data/total_text/test
label_file_list: [./train_data/total_text/test/]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- E2ELabelEncode:
- E2EResizeForTest:
max_side_len: 768
- NormalizeImage:
......@@ -108,7 +108,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
keep_keys: [ 'image', 'shape', 'img_id']
loader:
shuffle: False
drop_last: False
......
......@@ -118,7 +118,6 @@ class ArgsParser(ArgumentParser):
return config
def _set_language(self, type):
print("type:", type)
lang = type[0]
assert (type), "please use -l or --language to choose language type"
assert(
......
......@@ -40,6 +40,7 @@ endif()
if (WIN32)
include_directories("${PADDLE_LIB}/paddle/fluid/inference")
include_directories("${PADDLE_LIB}/paddle/include")
link_directories("${PADDLE_LIB}/paddle/lib")
link_directories("${PADDLE_LIB}/paddle/fluid/inference")
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
......@@ -140,22 +141,22 @@ else()
endif ()
endif()
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
# Note: libpaddle_inference_api.so/a must put before libpaddle_inference.so/a
if(WITH_STATIC_LIB)
if(WIN32)
set(DEPS
${PADDLE_LIB}/paddle/lib/paddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
endif()
else()
if(WIN32)
set(DEPS
${PADDLE_LIB}/paddle/lib/paddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
else()
set(DEPS
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
endif(WITH_STATIC_LIB)
......
......@@ -74,9 +74,10 @@ opencv3/
* 有2种方式获取Paddle预测库,下面进行详细介绍。
#### 1.2.1 直接下载安装
* [Paddle预测库官网](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html)上提供了不同cuda版本的Linux预测库,可以在官网查看并选择合适的预测库版本
* [Paddle预测库官网](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html)上提供了不同cuda版本的Linux预测库,可以在官网查看并选择合适的预测库版本(*建议选择paddle版本>=2.0.1版本的预测库*
* 下载之后使用下面的方法解压。
......@@ -130,8 +131,6 @@ build/paddle_inference_install_dir/
其中`paddle`就是C++预测所需的Paddle库,`version.txt`中包含当前预测库的版本信息。
## 2 开始运行
### 2.1 将模型导出为inference model
......@@ -232,7 +231,7 @@ visualize 1 # 是否对结果进行可视化,为1时,会在当前文件夹
最终屏幕上会输出检测结果如下。
<div align="center">
<img src="../imgs/cpp_infer_pred_12.png" width="600">
<img src="./imgs/cpp_infer_pred_12.png" width="600">
</div>
......
......@@ -91,8 +91,8 @@ tar -xf paddle_inference.tgz
Finally you can see the following files in the folder of `paddle_inference/`.
#### 1.2.2 Compile from the source code
* If you want to get the latest Paddle inference library features, you can download the latest code from Paddle github repository and compile the inference library from the source code.
* You can refer to [Paddle inference library] (https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/05_inference_deployment/inference/build_and_install_lib_en.html) to get the Paddle source code from github, and then compile To generate the latest inference library. The method of using git to access the code is as follows.
* If you want to get the latest Paddle inference library features, you can download the latest code from Paddle github repository and compile the inference library from the source code. It is recommended to download the inference library with paddle version greater than or equal to 2.0.1.
* You can refer to [Paddle inference library] (https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/inference_deployment/inference/build_and_install_lib_en.html) to get the Paddle source code from github, and then compile To generate the latest inference library. The method of using git to access the code is as follows.
```shell
......@@ -238,7 +238,7 @@ visualize 1 # Whether to visualize the results,when it is set as 1, The predic
The detection results will be shown on the screen, which is as follows.
<div align="center">
<img src="../imgs/cpp_infer_pred_12.png" width="600">
<img src="./imgs/cpp_infer_pred_12.png" width="600">
</div>
......
......@@ -113,7 +113,7 @@ python3 generate_multi_language_configs.py -l it \
| cyrillic_mobile_v2.0_rec | 斯拉夫字母 | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
| devanagari_mobile_v2.0_rec | 梵文字母 | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
更多支持语种请参考: [多语言模型](./multi_languages.md)
<a name="文本方向分类模型"></a>
......
......@@ -134,7 +134,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
<a name="python_脚本运行"></a>
### 2.2 python 脚本运行
ppocr 也支持在python脚本中运行,便于嵌入到您自己的代码中:
ppocr 也支持在python脚本中运行,便于嵌入到您自己的代码中
* 整图预测(检测+识别)
......@@ -155,7 +155,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf')
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
......@@ -240,7 +240,7 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
|德文|german|german|
|日文|japan|japan|
|韩文|korean|korean|
|中文繁体|chinese traditional |ch_tra|
|中文繁体|chinese traditional |chinese_cht|
|意大利文| Italian |it|
|西班牙文|Spanish |es|
|葡萄牙文| Portuguese|pt|
......@@ -259,7 +259,6 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
|乌克兰文|Ukranian|uk|
|白俄罗斯文|Belarusian|be|
|泰卢固文|Telugu |te|
|卡纳达文|Kannada |kn|
|泰米尔文|Tamil |ta|
|南非荷兰文 |Afrikaans |af|
|阿塞拜疆文 |Azerbaijani |az|
......
......@@ -111,7 +111,7 @@ python3 generate_multi_language_configs.py -l it \
| cyrillic_mobile_v2.0_rec | Lightweight model for cyrillic recognition | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
| devanagari_mobile_v2.0_rec | Lightweight model for devanagari recognition | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
For more supported languages, please refer to : [Multi-language model](./multi_languages_en.md)
<a name="Angle"></a>
......
......@@ -153,7 +153,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf')
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
......@@ -232,7 +232,7 @@ For functions such as data annotation, you can read the complete [Document Tutor
|german|german|
|japan|japan|
|korean|korean|
|chinese traditional |ch_tra|
|chinese traditional |chinese_cht|
| Italian |it|
|Spanish |es|
| Portuguese|pt|
......@@ -251,7 +251,6 @@ For functions such as data annotation, you can read the complete [Document Tutor
|Ukranian|uk|
|Belarusian|be|
|Telugu |te|
|Kannada |kn|
|Tamil |ta|
|Afrikaans |af|
|Azerbaijani |az|
......
......@@ -30,6 +30,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr
__all__ = ['PaddleOCR']
......@@ -117,7 +118,7 @@ model_urls = {
}
SUPPORT_DET_MODEL = ['DB']
VERSION = 2.1
VERSION = '2.1'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
......@@ -315,14 +316,13 @@ class PaddleOCR(predict_system.TextSystem):
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(
BASE_DIR, '{}/det/{}'.format(VERSION, det_lang))
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(
BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(
BASE_DIR, '{}/cls'.format(VERSION))
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
# download model
maybe_download(postprocess_params.det_model_dir,
......
......@@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne'
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
......@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character
class E2ELabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
self.pad_num = len(self.dict) # the length to pad
class E2ELabelEncode(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
texts = data['strs']
temp_texts = []
for text in texts:
text = text.lower()
text = self.encode(text)
if text is None:
return None
text = text + [self.pad_num] * (self.max_text_len - len(text))
temp_texts.append(text)
data['strs'] = np.array(temp_texts)
import json
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['texts'] = txts
data['ignore_tags'] = txt_tags
return data
......
......@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return min_area_quad
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
def check_and_validate_polys(self, polys, tags, im_size):
"""
check so that the text poly is in the same direction,
and also filter some invalid polygons
......@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:return:
"""
(h, w) = xxx_todo_changeme
(h, w) = im_size
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
......@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size = 512
im = data['image']
text_polys = data['polys']
text_tags = data['tags']
text_strs = data['strs']
text_tags = data['ignore_tags']
text_strs = data['texts']
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
......
......@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
self.data_format = dataset_config.get('data_format', 'icdar')
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
self.data_format)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
......@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
random.shuffle(self.data_lines)
return
def extract_polys(self, poly_txt_path):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys, txt_tags, txts = [], [], []
with open(poly_txt_path) as f:
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(',')))
text_polys.append(
np.array(
poly, dtype=np.float32).reshape(-1, 2))
txts.append(txt)
txt_tags.append(txt == '###')
return np.array(list(map(np.array, text_polys))), \
np.array(txt_tags, dtype=np.bool), txts
def extract_info_textnet(self, im_fn, img_dir=''):
"""
Extract information from line in textnet format.
"""
info_list = im_fn.split('\t')
img_path = ''
for ext in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
]:
if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
img_path = os.path.join(img_dir, info_list[0] + "." + ext)
break
if img_path == '':
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
info_list[0], img_dir))
nBox = (len(info_list) - 1) // 9
wordBBs, txts, txt_tags = [], [], []
for n in range(0, nBox):
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
txt = info_list[(n + 1) * 9]
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, data_source in enumerate(file_list):
image_files = []
if data_format == 'icdar':
image_files = [(data_source, x) for x in
os.listdir(os.path.join(data_source, 'rgb'))
if x.split('.')[-1] in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
'tiff', 'gif', 'JPG'
]]
elif data_format == 'textnet':
with open(data_source) as f:
image_files = [(data_source, x.strip())
for x in f.readlines()]
else:
print("Unrecognized data format...")
exit(-1)
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
image_files = random.sample(
image_files, round(len(image_files) * ratio_list[idx]))
data_lines.extend(image_files)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_path, data_line = self.data_lines[file_idx]
data_line = self.data_lines[file_idx]
try:
if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line)
poly_path = os.path.join(data_path, 'poly',
data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
if self.mode.lower() == 'eval':
img_id = int(data_line.split(".")[0][7:])
else:
image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir)
img_id = int(data_line.split(".")[0][3:])
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
'strs': text_strs,
'img_id': img_id
}
img_id = 0
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
......
......@@ -35,11 +35,11 @@ class E2EMetric(object):
self.reset()
def __call__(self, preds, batch, **kwargs):
img_id = batch[5][0]
img_id = batch[2][0]
e2e_info_list = [{
'points': det_polyon,
'text': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
'texts': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result)
......
......@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN'
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
......
......@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
text = pred_dict[i]['texts']
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
......
......@@ -21,6 +21,7 @@ import math
import numpy as np
from itertools import groupby
from cv2.ximgproc import thinning as thin
from skimage.morphology._skeletonize import thin
......
......@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'strs': keep_str_list,
'texts': keep_str_list,
}
return data
......@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit(-1)
data = {
'points': poly_list,
'strs': keep_str_list,
'texts': keep_str_list,
}
return data
......@@ -122,7 +122,7 @@ class TextE2E(object):
else:
raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
points, strs = post_result['points'], post_result['texts']
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, strs, elapse
......
......@@ -103,7 +103,7 @@ def main():
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
points, strs = post_result['points'], post_result['texts']
# write resule
dt_boxes_json = []
for poly, str in zip(points, strs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册