提交 ee060d07 编写于 作者: R redearly123/PaddleOCR

add feature lockedShapes

上级 d671e49d
...@@ -36,6 +36,7 @@ import numpy as np ...@@ -36,6 +36,7 @@ import numpy as np
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../PaddleOCR'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../PaddleOCR')))
sys.path.append("../ppstructure/vqa")
sys.path.append("..") sys.path.append("..")
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
...@@ -126,7 +127,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -126,7 +127,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelHist = [] self.labelHist = []
self.lastOpenDir = None self.lastOpenDir = None
self.result_dic = [] self.result_dic = []
self.result_dic_locked = []
self.changeFileFolder = False self.changeFileFolder = False
self.haveAutoReced = False self.haveAutoReced = False
self.labelFile = None self.labelFile = None
...@@ -1058,6 +1059,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1058,6 +1059,7 @@ class MainWindow(QMainWindow, WindowMixin):
shape.addPoint(QPointF(x, y)) shape.addPoint(QPointF(x, y))
shape.difficult = difficult shape.difficult = difficult
#shape.locked = False
shape.close() shape.close()
s.append(shape) s.append(shape)
...@@ -1117,8 +1119,8 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1117,8 +1119,8 @@ class MainWindow(QMainWindow, WindowMixin):
shapes = [] if mode == 'Auto' else \ shapes = [] if mode == 'Auto' else \
[format_shape(shape) for shape in self.canvas.shapes] [format_shape(shape) for shape in self.canvas.shapes]
# Can add differrent annotation formats here # Can add differrent annotation formats here
print("in save labels self.result_dic",self.result_dic)
for box in self.result_dic: for box in self.result_dic :
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False} trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
if trans_dic["label"] == "" and mode == 'Auto': if trans_dic["label"] == "" and mode == 'Auto':
continue continue
...@@ -1322,6 +1324,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1322,6 +1324,7 @@ class MainWindow(QMainWindow, WindowMixin):
# unicodeFilePath = os.path.abspath(unicodeFilePath) # unicodeFilePath = os.path.abspath(unicodeFilePath)
# Tzutalin 20160906 : Add file list and dock to move faster # Tzutalin 20160906 : Add file list and dock to move faster
# Highlight the file item # Highlight the file item
if unicodeFilePath and self.fileListWidget.count() > 0: if unicodeFilePath and self.fileListWidget.count() > 0:
if unicodeFilePath in self.mImgList: if unicodeFilePath in self.mImgList:
index = self.mImgList.index(unicodeFilePath) index = self.mImgList.index(unicodeFilePath)
...@@ -1370,7 +1373,8 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1370,7 +1373,8 @@ class MainWindow(QMainWindow, WindowMixin):
else: else:
self.dirty = False self.dirty = False
self.actions.save.setEnabled(True) self.actions.save.setEnabled(True)
if len(self.canvas.lockedShapes) != 0:
self.actions.save.setEnabled(True)
self.canvas.setEnabled(True) self.canvas.setEnabled(True)
self.adjustScale(initial=True) self.adjustScale(initial=True)
self.paintCanvas() self.paintCanvas()
...@@ -1397,11 +1401,16 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1397,11 +1401,16 @@ class MainWindow(QMainWindow, WindowMixin):
#box['ratio'] of the shapes saved in lockedShapes contains the ratio of the #box['ratio'] of the shapes saved in lockedShapes contains the ratio of the
# four corner coordinates of the shapes to the height and width of the image # four corner coordinates of the shapes to the height and width of the image
for box in self.canvas.lockedShapes: for box in self.canvas.lockedShapes:
shapes.append(("锁定框:待识别", [[s[0]*width,s[1]*height]for s in box['ratio']],DEFAULT_LOCK_COLOR, None, box['difficult'])) if self.canvas.isInTheSameImage:
shapes.append((box['transcription'], [[s[0]*width,s[1]*height]for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
else:
shapes.append(('锁定框:待检测', [[s[0]*width,s[1]*height]for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
if imgidx in self.PPlabel.keys(): if imgidx in self.PPlabel.keys():
for box in self.PPlabel[imgidx]: for box in self.PPlabel[imgidx]:
# print(box)
shapes.append((box['transcription'], box['points'], None, None, box['difficult'])) shapes.append((box['transcription'], box['points'], None, None, box['difficult']))
self.loadLabels(shapes) self.loadLabels(shapes)
self.canvas.verified = False self.canvas.verified = False
...@@ -1660,8 +1669,37 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1660,8 +1669,37 @@ class MainWindow(QMainWindow, WindowMixin):
return fullFilePath return fullFilePath
return '' return ''
def saveLockedShapes(self):
self.canvas.lockedShapes = []
self.canvas.selectedShapes = []
for s in self.canvas.shapes:
if s.line_color == DEFAULT_LOCK_COLOR:
self.canvas.selectedShapes.append(s)
self.lockSelectedShape()
for s in self.canvas.shapes:
if s.line_color == DEFAULT_LOCK_COLOR:
self.canvas.selectedShapes.remove(s)
self.canvas.shapes.remove(s)
def _saveFile(self, annotationFilePath, mode='Manual'): def _saveFile(self, annotationFilePath, mode='Manual'):
if len(self.canvas.lockedShapes) != 0:
self.saveLockedShapes()
if mode == 'Manual': if mode == 'Manual':
if len(self.result_dic_locked) == 0:
img = cv2.imread(self.filePath)
width, height = self.image.width(), self.image.height()
for shape in self.canvas.lockedShapes:
print(shape)
box = [[int(p[0]*width), int(p[1]*height)] for p in shape['ratio']]
assert len(box) == 4
result = [(shape['transcription'],1)]
result.insert(0, box)
self.result_dic_locked.append(result)
self.result_dic += self.result_dic_locked
self.result_dic_locked = []
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode): if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean() self.setClean()
self.statusBar().showMessage('Saved to %s' % annotationFilePath) self.statusBar().showMessage('Saved to %s' % annotationFilePath)
...@@ -1676,13 +1714,13 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1676,13 +1714,13 @@ class MainWindow(QMainWindow, WindowMixin):
self.savePPlabel(mode='Auto') self.savePPlabel(mode='Auto')
self.fileListWidget.insertItem(int(currIndex), item) self.fileListWidget.insertItem(int(currIndex), item)
if not self.canvas.isInTheSameImage:
self.openNextImg() self.openNextImg()
self.actions.saveRec.setEnabled(True) self.actions.saveRec.setEnabled(True)
self.actions.saveLabel.setEnabled(True) self.actions.saveLabel.setEnabled(True)
elif mode == 'Auto': elif mode == 'Auto':
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode): if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean() self.setClean()
self.statusBar().showMessage('Saved to %s' % annotationFilePath) self.statusBar().showMessage('Saved to %s' % annotationFilePath)
self.statusBar().show() self.statusBar().show()
...@@ -1746,7 +1784,9 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1746,7 +1784,9 @@ class MainWindow(QMainWindow, WindowMixin):
if discardChanges == QMessageBox.No: if discardChanges == QMessageBox.No:
return True return True
elif discardChanges == QMessageBox.Yes: elif discardChanges == QMessageBox.Yes:
self.canvas.isInTheSameImage = True
self.saveFile() self.saveFile()
self.canvas.isInTheSameImage = False
return True return True
else: else:
return False return False
...@@ -1885,6 +1925,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1885,6 +1925,7 @@ class MainWindow(QMainWindow, WindowMixin):
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]] # org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes: if self.canvas.shapes:
self.result_dic = [] self.result_dic = []
self.result_dic_locked = [] # result_dic_locked stores the ocr result of self.canvas.lockedShapes
rec_flag = 0 rec_flag = 0
for shape in self.canvas.shapes: for shape in self.canvas.shapes:
box = [[int(p.x()), int(p.y())] for p in shape.points] box = [[int(p.x()), int(p.y())] for p in shape.points]
...@@ -1896,21 +1937,33 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1896,21 +1937,33 @@ class MainWindow(QMainWindow, WindowMixin):
return return
result = self.ocr.ocr(img_crop, cls=True, det=False) result = self.ocr.ocr(img_crop, cls=True, det=False)
if result[0][0] != '': if result[0][0] != '':
if shape.line_color == DEFAULT_LOCK_COLOR:
shape.label = result[0][0]
result.insert(0, box)
self.result_dic_locked.append(result)
else:
result.insert(0, box) result.insert(0, box)
print('result in reRec is ', result)
self.result_dic.append(result) self.result_dic.append(result)
else: else:
print('Can not recognise the box') print('Can not recognise the box')
if shape.line_color == DEFAULT_LOCK_COLOR:
shape.label = result[0][0]
self.result_dic_locked.append([box,(self.noLabelText,0)])
else:
self.result_dic.append([box,(self.noLabelText,0)]) self.result_dic.append([box,(self.noLabelText,0)])
try:
if self.noLabelText == shape.label or result[1][0] == shape.label: if self.noLabelText == shape.label or result[1][0] == shape.label:
print('label no change') print('label no change')
else: else:
rec_flag += 1 rec_flag += 1
except IndexError as e:
print('except:', e)
if len(self.result_dic) > 0 and rec_flag > 0: if (len(self.result_dic) > 0 and rec_flag > 0)or self.canvas.lockedShapes:
self.canvas.isInTheSameImage = True
self.saveFile(mode='Auto') self.saveFile(mode='Auto')
self.loadFile(self.filePath) self.loadFile(self.filePath)
self.canvas.isInTheSameImage = False
self.setDirty() self.setDirty()
elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0: elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0:
QMessageBox.information(self, "Information", "The recognition result remains unchanged!") QMessageBox.information(self, "Information", "The recognition result remains unchanged!")
...@@ -2124,6 +2177,12 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -2124,6 +2177,12 @@ class MainWindow(QMainWindow, WindowMixin):
def lockSelectedShape(self): def lockSelectedShape(self):
"""lock the selsected shapes.
Add self.selectedShapes to lock self.canvas.lockedShapes,
which holds the ratio of the four coordinates of the locked shapes
to the width and height of the image
"""
width, height = self.image.width(), self.image.height() width, height = self.image.width(), self.image.height()
def format_shape(s): def format_shape(s):
return dict(label=s.label, # str return dict(label=s.label, # str
...@@ -2135,19 +2194,23 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -2135,19 +2194,23 @@ class MainWindow(QMainWindow, WindowMixin):
#lock #lock
if len(self.canvas.lockedShapes) == 0: if len(self.canvas.lockedShapes) == 0:
for s in self.canvas.selectedShapes: for s in self.canvas.selectedShapes:
s.line_color=DEFAULT_LOCK_COLOR s.line_color = DEFAULT_LOCK_COLOR
shapes=[format_shape(shape) for shape in self.canvas.selectedShapes] s.locked = True
shapes = [format_shape(shape) for shape in self.canvas.selectedShapes]
trans_dic = [] trans_dic = []
for box in shapes: for box in shapes:
trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], 'difficult': box['difficult']}) trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], 'difficult': box['difficult']})
self.canvas.lockedShapes = trans_dic self.canvas.lockedShapes = trans_dic
# print("self.canvas.lockedShapes:",self.canvas.lockedShapes) self.actions.save.setEnabled(True)
#unlock #unlock
else: else:
for s in self.canvas.shapes: for s in self.canvas.shapes:
s.line_color=DEFAULT_LINE_COLOR s.line_color = DEFAULT_LINE_COLOR
trans_dic = [] self.canvas.lockedShapes = []
self.canvas.lockedShapes = trans_dic self.result_dic_locked = []
self.setDirty()
self.actions.save.setEnabled(True)
def inverted(color): def inverted(color):
......
...@@ -86,7 +86,11 @@ class Canvas(QWidget): ...@@ -86,7 +86,11 @@ class Canvas(QWidget):
#initialisation for panning #initialisation for panning
self.pan_initial_pos = QPoint() self.pan_initial_pos = QPoint()
#lockedshapes related
self.lockedShapes = [] self.lockedShapes = []
self.isInTheSameImage = False
def setDrawingColor(self, qColor): def setDrawingColor(self, qColor):
self.drawingLineColor = qColor self.drawingLineColor = qColor
self.drawingRectColor = qColor self.drawingRectColor = qColor
......
此差异已折叠。
...@@ -58,7 +58,7 @@ class Shape(object): ...@@ -58,7 +58,7 @@ class Shape(object):
self.selected = False self.selected = False
self.difficult = difficult self.difficult = difficult
self.paintLabel = paintLabel self.paintLabel = paintLabel
self.locked = False
self._highlightIndex = None self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX self._highlightMode = self.NEAR_VERTEX
self._highlightSettings = { self._highlightSettings = {
......
...@@ -24,7 +24,7 @@ import paddle ...@@ -24,7 +24,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments from vaq_utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator from data_collator import DataCollator
from metric import re_score from metric import re_score
......
...@@ -33,7 +33,7 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMFor ...@@ -33,7 +33,7 @@ from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMFor
from xfun import XFUNDataset from xfun import XFUNDataset
from losses import SERLoss from losses import SERLoss
from utils import parse_args, get_bio_label_maps, print_arguments from vaq_utils import parse_args, get_bio_label_maps, print_arguments
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
......
...@@ -15,7 +15,7 @@ import paddle ...@@ -15,7 +15,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, draw_re_results from vaq_utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator from data_collator import DataCollator
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
......
...@@ -22,7 +22,7 @@ from copy import deepcopy ...@@ -22,7 +22,7 @@ from copy import deepcopy
import paddle import paddle
# relative reference # relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps from vaq_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
......
...@@ -25,9 +25,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM ...@@ -25,9 +25,9 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference # relative reference
from utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps from vaq_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info from vaq_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
MODELS = { MODELS = {
'LayoutXLM': 'LayoutXLM':
......
...@@ -24,7 +24,7 @@ import paddle ...@@ -24,7 +24,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
# relative reference # relative reference
from utils import parse_args, get_image_file_list, draw_re_results from vaq_utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor from infer_ser_e2e import SerPredictor
......
...@@ -27,7 +27,7 @@ import paddle ...@@ -27,7 +27,7 @@ import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments, set_seed from vaq_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator from data_collator import DataCollator
from eval_re import evaluate from eval_re import evaluate
......
...@@ -32,7 +32,7 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM ...@@ -32,7 +32,7 @@ from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLM
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset from xfun import XFUNDataset
from utils import parse_args, get_bio_label_maps, print_arguments, set_seed from vaq_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from eval_ser import evaluate from eval_ser import evaluate
from losses import SERLoss from losses import SERLoss
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册