提交 7ad8d5c5 编写于 作者: K Kentaro Wada

Refactoring to merge the feature of label-specific flags

上级 425b3c4b
......@@ -847,12 +847,20 @@ class MainWindow(QtWidgets.QMainWindow):
return True
return False
def editLabel(self, item=None):
def editLabel(self, item=False):
if item and not isinstance(item, QtWidgets.QListWidgetItem):
raise TypeError('unsupported type of item: {}'.format(type(item)))
if not self.canvas.editing():
return
item = item if item else self.currentItem()
if not item:
item = self.currentItem()
if item is None:
return
shape = self.labelList.get_shape_from_item(item)
text, flags = self.labelDialog.popUp(item.text() if item else None, flags=(shape.flags if shape else None))
if shape is None:
return
text, flags = self.labelDialog.popUp(shape.label, flags=shape.flags)
if text is None:
return
if not self.validateLabel(text):
......@@ -860,8 +868,8 @@ class MainWindow(QtWidgets.QMainWindow):
"Invalid label '{}' with validation type '{}'"
.format(text, self._config['validate_label']))
return
shape.flags = flags
shape.label = text
shape.flags = flags
item.setText(text)
self.setDirty()
if not self.uniqLabelList.findItems(text, Qt.MatchExactly):
......@@ -908,13 +916,6 @@ class MainWindow(QtWidgets.QMainWindow):
self.actions.shapeFillColor.setEnabled(selected)
def addLabel(self, shape):
if not shape.flags:
shape.flags = {}
if self._config['label_flags']:
for label in ["__all__", shape.label]:
if label in self._config['label_flags']:
for key in self._config['label_flags'][label]:
shape.flags[key] = False
item = QtWidgets.QListWidgetItem(shape.label)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Checked)
......@@ -943,12 +944,22 @@ class MainWindow(QtWidgets.QMainWindow):
for x, y in points:
shape.addPoint(QtCore.QPoint(x, y))
shape.close()
s.append(shape)
if line_color:
shape.line_color = QtGui.QColor(*line_color)
if fill_color:
shape.fill_color = QtGui.QColor(*fill_color)
shape.flags = flags
default_flags = {}
if self._config['label_flags']:
for l in ['__all__', label]:
for k in self._config['label_flags'].get(l, []):
default_flags[k] = False
shape.flags = default_flags
shape.flags.update(flags)
s.append(shape)
self.loadShapes(s)
def loadFlags(self, flags):
......@@ -1021,11 +1032,10 @@ class MainWindow(QtWidgets.QMainWindow):
def labelSelectionChanged(self):
item = self.currentItem()
if item:
if item and self.canvas.editing():
self._noSelectionSlot = True
shape = self.labelList.get_shape_from_item(item)
if self.canvas.editing():
self._noSelectionSlot = True
self.canvas.selectShape(shape)
self.canvas.selectShape(shape)
def labelItemChanged(self, item):
shape = self.labelList.get_shape_from_item(item)
......@@ -1045,6 +1055,7 @@ class MainWindow(QtWidgets.QMainWindow):
"""
items = self.uniqLabelList.selectedItems()
text = None
flags = None
if items:
text = items[0].text()
if self._config['display_label_popup'] or not text:
......@@ -1066,9 +1077,7 @@ class MainWindow(QtWidgets.QMainWindow):
self.canvas.undoLastLine()
self.canvas.shapesBackups.pop()
else:
shape = self.canvas.setLastLabel(text)
shape.flags = flags
self.addLabel(shape)
self.addLabel(self.canvas.setLastLabel(text, flags))
self.actions.editMode.setEnabled(True)
self.actions.undoLastPoint.setEnabled(False)
self.actions.undo.setEnabled(True)
......
......@@ -89,7 +89,7 @@ class LabelFile(object):
s['line_color'],
s['fill_color'],
s.get('shape_type', 'polygon'),
s['flags'] if 'flags' in s else None
s.get('flags', {}),
)
for s in data['shapes']
)
......
......@@ -81,7 +81,9 @@ def _main():
parser.add_argument(
'--labelflags',
dest='label_flags',
help='yaml string of label specific flags OR file containing json string of label specific flags (ex. {human:[male,female],dog:[big],__all__:[occluded]} )',
help='yaml string of label specific flags OR file containing json '
'string of label specific flags (ex. {person: [male, tall], '
'dog: [big, black, brown, white], __all__: [occluded]})',
default=argparse.SUPPRESS,
)
parser.add_argument(
......@@ -137,11 +139,11 @@ def _main():
else:
args.label_flags = yaml.load(args.label_flags)
# Add not overlapping labels from label flags
# add not overlapping labels from label flags
if not hasattr(args, 'labels'):
args.labels = []
for label in args.label_flags.keys():
if label != "__all__" and label not in args.labels:
if label != '__all__' and label not in args.labels:
args.labels.append(label)
config_from_args = args.__dict__
......
......@@ -36,7 +36,8 @@ class Shape(object):
point_size = 8
scale = 1.0
def __init__(self, label=None, line_color=None, shape_type=None, flags=None):
def __init__(self, label=None, line_color=None, shape_type=None,
flags=None):
self.label = label
self.points = []
self.fill = False
......
......@@ -643,9 +643,10 @@ class Canvas(QtWidgets.QWidget):
elif key == QtCore.Qt.Key_Return and self.canCloseShape():
self.finalise()
def setLastLabel(self, text):
def setLastLabel(self, text, flags):
assert text
self.shapes[-1].label = text
self.shapes[-1].flags = flags
self.shapesBackups.pop()
self.storeShapes()
return self.shapes[-1]
......
......@@ -77,12 +77,13 @@ class LabelDialog(QtWidgets.QDialog):
self.edit.setListWidget(self.labelList)
layout.addWidget(self.labelList)
# label_flags
self.flags = flags
self.label_flags = None
if flags:
self.label_flags = QtWidgets.QVBoxLayout()
self.resetFlags()
layout.addItem(self.label_flags)
if flags is None:
flags = {}
self._flags = flags
self.flagsLayout = QtWidgets.QVBoxLayout()
self.resetFlags()
layout.addItem(self.flagsLayout)
self.edit.textChanged.connect(self.updateFlags)
self.setLayout(layout)
# completion
completer = QtWidgets.QCompleter()
......@@ -114,15 +115,6 @@ class LabelDialog(QtWidgets.QDialog):
def labelSelected(self, item):
self.edit.setText(item.text())
def updateFlags(self, text):
flags = self.getFlags()
newFlags = {}
for label in ["__all__", text]:
if label in self.flags:
for key in self.flags[label]:
newFlags[key] = False if key not in flags else flags[key]
self.setFlags(newFlags)
def validate(self):
text = self.edit.text()
if hasattr(text, 'strip'):
......@@ -140,47 +132,41 @@ class LabelDialog(QtWidgets.QDialog):
text = text.trimmed()
self.edit.setText(text)
def updateFlags(self, label_new):
# keep state of shared flags
flags_old = self.getFlags()
flags_new = {}
for label in ['__all__', label_new]:
for key in self._flags.get(label, []):
flags_new[key] = flags_old.get(key, False)
self.setFlags(flags_new)
def deleteFlags(self):
for i in reversed(range(self.label_flags.count())):
item = self.label_flags.itemAt(i).widget()
self.label_flags.removeWidget(item)
for i in reversed(range(self.flagsLayout.count())):
item = self.flagsLayout.itemAt(i).widget()
self.flagsLayout.removeWidget(item)
item.setParent(None)
def resetFlags(self, text=''):
self.deleteFlags()
# Add all flags
for label in ["__all__", text]:
if label in self.flags:
for key in self.flags[label]:
item = QtWidgets.QCheckBox(key, self)
self.label_flags.addWidget(item)
item.show()
def resetFlags(self, label=None):
flags = {k: False for k in self._flags.get('__all__', [])}
if label:
flags.update({k: False for k in self._flags.get(label, [])})
self.setFlags(flags)
def setFlags(self, flags, text=''):
def setFlags(self, flags):
self.deleteFlags()
# Add flags not set
for label in ["__all__", text]:
if label in self.flags:
for key in self.flags[label]:
if key not in flags:
item = QtWidgets.QCheckBox(key, self)
self.label_flags.addWidget(item)
item.show()
# Add set flags
for key in flags:
item = QtWidgets.QCheckBox(key, self)
item.setChecked(flags[key])
self.label_flags.addWidget(item)
self.flagsLayout.addWidget(item)
item.show()
def getFlags(self):
flags = {}
for i in range(self.label_flags.count()):
item = self.label_flags.itemAt(i).widget()
flags[item.text()] = True if item.isChecked() else False
for i in range(self.flagsLayout.count()):
item = self.flagsLayout.itemAt(i).widget()
flags[item.text()] = item.isChecked()
return flags
def popUp(self, text=None, move=True, flags=None):
......@@ -195,11 +181,10 @@ class LabelDialog(QtWidgets.QDialog):
# if text is None, the previous label in self.edit is kept
if text is None:
text = self.edit.text()
if self.label_flags:
if flags:
self.setFlags(flags)
else:
self.resetFlags(text)
if flags:
self.setFlags(flags)
else:
self.resetFlags(text)
self.edit.setText(text)
self.edit.setSelection(0, len(text))
items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
......@@ -213,6 +198,6 @@ class LabelDialog(QtWidgets.QDialog):
if move:
self.move(QtGui.QCursor.pos())
if self.exec_():
return self.edit.text(), self.getFlags() if self.flags else None
return self.edit.text(), self.getFlags()
else:
return None, None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册