', '', '',
@@ -169,7 +168,8 @@ class StructureSystem(object):
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
'img': roi_img,
- 'res': res
+ 'res': res,
+ 'img_idx': img_idx
})
end = time.time()
time_dict['all'] = end - start
@@ -179,26 +179,29 @@ class StructureSystem(object):
return None, None
-def save_structure_res(res, save_folder, img_name):
+def save_structure_res(res, save_folder, img_name, img_idx=0):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
res_cp = deepcopy(res)
# save res
with open(
- os.path.join(excel_save_folder, 'res.txt'), 'w',
+ os.path.join(excel_save_folder, 'res_{}.txt'.format(img_idx)),
+ 'w',
encoding='utf8') as f:
for region in res_cp:
roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
- if region['type'] == 'table' and len(region[
+ if region['type'].lower() == 'table' and len(region[
'res']) > 0 and 'html' in region['res']:
- excel_path = os.path.join(excel_save_folder,
- '{}.xlsx'.format(region['bbox']))
+ excel_path = os.path.join(
+ excel_save_folder,
+ '{}_{}.xlsx'.format(region['bbox'], img_idx))
to_excel(region['res']['html'], excel_path)
- elif region['type'] == 'figure':
- img_path = os.path.join(excel_save_folder,
- '{}.jpg'.format(region['bbox']))
+ elif region['type'].lower() == 'figure':
+ img_path = os.path.join(
+ excel_save_folder,
+ '{}_{}.jpg'.format(region['bbox'], img_idx))
cv2.imwrite(img_path, roi_img)
@@ -214,28 +217,75 @@ def main(args):
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
- img, flag = check_and_read_gif(image_file)
+ img, flag_gif, flag_pdf = check_and_read(image_file)
img_name = os.path.basename(image_file).split('.')[0]
- if not flag:
+ if not flag_gif and not flag_pdf:
img = cv2.imread(image_file)
- if img is None:
- logger.error("error in loading image:{}".format(image_file))
- continue
- res, time_dict = structure_sys(img)
- if structure_sys.mode == 'structure':
- save_structure_res(res, save_folder, img_name)
- draw_img = draw_structure_result(img, res, args.vis_font_path)
- img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
- elif structure_sys.mode == 'vqa':
- raise NotImplementedError
- # draw_img = draw_ser_results(img, res, args.vis_font_path)
- # img_save_path = os.path.join(save_folder, img_name + '.jpg')
- cv2.imwrite(img_save_path, draw_img)
- logger.info('result save to {}'.format(img_save_path))
- if args.recovery:
- convert_info_docx(img, res, save_folder, img_name)
+ if not flag_pdf:
+ if img is None:
+ logger.error("error in loading image:{}".format(image_file))
+ continue
+ res, time_dict = structure_sys(img)
+
+ if structure_sys.mode == 'structure':
+ save_structure_res(res, save_folder, img_name)
+ draw_img = draw_structure_result(img, res, args.vis_font_path)
+ img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
+ elif structure_sys.mode == 'vqa':
+ raise NotImplementedError
+ # draw_img = draw_ser_results(img, res, args.vis_font_path)
+ # img_save_path = os.path.join(save_folder, img_name + '.jpg')
+ cv2.imwrite(img_save_path, draw_img)
+ logger.info('result save to {}'.format(img_save_path))
+ if args.recovery:
+ try:
+ from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
+ h, w, _ = img.shape
+ res = sorted_layout_boxes(res, w)
+ convert_info_docx(img, res, save_folder, img_name,
+ args.save_pdf)
+ except Exception as ex:
+ logger.error(
+ "error in layout recovery image:{}, err msg: {}".format(
+ image_file, ex))
+ continue
+ else:
+ pdf_imgs = img
+ all_res = []
+ for index, img in enumerate(pdf_imgs):
+
+ res, time_dict = structure_sys(img, index)
+ if structure_sys.mode == 'structure' and res != []:
+ save_structure_res(res, save_folder, img_name, index)
+ draw_img = draw_structure_result(img, res,
+ args.vis_font_path)
+ img_save_path = os.path.join(save_folder, img_name,
+ 'show_{}.jpg'.format(index))
+ elif structure_sys.mode == 'vqa':
+ raise NotImplementedError
+ # draw_img = draw_ser_results(img, res, args.vis_font_path)
+ # img_save_path = os.path.join(save_folder, img_name + '.jpg')
+ if res != []:
+ cv2.imwrite(img_save_path, draw_img)
+ logger.info('result save to {}'.format(img_save_path))
+ if args.recovery and res != []:
+ from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
+ h, w, _ = img.shape
+ res = sorted_layout_boxes(res, w)
+ all_res += res
+
+ if args.recovery and all_res != []:
+ try:
+ convert_info_docx(img, all_res, save_folder, img_name,
+ args.save_pdf)
+ except Exception as ex:
+ logger.error(
+ "error in layout recovery image:{}, err msg: {}".format(
+ image_file, ex))
+ continue
+
logger.info("Predict time : {:.3f}s".format(time_dict['all']))
diff --git a/ppstructure/recovery/README.md b/ppstructure/recovery/README.md
index 883dbef3e829dfa213644b610af1ca279dac8641..713d0307dbbd66664db15d19df484af76efea75a 100644
--- a/ppstructure/recovery/README.md
+++ b/ppstructure/recovery/README.md
@@ -78,9 +78,27 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the ultra-lightweight English table inch model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+# Download the layout model of publaynet dataset and unzip it
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar \
+ && tar xf picodet_lcnet_x1_0_layout_infer.tar
cd ..
# run
-python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png
+python3 predict_system.py \
+ --image_dir=./docs/table/1.png \
+ --det_model_dir=inference/en_PP-OCRv3_det_infer \
+ --rec_model_dir=inference/en_PP-OCRv3_rec_infer \
+ --rec_char_dict_path=../ppocr/utils/en_dict.txt \
+ --output=../output/ \
+ --table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
+ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
+ --table_max_len=488 \
+ --layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
+ --layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --recovery=True \
+ --save_pdf=False
```
-After running, the docx of each picture will be saved in the directory specified by the output field
\ No newline at end of file
+After running, the docx of each picture will be saved in the directory specified by the output field
+
+The table-recovery-to-Word code [table_process.py] is adapted from: https://github.com/pqzx/html2docx.git
\ No newline at end of file
diff --git a/ppstructure/recovery/README_ch.md b/ppstructure/recovery/README_ch.md
index 5a05abffd0399387bc0d22d878e64d03d8894a79..14ca8836a0332a5b0e119be4bf6bcb36fb011d1e 100644
--- a/ppstructure/recovery/README_ch.md
+++ b/ppstructure/recovery/README_ch.md
@@ -35,21 +35,15 @@
python3 -m pip install --upgrade pip
# GPU安装
-python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle-gpu>=2.3" -i https://mirror.baidu.com/pypi/simple
# CPU安装
-python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle>=2.3" -i https://mirror.baidu.com/pypi/simple
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
-* **(2)安装依赖**
-
-```bash
-python3 -m pip install -r ppstructure/recovery/requirements.txt
-```
-
### 2.2 安装PaddleOCR
@@ -87,11 +81,28 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
# 下载英文轻量级PP-OCRv3模型的识别模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# 下载超轻量级英文表格英寸模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
+# 下载英文版面分析模型
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar xf picodet_lcnet_x1_0_layout_infer.tar
cd ..
+
# 执行预测
-python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png
+python3 predict_system.py \
+ --image_dir=./docs/table/1.png \
+ --det_model_dir=inference/en_PP-OCRv3_det_infer \
+ --rec_model_dir=inference/en_PP-OCRv3_rec_infer \
+ --rec_char_dict_path=../ppocr/utils/en_dict.txt \
+ --output=../output/ \
+ --table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
+ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
+ --table_max_len=488 \
+ --layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
+ --layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --recovery=True \
+ --save_pdf=False
```
-运行完成后,每张图片的docx文档会保存到output字段指定的目录下
+运行完成后,每张图片的docx文档会保存到`output`字段指定的目录下
+表格恢复到Word代码[table_process.py]来自:https://github.com/pqzx/html2docx.git
diff --git a/ppstructure/recovery/recovery_to_doc.py b/ppstructure/recovery/recovery_to_doc.py
index 5278217d5b983008d357b6b1be3ab1b883a4939d..4401b1f27cf10f8483ee9b2b4a61315ad6aad264 100644
--- a/ppstructure/recovery/recovery_to_doc.py
+++ b/ppstructure/recovery/recovery_to_doc.py
@@ -22,21 +22,23 @@ from docx import shared
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.section import WD_SECTION
from docx.oxml.ns import qn
+from docx.enum.table import WD_TABLE_ALIGNMENT
+
+from table_process import HtmlToDocx
from ppocr.utils.logging import get_logger
logger = get_logger()
-def convert_info_docx(img, res, save_folder, img_name):
+def convert_info_docx(img, res, save_folder, img_name, save_pdf):
doc = Document()
doc.styles['Normal'].font.name = 'Times New Roman'
doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体')
doc.styles['Normal'].font.size = shared.Pt(6.5)
- h, w, _ = img.shape
- res = sorted_layout_boxes(res, w)
flag = 1
for i, region in enumerate(res):
+ img_idx = region['img_idx']
if flag == 2 and region['layout'] == 'single':
section = doc.add_section(WD_SECTION.CONTINUOUS)
section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '1')
@@ -46,10 +48,10 @@ def convert_info_docx(img, res, save_folder, img_name):
section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '2')
flag = 2
- if region['type'] == 'Figure':
+ if region['type'].lower() == 'figure':
excel_save_folder = os.path.join(save_folder, img_name)
img_path = os.path.join(excel_save_folder,
- '{}.jpg'.format(region['bbox']))
+ '{}_{}.jpg'.format(region['bbox'], img_idx))
paragraph_pic = doc.add_paragraph()
paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER
run = paragraph_pic.add_run("")
@@ -57,40 +59,38 @@ def convert_info_docx(img, res, save_folder, img_name):
run.add_picture(img_path, width=shared.Inches(5))
elif flag == 2:
run.add_picture(img_path, width=shared.Inches(2))
- elif region['type'] == 'Title':
+ elif region['type'].lower() == 'title':
doc.add_heading(region['res'][0]['text'])
- elif region['type'] == 'Text':
+ elif region['type'].lower() == 'table':
+ paragraph = doc.add_paragraph()
+ new_parser = HtmlToDocx()
+ new_parser.table_style = 'TableGrid'
+ table = new_parser.handle_table(html=region['res']['html'])
+ new_table = deepcopy(table)
+ new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
+ paragraph.add_run().element.addnext(new_table._tbl)
+
+ else:
paragraph = doc.add_paragraph()
paragraph_format = paragraph.paragraph_format
for i, line in enumerate(region['res']):
if i == 0:
paragraph_format.first_line_indent = shared.Inches(0.25)
text_run = paragraph.add_run(line['text'] + ' ')
- text_run.font.size = shared.Pt(9)
- elif region['type'] == 'Table':
- pypandoc.convert(
- source=region['res']['html'],
- format='html',
- to='docx',
- outputfile='tmp.docx')
- tmp_doc = Document('tmp.docx')
- paragraph = doc.add_paragraph()
-
- table = tmp_doc.tables[0]
- new_table = deepcopy(table)
- new_table.style = doc.styles['Table Grid']
- from docx.enum.table import WD_TABLE_ALIGNMENT
- new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
- paragraph.add_run().element.addnext(new_table._tbl)
- os.remove('tmp.docx')
- else:
- continue
+ text_run.font.size = shared.Pt(10)
# save to docx
docx_path = os.path.join(save_folder, '{}.docx'.format(img_name))
doc.save(docx_path)
logger.info('docx save to {}'.format(docx_path))
+ # save to pdf
+ if save_pdf:
+ pdf_path = os.path.join(save_folder, '{}.pdf'.format(img_name))
+ from docx2pdf import convert
+ convert(docx_path, pdf_path)
+ logger.info('pdf save to {}'.format(pdf_path))
+
def sorted_layout_boxes(res, w):
"""
diff --git a/ppstructure/recovery/requirements.txt b/ppstructure/recovery/requirements.txt
index 04187baa2a72d2ac60f0a4e5ce643f882b7255fb..5ba3099d64574954c65ac8169798759dd7c053ac 100644
--- a/ppstructure/recovery/requirements.txt
+++ b/ppstructure/recovery/requirements.txt
@@ -1,3 +1,5 @@
-opencv-contrib-python==4.4.0.46
pypandoc
-python-docx
\ No newline at end of file
+python-docx
+docx2pdf
+# NOTE(review): the "fitz" import is provided by PyMuPDF; the separate PyPI "fitz" package is unrelated and breaks installs
+PyMuPDF
\ No newline at end of file
diff --git a/ppstructure/recovery/table_process.py b/ppstructure/recovery/table_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..243aaf8933791bf4704964d9665173fe70982f95
--- /dev/null
+++ b/ppstructure/recovery/table_process.py
@@ -0,0 +1,632 @@
+
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:https://github.com/pqzx/html2docx/blob/8f6695a778c68befb302e48ac0ed5201ddbd4524/htmldocx/h2d.py
+
+"""
+import re, argparse
+import io, os
+import urllib.request
+from urllib.parse import urlparse
+from html.parser import HTMLParser
+
+import docx, docx.table
+from docx import Document
+from docx.shared import RGBColor, Pt, Inches
+from docx.enum.text import WD_COLOR, WD_ALIGN_PARAGRAPH
+from docx.oxml import OxmlElement
+from docx.oxml.ns import qn
+
+from bs4 import BeautifulSoup
+
+# values in inches
+INDENT = 0.25
+LIST_INDENT = 0.5
+MAX_INDENT = 5.5 # To stop indents going off the page
+
+# Style to use with tables. By default no style is used.
+DEFAULT_TABLE_STYLE = None
+
+# Style to use with paragraphs. By default no style is used.
+DEFAULT_PARAGRAPH_STYLE = None
+
+
+def get_filename_from_url(url):
+ return os.path.basename(urlparse(url).path)
+
+def is_url(url):
+ """
+ Not to be used for actually validating a url, but in our use case we only
+ care if it's a url or a file path, and they're pretty distinguishable
+ """
+ parts = urlparse(url)
+ return all([parts.scheme, parts.netloc, parts.path])
+
+def fetch_image(url):
+ """
+ Attempts to fetch an image from a url.
+ If successful returns a bytes object, else returns None
+ :return:
+ """
+ try:
+ with urllib.request.urlopen(url) as response:
+ # security flaw?
+ return io.BytesIO(response.read())
+ except urllib.error.URLError:
+ return None
+
+def remove_last_occurence(ls, x):
+ ls.pop(len(ls) - ls[::-1].index(x) - 1)
+
+def remove_whitespace(string, leading=False, trailing=False):
+ """Remove white space from a string.
+ Args:
+ string(str): The string to remove white space from.
+ leading(bool, optional): Remove leading new lines when True.
+ trailing(bool, optional): Remove trailing new lines when False.
+ Returns:
+ str: The input string with new line characters removed and white space squashed.
+ Examples:
+ Single or multiple new line characters are replaced with space.
+ >>> remove_whitespace("abc\\ndef")
+ 'abc def'
+ >>> remove_whitespace("abc\\n\\n\\ndef")
+ 'abc def'
+ New line characters surrounded by white space are replaced with a single space.
+ >>> remove_whitespace("abc \\n \\n \\n def")
+ 'abc def'
+ >>> remove_whitespace("abc \\n \\n \\n def")
+ 'abc def'
+ Leading and trailing new lines are replaced with a single space.
+ >>> remove_whitespace("\\nabc")
+ ' abc'
+ >>> remove_whitespace(" \\n abc")
+ ' abc'
+ >>> remove_whitespace("abc\\n")
+ 'abc '
+ >>> remove_whitespace("abc \\n ")
+ 'abc '
+ Use ``leading=True`` to remove leading new line characters, including any surrounding
+ white space:
+ >>> remove_whitespace("\\nabc", leading=True)
+ 'abc'
+ >>> remove_whitespace(" \\n abc", leading=True)
+ 'abc'
+ Use ``trailing=True`` to remove trailing new line characters, including any surrounding
+ white space:
+ >>> remove_whitespace("abc \\n ", trailing=True)
+ 'abc'
+ """
+ # Remove any leading new line characters along with any surrounding white space
+ if leading:
+ string = re.sub(r'^\s*\n+\s*', '', string)
+
+ # Remove any trailing new line characters along with any surrounding white space
+ if trailing:
+ string = re.sub(r'\s*\n+\s*$', '', string)
+
+ # Replace new line characters and absorb any surrounding space.
+ string = re.sub(r'\s*\n\s*', ' ', string)
+ # TODO need some way to get rid of extra spaces in e.g. text text
+ return re.sub(r'\s+', ' ', string)
+
+def delete_paragraph(paragraph):
+ # https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907
+ p = paragraph._element
+ p.getparent().remove(p)
+ p._p = p._element = None
+
+font_styles = {
+ 'b': 'bold',
+ 'strong': 'bold',
+ 'em': 'italic',
+ 'i': 'italic',
+ 'u': 'underline',
+ 's': 'strike',
+ 'sup': 'superscript',
+ 'sub': 'subscript',
+ 'th': 'bold',
+}
+
+font_names = {
+ 'code': 'Courier',
+ 'pre': 'Courier',
+}
+
+styles = {
+ 'LIST_BULLET': 'List Bullet',
+ 'LIST_NUMBER': 'List Number',
+}
+
+class HtmlToDocx(HTMLParser):
+
+ def __init__(self):
+ super().__init__()
+ self.options = {
+ 'fix-html': True,
+ 'images': True,
+ 'tables': True,
+ 'styles': True,
+ }
+ self.table_row_selectors = [
+ 'table > tr',
+ 'table > thead > tr',
+ 'table > tbody > tr',
+ 'table > tfoot > tr'
+ ]
+ self.table_style = DEFAULT_TABLE_STYLE
+ self.paragraph_style = DEFAULT_PARAGRAPH_STYLE
+
+ def set_initial_attrs(self, document=None):
+ self.tags = {
+ 'span': [],
+ 'list': [],
+ }
+ if document:
+ self.doc = document
+ else:
+ self.doc = Document()
+ self.bs = self.options['fix-html'] # whether or not to clean with BeautifulSoup
+ self.document = self.doc
+ self.include_tables = True #TODO add this option back in?
+ self.include_images = self.options['images']
+ self.include_styles = self.options['styles']
+ self.paragraph = None
+ self.skip = False
+ self.skip_tag = None
+ self.instances_to_skip = 0
+
+ def copy_settings_from(self, other):
+ """Copy settings from another instance of HtmlToDocx"""
+ self.table_style = other.table_style
+ self.paragraph_style = other.paragraph_style
+
+ def get_cell_html(self, soup):
+ # Returns string of td element with opening and closing tags removed
+ # Cannot use find_all as it only finds element tags and does not find text which
+ # is not inside an element
+ return ' '.join([str(i) for i in soup.contents])
+
+ def add_styles_to_paragraph(self, style):
+ if 'text-align' in style:
+ align = style['text-align']
+ if align == 'center':
+ self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.CENTER
+ elif align == 'right':
+ self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.RIGHT
+ elif align == 'justify':
+ self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
+ if 'margin-left' in style:
+ margin = style['margin-left']
+ units = re.sub(r'[0-9]+', '', margin)
+ margin = int(float(re.sub(r'[a-z]+', '', margin)))
+ if units == 'px':
+ self.paragraph.paragraph_format.left_indent = Inches(min(margin // 10 * INDENT, MAX_INDENT))
+ # TODO handle non px units
+
+ def add_styles_to_run(self, style):
+ if 'color' in style:
+ if 'rgb' in style['color']:
+ color = re.sub(r'[a-z()]+', '', style['color'])
+ colors = [int(x) for x in color.split(',')]
+ elif '#' in style['color']:
+ color = style['color'].lstrip('#')
+ colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
+ else:
+ colors = [0, 0, 0]
+ # TODO map colors to named colors (and extended colors...)
+ # For now set color to black to prevent crashing
+ self.run.font.color.rgb = RGBColor(*colors)
+
+ if 'background-color' in style:
+ if 'rgb' in style['background-color']:
+ color = color = re.sub(r'[a-z()]+', '', style['background-color'])
+ colors = [int(x) for x in color.split(',')]
+ elif '#' in style['background-color']:
+ color = style['background-color'].lstrip('#')
+ colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
+ else:
+ colors = [0, 0, 0]
+ # TODO map colors to named colors (and extended colors...)
+ # For now set color to black to prevent crashing
+ self.run.font.highlight_color = WD_COLOR.GRAY_25 #TODO: map colors
+
+ def apply_paragraph_style(self, style=None):
+ try:
+ if style:
+ self.paragraph.style = style
+ elif self.paragraph_style:
+ self.paragraph.style = self.paragraph_style
+ except KeyError as e:
+ raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e
+
+ def parse_dict_string(self, string, separator=';'):
+ new_string = string.replace(" ", '').split(separator)
+ string_dict = dict([x.split(':') for x in new_string if ':' in x])
+ return string_dict
+
+ def handle_li(self):
+ # check list stack to determine style and depth
+ list_depth = len(self.tags['list'])
+ if list_depth:
+ list_type = self.tags['list'][-1]
+ else:
+ list_type = 'ul' # assign unordered if no tag
+
+ if list_type == 'ol':
+ list_style = styles['LIST_NUMBER']
+ else:
+ list_style = styles['LIST_BULLET']
+
+ self.paragraph = self.doc.add_paragraph(style=list_style)
+ self.paragraph.paragraph_format.left_indent = Inches(min(list_depth * LIST_INDENT, MAX_INDENT))
+ self.paragraph.paragraph_format.line_spacing = 1
+
+ def add_image_to_cell(self, cell, image):
+ # python-docx doesn't have method yet for adding images to table cells. For now we use this
+ paragraph = cell.add_paragraph()
+ run = paragraph.add_run()
+ run.add_picture(image)
+
+ def handle_img(self, current_attrs):
+ if not self.include_images:
+ self.skip = True
+ self.skip_tag = 'img'
+ return
+ src = current_attrs['src']
+ # fetch image
+ src_is_url = is_url(src)
+ if src_is_url:
+ try:
+ image = fetch_image(src)
+ except urllib.error.URLError:
+ image = None
+ else:
+ image = src
+ # add image to doc
+ if image:
+ try:
+ if isinstance(self.doc, docx.document.Document):
+ self.doc.add_picture(image)
+ else:
+ self.add_image_to_cell(self.doc, image)
+ except FileNotFoundError:
+ image = None
+ if not image:
+ if src_is_url:
+ self.doc.add_paragraph("" % src)
+ else:
+ # avoid exposing filepaths in document
+ self.doc.add_paragraph("" % get_filename_from_url(src))
+
+
+ def handle_table(self, html):
+ """
+ To handle nested tables, we will parse tables manually as follows:
+ Get table soup
+ Create docx table
+ Iterate over soup and fill docx table with new instances of this parser
+ Tell HTMLParser to ignore any tags until the corresponding closing table tag
+ """
+ doc = Document()
+ table_soup = BeautifulSoup(html, 'html.parser')
+ rows, cols_len = self.get_table_dimensions(table_soup)
+ table = doc.add_table(len(rows), cols_len)
+ table.style = doc.styles['Table Grid']
+ cell_row = 0
+ for index, row in enumerate(rows):
+ cols = self.get_table_columns(row)
+ cell_col = 0
+ for col in cols:
+ colspan = int(col.attrs.get('colspan', 1))
+ rowspan = int(col.attrs.get('rowspan', 1))
+
+ cell_html = self.get_cell_html(col)
+
+ if col.name == 'th':
+ cell_html = "%s" % cell_html
+ docx_cell = table.cell(cell_row, cell_col)
+ while docx_cell.text != '': # Skip the merged cell
+ cell_col += 1
+ docx_cell = table.cell(cell_row, cell_col)
+
+ cell_to_merge = table.cell(cell_row + rowspan - 1, cell_col + colspan - 1)
+ if docx_cell != cell_to_merge:
+ docx_cell.merge(cell_to_merge)
+
+ child_parser = HtmlToDocx()
+ child_parser.copy_settings_from(self)
+
+ child_parser.add_html_to_cell(cell_html or ' ', docx_cell) # occupy the position
+
+ cell_col += colspan
+ cell_row += 1
+
+ # skip all tags until corresponding closing tag
+ self.instances_to_skip = len(table_soup.find_all('table'))
+ self.skip_tag = 'table'
+ self.skip = True
+ self.table = None
+ return table
+
+ def handle_link(self, href, text):
+ # Link requires a relationship
+ is_external = href.startswith('http')
+ rel_id = self.paragraph.part.relate_to(
+ href,
+ docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK,
+ is_external=True # don't support anchor links for this library yet
+ )
+
+ # Create the w:hyperlink tag and add needed values
+ hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
+ hyperlink.set(docx.oxml.shared.qn('r:id'), rel_id)
+
+
+ # Create sub-run
+ subrun = self.paragraph.add_run()
+ rPr = docx.oxml.shared.OxmlElement('w:rPr')
+
+ # add default color
+ c = docx.oxml.shared.OxmlElement('w:color')
+ c.set(docx.oxml.shared.qn('w:val'), "0000EE")
+ rPr.append(c)
+
+ # add underline
+ u = docx.oxml.shared.OxmlElement('w:u')
+ u.set(docx.oxml.shared.qn('w:val'), 'single')
+ rPr.append(u)
+
+ subrun._r.append(rPr)
+ subrun._r.text = text
+
+ # Add subrun to hyperlink
+ hyperlink.append(subrun._r)
+
+ # Add hyperlink to run
+ self.paragraph._p.append(hyperlink)
+
+ def handle_starttag(self, tag, attrs):
+ if self.skip:
+ return
+ if tag == 'head':
+ self.skip = True
+ self.skip_tag = tag
+ self.instances_to_skip = 0
+ return
+ elif tag == 'body':
+ return
+
+ current_attrs = dict(attrs)
+
+ if tag == 'span':
+ self.tags['span'].append(current_attrs)
+ return
+ elif tag == 'ol' or tag == 'ul':
+ self.tags['list'].append(tag)
+ return # don't apply styles for now
+ elif tag == 'br':
+ self.run.add_break()
+ return
+
+ self.tags[tag] = current_attrs
+ if tag in ['p', 'pre']:
+ self.paragraph = self.doc.add_paragraph()
+ self.apply_paragraph_style()
+
+ elif tag == 'li':
+ self.handle_li()
+
+ elif tag == "hr":
+
+ # This implementation was taken from:
+ # https://github.com/python-openxml/python-docx/issues/105#issuecomment-62806373
+
+ self.paragraph = self.doc.add_paragraph()
+ pPr = self.paragraph._p.get_or_add_pPr()
+ pBdr = OxmlElement('w:pBdr')
+ pPr.insert_element_before(pBdr,
+ 'w:shd', 'w:tabs', 'w:suppressAutoHyphens', 'w:kinsoku', 'w:wordWrap',
+ 'w:overflowPunct', 'w:topLinePunct', 'w:autoSpaceDE', 'w:autoSpaceDN',
+ 'w:bidi', 'w:adjustRightInd', 'w:snapToGrid', 'w:spacing', 'w:ind',
+ 'w:contextualSpacing', 'w:mirrorIndents', 'w:suppressOverlap', 'w:jc',
+ 'w:textDirection', 'w:textAlignment', 'w:textboxTightWrap',
+ 'w:outlineLvl', 'w:divId', 'w:cnfStyle', 'w:rPr', 'w:sectPr',
+ 'w:pPrChange'
+ )
+ bottom = OxmlElement('w:bottom')
+ bottom.set(qn('w:val'), 'single')
+ bottom.set(qn('w:sz'), '6')
+ bottom.set(qn('w:space'), '1')
+ bottom.set(qn('w:color'), 'auto')
+ pBdr.append(bottom)
+
+ elif re.match('h[1-9]', tag):
+ if isinstance(self.doc, docx.document.Document):
+ h_size = int(tag[1])
+ self.paragraph = self.doc.add_heading(level=min(h_size, 9))
+ else:
+ self.paragraph = self.doc.add_paragraph()
+
+ elif tag == 'img':
+ self.handle_img(current_attrs)
+ return
+
+ elif tag == 'table':
+ self.handle_table(str(self.tables[self.table_no]))
+ return
+
+ # set new run reference point in case of leading line breaks
+ if tag in ['p', 'li', 'pre']:
+ self.run = self.paragraph.add_run()
+
+ # add style
+ if not self.include_styles:
+ return
+ if 'style' in current_attrs and self.paragraph:
+ style = self.parse_dict_string(current_attrs['style'])
+ self.add_styles_to_paragraph(style)
+
+ def handle_endtag(self, tag):
+ if self.skip:
+ if not tag == self.skip_tag:
+ return
+
+ if self.instances_to_skip > 0:
+ self.instances_to_skip -= 1
+ return
+
+ self.skip = False
+ self.skip_tag = None
+ self.paragraph = None
+
+ if tag == 'span':
+ if self.tags['span']:
+ self.tags['span'].pop()
+ return
+ elif tag == 'ol' or tag == 'ul':
+ remove_last_occurence(self.tags['list'], tag)
+ return
+ elif tag == 'table':
+ self.table_no += 1
+ self.table = None
+ self.doc = self.document
+ self.paragraph = None
+
+ if tag in self.tags:
+ self.tags.pop(tag)
+ # maybe set relevant reference to None?
+
+ def handle_data(self, data):
+ if self.skip:
+ return
+
+ # Only remove white space if we're not in a pre block.
+ if 'pre' not in self.tags:
+ # remove leading and trailing whitespace in all instances
+ data = remove_whitespace(data, True, True)
+
+ if not self.paragraph:
+ self.paragraph = self.doc.add_paragraph()
+ self.apply_paragraph_style()
+
+ # There can only be one nested link in a valid html document
+ # You cannot have interactive content in an A tag, this includes links
+ # https://html.spec.whatwg.org/#interactive-content
+ link = self.tags.get('a')
+ if link:
+ self.handle_link(link['href'], data)
+ else:
+ # If there's a link, dont put the data directly in the run
+ self.run = self.paragraph.add_run(data)
+ spans = self.tags['span']
+ for span in spans:
+ if 'style' in span:
+ style = self.parse_dict_string(span['style'])
+ self.add_styles_to_run(style)
+
+ # add font style and name
+ for tag in self.tags:
+ if tag in font_styles:
+ font_style = font_styles[tag]
+ setattr(self.run.font, font_style, True)
+
+ if tag in font_names:
+ font_name = font_names[tag]
+ self.run.font.name = font_name
+
+ def ignore_nested_tables(self, tables_soup):
+ """
+ Returns array containing only the highest level tables
+ Operates on the assumption that bs4 returns child elements immediately after
+ the parent element in `find_all`. If this changes in the future, this method will need to be updated
+ :return:
+ """
+ new_tables = []
+ nest = 0
+ for table in tables_soup:
+ if nest:
+ nest -= 1
+ continue
+ new_tables.append(table)
+ nest = len(table.find_all('table'))
+ return new_tables
+
+ def get_table_rows(self, table_soup):
+ # If there's a header, body, footer or direct child tr tags, add row dimensions from there
+ return table_soup.select(', '.join(self.table_row_selectors), recursive=False)
+
+ def get_table_columns(self, row):
+ # Get all columns for the specified row tag.
+ return row.find_all(['th', 'td'], recursive=False) if row else []
+
+ def get_table_dimensions(self, table_soup):
+ # Get rows for the table
+ rows = self.get_table_rows(table_soup)
+ # Table is either empty or has non-direct children between table and tr tags
+ # Thus the row dimensions and column dimensions are assumed to be 0
+
+ cols = self.get_table_columns(rows[0]) if rows else []
+ # Add colspan calculation column number
+ col_count = 0
+ for col in cols:
+ colspan = col.attrs.get('colspan', 1)
+ col_count += int(colspan)
+
+ # return len(rows), col_count
+ return rows, col_count
+
+ def get_tables(self):
+ if not hasattr(self, 'soup'):
+ self.include_tables = False
+ return
+ # find other way to do it, or require this dependency?
+ self.tables = self.ignore_nested_tables(self.soup.find_all('table'))
+ self.table_no = 0
+
+ def run_process(self, html):
+ if self.bs and BeautifulSoup:
+ self.soup = BeautifulSoup(html, 'html.parser')
+ html = str(self.soup)
+ if self.include_tables:
+ self.get_tables()
+ self.feed(html)
+
+ def add_html_to_document(self, html, document):
+ if not isinstance(html, str):
+ raise ValueError('First argument needs to be a %s' % str)
+ elif not isinstance(document, docx.document.Document) and not isinstance(document, docx.table._Cell):
+ raise ValueError('Second argument needs to be a %s' % docx.document.Document)
+ self.set_initial_attrs(document)
+ self.run_process(html)
+
+ def add_html_to_cell(self, html, cell):
+ self.set_initial_attrs(cell)
+ self.run_process(html)
+
+ def parse_html_file(self, filename_html, filename_docx=None):
+ with open(filename_html, 'r') as infile:
+ html = infile.read()
+ self.set_initial_attrs()
+ self.run_process(html)
+ if not filename_docx:
+ path, filename = os.path.split(filename_html)
+ filename_docx = '%s/new_docx_file_%s' % (path, filename)
+ self.doc.save('%s.docx' % filename_docx)
+
+ def parse_html_string(self, html):
+ self.set_initial_attrs()
+ self.run_process(html)
+ return self.doc
\ No newline at end of file
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index cda4c063bccbd2aff34cf25768866feb4d68dc2d..2cf20eb53f87a8f8fbe2bdb4c3ead77f40120370 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -38,7 +38,7 @@ def init_args():
parser.add_argument(
"--layout_dict_path",
type=str,
- default="../ppocr/utils/dict/layout_publaynet_dict.txt")
+ default="../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt")
parser.add_argument(
"--layout_score_threshold",
type=float,
@@ -89,6 +89,11 @@ def init_args():
type=bool,
default=False,
help='Whether to enable layout of recovery')
+ parser.add_argument(
+ "--save_pdf",
+ type=bool,
+ default=False,
+ help='Whether to save pdf file')
return parser
diff --git a/test_tipc/common_func.sh b/test_tipc/common_func.sh
index f7d8a1e04adee9d32332eda8cb5913bbaf168481..1bbf829165323b76341461b297b71102462d83af 100644
--- a/test_tipc/common_func.sh
+++ b/test_tipc/common_func.sh
@@ -58,10 +58,11 @@ function status_check(){
run_command=$2
run_log=$3
model_name=$4
+ log_path=$5
if [ $last_status -eq 0 ]; then
- echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
+ echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
else
- echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log}
+ echo -e "\033[33m Run failed with command - ${model_name} - ${run_command} - ${log_path} \033[0m" | tee -a ${run_log}
fi
}
diff --git a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt b/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt
deleted file mode 100644
index df88c0e5434511fb48deac699e8f67fc535765d3..0000000000000000000000000000000000000000
--- a/test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt
+++ /dev/null
@@ -1,58 +0,0 @@
-===========================train_params===========================
-model_name:det_r18_db_v2_0
-python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:null
-Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
-Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4
-Global.pretrained_model:null
-train_model_name:latest
-train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
-null:null
-##
-trainer:norm_train
-norm_train:tools/train.py -c configs/det/det_res18_db_v2.0.yml -o
-quant_export:null
-fpgm_export:null
-distill_train:null
-null:null
-null:null
-##
-===========================eval_params===========================
-eval:null
-null:null
-##
-===========================infer_params===========================
-Global.save_inference_dir:./output/
-Global.checkpoints:
-norm_export:null
-quant_export:null
-fpgm_export:null
-distill_export:null
-export1:null
-export2:null
-##
-train_model:null
-infer_export:null
-infer_quant:False
-inference:tools/infer/predict_det.py
---use_gpu:True|False
---enable_mkldnn:False
---cpu_threads:6
---rec_batch_num:1
---use_tensorrt:False
---precision:fp32
---det_model_dir:
---image_dir:./inference/ch_det_data_50/all-sum-510/
---save_log_path:null
---benchmark:True
-null:null
-===========================infer_benchmark_params==========================
-random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
-===========================train_benchmark_params==========================
-batch_size:8|16
-fp_items:fp32|fp16
-epoch:15
---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
diff --git a/test_tipc/configs/en_table_structure/train_infer_python.txt b/test_tipc/configs/en_table_structure/train_infer_python.txt
index 633b6185d976ac61408283025bd4ba305187317d..3fd5dc9f60a9621026d488e5654cd7e1421e8b65 100644
--- a/test_tipc/configs/en_table_structure/train_infer_python.txt
+++ b/test_tipc/configs/en_table_structure/train_infer_python.txt
@@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,488,488]}]
===========================train_benchmark_params==========================
batch_size:32
fp_items:fp32|fp16
-epoch:1
+epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
index 34082bc193a2ebd8f4c7a9e7c9ce55dc8dbf8e40..5284ffabe2de4eb8bb000e7fb745ef2846ed6b64 100644
--- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
+++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
@@ -52,7 +52,7 @@ null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_benchmark_params==========================
-batch_size:4
+batch_size:8
fp_items:fp32|fp16
epoch:3
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
diff --git a/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b5466d4478be27d6fd152ee467f7f25731c8dce0
--- /dev/null
+++ b/test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml
@@ -0,0 +1,111 @@
+Global:
+ use_gpu: true
+ epoch_num: 5
+ log_smooth_window: 20
+ print_batch_step: 20
+ save_model_dir: ./output/rec/rec_r31_robustscanner/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: ./inference/rec_inference
+ # for data or label process
+ character_dict_path: ppocr/utils/dict90.txt
+ max_text_length: &max_text_length 40
+ infer_mode: False
+ use_space_char: False
+ rm_symbol: True
+ save_res_path: ./output/rec/predicts_robustscanner.txt
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ decay_epochs: [3, 4]
+ values: [0.001, 0.0001, 0.00001]
+ regularizer:
+ name: 'L2'
+ factor: 0
+
+Architecture:
+ model_type: rec
+ algorithm: RobustScanner
+ Transform:
+ Backbone:
+ name: ResNet31
+ init_type: KaimingNormal
+ Head:
+ name: RobustScannerHead
+ enc_outchannles: 128
+ hybrid_dec_rnn_layers: 2
+ hybrid_dec_dropout: 0
+ position_dec_rnn_layers: 2
+ start_idx: 91
+ mask: True
+ padding_idx: 92
+ encode_value: False
+ max_text_length: *max_text_length
+
+Loss:
+ name: SARLoss
+
+PostProcess:
+ name: SARLabelDecode
+
+Metric:
+ name: RecMetric
+ is_filter: True
+
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SARLabelEncode: # Class handling label
+ - RobustScannerRecResizeImg:
+ image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
+ width_downsample_ratio: 0.25
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 16
+ drop_last: True
+ num_workers: 0
+ use_shared_memory: False
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SARLabelEncode: # Class handling label
+ - RobustScannerRecResizeImg:
+ image_shape: [3, 48, 48, 160]
+ max_text_length: *max_text_length
+ width_downsample_ratio: 0.25
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'valid_ratio', 'word_positons'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 16
+ num_workers: 0
+ use_shared_memory: False
+
diff --git a/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..07498c9e81ada9652343b8d8fff0f102d4684380
--- /dev/null
+++ b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt
@@ -0,0 +1,54 @@
+===========================train_params===========================
+model_name:rec_r31_robustscanner
+python:python
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=5
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_r31_robustscanner/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/rec_r31_robustscanner.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner"
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1|6
+--use_tensorrt:False|False
+--precision:fp32|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,48,160]}]
+
diff --git a/test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml b/test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml
new file mode 100644
index 0000000000000000000000000000000000000000..860e4f53043138e7434d71a816fdf051048be6f7
--- /dev/null
+++ b/test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml
@@ -0,0 +1,108 @@
+Global:
+ use_gpu: true
+ epoch_num: 8
+ log_smooth_window: 200
+ print_batch_step: 200
+ save_model_dir: ./output/rec/r45_visionlan
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/en/word_2.png
+ # for data or label process
+ character_dict_path:
+ max_text_length: &max_text_length 25
+ training_step: &training_step LA
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_visionlan.txt
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 20.0
+ group_lr: true
+ training_step: *training_step
+ lr:
+ name: Piecewise
+ decay_epochs: [6]
+ values: [0.0001, 0.00001]
+ regularizer:
+ name: 'L2'
+ factor: 0
+
+Architecture:
+ model_type: rec
+ algorithm: VisionLAN
+ Transform:
+ Backbone:
+ name: ResNet45
+ strides: [2, 2, 2, 1, 1]
+ Head:
+ name: VLHead
+ n_layers: 3
+ n_position: 256
+ n_dim: 512
+ max_text_length: *max_text_length
+ training_step: *training_step
+
+Loss:
+ name: VLLoss
+ mode: *training_step
+ weight_res: 0.5
+ weight_mas: 0.5
+
+PostProcess:
+ name: VLLabelDecode
+
+Metric:
+ name: RecMetric
+ is_filter: true
+
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - ABINetRecAug:
+ - VLLabelEncode: # Class handling label
+ - VLRecResizeImg:
+ image_shape: [3, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 220
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: RGB
+ channel_first: False
+ - VLLabelEncode: # Class handling label
+ - VLRecResizeImg:
+ image_shape: [3, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 64
+ num_workers: 4
+
diff --git a/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt b/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c08ae7beb6c867bf36283e60dc1e70cfd9ee06a7
--- /dev/null
+++ b/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_r45_visionlan
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=32|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_r45_visionlan_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r45_visionlan/rec_r45_visionlan.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,64,256" --rec_algorithm="VisionLAN" --use_space_char=False
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1|6
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,64,256]}]
diff --git a/test_tipc/readme.md b/test_tipc/readme.md
index f9e9d89e4198c1ad5fabdf58775c6f7b6d190322..1442ee1c86a7c1319446a0eb22c08287e1ce689a 100644
--- a/test_tipc/readme.md
+++ b/test_tipc/readme.md
@@ -54,6 +54,7 @@
| NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 混合精度 | - | - |
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 混合精度 | - | - |
| SPIN |rec_r32_gaspin_bilstm_att | 识别 | 支持 | 多机多卡 混合精度 | - | - |
+| RobustScanner |rec_r31_robustscanner | 识别 | 支持 | 多机多卡 混合精度 | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 混合精度 | - | - |
| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡 混合精度 | - | - |
diff --git a/test_tipc/test_inference_cpp.sh b/test_tipc/test_inference_cpp.sh
index c0c7c18a38a46b00c839757e303049135a508691..aadaa8b0773632885138806861fc851ede503f3d 100644
--- a/test_tipc/test_inference_cpp.sh
+++ b/test_tipc/test_inference_cpp.sh
@@ -84,7 +84,7 @@ function func_cpp_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@@ -117,7 +117,7 @@ function func_cpp_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
diff --git a/test_tipc/test_inference_python.sh b/test_tipc/test_inference_python.sh
index 2a31a468f0d54d1979e82c8f0da98cac6f4edcec..e9908df1f6049f9d38524dc6598499ddd2b58af8 100644
--- a/test_tipc/test_inference_python.sh
+++ b/test_tipc/test_inference_python.sh
@@ -88,7 +88,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@@ -119,7 +119,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
@@ -146,14 +146,15 @@ if [ ${MODE} = "whole_infer" ]; then
for infer_model in ${infer_model_dir_list[*]}; do
# run export
if [ ${infer_run_exports[Count]} != "null" ];then
+ _save_log_path="${LOG_PATH}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}_infermodel_${infer_model}.log"
save_infer_dir=$(dirname $infer_model)
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
- export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
+ export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${_save_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
eval $export_cmd
status_export=$?
- status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+ status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
else
save_infer_dir=${infer_model}
fi
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
index 78d79d0b8eaac782f98c1e883d091a001443f41a..bace6b2d4684e0ad40ffbd76b37a78ddf1e70722 100644
--- a/test_tipc/test_paddle2onnx.sh
+++ b/test_tipc/test_paddle2onnx.sh
@@ -66,7 +66,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
# trans rec
set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
@@ -78,7 +78,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
elif [[ ${model_name} =~ "det" ]]; then
# trans det
set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
@@ -91,7 +91,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}"
elif [[ ${model_name} =~ "rec" ]]; then
# trans rec
set_dirname=$(func_set_params "--model_dir" "${rec_infer_model_dir_value}")
@@ -104,7 +104,7 @@ function func_paddle2onnx(){
trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
fi
# python inference
@@ -127,7 +127,7 @@ function func_paddle2onnx(){
eval $infer_model_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
_save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
@@ -146,7 +146,7 @@ function func_paddle2onnx(){
eval $infer_model_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${infer_model_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
else
echo "Does not support hardware other than CPU and GPU Currently!"
fi
@@ -158,4 +158,4 @@ echo "################### run test ###################"
export Count=0
IFS="|"
-func_paddle2onnx
\ No newline at end of file
+func_paddle2onnx
diff --git a/test_tipc/test_ptq_inference_python.sh b/test_tipc/test_ptq_inference_python.sh
index e2939fd5e638ad0f6b4c44422a6fec6459903d1c..caf3d506029ee066aa5abebc25b739439b6e9d75 100644
--- a/test_tipc/test_ptq_inference_python.sh
+++ b/test_tipc/test_ptq_inference_python.sh
@@ -84,7 +84,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@@ -109,7 +109,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
@@ -145,7 +145,7 @@ if [ ${MODE} = "whole_infer" ]; then
echo $export_cmd
eval $export_cmd
status_export=$?
- status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+ status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
else
save_infer_dir=${infer_model}
fi
diff --git a/test_tipc/test_serving_infer_cpp.sh b/test_tipc/test_serving_infer_cpp.sh
index 0be6a45adf3105f088a96336dddfbe9ac612f19b..10ddecf3fa26805fef7bc6ae10d78ee5e741cd27 100644
--- a/test_tipc/test_serving_infer_cpp.sh
+++ b/test_tipc/test_serving_infer_cpp.sh
@@ -83,7 +83,7 @@ function func_serving(){
trans_model_cmd="${python_list[0]} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client} > ${trans_rec_log} 2>&1 "
eval $trans_model_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}"
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list})
cd ${serving_dir_value}
@@ -95,14 +95,14 @@ function func_serving(){
web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
sleep 5s
_save_log_path="${LOG_PATH}/cpp_client_cpu.log"
cpp_client_cmd="${python_list[0]} ${cpp_client_py} ${det_client_value} ${rec_client_value} > ${_save_log_path} 2>&1"
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
else
server_log_path="${LOG_PATH}/cpp_server_gpu.log"
@@ -114,7 +114,7 @@ function func_serving(){
eval $cpp_client_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${cpp_client_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
fi
done
diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh
index 4b7dfcf785a3c8459cce95d55744dbcd4f97027a..c7d305d5d2dcd2ea1bf5a7c3254eea4231d59879 100644
--- a/test_tipc/test_serving_infer_python.sh
+++ b/test_tipc/test_serving_infer_python.sh
@@ -126,19 +126,19 @@ function func_serving(){
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
@@ -147,7 +147,7 @@ function func_serving(){
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
sleep 2s
done
ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
@@ -177,19 +177,19 @@ function func_serving(){
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
- status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" "${server_log_path}"
fi
sleep 2s
for pipeline in ${pipeline_py[*]}; do
@@ -198,7 +198,7 @@ function func_serving(){
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}"
+ status_check $last_status "${pipeline_cmd}" "${status_log}" "${model_name}" "${_save_log_path}"
sleep 2s
done
ps ux | grep -E 'web_service' | awk '{print $2}' | xargs kill -s 9
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 545cdbba2051c8123ef7f70f2aeb4b4b5a57b7c5..e182fa57f060c81af012a5da89b892bde02b4a2b 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -133,7 +133,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
done
@@ -164,7 +164,7 @@ function func_inference(){
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}" "${model_name}"
+ status_check $last_status "${command}" "${status_log}" "${model_name}" "${_save_log_path}"
done
done
@@ -201,7 +201,7 @@ if [ ${MODE} = "whole_infer" ]; then
echo $export_cmd
eval $export_cmd
status_export=$?
- status_check $status_export "${export_cmd}" "${status_log}" "${model_name}"
+ status_check $status_export "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
else
save_infer_dir=${infer_model}
fi
@@ -298,7 +298,7 @@ else
# run train
eval $cmd
eval "cat ${save_log}/train.log >> ${save_log}.log"
- status_check $? "${cmd}" "${status_log}" "${model_name}"
+ status_check $? "${cmd}" "${status_log}" "${model_name}" "${save_log}.log"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
@@ -309,7 +309,7 @@ else
eval_log_path="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}_nodes_${nodes}_eval.log"
eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1} > ${eval_log_path} 2>&1 "
eval $eval_cmd
- status_check $? "${eval_cmd}" "${status_log}" "${model_name}"
+ status_check $? "${eval_cmd}" "${status_log}" "${model_name}" "${eval_log_path}"
fi
# run export model
if [ ${run_export} != "null" ]; then
@@ -320,7 +320,7 @@ else
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
eval $export_cmd
- status_check $? "${export_cmd}" "${status_log}" "${model_name}"
+ status_check $? "${export_cmd}" "${status_log}" "${model_name}" "${export_log_path}"
#run inference
eval $env
diff --git a/tools/eval.py b/tools/eval.py
index 2fc53488efa2c4c475d31af47f69b3560e6cc69a..38d72d178db45a4787ddc09c865afba9222f385a 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -73,7 +73,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
diff --git a/tools/export_model.py b/tools/export_model.py
index c6763374a634dfca125f64b63d7c85716f68f142..193988cc1b62a6c4536a8d2ec640e3e5fc81a79c 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -58,6 +58,8 @@ def export_single_model(model,
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
+ [paddle.static.InputSpec(
+ shape=[None], dtype="float32")]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR":
@@ -109,6 +111,22 @@ def export_single_model(model,
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "RobustScanner":
+ max_text_length = arch_config["Head"]["max_text_length"]
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 48, 160], dtype="float32"),
+
+ [
+ paddle.static.InputSpec(
+ shape=[None, ],
+ dtype="float32"),
+ paddle.static.InputSpec(
+ shape=[None, max_text_length],
+ dtype="int64")
+ ]
+ ]
+ model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
@@ -128,7 +146,7 @@ def export_single_model(model,
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
- infer_shape = [3, 48, -1] # for rec model, H must be 32
+ infer_shape = [3, 32, -1] # for rec model, H must be 32
if "Transform" in arch_config and arch_config[
"Transform"] is not None and arch_config["Transform"][
"name"] == "TPS":
@@ -234,4 +252,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 449f69ed6a22cb12af8a6a2ef8f2eedc1aca087c..53dab6f26d8b84a224360f2fa6fe5f411eea751f 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -68,7 +68,7 @@ class TextRecognizer(object):
'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
- }
+ }
elif self.rec_algorithm == "VisionLAN":
postprocess_params = {
'name': 'VLLabelDecode',
@@ -93,6 +93,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ elif self.rec_algorithm == "RobustScanner":
+ postprocess_params = {
+ 'name': 'SARLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char,
+ "rm_symbol": True
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -390,6 +397,18 @@ class TextRecognizer(object):
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == "RobustScanner":
+ norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+ img_list[indices[ino]], self.rec_image_shape, width_downsample_ratio=0.25)
+ norm_img = norm_img[np.newaxis, :]
+ valid_ratio = np.expand_dims(valid_ratio, axis=0)
+ valid_ratios = []
+ valid_ratios.append(valid_ratio)
+ norm_img_batch.append(norm_img)
+ word_positions_list = []
+ word_positions = np.array(range(0, 40)).astype('int64')
+ word_positions = np.expand_dims(word_positions, axis=0)
+ word_positions_list.append(word_positions)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
@@ -437,10 +456,40 @@ class TextRecognizer(object):
preds = {"predict": outputs[2]}
elif self.rec_algorithm == "SAR":
valid_ratios = np.concatenate(valid_ratios)
+ inputs = [
+ norm_img_batch,
+ np.array(
+ [valid_ratios], dtype=np.float32),
+ ]
+ if self.use_onnx:
+ input_dict = {}
+ input_dict[self.input_tensor.name] = norm_img_batch
+ outputs = self.predictor.run(self.output_tensors,
+ input_dict)
+ preds = outputs[0]
+ else:
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(
+ input_names[i])
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
+ preds = outputs[0]
+ elif self.rec_algorithm == "RobustScanner":
+ valid_ratios = np.concatenate(valid_ratios)
+ word_positions_list = np.concatenate(word_positions_list)
inputs = [
norm_img_batch,
valid_ratios,
+ word_positions_list
]
+
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 81d0196ccd6b86741e73524d9321618f3f5cc34b..1eebc73f31e6b48a473c20d907ca401ad919fe0b 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -231,89 +231,10 @@ def create_predictor(args, mode, logger):
)
config.enable_tuned_tensorrt_dynamic_shape(
args.shape_info_filename, True)
-
- use_dynamic_shape = True
- if mode == "det":
- min_input_shape = {
- "x": [1, 3, 50, 50],
- "conv2d_92.tmp_0": [1, 120, 20, 20],
- "conv2d_91.tmp_0": [1, 24, 10, 10],
- "conv2d_59.tmp_0": [1, 96, 20, 20],
- "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
- "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
- "conv2d_124.tmp_0": [1, 256, 20, 20],
- "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
- "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
- "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
- "elementwise_add_7": [1, 56, 2, 2],
- "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
- }
- max_input_shape = {
- "x": [1, 3, 1536, 1536],
- "conv2d_92.tmp_0": [1, 120, 400, 400],
- "conv2d_91.tmp_0": [1, 24, 200, 200],
- "conv2d_59.tmp_0": [1, 96, 400, 400],
- "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
- "conv2d_124.tmp_0": [1, 256, 400, 400],
- "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
- "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
- "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
- "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
- "elementwise_add_7": [1, 56, 400, 400],
- "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
- }
- opt_input_shape = {
- "x": [1, 3, 640, 640],
- "conv2d_92.tmp_0": [1, 120, 160, 160],
- "conv2d_91.tmp_0": [1, 24, 80, 80],
- "conv2d_59.tmp_0": [1, 96, 160, 160],
- "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
- "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
- "conv2d_124.tmp_0": [1, 256, 160, 160],
- "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
- "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
- "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
- "elementwise_add_7": [1, 56, 40, 40],
- "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
- }
- min_pact_shape = {
- "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
- "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
- "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
- "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
- }
- max_pact_shape = {
- "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
- "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
- "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
- "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
- }
- opt_pact_shape = {
- "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
- "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
- "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
- "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
- }
- min_input_shape.update(min_pact_shape)
- max_input_shape.update(max_pact_shape)
- opt_input_shape.update(opt_pact_shape)
- elif mode == "rec":
- if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
- use_dynamic_shape = False
- imgH = int(args.rec_image_shape.split(',')[-2])
- min_input_shape = {"x": [1, 3, imgH, 10]}
- max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
- opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
- config.exp_disable_tensorrt_ops(["transpose2"])
- elif mode == "cls":
- min_input_shape = {"x": [1, 3, 48, 10]}
- max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
- opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
- use_dynamic_shape = False
- if use_dynamic_shape:
- config.set_trt_dynamic_shape_info(
- min_input_shape, max_input_shape, opt_input_shape)
+ logger.info(
+            f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dynamic shape tuning"
+ )
elif args.use_xpu:
config.enable_xpu(10 * 1024 * 1024)
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 182694e6cda12ead0e263bb94a7d6483a6f7f212..14b14544eb11e9fb0a0c2cdf92aff9d7cb4b5ba7 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -96,6 +96,8 @@ def main():
]
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
+ elif config['Architecture']['algorithm'] == "RobustScanner":
+ op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
@@ -131,12 +133,20 @@ def main():
if config['Architecture']['algorithm'] == "SAR":
valid_ratio = np.expand_dims(batch[-1], axis=0)
img_metas = [paddle.to_tensor(valid_ratio)]
+ if config['Architecture']['algorithm'] == "RobustScanner":
+ valid_ratio = np.expand_dims(batch[1], axis=0)
+ word_positons = np.expand_dims(batch[2], axis=0)
+ img_metas = [paddle.to_tensor(valid_ratio),
+ paddle.to_tensor(word_positons),
+ ]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
elif config['Architecture']['algorithm'] == "SAR":
preds = model(images, img_metas)
+ elif config['Architecture']['algorithm'] == "RobustScanner":
+ preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
diff --git a/tools/infer_sr.py b/tools/infer_sr.py
index 0bc2f6aaa7c4400676268ec64d37e721af0f99c2..df4334f3427e57b9062dd819aa16c110fd771e8c 100755
--- a/tools/infer_sr.py
+++ b/tools/infer_sr.py
@@ -63,14 +63,14 @@ def main():
elif op_name in ['SRResize']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['imge_lr']
+ op[op_name]['keep_keys'] = ['img_lr']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
- save_res_path = config['Global'].get('save_res_path', "./infer_result")
- if not os.path.exists(os.path.dirname(save_res_path)):
- os.makedirs(os.path.dirname(save_res_path))
+ save_visual_path = config['Global'].get('save_visual', "infer_result/")
+ if not os.path.exists(os.path.dirname(save_visual_path)):
+ os.makedirs(os.path.dirname(save_visual_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
@@ -87,7 +87,7 @@ def main():
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1]
- cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
+ cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".format(
img_name_pure))
diff --git a/tools/program.py b/tools/program.py
index b450bc5a3abf0be500b42712d72c81c190412f34..5a4d3ea4d2ec6832e6735d15096d46fbb62f86dd 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -162,18 +162,18 @@ def to_float32(preds):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
- else:
- preds[k] = paddle.to_tensor(preds[k], dtype='float32')
+ elif isinstance(preds[k], paddle.Tensor):
+ preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
- else:
- preds[k] = paddle.to_tensor(preds[k], dtype='float32')
- else:
- preds = paddle.to_tensor(preds, dtype='float32')
+ elif isinstance(preds[k], paddle.Tensor):
+ preds[k] = preds[k].astype(paddle.float32)
+ elif isinstance(preds, paddle.Tensor):
+ preds = preds.astype(paddle.float32)
return preds
@@ -190,7 +190,8 @@ def train(config,
pre_best_model_dict,
logger,
log_writer=None,
- scaler=None):
+ scaler=None,
+ amp_level='O2'):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
@@ -230,7 +231,8 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
- "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN"
+ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
+ "RobustScanner"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
@@ -276,7 +278,8 @@ def train(config,
model_average = True
# use amp
if scaler:
- with paddle.amp.auto_cast(level='O2'):
+ custom_black_list = config['Global'].get('amp_custom_black_list',[])
+ with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
@@ -502,18 +505,9 @@ def eval(model,
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
-
- for i in (range(sr_img.shape[0])):
- fm_sr = (sr_img[i].numpy() * 255).transpose(
- 1, 2, 0).astype(np.uint8)
- fm_lr = (lr_img[i].numpy() * 255).transpose(
- 1, 2, 0).astype(np.uint8)
- cv2.imwrite("output/images/{}_{}_sr.jpg".format(
- sum_images, i), fm_sr)
- cv2.imwrite("output/images/{}_{}_lr.jpg".format(
- sum_images, i), fm_lr)
else:
preds = model(images)
+ preds = to_float32(preds)
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
@@ -523,16 +517,6 @@ def eval(model,
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
-
- for i in (range(sr_img.shape[0])):
- fm_sr = (sr_img[i].numpy() * 255).transpose(
- 1, 2, 0).astype(np.uint8)
- fm_lr = (lr_img[i].numpy() * 255).transpose(
- 1, 2, 0).astype(np.uint8)
- cv2.imwrite("output/images/{}_{}_sr.jpg".format(
- sum_images, i), fm_sr)
- cv2.imwrite("output/images/{}_{}_lr.jpg".format(
- sum_images, i), fm_lr)
else:
preds = model(images)
@@ -653,7 +637,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
- 'Gestalt', 'SLANet'
+ 'Gestalt', 'SLANet', 'RobustScanner'
]
if use_xpu:
diff --git a/tools/train.py b/tools/train.py
index 0c881ecae8daf78860829b1419178358c2209f25..5f310938f3ae3488281b47ccdb436697595b5578 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer):
len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False)
+ amp_level = config["Global"].get("amp_level", 'O2')
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
@@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
- model, optimizer = paddle.amp.decorate(
- models=model, optimizers=optimizer, level='O2', master_weight=True)
+ if amp_level == "O2":
+ model, optimizer = paddle.amp.decorate(
+ models=model, optimizers=optimizer, level=amp_level, master_weight=True)
else:
scaler = None
@@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
+                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler, amp_level)
def test_reader(config, device, logger):
|