提交 c2ba1975 编写于 作者: C cmerchant 提交者: Kentaro Wada

Added the ability to have different flags for each label plus shared flags for all labels

上级 92d2b226
......@@ -100,8 +100,6 @@ class MainWindow(QtWidgets.QMainWindow):
self.label_flag_dock = QtWidgets.QDockWidget('Label Flags', self)
self.label_flag_dock.setObjectName('Label Flags')
self.label_flag_widget = QtWidgets.QListWidget()
if config['label_flags']:
self.loadLabelFlags({k: False for k in config['label_flags']})
self.label_flag_dock.setWidget(self.label_flag_widget)
self.label_flag_widget.itemChanged.connect(self.labelFlagChanged)
......@@ -863,7 +861,7 @@ class MainWindow(QtWidgets.QMainWindow):
return
item = item if item else self.currentItem()
shape = self.labelList.get_shape_from_item(item)
text, flags = self.labelDialog.popUp(item.text() if item else None, flags=(shape.flags if item else None))
text, flags = self.labelDialog.popUp(item.text() if item else None, flags=(shape.flags if shape else None))
if text is None:
return
if not self.validateLabel(text):
......@@ -872,7 +870,9 @@ class MainWindow(QtWidgets.QMainWindow):
.format(text, self._config['validate_label']))
return
shape.flags = flags
self.loadLabelFlags(flags)
shape.label = text
if self._config['label_flags']:
self.loadLabelFlags(flags, shape.label)
item.setText(text)
self.setDirty()
if not self.uniqLabelList.findItems(text, Qt.MatchExactly):
......@@ -920,7 +920,12 @@ class MainWindow(QtWidgets.QMainWindow):
def addLabel(self, shape):
if not shape.flags:
shape.flags = {k: False for k in self._config['label_flags']}
shape.flags = {}
if self._config['label_flags']:
for label in ["__all__", shape.label]:
if label in self._config['label_flags']:
for key in self._config['label_flags'][label]:
shape.flags[key] = False
item = QtWidgets.QListWidgetItem(shape.label)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Checked)
......@@ -941,8 +946,6 @@ class MainWindow(QtWidgets.QMainWindow):
for shape in shapes:
self.addLabel(shape)
self.canvas.loadShapes(shapes, replace=replace)
if self._config['label_flags']:
self.loadLabelFlags({k: False for k in self._config['label_flags']})
def loadLabels(self, shapes):
s = []
......@@ -967,13 +970,16 @@ class MainWindow(QtWidgets.QMainWindow):
item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
self.flag_widget.addItem(item)
def loadLabelFlags(self, flags):
def loadLabelFlags(self, flags=None, label=''):
self.label_flag_widget.clear()
for key, flag in flags.items():
item = QtWidgets.QListWidgetItem(key)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Checked if flag else Qt.Unchecked)
self.label_flag_widget.addItem(item)
if flags:
for label in ["__all__", label]:
if label in self._config['label_flags']:
for key in self._config['label_flags'][label]:
item = QtWidgets.QListWidgetItem(key)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Checked if (key in flags and flags[key]) else Qt.Unchecked)
self.label_flag_widget.addItem(item)
def saveLabels(self, filename):
lf = LabelFile()
......@@ -1039,7 +1045,8 @@ class MainWindow(QtWidgets.QMainWindow):
item = self.currentItem()
if item:
shape = self.labelList.get_shape_from_item(item)
self.loadLabelFlags(shape.flags)
if self._config['label_flags']:
self.loadLabelFlags(shape.flags, shape.label)
if self.canvas.editing():
self._noSelectionSlot = True
self.canvas.selectShape(shape)
......@@ -1056,14 +1063,16 @@ class MainWindow(QtWidgets.QMainWindow):
def labelFlagChanged(self):
item = self.currentItem()
if item:
shape = self.labelList.get_shape_from_item(item)
index = 0
flags = {}
for key in self._config["label_flags"]:
checkBox = self.label_flag_widget.item(index)
index = index + 1
value = True if checkBox.checkState() else False
flags[key] = value
shape = self.labelList.get_shape_from_item(item)
for label in ["__all__", shape.label]:
if label in self._config['label_flags']:
for key in self._config['label_flags'][label]:
checkBox = self.label_flag_widget.item(index)
index = index + 1
value = True if checkBox.checkState() else False
flags[key] = value
if shape.flags != flags:
shape.flags = flags
self.setDirty()
......@@ -1226,6 +1235,8 @@ class MainWindow(QtWidgets.QMainWindow):
self.canvas.loadPixmap(QtGui.QPixmap.fromImage(image))
if self._config['flags']:
self.loadFlags({k: False for k in self._config['flags']})
if self._config['label_flags']:
self.loadLabelFlags()
if self.labelFile:
self.loadLabels(self.labelFile.shapes)
if self.labelFile.flags is not None:
......@@ -1557,8 +1568,8 @@ class MainWindow(QtWidgets.QMainWindow):
self.remLabel(self.canvas.deleteSelected())
self.setDirty()
if self.noShapes():
if self._config['shape_flags']:
self.loadLabelFlags({k: False for k in self._config['shape_flags']})
if self._config['label_flags']:
self.loadLabelFlags()
for action in self.actions.onShapesPresent:
action.setEnabled(False)
......
......@@ -3,6 +3,7 @@ import codecs
import logging
import os
import sys
import yaml
from qtpy import QtWidgets
......@@ -81,7 +82,7 @@ def _main():
parser.add_argument(
'--labelflags',
dest='label_flags',
help='comma separated list of label specific flags OR file containing flags',
help='yaml string of label specific flags OR file containing json string of label specific flags (ex. {human:[male,female],dog:[big],__all__:[occluded]} )',
default=argparse.SUPPRESS,
)
parser.add_argument(
......@@ -122,13 +123,6 @@ def _main():
args.flags = [l.strip() for l in f if l.strip()]
else:
args.flags = [l for l in args.flags.split(',') if l]
if hasattr(args, 'label_flags'):
if os.path.isfile(args.label_flags):
with codecs.open(args.label_flags, 'r', encoding='utf-8') as f:
args.label_flags = [l.strip() for l in f if l.strip()]
else:
args.label_flags = [l for l in args.label_flags.split(',') if l]
if hasattr(args, 'labels'):
if os.path.isfile(args.labels):
......@@ -137,6 +131,20 @@ def _main():
else:
args.labels = [l for l in args.labels.split(',') if l]
if hasattr(args, 'label_flags'):
if os.path.isfile(args.label_flags):
with codecs.open(args.label_flags, 'r', encoding='utf-8') as f:
args.label_flags = yaml.load(f)
else:
args.label_flags = yaml.load(args.label_flags)
# Add not overlapping labels from label flags
if not hasattr(args, 'labels'):
args.labels = []
for label in args.label_flags.keys():
if label != "__all__" and label not in args.labels:
args.labels.append(label)
config_from_args = args.__dict__
config_from_args.pop('version')
reset_config = config_from_args.pop('reset_config')
......
......@@ -29,7 +29,7 @@ class LabelDialog(QtWidgets.QDialog):
def __init__(self, text="Enter object label", parent=None, labels=None,
sort_labels=True, show_text_field=True,
completion='startswith', fit_to_content=None, flags=[]):
completion='startswith', fit_to_content=None, flags=None):
if fit_to_content is None:
fit_to_content = {'row': False, 'column': True}
self._fit_to_content = fit_to_content
......@@ -39,6 +39,8 @@ class LabelDialog(QtWidgets.QDialog):
self.edit.setPlaceholderText(text)
self.edit.setValidator(labelme.utils.labelValidator())
self.edit.editingFinished.connect(self.postProcess)
if flags:
self.edit.textChanged.connect(self.updateFlags)
layout = QtWidgets.QVBoxLayout()
if show_text_field:
layout.addWidget(self.edit)
......@@ -79,8 +81,7 @@ class LabelDialog(QtWidgets.QDialog):
self.label_flags = None
if flags:
self.label_flags = QtWidgets.QVBoxLayout()
for flag in flags:
self.label_flags.addWidget(QtWidgets.QCheckBox(flag, self))
self.resetFlags()
layout.addItem(self.label_flags)
self.setLayout(layout)
# completion
......@@ -113,6 +114,15 @@ class LabelDialog(QtWidgets.QDialog):
def labelSelected(self, item):
self.edit.setText(item.text())
def updateFlags(self, text):
flags = self.getFlags()
newFlags = {}
for label in ["__all__", text]:
if label in self.flags:
for key in self.flags[label]:
newFlags[key] = False if key not in flags else flags[key]
self.setFlags(newFlags)
def validate(self):
text = self.edit.text()
if hasattr(text, 'strip'):
......@@ -130,15 +140,41 @@ class LabelDialog(QtWidgets.QDialog):
text = text.trimmed()
self.edit.setText(text)
def resetFlags(self):
for i in range(self.label_flags.count()):
item = self.label_flags.itemAt(i).widget()
item.setChecked(False)
def setFlags(self, flags):
for i in range(self.label_flags.count()):
def deleteFlags(self):
for i in reversed(range(self.label_flags.count())):
item = self.label_flags.itemAt(i).widget()
item.setChecked(flags[item.text()])
self.label_flags.removeWidget(item)
item.setParent(None)
def resetFlags(self, text=''):
self.deleteFlags()
# Add all flags
for label in ["__all__", text]:
if label in self.flags:
for key in self.flags[label]:
item = QtWidgets.QCheckBox(key, self)
self.label_flags.addWidget(item)
item.show()
def setFlags(self, flags, text=''):
self.deleteFlags()
# Add flags not set
for label in ["__all__", text]:
if label in self.flags:
for key in self.flags[label]:
if key not in flags:
item = QtWidgets.QCheckBox(key, self)
self.label_flags.addWidget(item)
item.show()
# Add set flags
for key in flags:
item = QtWidgets.QCheckBox(key, self)
item.setChecked(flags[key])
self.label_flags.addWidget(item)
item.show()
def getFlags(self):
flags = {}
......@@ -159,10 +195,11 @@ class LabelDialog(QtWidgets.QDialog):
# if text is None, the previous label in self.edit is kept
if text is None:
text = self.edit.text()
if flags:
self.setFlags(flags)
else:
self.resetFlags()
if self.label_flags:
if flags:
self.setFlags(flags)
else:
self.resetFlags(text)
self.edit.setText(text)
self.edit.setSelection(0, len(text))
items = self.labelList.findItems(text, QtCore.Qt.MatchFixedString)
......@@ -175,4 +212,4 @@ class LabelDialog(QtWidgets.QDialog):
self.edit.setFocus(QtCore.Qt.PopupFocusReason)
if move:
self.move(QtGui.QCursor.pos())
return (self.edit.text(), self.getFlags()) if self.exec_() else None
return (self.edit.text(), self.getFlags() if self.flags else None) if self.exec_() else (None, None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册