diff --git a/.gitignore b/.gitignore
index 70f136ccff632b77df358643e5815eb0bf6b0395..ed9bc452f4f80e02c963728c0acd9728f5912095 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,9 +10,14 @@ __pycache__/
inference/
inference_results/
output/
+<<<<<<< HEAD
train_data
log
+=======
+train_data/
+log/
+>>>>>>> 1696b36bdb4152138ed5cb08a357df8fe03dc067
*.DS_Store
*.vs
*.user
diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index ce3d66f07f89cedab463cf38bc0cdc56f3a61237..aeed6435a4e84b2b58811e9087d9713300848104 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -28,7 +28,7 @@ from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPoin
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \
- QFileDialog, QListWidgetItem, QComboBox, QDialog
+ QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView
__dir__ = os.path.dirname(os.path.abspath(__file__))
@@ -242,6 +242,20 @@ class MainWindow(QMainWindow):
self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
listLayout.addWidget(self.labelListDock)
+ # enable labelList drag_drop to adjust bbox order
+ # 设置选择模式为单选
+ self.labelList.setSelectionMode(QAbstractItemView.SingleSelection)
+ # 启用拖拽
+ self.labelList.setDragEnabled(True)
+ # 设置接受拖放
+ self.labelList.viewport().setAcceptDrops(True)
+ # 设置显示将要被放置的位置
+ self.labelList.setDropIndicatorShown(True)
+ # 设置拖放模式为移动项目,如果不设置,默认为复制项目
+ self.labelList.setDragDropMode(QAbstractItemView.InternalMove)
+ # 触发放置
+ self.labelList.model().rowsMoved.connect(self.drag_drop_happened)
+
# ================== Detection Box ==================
self.BoxList = QListWidget()
@@ -589,15 +603,23 @@ class MainWindow(QMainWindow):
self.displayLabelOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayLabelOption.triggered.connect(self.togglePaintLabelsOption)
+ # Add option to enable/disable box index being displayed at the top of bounding boxes
+ self.displayIndexOption = QAction(getStr('displayIndex'), self)
+ self.displayIndexOption.setCheckable(True)
+ self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
+ self.displayIndexOption.triggered.connect(self.togglePaintIndexOption)
+
self.labelDialogOption = QAction(getStr('labelDialogOption'), self)
self.labelDialogOption.setShortcut("Ctrl+Shift+L")
self.labelDialogOption.setCheckable(True)
self.labelDialogOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
+ self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.labelDialogOption.triggered.connect(self.speedChoose)
self.autoSaveOption = QAction(getStr('autoSaveMode'), self)
self.autoSaveOption.setCheckable(True)
self.autoSaveOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
+ self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.autoSaveOption.triggered.connect(self.autoSaveFunc)
addActions(self.menus.file,
@@ -606,7 +628,7 @@ class MainWindow(QMainWindow):
addActions(self.menus.help, (showKeys, showSteps, showInfo))
addActions(self.menus.view, (
- self.displayLabelOption, self.labelDialogOption,
+ self.displayLabelOption, self.displayIndexOption, self.labelDialogOption,
None,
hideAll, showAll, None,
zoomIn, zoomOut, zoomOrg, None,
@@ -964,9 +986,10 @@ class MainWindow(QMainWindow):
else:
self.canvas.selectedShapes_hShape = self.canvas.selectedShapes
for shape in self.canvas.selectedShapes_hShape:
- item = self.shapesToItemsbox[shape] # listitem
- text = [(int(p.x()), int(p.y())) for p in shape.points]
- item.setText(str(text))
+ if shape in self.shapesToItemsbox.keys():
+ item = self.shapesToItemsbox[shape] # listitem
+ text = [(int(p.x()), int(p.y())) for p in shape.points]
+ item.setText(str(text))
self.actions.undo.setEnabled(True)
self.setDirty()
@@ -1040,6 +1063,8 @@ class MainWindow(QMainWindow):
def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked()
+ shape.paintIdx = self.displayIndexOption.isChecked()
+
item = HashableQListWidgetItem(shape.label)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
@@ -1083,6 +1108,7 @@ class MainWindow(QMainWindow):
def loadLabels(self, shapes):
s = []
+ shape_index = 0
for label, points, line_color, key_cls, difficult in shapes:
shape = Shape(label=label, line_color=line_color, key_cls=key_cls)
for x, y in points:
@@ -1094,6 +1120,8 @@ class MainWindow(QMainWindow):
shape.addPoint(QPointF(x, y))
shape.difficult = difficult
+ shape.idx = shape_index
+ shape_index += 1
# shape.locked = False
shape.close()
s.append(shape)
@@ -1209,18 +1237,54 @@ class MainWindow(QMainWindow):
self.canvas.deSelectShape()
def labelItemChanged(self, item):
- shape = self.itemsToShapes[item]
- label = item.text()
- if label != shape.label:
- shape.label = item.text()
- # shape.line_color = generateColorByText(shape.label)
- self.setDirty()
- elif not ((item.checkState() == Qt.Unchecked) ^ (not shape.difficult)):
- shape.difficult = True if item.checkState() == Qt.Unchecked else False
- self.setDirty()
- else: # User probably changed item visibility
- self.canvas.setShapeVisible(shape, True) # item.checkState() == Qt.Checked
- # self.actions.save.setEnabled(True)
+ # avoid accidentally triggering the itemChanged siganl with unhashable item
+ # Unknown trigger condition
+ if type(item) == HashableQListWidgetItem:
+ shape = self.itemsToShapes[item]
+ label = item.text()
+ if label != shape.label:
+ shape.label = item.text()
+ # shape.line_color = generateColorByText(shape.label)
+ self.setDirty()
+ elif not ((item.checkState() == Qt.Unchecked) ^ (not shape.difficult)):
+ shape.difficult = True if item.checkState() == Qt.Unchecked else False
+ self.setDirty()
+ else: # User probably changed item visibility
+ self.canvas.setShapeVisible(shape, True) # item.checkState() == Qt.Checked
+ # self.actions.save.setEnabled(True)
+ else:
+ print('enter labelItemChanged slot with unhashable item: ', item, item.text())
+
+ def drag_drop_happened(self):
+ '''
+ label list drag drop signal slot
+ '''
+ # print('___________________drag_drop_happened_______________')
+ # should only select single item
+ for item in self.labelList.selectedItems():
+ newIndex = self.labelList.indexFromItem(item).row()
+
+ # only support drag_drop one item
+ assert len(self.canvas.selectedShapes) > 0
+ for shape in self.canvas.selectedShapes:
+ selectedShapeIndex = shape.idx
+
+ if newIndex == selectedShapeIndex:
+ return
+
+ # move corresponding item in shape list
+ shape = self.canvas.shapes.pop(selectedShapeIndex)
+ self.canvas.shapes.insert(newIndex, shape)
+
+ # update bbox index
+ self.canvas.updateShapeIndex()
+
+ # boxList update simultaneously
+ item = self.BoxList.takeItem(selectedShapeIndex)
+ self.BoxList.insertItem(newIndex, item)
+
+ # changes happen
+ self.setDirty()
# Callback functions:
def newShape(self, value=True):
@@ -1560,6 +1624,7 @@ class MainWindow(QMainWindow):
settings[SETTING_LAST_OPEN_DIR] = ''
settings[SETTING_PAINT_LABEL] = self.displayLabelOption.isChecked()
+ settings[SETTING_PAINT_INDEX] = self.displayIndexOption.isChecked()
settings[SETTING_DRAW_SQUARE] = self.drawSquaresOption.isChecked()
settings.save()
try:
@@ -1946,8 +2011,16 @@ class MainWindow(QMainWindow):
self.labelHist.append(line)
def togglePaintLabelsOption(self):
+ self.displayIndexOption.setChecked(False)
+ for shape in self.canvas.shapes:
+ shape.paintLabel = self.displayLabelOption.isChecked()
+ shape.paintIdx = self.displayIndexOption.isChecked()
+
+ def togglePaintIndexOption(self):
+ self.displayLabelOption.setChecked(False)
for shape in self.canvas.shapes:
shape.paintLabel = self.displayLabelOption.isChecked()
+ shape.paintIdx = self.displayIndexOption.isChecked()
def toogleDrawSquare(self):
self.canvas.setDrawingShapeToSquare(self.drawSquaresOption.isChecked())
@@ -2042,7 +2115,7 @@ class MainWindow(QMainWindow):
self.init_key_list(self.Cachelabel)
def reRecognition(self):
- img = cv2.imread(self.filePath)
+ img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes:
self.result_dic = []
@@ -2111,7 +2184,7 @@ class MainWindow(QMainWindow):
QMessageBox.information(self, "Information", "Draw a box!")
def singleRerecognition(self):
- img = cv2.imread(self.filePath)
+ img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
for shape in self.canvas.selectedShapes:
box = [[int(p.x()), int(p.y())] for p in shape.points]
if len(box) > 4:
@@ -2187,6 +2260,7 @@ class MainWindow(QMainWindow):
shapes = []
result_len = len(region['res']['boxes'])
+ order_index = 0
for i in range(result_len):
bbox = np.array(region['res']['boxes'][i])
rec_text = region['res']['rec_res'][i][0]
@@ -2205,6 +2279,8 @@ class MainWindow(QMainWindow):
x, y, snapped = self.canvas.snapPointToCanvas(x, y)
shape.addPoint(QPointF(x, y))
shape.difficult = False
+ shape.idx = order_index
+ order_index += 1
# shape.locked = False
shape.close()
self.addLabel(shape)
diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py
index e6cddf13ede235fa193daf84d4395d77c371049a..81f37995126140b03650f5ddea37ea282d5ceb09 100644
--- a/PPOCRLabel/libs/canvas.py
+++ b/PPOCRLabel/libs/canvas.py
@@ -314,21 +314,23 @@ class Canvas(QWidget):
QApplication.restoreOverrideCursor() # ?
if self.movingShape and self.hShape:
- index = self.shapes.index(self.hShape)
- if (
- self.shapesBackups[-1][index].points
- != self.shapes[index].points
- ):
- self.storeShapes()
- self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
+ if self.hShape in self.shapes:
+ index = self.shapes.index(self.hShape)
+ if (
+ self.shapesBackups[-1][index].points
+ != self.shapes[index].points
+ ):
+ self.storeShapes()
+ self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
- self.movingShape = False
+ self.movingShape = False
def endMove(self, copy=False):
assert self.selectedShapes and self.selectedShapesCopy
assert len(self.selectedShapesCopy) == len(self.selectedShapes)
if copy:
for i, shape in enumerate(self.selectedShapesCopy):
+ shape.idx = len(self.shapes) # add current box index
self.shapes.append(shape)
self.selectedShapes[i].selected = False
self.selectedShapes[i] = shape
@@ -524,6 +526,9 @@ class Canvas(QWidget):
self.storeShapes()
self.selectedShapes = []
self.update()
+
+ self.updateShapeIndex()
+
return deleted_shapes
def storeShapes(self):
@@ -619,6 +624,13 @@ class Canvas(QWidget):
pal.setColor(self.backgroundRole(), QColor(232, 232, 232, 255))
self.setPalette(pal)
+ # adaptive BBOX label & index font size
+ if self.pixmap:
+ h, w = self.pixmap.size().height(), self.pixmap.size().width()
+ fontszie = int(max(h, w) / 48)
+ for s in self.shapes:
+ s.fontsize = fontszie
+
p.end()
def fillDrawing(self):
@@ -651,7 +663,8 @@ class Canvas(QWidget):
return
self.current.close()
- self.shapes.append(self.current)
+ self.current.idx = len(self.shapes) # add current box index
+ self.shapes.append(self.current)
self.current = None
self.setHiding(False)
self.newShape.emit()
@@ -842,6 +855,7 @@ class Canvas(QWidget):
self.hVertex = None
# self.hEdge = None
self.storeShapes()
+ self.updateShapeIndex()
self.repaint()
def setShapeVisible(self, shape, value):
@@ -883,10 +897,16 @@ class Canvas(QWidget):
self.selectedShapes = []
for shape in self.shapes:
shape.selected = False
+ self.updateShapeIndex()
self.repaint()
-
+
@property
def isShapeRestorable(self):
if len(self.shapesBackups) < 2:
return False
- return True
\ No newline at end of file
+ return True
+
+ def updateShapeIndex(self):
+ for i in range(len(self.shapes)):
+ self.shapes[i].idx = i
+ self.update()
\ No newline at end of file
diff --git a/PPOCRLabel/libs/constants.py b/PPOCRLabel/libs/constants.py
index 58c8222ec52dcdbff7ddda04911f6703a2bdedc7..f075f4a53919db483ce9af7a09a2547f7ec3df6a 100644
--- a/PPOCRLabel/libs/constants.py
+++ b/PPOCRLabel/libs/constants.py
@@ -21,6 +21,7 @@ SETTING_ADVANCE_MODE = 'advanced'
SETTING_WIN_STATE = 'window/state'
SETTING_SAVE_DIR = 'savedir'
SETTING_PAINT_LABEL = 'paintlabel'
+SETTING_PAINT_INDEX = 'paintindex'
SETTING_LAST_OPEN_DIR = 'lastOpenDir'
SETTING_AUTO_SAVE = 'autosave'
SETTING_SINGLE_CLASS = 'singleclass'
diff --git a/PPOCRLabel/libs/editinlist.py b/PPOCRLabel/libs/editinlist.py
index 79d2d3aa371ac076de513a4d52ea51b27c6e08f2..4bcc11ec47e090e1cda9083a35baf5b451acf8fc 100644
--- a/PPOCRLabel/libs/editinlist.py
+++ b/PPOCRLabel/libs/editinlist.py
@@ -26,4 +26,4 @@ class EditInList(QListWidget):
def leaveEvent(self, event):
# close edit
for i in range(self.count()):
- self.closePersistentEditor(self.item(i))
+ self.closePersistentEditor(self.item(i))
\ No newline at end of file
diff --git a/PPOCRLabel/libs/shape.py b/PPOCRLabel/libs/shape.py
index 97e2eb72380be5c1fd1e06785be846b596763986..121e43b8aee62dd3e5e0b2e2fbdebf10e775f57b 100644
--- a/PPOCRLabel/libs/shape.py
+++ b/PPOCRLabel/libs/shape.py
@@ -46,15 +46,16 @@ class Shape(object):
point_size = 8
scale = 1.0
- def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False):
+ def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False, paintIdx=False):
self.label = label
- self.idx = 0
+ self.idx = None # bbox order, only for table annotation
self.points = []
self.fill = False
self.selected = False
self.difficult = difficult
self.key_cls = key_cls
self.paintLabel = paintLabel
+ self.paintIdx = paintIdx
self.locked = False
self.direction = 0
self.center = None
@@ -65,6 +66,7 @@ class Shape(object):
self.NEAR_VERTEX: (4, self.P_ROUND),
self.MOVE_VERTEX: (1.5, self.P_SQUARE),
}
+ self.fontsize = 8
self._closed = False
@@ -155,7 +157,7 @@ class Shape(object):
min_y = min(min_y, point.y())
if min_x != sys.maxsize and min_y != sys.maxsize:
font = QFont()
- font.setPointSize(8)
+ font.setPointSize(self.fontsize)
font.setBold(True)
painter.setFont(font)
if self.label is None:
@@ -164,6 +166,25 @@ class Shape(object):
min_y += MIN_Y_LABEL
painter.drawText(min_x, min_y, self.label)
+ # Draw number at the top-right
+ if self.paintIdx:
+ min_x = sys.maxsize
+ min_y = sys.maxsize
+ for point in self.points:
+ min_x = min(min_x, point.x())
+ min_y = min(min_y, point.y())
+ if min_x != sys.maxsize and min_y != sys.maxsize:
+ font = QFont()
+ font.setPointSize(self.fontsize)
+ font.setBold(True)
+ painter.setFont(font)
+ text = ''
+ if self.idx != None:
+ text = str(self.idx)
+ if min_y < MIN_Y_LABEL:
+ min_y += MIN_Y_LABEL
+ painter.drawText(min_x, min_y, text)
+
if self.fill:
color = self.select_fill_color if self.selected else self.fill_color
painter.fillPath(line_path, color)
diff --git a/PPOCRLabel/resources/strings/strings-en.properties b/PPOCRLabel/resources/strings/strings-en.properties
index 0b112c46461b6626dfbebaa87babd691b2492d0a..1b628016c079ad1c5eb5514c7d6eb2cba842b7e3 100644
--- a/PPOCRLabel/resources/strings/strings-en.properties
+++ b/PPOCRLabel/resources/strings/strings-en.properties
@@ -61,6 +61,7 @@ labels=Labels
autoSaveMode=Auto Save mode
singleClsMode=Single Class Mode
displayLabel=Display Labels
+displayIndex=Display box index
fileList=File List
files=Files
advancedMode=Advanced Mode
diff --git a/PPOCRLabel/resources/strings/strings-zh-CN.properties b/PPOCRLabel/resources/strings/strings-zh-CN.properties
index 184247e85b634af22394d6c038229ce3aadd9e8d..0758729a8ca0cae862a4bf5bcf2e5b24f2d95822 100644
--- a/PPOCRLabel/resources/strings/strings-zh-CN.properties
+++ b/PPOCRLabel/resources/strings/strings-zh-CN.properties
@@ -61,6 +61,7 @@ labels=标签
autoSaveMode=自动保存模式
singleClsMode=单一类别模式
displayLabel=显示类别
+displayIndex=显示box序号
fileList=文件列表
files=文件
advancedMode=专家模式
diff --git a/configs/det/det_r50_db++_ic15.yml b/configs/det/det_r50_db++_icdar15.yml
similarity index 100%
rename from configs/det/det_r50_db++_ic15.yml
rename to configs/det/det_r50_db++_icdar15.yml
diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml
index 4e5826adc990c30aaee1d63fd5b4523944906eee..eacde2965cad1954f9471ba11635936b3654da4b 100644
--- a/configs/rec/rec_mtb_nrtr.yml
+++ b/configs/rec/rec_mtb_nrtr.yml
@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/validation/
+ data_dir: ./train_data/data_lmdb_release/evaluaiton/
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aea71388f703376120af4d0caf2fa8ccd4d92cce
--- /dev/null
+++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml
@@ -0,0 +1,116 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 50
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
+ save_epoch_step: 3
+ # evaluation is run every 2000 iterations after the 4000th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ decay_epochs: [3, 4, 5]
+ values: [0.001, 0.0003, 0.00009, 0.000027]
+ clip_norm: 5
+
+Architecture:
+ model_type: rec
+ algorithm: SPIN
+ in_channels: 1
+ Transform:
+ name: GA_SPIN
+ offsets: True
+ default_type: 6
+ loc_lr: 0.1
+ stn: True
+ Backbone:
+ name: ResNet32
+ out_channels: 512
+ Neck:
+ name: SequenceEncoder
+ encoder_type: cascadernn
+ hidden_size: 256
+ out_channels: [256, 512]
+ with_linear: True
+ Head:
+ name: SPINAttentionHead
+ hidden_size: 256
+
+
+Loss:
+ name: SPINAttentionLoss
+ ignore_index: 0
+
+PostProcess:
+ name: SPINLabelDecode
+ use_space_char: False
+
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 8
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 8
+ num_workers: 2
diff --git a/configs/rec/rec_r45_abinet.yml b/configs/rec/rec_r45_abinet.yml
index e604bcd12d9c06d1358e250bc76700af920db93b..bc9048a2005044874058179e8d683437a90a8519 100644
--- a/configs/rec/rec_r45_abinet.yml
+++ b/configs/rec/rec_r45_abinet.yml
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
- pretrained_model:
+ pretrained_model: ./pretrain_models/abinet_vl_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/validation/
+ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: RGB
diff --git a/configs/rec/rec_svtrnet.yml b/configs/rec/rec_svtrnet.yml
index c1f5cc380ab9652e3ac750f8b8edacde9837daf0..e8ceefead6e42de5167984ffa0c18f7ecb03157b 100644
--- a/configs/rec/rec_svtrnet.yml
+++ b/configs/rec/rec_svtrnet.yml
@@ -77,7 +77,7 @@ Metric:
Train:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/training
+ data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
@@ -97,7 +97,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/validation
+ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/configs/rec/rec_vitstr_none_ce.yml b/configs/rec/rec_vitstr_none_ce.yml
index b969c83a5d77b8c8ffde97ce5f914512074cf26e..ebe304fa4bad3a33b4a226e761ccabbfda41e202 100644
--- a/configs/rec/rec_vitstr_none_ce.yml
+++ b/configs/rec/rec_vitstr_none_ce.yml
@@ -81,7 +81,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/data_lmdb_release/validation/
+ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/doc/doc_ch/algorithm_det_db.md b/doc/doc_ch/algorithm_det_db.md
index afdddb1a73a495cbb3186348704b235f8076c7d1..5401132061e507773ae77be49555ba754d1cba15 100644
--- a/doc/doc_ch/algorithm_det_db.md
+++ b/doc/doc_ch/algorithm_det_db.md
@@ -32,7 +32,7 @@
| --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
-|DB++|ResNet50|[configs/det/det_r50_db++_ic15.yml](../../configs/det/det_r50_db++_ic15.yml)|90.89%|82.66%|86.58%|[合成数据预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
+|DB++|ResNet50|[configs/det/det_r50_db++_icdar15.yml](../../configs/det/det_r50_db++_icdar15.yml)|90.89%|82.66%|86.58%|[合成数据预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
在TD_TR文本检测公开数据集上,算法复现效果如下:
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 2d763fd71c31616c182b66f9dba22e63a85aabec..a6ba71592fd2048b8e74a28e956e6d1711e4d322 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -69,6 +69,7 @@
- [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md)
+- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
@@ -90,7 +91,12 @@
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+<<<<<<< HEAD
|RobustScanner|ResNet31V2| 87.77% | rec_r31_robustscanner | coming soon |
+=======
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
+
+>>>>>>> 1696b36bdb4152138ed5cb08a357df8fe03dc067
diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md
new file mode 100644
index 0000000000000000000000000000000000000000..c996992d2fa6297e6086ffae4bc36ad3e880873d
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_spin.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. 算法简介](#1)
+- [2. 环境配置](#2)
+- [3. 模型训练、评估、预测](#3)
+ - [3.1 训练](#3-1)
+ - [3.2 评估](#3-2)
+ - [3.3 预测](#3-3)
+- [4. 推理部署](#4)
+ - [4.1 Python推理](#4-1)
+ - [4.2 C++推理](#4-2)
+ - [4.3 Serving服务化部署](#4-3)
+ - [4.4 更多推理部署](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. 算法简介
+
+论文信息:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识别中,矫正网络是一种较为常见的前置处理模块,但诸如RARE\ASTER\ESIR等只考虑了空间变换,并没有考虑色度变换。本文提出了一种结构Structure-Preserving Inner Offset Network (SPIN),可以在色彩空间上进行变换。该模块是可微分的,可以加入到任意识别器中。
+使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
+
+|模型|骨干网络|配置文件|Acc|下载链接|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+
+## 2. 环境配置
+请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
+
+
+
+## 3. 模型训练、评估、预测
+
+请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
+
+训练
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+
+```
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+评估
+
+```
+# GPU 评估, Global.pretrained_model 为待测权重
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+预测:
+
+```
+# 预测使用的配置文件必须与训练一致
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. 推理部署
+
+
+### 4.1 Python推理
+首先将SPIN文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+SPIN文本识别模型推理,可以执行如下命令:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=Falsee
+```
+
+
+### 4.2 C++推理
+
+由于C++预处理后处理还未支持SPIN,所以暂未支持
+
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+
+### 4.4 更多推理部署
+
+暂不支持
+
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@article{2020SPIN,
+ title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+ author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+ journal={AAAI2020},
+ year={2020},
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 603af76fac065ef1c2aa8eb4cc0f275145956214..390f7c13c791c0511e731843178648a3c94d4c9e 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -68,7 +68,11 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md)
+<<<<<<< HEAD
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
+=======
+- [x] [SPIN](./algorithm_rec_spin_en.md)
+>>>>>>> 1696b36bdb4152138ed5cb08a357df8fe03dc067
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -89,6 +93,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
+|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..43ab30ce7d96cbb64ddf87156fee3012d666b2bf
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_spin_en.md
@@ -0,0 +1,112 @@
+# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
+> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
+> AAAI, 2020
+
+Using MJSynth and SynthText two text recognition datasets for training, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets. The algorithm reproduction effect is as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
+
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during the SPIN text recognition training process is converted into an inference model. you can use the following command to convert:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
+```
+
+For SPIN text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="/ppocr/utils/dict/spin_dict.txt" --use_space_char=False
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@article{2020SPIN,
+ title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
+ author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
+ journal={AAAI2020},
+ year={2020},
+}
+```
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 512f7502af4d96d154421f62f666b2efb1cc5afb..30d422b61e6ffaaf220ef72a26c3d2f51d784433 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -26,7 +26,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, RobustScannerRecResizeImg
+ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg, RobustScannerRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index a4087d53287fcd57f9c4992ba712c700f33b9981..97539faf232ec157340d3136d2efc0daca8deda8 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1216,3 +1216,36 @@ class ABINetLabelEncode(BaseRecLabelEncode):
def add_special_char(self, dict_character):
dict_character = [''] + dict_character
return dict_character
+
+class SPINLabelEncode(AttnLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ lower=True,
+ **kwargs):
+ super(SPINLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+ self.lower = lower
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = [self.beg_str] + [self.end_str] + dict_character
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ target = [0] + text + [1]
+ padded_text = [0 for _ in range(self.max_text_len + 2)]
+
+ padded_text[:len(target)] = target
+ data['label'] = np.array(padded_text)
+ return data
\ No newline at end of file
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 1055e369e4cdf8edfbff94fec0b20520001de11d..a5620f84f8cb00584215f2b839055e9a46c25c5a 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -259,6 +259,7 @@ class PRENResizeImg(object):
data['image'] = resized_img.astype(np.float32)
return data
+<<<<<<< HEAD
class RobustScannerRecResizeImg(object):
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
@@ -275,6 +276,50 @@ class RobustScannerRecResizeImg(object):
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
data['word_positons'] = word_positons
+=======
+class SPINRecResizeImg(object):
+ def __init__(self,
+ image_shape,
+ interpolation=2,
+ mean=(127.5, 127.5, 127.5),
+ std=(127.5, 127.5, 127.5),
+ **kwargs):
+ self.image_shape = image_shape
+
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.interpolation = interpolation
+
+ def __call__(self, data):
+ img = data['image']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # different interpolation type corresponding the OpenCV
+ if self.interpolation == 0:
+ interpolation = cv2.INTER_NEAREST
+ elif self.interpolation == 1:
+ interpolation = cv2.INTER_LINEAR
+ elif self.interpolation == 2:
+ interpolation = cv2.INTER_CUBIC
+ elif self.interpolation == 3:
+ interpolation = cv2.INTER_AREA
+ else:
+ raise Exception("Unsupported interpolation type !!!")
+ # Deal with the image error during image loading
+ if img is None:
+ return None
+
+ img = cv2.resize(img, tuple(self.image_shape), interpolation)
+ img = np.array(img, np.float32)
+ img = np.expand_dims(img, -1)
+ img = img.transpose((2, 0, 1))
+ # normalize the image
+ img = img.copy().astype(np.float32)
+ mean = np.float64(self.mean.reshape(1, -1))
+ stdinv = 1 / np.float64(self.std.reshape(1, -1))
+ img -= mean
+ img *= stdinv
+ data['image'] = img
+>>>>>>> 1696b36bdb4152138ed5cb08a357df8fe03dc067
return data
class GrayRecResizeImg(object):
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 62e0544ea94daaaff7d019e6a48e65a2d508aca0..30120ac56756edd38676c40c39f0130f1b07c3ef 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
+from .rec_spin_att_loss import SPINAttentionLoss
# cls loss
from .cls_loss import ClsLoss
@@ -62,7 +63,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
- 'TableMasterLoss'
+ 'TableMasterLoss', 'SPINAttentionLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_spin_att_loss.py b/ppocr/losses/rec_spin_att_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..195780c7bfaf4aae5dd23bd72ace268bed9c1d4f
--- /dev/null
+++ b/ppocr/losses/rec_spin_att_loss.py
@@ -0,0 +1,45 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+'''This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR
+'''
+
+class SPINAttentionLoss(nn.Layer):
+ def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
+ super(SPINAttentionLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index)
+
+ def forward(self, predicts, batch):
+ targets = batch[1].astype("int64")
+ targets = targets[:, 1:] # remove [eos] in label
+
+ label_lengths = batch[2].astype('int64')
+ batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
+ 1], predicts.shape[2]
+ assert len(targets.shape) == len(list(predicts.shape)) - 1, \
+ "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+
+ inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
+ targets = paddle.reshape(targets, [-1])
+
+ return {'loss': self.loss_func(inputs, targets)}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index f4094d796b1f14c955e5962936e86bd6b3f5ec78..d4f5b15f56d34a9f6a6501058179a643ac7e8318 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -32,6 +32,7 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
+ from .rec_resnet_32 import ResNet32
from .rec_resnet_45 import ResNet45
from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
@@ -41,7 +42,7 @@ def build_backbone(config, model_type):
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
- 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR'
+ 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_32.py b/ppocr/modeling/backbones/rec_resnet_32.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd19251a3ed43a472d49f03743ead1491aa86ac
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_32.py
@@ -0,0 +1,269 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.nn as nn
+
+__all__ = ["ResNet32"]
+
+conv_weight_attr = nn.initializer.KaimingNormal()
+
+class ResNet32(nn.Layer):
+ """
+ Feature Extractor is proposed in FAN Ref [1]
+
+ Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017
+ """
+
+ def __init__(self, in_channels, out_channels=512):
+ """
+
+ Args:
+ in_channels (int): input channel
+ output_channel (int): output channel
+ """
+ super(ResNet32, self).__init__()
+ self.out_channels = out_channels
+ self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs: input feature
+
+ Returns:
+ output feature
+
+ """
+ return self.ConvNet(inputs)
+
+class BasicBlock(nn.Layer):
+ """Res-net Basic Block"""
+ expansion = 1
+
+ def __init__(self, inplanes, planes,
+ stride=1, downsample=None,
+ norm_type='BN', **kwargs):
+ """
+ Args:
+ inplanes (int): input channel
+ planes (int): channels of the middle feature
+ stride (int): stride of the convolution
+ downsample (int): type of the down_sample
+ norm_type (str): type of the normalization
+ **kwargs (None): backup parameter
+ """
+ super(BasicBlock, self).__init__()
+ self.conv1 = self._conv3x3(inplanes, planes)
+ self.bn1 = nn.BatchNorm2D(planes)
+ self.conv2 = self._conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2D(planes)
+ self.relu = nn.ReLU()
+ self.downsample = downsample
+ self.stride = stride
+
+ def _conv3x3(self, in_planes, out_planes, stride=1):
+ """
+
+ Args:
+ in_planes (int): input channel
+ out_planes (int): channels of the middle feature
+ stride (int): stride of the convolution
+ Returns:
+ nn.Layer: Conv2D with kernel = 3
+
+ """
+
+ return nn.Conv2D(in_planes, out_planes,
+ kernel_size=3, stride=stride,
+ padding=1, weight_attr=conv_weight_attr,
+ bias_attr=False)
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+class ResNet(nn.Layer):
+ """Res-Net network structure"""
+ def __init__(self, input_channel,
+ output_channel, block, layers):
+ """
+
+ Args:
+ input_channel (int): input channel
+ output_channel (int): output channel
+ block (BasicBlock): convolution block
+ layers (list): layers of the block
+ """
+ super(ResNet, self).__init__()
+
+ self.output_channel_block = [int(output_channel / 4),
+ int(output_channel / 2),
+ output_channel,
+ output_channel]
+
+ self.inplanes = int(output_channel / 8)
+ self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16),
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16))
+ self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes,
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn0_2 = nn.BatchNorm2D(self.inplanes)
+ self.relu = nn.ReLU()
+
+ self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.layer1 = self._make_layer(block,
+ self.output_channel_block[0],
+ layers[0])
+ self.conv1 = nn.Conv2D(self.output_channel_block[0],
+ self.output_channel_block[0],
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm2D(self.output_channel_block[0])
+
+ self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.layer2 = self._make_layer(block,
+ self.output_channel_block[1],
+ layers[1], stride=1)
+ self.conv2 = nn.Conv2D(self.output_channel_block[1],
+ self.output_channel_block[1],
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False,)
+ self.bn2 = nn.BatchNorm2D(self.output_channel_block[1])
+
+ self.maxpool3 = nn.MaxPool2D(kernel_size=2,
+ stride=(2, 1),
+ padding=(0, 1))
+ self.layer3 = self._make_layer(block, self.output_channel_block[2],
+ layers[2], stride=1)
+ self.conv3 = nn.Conv2D(self.output_channel_block[2],
+ self.output_channel_block[2],
+ kernel_size=3, stride=1,
+ padding=1,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn3 = nn.BatchNorm2D(self.output_channel_block[2])
+
+ self.layer4 = self._make_layer(block, self.output_channel_block[3],
+ layers[3], stride=1)
+ self.conv4_1 = nn.Conv2D(self.output_channel_block[3],
+ self.output_channel_block[3],
+ kernel_size=2, stride=(2, 1),
+ padding=(0, 1),
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3])
+ self.conv4_2 = nn.Conv2D(self.output_channel_block[3],
+ self.output_channel_block[3],
+ kernel_size=2, stride=1,
+ padding=0,
+ weight_attr=conv_weight_attr,
+ bias_attr=False)
+ self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3])
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ """
+
+ Args:
+ block (block): convolution block
+ planes (int): input channels
+ blocks (list): layers of the block
+ stride (int): stride of the convolution
+
+ Returns:
+ nn.Sequential: the combination of the convolution block
+
+ """
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2D(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride,
+ weight_attr=conv_weight_attr,
+ bias_attr=False),
+ nn.BatchNorm2D(planes * block.expansion),
+ )
+
+ layers = list()
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv0_1(x)
+ x = self.bn0_1(x)
+ x = self.relu(x)
+ x = self.conv0_2(x)
+ x = self.bn0_2(x)
+ x = self.relu(x)
+
+ x = self.maxpool1(x)
+ x = self.layer1(x)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.maxpool2(x)
+ x = self.layer2(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+
+ x = self.maxpool3(x)
+ x = self.layer3(x)
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.relu(x)
+
+ x = self.layer4(x)
+ x = self.conv4_1(x)
+ x = self.bn4_1(x)
+ x = self.relu(x)
+ x = self.conv4_2(x)
+ x = self.bn4_2(x)
+ x = self.relu(x)
+ return x
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 9a1b76576e40aacc872003502ccce7a39930ac79..ca861c352dc7a4767e762619831708353dc67274 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,6 +33,7 @@ def build_head(config):
from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead
+ from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
@@ -49,7 +50,7 @@ def build_head(config):
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
- 'MultiHead', 'ABINetHead', 'TableMasterHead', 'RobustScannerHead'
+ 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'RobustScannerHead'
]
#table head
diff --git a/ppocr/modeling/heads/rec_abinet_head.py b/ppocr/modeling/heads/rec_abinet_head.py
index a0f60f1be1727e85380eedb7d311ce9445f88b8e..2309ad65e6ebd32592df19eed0ce1fbd11cb9a81 100644
--- a/ppocr/modeling/heads/rec_abinet_head.py
+++ b/ppocr/modeling/heads/rec_abinet_head.py
@@ -273,7 +273,8 @@ def _get_length(logit):
out = out.cast('int32')
out = out.argmax(-1)
out = out + 1
- out = paddle.where(abn, out, paddle.to_tensor(logit.shape[1]))
+ len_seq = paddle.zeros_like(out) + logit.shape[1]
+ out = paddle.where(abn, out, len_seq)
return out
diff --git a/ppocr/modeling/heads/rec_spin_att_head.py b/ppocr/modeling/heads/rec_spin_att_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..86e35e4339d8e1006cfe43d6cf4f2f7d231082c4
--- /dev/null
+++ b/ppocr/modeling/heads/rec_spin_att_head.py
@@ -0,0 +1,115 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SPINAttentionHead(nn.Layer):
+ def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
+ super(SPINAttentionHead, self).__init__()
+ self.input_size = in_channels
+ self.hidden_size = hidden_size
+ self.num_classes = out_channels
+
+ self.attention_cell = AttentionLSTMCell(
+ in_channels, hidden_size, out_channels, use_gru=False)
+ self.generator = nn.Linear(hidden_size, out_channels)
+
+ def _char_to_onehot(self, input_char, onehot_dim):
+ input_ont_hot = F.one_hot(input_char, onehot_dim)
+ return input_ont_hot
+
+ def forward(self, inputs, targets=None, batch_max_length=25):
+ batch_size = paddle.shape(inputs)[0]
+ num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
+
+ hidden = (paddle.zeros((batch_size, self.hidden_size)),
+ paddle.zeros((batch_size, self.hidden_size)))
+ output_hiddens = []
+ if self.training: # for train
+ targets = targets[0]
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets[:, i], onehot_dim=self.num_classes)
+ (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
+ output = paddle.concat(output_hiddens, axis=1)
+ probs = self.generator(output)
+ else:
+ targets = paddle.zeros(shape=[batch_size], dtype="int32")
+ probs = None
+ char_onehots = None
+ outputs = None
+ alpha = None
+
+ for i in range(num_steps):
+ char_onehots = self._char_to_onehot(
+ targets, onehot_dim=self.num_classes)
+ (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
+ char_onehots)
+ probs_step = self.generator(outputs)
+ if probs is None:
+ probs = paddle.unsqueeze(probs_step, axis=1)
+ else:
+ probs = paddle.concat(
+ [probs, paddle.unsqueeze(
+ probs_step, axis=1)], axis=1)
+ next_input = probs_step.argmax(axis=1)
+ targets = next_input
+ if not self.training:
+ probs = paddle.nn.functional.softmax(probs, axis=2)
+ return probs
+
+
+class AttentionLSTMCell(nn.Layer):
+ def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
+ super(AttentionLSTMCell, self).__init__()
+ self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
+ self.h2h = nn.Linear(hidden_size, hidden_size)
+ self.score = nn.Linear(hidden_size, 1, bias_attr=False)
+ if not use_gru:
+ self.rnn = nn.LSTMCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+ else:
+ self.rnn = nn.GRUCell(
+ input_size=input_size + num_embeddings, hidden_size=hidden_size)
+
+ self.hidden_size = hidden_size
+
+ def forward(self, prev_hidden, batch_H, char_onehots):
+ batch_H_proj = self.i2h(batch_H)
+ prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
+ res = paddle.add(batch_H_proj, prev_hidden_proj)
+ res = paddle.tanh(res)
+ e = self.score(res)
+
+ alpha = F.softmax(e, axis=1)
+ alpha = paddle.transpose(alpha, [0, 2, 1])
+ context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
+ concat_context = paddle.concat([context, char_onehots], 1)
+ cur_hidden = self.rnn(concat_context, prev_hidden)
+
+ return cur_hidden, alpha
diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index c8a774b8c543b9ccc14223c52f1b79ce690592f6..33be9400b34cb535d260881748e179c3df106caa 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -47,6 +47,56 @@ class EncoderWithRNN(nn.Layer):
x, _ = self.lstm(x)
return x
+class BidirectionalLSTM(nn.Layer):
+ def __init__(self, input_size,
+ hidden_size,
+ output_size=None,
+ num_layers=1,
+ dropout=0,
+ direction=False,
+ time_major=False,
+ with_linear=False):
+ super(BidirectionalLSTM, self).__init__()
+ self.with_linear = with_linear
+ self.rnn = nn.LSTM(input_size,
+ hidden_size,
+ num_layers=num_layers,
+ dropout=dropout,
+ direction=direction,
+ time_major=time_major)
+
+ # text recognition the specified structure LSTM with linear
+ if self.with_linear:
+ self.linear = nn.Linear(hidden_size * 2, output_size)
+
+ def forward(self, input_feature):
+ recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+ if self.with_linear:
+ output = self.linear(recurrent) # batch_size x T x output_size
+ return output
+ return recurrent
+
+class EncoderWithCascadeRNN(nn.Layer):
+ def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
+ super(EncoderWithCascadeRNN, self).__init__()
+ self.out_channels = out_channels[-1]
+ self.encoder = nn.LayerList(
+ [BidirectionalLSTM(
+ in_channels if i == 0 else out_channels[i - 1],
+ hidden_size,
+ output_size=out_channels[i],
+ num_layers=1,
+ direction='bidirectional',
+ with_linear=with_linear)
+ for i in range(num_layers)]
+ )
+
+
+ def forward(self, x):
+ for i, l in enumerate(self.encoder):
+ x = l(x)
+ return x
+
class EncoderWithFC(nn.Layer):
def __init__(self, in_channels, hidden_size):
@@ -166,13 +216,17 @@ class SequenceEncoder(nn.Layer):
'reshape': Im2Seq,
'fc': EncoderWithFC,
'rnn': EncoderWithRNN,
- 'svtr': EncoderWithSVTR
+ 'svtr': EncoderWithSVTR,
+ 'cascadernn': EncoderWithCascadeRNN
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, **kwargs)
+ elif encoder_type == 'cascadernn':
+ self.encoder = support_encoder_dict[encoder_type](
+ self.encoder_reshape.out_channels, hidden_size, **kwargs)
else:
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size)
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 405ab3cc6c0380654f61e42e523ddc85839139b3..7e4ffdf46854416f71e1c8f4e131d1f0283bb725 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -18,8 +18,10 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
from .stn import STN_ON
+ from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
- support_dict = ['TPS', 'STN_ON']
+
+ support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/gaspin_transformer.py b/ppocr/modeling/transforms/gaspin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4719eb2162a02141620586bcb6a849ae16f3b62
--- /dev/null
+++ b/ppocr/modeling/transforms/gaspin_transformer.py
@@ -0,0 +1,284 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import functools
+from .tps import GridGenerator
+
+'''This code is refer from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
+'''
+
+class SP_TransformerNetwork(nn.Layer):
+ """
+ Sturture-Preserving Transformation (SPT) as Equa. (2) in Ref. [1]
+ Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
+ """
+
+ def __init__(self, nc=1, default_type=5):
+ """ Based on SPIN
+ Args:
+ nc (int): number of input channels (usually in 1 or 3)
+ default_type (int): the complexity of transformation intensities (by default set to 6 as the paper)
+ """
+ super(SP_TransformerNetwork, self).__init__()
+ self.power_list = self.cal_K(default_type)
+ self.sigmoid = nn.Sigmoid()
+ self.bn = nn.InstanceNorm2D(nc)
+
+ def cal_K(self, k=5):
+ """
+
+ Args:
+ k (int): the complexity of transformation intensities (by default set to 6 as the paper)
+
+ Returns:
+ List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)]
+
+ """
+ from math import log
+ x = []
+ if k != 0:
+ for i in range(1, k+1):
+ lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2)
+ upper = round(1/lower, 2)
+ x.append(lower)
+ x.append(upper)
+ x.append(1.00)
+ return x
+
+ def forward(self, batch_I, weights, offsets, lambda_color=None):
+ """
+
+ Args:
+ batch_I (Tensor): batch of input images [batch_size x nc x I_height x I_width]
+ weights:
+ offsets: the predicted offset by AIN, a scalar
+ lambda_color: the learnable update gate \alpha in Equa. (5) as
+ g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets}
+
+ Returns:
+ Tensor: transformed images by SPN as Equa. (4) in Ref. [1]
+ [batch_size x I_channel_num x I_r_height x I_r_width]
+
+ """
+ batch_I = (batch_I + 1) * 0.5
+ if offsets is not None:
+ batch_I = batch_I*(1-lambda_color) + offsets*lambda_color
+ batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
+ batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1)
+
+ batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1)
+ batch_weight_sum = self.bn(batch_weight_sum)
+ batch_weight_sum = self.sigmoid(batch_weight_sum)
+ batch_weight_sum = batch_weight_sum * 2 - 1
+ return batch_weight_sum
+
+class GA_SPIN_Transformer(nn.Layer):
+ """
+ Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]
+
+
+ Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
+ """
+
+ def __init__(self, in_channels=1,
+ I_r_size=(32, 100),
+ offsets=False,
+ norm_type='BN',
+ default_type=6,
+ loc_lr=1,
+ stn=True):
+ """
+ Args:
+ in_channels (int): channel of input features,
+ set it to 1 if the grayscale images and 3 if RGB input
+ I_r_size (tuple): size of rectified images (used in STN transformations)
+ offsets (bool): set it to False if use SPN w.o. AIN,
+ and set it to True if use SPIN (both with SPN and AIN)
+ norm_type (str): the normalization type of the module,
+ set it to 'BN' by default, 'IN' optionally
+ default_type (int): the K chromatic space,
+ set it to 3/5/6 depend on the complexity of transformation intensities
+ loc_lr (float): learning rate of location network
+ stn (bool): whther to use stn.
+
+ """
+ super(GA_SPIN_Transformer, self).__init__()
+ self.nc = in_channels
+ self.spt = True
+ self.offsets = offsets
+ self.stn = stn # set to True in GA-SPIN, while set it to False in SPIN
+ self.I_r_size = I_r_size
+ self.out_channels = in_channels
+ if norm_type == 'BN':
+ norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
+ elif norm_type == 'IN':
+ norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False,
+ use_global_stats=False)
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+
+ if self.spt:
+ self.sp_net = SP_TransformerNetwork(in_channels,
+ default_type)
+ self.spt_convnet = nn.Sequential(
+ # 32*100
+ nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
+ norm_layer(32), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 16*50
+ nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
+ norm_layer(64), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 8*25
+ nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
+ norm_layer(128), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ # 4*12
+ )
+ self.stucture_fc1 = nn.Sequential(
+ nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
+ norm_layer(256), nn.ReLU(),
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
+ norm_layer(256), nn.ReLU(), # 2*6
+ nn.MaxPool2D(kernel_size=2, stride=2),
+ nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
+ norm_layer(512), nn.ReLU(), # 1*3
+ nn.AdaptiveAvgPool2D(1),
+ nn.Flatten(1, -1), # batch_size x 512
+ nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
+ nn.BatchNorm1D(256), nn.ReLU()
+ )
+ self.out_weight = 2*default_type+1
+ self.spt_length = 2*default_type+1
+ if offsets:
+ self.out_weight += 1
+ if self.stn:
+ self.F = 20
+ self.out_weight += self.F * 2
+ self.GridGenerator = GridGenerator(self.F*2, self.F)
+
+ # self.out_weight*=nc
+ # Init structure_fc2 in LocalizationNetwork
+ initial_bias = self.init_spin(default_type*2)
+ initial_bias = initial_bias.reshape(-1)
+ param_attr = ParamAttr(
+ learning_rate=loc_lr,
+ initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])))
+ bias_attr = ParamAttr(
+ learning_rate=loc_lr,
+ initializer=nn.initializer.Assign(initial_bias))
+ self.stucture_fc2 = nn.Linear(256, self.out_weight,
+ weight_attr=param_attr,
+ bias_attr=bias_attr)
+ self.sigmoid = nn.Sigmoid()
+
+ if offsets:
+ self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16,
+ 3, 1, 1,
+ bias_attr=False),
+ norm_layer(16),
+ nn.ReLU(),)
+ self.offset_fc2 = nn.Conv2D(16, in_channels,
+ 3, 1, 1)
+ self.pool = nn.MaxPool2D(2, 2)
+
+ def init_spin(self, nz):
+ """
+ Args:
+ nz (int): number of paired \betas exponents, which means the value of K x 2
+
+ """
+ init_id = [0.00]*nz+[5.00]
+ if self.offsets:
+ init_id += [-5.00]
+ # init_id *=3
+ init = np.array(init_id)
+
+ if self.stn:
+ F = self.F
+ ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+ ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+ ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ initial_bias = initial_bias.reshape(-1)
+ init = np.concatenate([init, initial_bias], axis=0)
+ return init
+
+ def forward(self, x, return_weight=False):
+ """
+ Args:
+ x (Tensor): input image batch
+ return_weight (bool): set to False by default,
+ if set to True return the predicted offsets of AIN, denoted as x_{offsets}
+
+ Returns:
+ Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size
+ """
+
+ if self.spt:
+ feat = self.spt_convnet(x)
+ fc1 = self.stucture_fc1(feat)
+ sp_weight_fusion = self.stucture_fc2(fc1)
+ sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1])
+ if self.offsets: # SPIN w. AIN
+ lambda_color = sp_weight_fusion[:, self.spt_length, 0]
+ lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+ offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))
+
+ assert offsets.shape[2] == 2 # 2
+ assert offsets.shape[3] == 6 # 16
+ offsets = self.sigmoid(offsets) # v12
+
+ if return_weight:
+ return offsets
+ offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear')
+
+ if self.stn:
+ batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2])
+ build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
+ build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
+ self.I_r_size[0],
+ self.I_r_size[1],
+ 2])
+
+ else: # SPIN w.o. AIN
+ sp_weight = sp_weight_fusion[:, :self.spt_length, :]
+ lambda_color, offsets = None, None
+
+ if self.stn:
+ batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2])
+ build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
+ build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
+ self.I_r_size[0],
+ self.I_r_size[1],
+ 2])
+
+ x = self.sp_net(x, sp_weight, offsets, lambda_color)
+ if self.stn:
+ x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
+ return x
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 1d414eb2e8562925f461b0c6f6ce15774b81bb8f..eeebc5803f321df0d6709bb57a009692659bfe77 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,8 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
- SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
+ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
+ SPINLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -44,7 +45,7 @@ def build_post_process(config, global_config=None):
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
- 'TableMasterLabelDecode'
+ 'TableMasterLabelDecode', 'SPINLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index cc7c2cb379cc476943152507569f0b0066189c46..3fe29aabe58f42faa02d1b25b4255ba8a19b3ea3 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -667,3 +667,18 @@ class ABINetLabelDecode(NRTRLabelDecode):
def add_special_char(self, dict_character):
dict_character = [''] + dict_character
return dict_character
+
+class SPINLabelDecode(AttnLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SPINLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = dict_character
+ dict_character = [self.beg_str] + [self.end_str] + dict_character
+ return dict_character
\ No newline at end of file
diff --git a/ppocr/utils/dict/spin_dict.txt b/ppocr/utils/dict/spin_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8ee8347fd9c85228a3cf46c810d4fc28ab05c492
--- /dev/null
+++ b/ppocr/utils/dict/spin_dict.txt
@@ -0,0 +1,68 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+:
+(
+'
+-
+,
+%
+>
+.
+[
+?
+)
+"
+=
+_
+*
+]
+;
+&
++
+$
+@
+/
+|
+!
+<
+#
+`
+{
+~
+\
+}
+^
\ No newline at end of file
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index b0ede5f3a1b88df6efed53d7ca33a696bc7a7fff..d6f2e24240ff783e14dbd61efdd27877f9ec39ff 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -34,7 +34,7 @@ from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
-from ppstructure.recovery.docx import convert_info_docx
+from ppstructure.recovery.recovery_to_doc import convert_info_docx
logger = get_logger()
diff --git a/ppstructure/recovery/README_ch.md b/ppstructure/recovery/README_ch.md
index 1f72f8de8a5e2eb51c8c4f58df30465f5361a301..5a05abffd0399387bc0d22d878e64d03d8894a79 100644
--- a/ppstructure/recovery/README_ch.md
+++ b/ppstructure/recovery/README_ch.md
@@ -44,6 +44,12 @@ python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simp
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
+* **(2)安装依赖**
+
+```bash
+python3 -m pip install -r ppstructure/recovery/requirements.txt
+```
+
### 2.2 安装PaddleOCR
diff --git a/ppstructure/recovery/docx.py b/ppstructure/recovery/recovery_to_doc.py
similarity index 100%
rename from ppstructure/recovery/docx.py
rename to ppstructure/recovery/recovery_to_doc.py
diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py
index aa05459589208dde66a6710322593d091af41325..becc6daef02e7e3e98fcccd3b87a93e725577886 100644
--- a/ppstructure/table/predict_table.py
+++ b/ppstructure/table/predict_table.py
@@ -129,11 +129,25 @@ class TableSystem(object):
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
+ dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html, pred
+ def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
+ y1 = pred_bboxes[:,1::2].min()
+ new_dt_boxes = []
+ new_rec_res = []
+
+ for box,rec in zip(dt_boxes, rec_res):
+ if np.max(box[1::2]) < y1:
+ continue
+ new_dt_boxes.append(box)
+ new_rec_res.append(rec)
+ return new_dt_boxes, new_rec_res
+
+
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
diff --git a/test_tipc/benchmark_train.sh b/test_tipc/benchmark_train.sh
index e3e4d627fa27f3a34ae0ae47a8613d6ec0a0f60e..1dcb0129e767e6c35adfad36aa5dce2fbd84a2fd 100644
--- a/test_tipc/benchmark_train.sh
+++ b/test_tipc/benchmark_train.sh
@@ -21,6 +21,18 @@ function func_parser_params(){
echo ${tmp}
}
+function set_dynamic_epoch(){
+ string=$1
+ num=$2
+ _str=${string:1:6}
+ IFS="C"
+ arr=(${_str})
+ M=${arr[0]}
+ P=${arr[1]}
+ ep=`expr $num \* $M \* $P`
+ echo $ep
+}
+
function func_sed_params(){
filename=$1
line=$2
@@ -139,10 +151,11 @@ else
device_num=${params_list[4]}
IFS=";"
- if [ ${precision} = "null" ];then
- precision="fp32"
+ if [ ${precision} = "fp16" ];then
+ precision="amp"
fi
+ epoch=$(set_dynamic_epoch $device_num $epoch)
fp_items_list=($precision)
batch_size_list=($batch_size)
device_num_list=($device_num)
@@ -150,10 +163,16 @@ fi
IFS="|"
for batch_size in ${batch_size_list[*]}; do
- for precision in ${fp_items_list[*]}; do
+ for train_precision in ${fp_items_list[*]}; do
for device_num in ${device_num_list[*]}; do
# sed batchsize and precision
- func_sed_params "$FILENAME" "${line_precision}" "$precision"
+ if [ ${train_precision} = "amp" ];then
+ precision="fp16"
+ else
+ precision="fp32"
+ fi
+
+ func_sed_params "$FILENAME" "${line_precision}" "$train_precision"
func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size"
func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch"
gpu_id=$(set_gpu_id $device_num)
diff --git a/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt b/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt
index 04a3e845859167f78d5b3dd799236f8b8a051e81..110b7f9319cebdef5d9620671b2e62f3c1fe4a6d 100644
--- a/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_db_plusplus/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c configs/det/det_r50_db++_ic15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
+norm_train:tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
+norm_export:tools/export_model.py -c configs/det/det_r50_db++_icdar15.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_db++_train/best_accuracy
-infer_export:tools/export_model.py -c configs/det/det_r50_db++_ic15.yml -o
+infer_export:tools/export_model.py -c configs/det/det_r50_db++_icdar15.yml -o
infer_quant:False
inference:tools/infer/predict_det.py --det_algorithm="DB++"
--use_gpu:True|False
@@ -51,9 +51,3 @@ null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
-===========================train_benchmark_params==========================
-batch_size:8|16
-fp_items:fp32|fp16
-epoch:2
---profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
-flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
diff --git a/test_tipc/configs/en_table_structure/train_pact_infer_python.txt b/test_tipc/configs/en_table_structure/train_pact_infer_python.txt
index f62e8b68bc6c1af06a65a8dfb438d5d63576e123..9890b906a1d3b1127352af567dca0d7186f94694 100644
--- a/test_tipc/configs/en_table_structure/train_pact_infer_python.txt
+++ b/test_tipc/configs/en_table_structure/train_pact_infer_python.txt
@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
+Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=2
Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d0cb20481f56a093f96c3d13f5fa2c2d13ae0c69
--- /dev/null
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
@@ -0,0 +1,117 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 50
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
+ save_epoch_step: 3
+ # evaluation is run every 5000 iterations after the 4000th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path: ./ppocr/utils/dict/spin_dict.txt
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ decay_epochs: [3, 4, 5]
+ values: [0.001, 0.0003, 0.00009, 0.000027]
+
+ clip_norm: 5
+
+Architecture:
+ model_type: rec
+ algorithm: SPIN
+ in_channels: 1
+ Transform:
+ name: GA_SPIN
+ offsets: True
+ default_type: 6
+ loc_lr: 0.1
+ stn: True
+ Backbone:
+ name: ResNet32
+ out_channels: 512
+ Neck:
+ name: SequenceEncoder
+ encoder_type: cascadernn
+ hidden_size: 256
+ out_channels: [256, 512]
+ with_linear: True
+ Head:
+ name: SPINAttentionHead
+ hidden_size: 256
+
+
+Loss:
+ name: SPINAttentionLoss
+ ignore_index: 0
+
+PostProcess:
+ name: SPINLabelDecode
+ use_space_char: False
+
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 128
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SPINLabelEncode: # Class handling label
+ - SPINRecResizeImg:
+ image_shape: [100, 32]
+ interpolation : 2
+ mean: [127.5]
+ std: [127.5]
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 1
+ num_workers: 1
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4915055a576f0a5c1f7b0935a31d1d3c266903a5
--- /dev/null
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_r32_gaspin_bilstm_att
+python:python
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1|6
+--use_tensorrt:False|False
+--precision:fp32|int8
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,32,100]}]
diff --git a/test_tipc/docs/benchmark_train.md b/test_tipc/docs/benchmark_train.md
index 3f846574ff75c8602ba1222977362c143582f560..ad2524c165da3079d24b2b1570a5111d152f8373 100644
--- a/test_tipc/docs/benchmark_train.md
+++ b/test_tipc/docs/benchmark_train.md
@@ -9,7 +9,7 @@
```shell
# 运行格式:bash test_tipc/prepare.sh train_benchmark.txt mode
-bash test_tipc/prepare.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train
+bash test_tipc/prepare.sh test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt benchmark_train
```
## 1.2 功能测试
@@ -33,7 +33,7 @@ dynamic_bs8_fp32_DP_N1C1为test_tipc/benchmark_train.sh传入的参数,格式
## 2. 日志输出
-运行后将保存模型的训练日志和解析日志,使用 `test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt` 参数文件的训练日志解析结果是:
+运行后将保存模型的训练日志和解析日志,使用 `test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt` 参数文件的训练日志解析结果是:
```
{"model_branch": "dygaph", "model_commit": "7c39a1996b19087737c05d883fd346d2f39dbcc0", "model_name": "det_mv3_db_v2_0_bs8_fp32_SingleP_DP", "batch_size": 8, "fp_item": "fp32", "run_process_type": "SingleP", "run_mode": "DP", "convergence_value": "5.413110", "convergence_key": "loss:", "ips": 19.333, "speed_unit": "samples/s", "device_num": "N1C1", "model_run_time": "0", "frame_commit": "8cc09552473b842c651ead3b9848d41827a3dbab", "frame_version": "0.0.0"}
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 8e1758abb8adb3b120704d590e77e05476fb9d4e..ec6dece42a0126e6d05405b3262c1c1d24f0a376 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -58,7 +58,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv3_det_distill_train.tar && cd ../
fi
- if [ ${model_name} == "en_table_structure" ];then
+ if [ ${model_name} == "en_table_structure" ] || [ ${model_name} == "en_table_structure_PACT" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
diff --git a/test_tipc/test_ptq_inference_python.sh b/test_tipc/test_ptq_inference_python.sh
index 288e6098966be4aaf2953d627e7890963100cb6e..e2939fd5e638ad0f6b4c44422a6fec6459903d1c 100644
--- a/test_tipc/test_ptq_inference_python.sh
+++ b/test_tipc/test_ptq_inference_python.sh
@@ -139,7 +139,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}_klquant"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
- export_log_path="${LOG_PATH}_export_${Count}.log"
+ export_log_path="${LOG_PATH}/${MODE}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
echo $export_cmd
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 907efcec9008f89740971bb6d4253bafb44938c4..402f636b1b92fa75380142803c6b513a897a89e4 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -265,7 +265,7 @@ else
if [ ${run_train} = "null" ]; then
continue
fi
- set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
+
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
@@ -287,11 +287,11 @@ else
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
- cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
+ cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
elif [ ${#ips} -le 15 ];then # train with multi-gpu
- cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+ cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
else # train with multi-machine
- cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
+ cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
fi
# run train
eval $cmd
diff --git a/tools/export_model.py b/tools/export_model.py
index a9f4a62e87a6ad34b8c9a123b8055ed4a483bfc6..1cd273b209db83db8a8d91b966e5fb7ddef58cb4 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -107,7 +107,7 @@ def export_single_model(model,
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
- elif arch_config["algorithm"] == "NRTR":
+ elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 5131903dc4ce5f907bd2a3ad3f0afbc93b1350ef..1a262892a5133e5aa0d39c924901eeb186b11f4b 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -89,6 +89,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ elif self.rec_algorithm == "SPIN":
+ postprocess_params = {
+ 'name': 'SPINLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -266,6 +272,22 @@ class TextRecognizer(object):
return padding_im, resize_shape, pad_shape, valid_ratio
+ def resize_norm_img_spin(self, img):
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # return padding_im
+ img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
+ img = np.array(img, np.float32)
+ img = np.expand_dims(img, -1)
+ img = img.transpose((2, 0, 1))
+ mean = [127.5]
+ std = [127.5]
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ mean = np.float32(mean.reshape(1, -1))
+ stdinv = 1 / np.float32(std.reshape(1, -1))
+ img -= mean
+ img *= stdinv
+ return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
@@ -346,6 +368,10 @@ class TextRecognizer(object):
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == 'SPIN':
+ norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape)
diff --git a/tools/infer_det.py b/tools/infer_det.py
index df346523896c9c3f82d254600986e0eb221e3c9f..f253e8f2876a5942538f18e93dfdada4391875b2 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -106,7 +106,7 @@ def main():
dt_boxes_list = []
for box in boxes:
tmp_json = {"transcription": ""}
- tmp_json['points'] = list(box)
+ tmp_json['points'] = np.array(box).tolist()
dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list
save_det_path = os.path.dirname(config['Global'][
@@ -118,7 +118,7 @@ def main():
# write result
for box in boxes:
tmp_json = {"transcription": ""}
- tmp_json['points'] = list(box)
+ tmp_json['points'] = np.array(box).tolist()
dt_boxes_json.append(tmp_json)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
diff --git a/tools/program.py b/tools/program.py
index 335ceb08a83fea468df278633b35ac3bc57ee2ed..3d62c32e6aa340d21fadac43c287dda6c7c77646 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -207,7 +207,7 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
- extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "RobustScanner"]
+ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "RobustScanner"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
@@ -579,7 +579,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
- 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'RobustScanner'
+ 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'RobustScanner'
]
if use_xpu: