提交 fc2b9c94 编写于 作者: A Andrey Zhavoronkov 提交者: Nikita Manovich

Az/fix no dump default attrs (#656)

* fill absent attributes by default values during annotation save
* fill absent attributes by default values during init from db
* fixed tests
* updated changelog, added some coments, minor fixes
上级 7fb7ba15
......@@ -39,7 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Installation of CVAT with OpenVINO on the Windows platform
- Background color was always black in utils/mask/converter.py
- Exception in attribute annotation mode when a label are switched to a value without any attributes
- Handling of wrong labelamp json file in auto annotation (https://github.com/opencv/cvat/issues/554)
- Handling of wrong labelamp json file in auto annotation (<https://github.com/opencv/cvat/issues/554>)
- No default attributes in dumped annotation (<https://github.com/opencv/cvat/issues/601>)
### Security
-
......
......@@ -4,6 +4,7 @@
import os
from enum import Enum
from collections import OrderedDict
from django.utils import timezone
from PIL import Image
......@@ -192,9 +193,21 @@ class JobAnnotation:
self.logger = slogger.job[self.db_job.id]
self.db_labels = {db_label.id:db_label
for db_label in db_segment.task.label_set.all()}
self.db_attributes = {db_attr.id:db_attr
for db_attr in models.AttributeSpec.objects.filter(
label__task__id=db_segment.task.id)}
self.db_attributes = {}
for db_label in self.db_labels.values():
self.db_attributes[db_label.id] = {
"mutable": OrderedDict(),
"immutable": OrderedDict(),
"all": OrderedDict(),
}
for db_attr in db_label.attributespec_set.all():
if db_attr.mutable:
self.db_attributes[db_label.id]["mutable"][db_attr.id] = db_attr
else:
self.db_attributes[db_label.id]["immutable"][db_attr.id] = db_attr
self.db_attributes[db_label.id]["all"][db_attr.id] = db_attr
def reset(self):
self.ir_data.reset()
......@@ -214,7 +227,7 @@ class JobAnnotation:
for attr in track_attributes:
db_attrval = models.LabeledTrackAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["immutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.track_id = len(db_tracks)
db_track_attrvals.append(db_attrval)
......@@ -228,7 +241,7 @@ class JobAnnotation:
for attr in shape_attributes:
db_attrval = models.TrackedShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["mutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes)
db_shape_attrvals.append(db_attrval)
......@@ -295,8 +308,9 @@ class JobAnnotation:
for attr in attributes:
db_attrval = models.LabeledShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_shape.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes)
db_attrvals.append(db_attrval)
......@@ -335,7 +349,7 @@ class JobAnnotation:
for attr in attributes:
db_attrval = models.LabeledImageAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes:
if db_attrval.spec_id not in self.db_attributes[db_tag.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.tag_id = len(db_tags)
db_attrvals.append(db_attrval)
......@@ -350,7 +364,7 @@ class JobAnnotation:
)
for db_attrval in db_attrvals:
db_attrval.tag_id = db_tags[db_attrval.tag_id].id
db_attrval.image_id = db_tags[db_attrval.tag_id].id
bulk_create(
db_model=models.LabeledImageAttributeVal,
......@@ -436,6 +450,16 @@ class JobAnnotation:
self._delete(data)
self._commit()
@staticmethod
def _extend_attributes(attributeval_set, attribute_specs):
shape_attribute_specs_set = set(attr.spec_id for attr in attributeval_set)
for db_attr_spec in attribute_specs:
if db_attr_spec.id not in shape_attribute_specs_set:
attributeval_set.append(OrderedDict([
('spec_id', db_attr_spec.id),
('value', db_attr_spec.default_value),
]))
def _init_tags_from_db(self):
db_tags = self.db_job.labeledimage_set.prefetch_related(
"label",
......@@ -461,6 +485,11 @@ class JobAnnotation:
},
field_id='id',
)
for db_tag in db_tags:
self._extend_attributes(db_tag.labeledimageattributeval_set,
self.db_attributes[db_tag.label_id]["all"].values())
serializer = serializers.LabeledImageSerializer(db_tags, many=True)
self.ir_data.tags = serializer.data
......@@ -493,6 +522,9 @@ class JobAnnotation:
},
field_id='id',
)
for db_shape in db_shapes:
self._extend_attributes(db_shape.labeledshapeattributeval_set,
self.db_attributes[db_shape.label_id]["all"].values())
serializer = serializers.LabeledShapeSerializer(db_shapes, many=True)
self.ir_data.shapes = serializer.data
......@@ -558,10 +590,15 @@ class JobAnnotation:
# A result table can consist many equal rows for track/shape attributes
# We need filter unique attributes manually
db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"]))
self._extend_attributes(db_track.labeledtrackattributeval_set,
self.db_attributes[db_track.label_id]["immutable"].values())
for db_shape in db_track["trackedshape_set"]:
db_shape["trackedshapeattributeval_set"] = list(
set(db_shape["trackedshapeattributeval_set"])
)
self._extend_attributes(db_shape["trackedshapeattributeval_set"],
self.db_attributes[db_track.label_id]["mutable"].values())
serializer = serializers.LabeledTrackSerializer(db_tracks, many=True)
self.ir_data.tracks = serializer.data
......
......@@ -1177,7 +1177,7 @@ class JobAnnotationAPITestCase(APITestCase):
"mutable": False,
"input_type": "select",
"default_value": "mazda",
"values": ["bmw", "mazda", "reno"]
"values": ["bmw", "mazda", "renault"]
},
{
"name": "parked",
......@@ -1212,6 +1212,27 @@ class JobAnnotationAPITestCase(APITestCase):
return (task, jobs)
@staticmethod
def _get_default_attr_values(task):
default_attr_values = {}
for label in task["labels"]:
default_attr_values[label["id"]] = {
"mutable": [],
"immutable": [],
"all": [],
}
for attr in label["attributes"]:
default_value = {
"spec_id": attr["id"],
"value": attr["default_value"],
}
if attr["mutable"]:
default_attr_values[label["id"]]["mutable"].append(default_value)
else:
default_attr_values[label["id"]]["immutable"].append(default_value)
default_attr_values[label["id"]]["all"].append(default_value)
return default_attr_values
def _put_api_v1_jobs_id_data(self, jid, user, data):
with ForceLogin(user, self.client):
response = self.client.put("/api/v1/jobs/{}/annotations".format(jid),
......@@ -1288,7 +1309,7 @@ class JobAnnotationAPITestCase(APITestCase):
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
],
"points": [1.0, 2.1, 100, 300.222],
......@@ -1310,7 +1331,12 @@ class JobAnnotationAPITestCase(APITestCase):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
......@@ -1319,14 +1345,10 @@ class JobAnnotationAPITestCase(APITestCase):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
}
"value": task["labels"][0]["attributes"][1]["default_value"]
},
]
},
{
......@@ -1357,6 +1379,8 @@ class JobAnnotationAPITestCase(APITestCase):
},
]
}
default_attr_values = self._get_default_attr_values(task)
response = self._put_api_v1_jobs_id_data(job["id"], annotator, data)
data["version"] += 1 # need to update the version
self.assertEqual(response.status_code, HTTP_200_OK)
......@@ -1364,6 +1388,9 @@ class JobAnnotationAPITestCase(APITestCase):
response = self._get_api_v1_jobs_id_data(job["id"], annotator)
self.assertEqual(response.status_code, HTTP_200_OK)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self._check_response(response, data)
response = self._delete_api_v1_jobs_id_data(job["id"], annotator)
......@@ -1402,7 +1429,7 @@ class JobAnnotationAPITestCase(APITestCase):
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
],
"points": [1.0, 2.1, 100, 300.222],
......@@ -1424,7 +1451,12 @@ class JobAnnotationAPITestCase(APITestCase):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
......@@ -1433,14 +1465,10 @@ class JobAnnotationAPITestCase(APITestCase):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
}
"value": task["labels"][0]["attributes"][1]["default_value"]
},
]
},
{
......@@ -1479,6 +1507,9 @@ class JobAnnotationAPITestCase(APITestCase):
response = self._get_api_v1_jobs_id_data(job["id"], annotator)
self.assertEqual(response.status_code, HTTP_200_OK)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self._check_response(response, data)
data = response.data
......@@ -1576,7 +1607,7 @@ class JobAnnotationAPITestCase(APITestCase):
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
]
},
......@@ -1733,7 +1764,12 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
......@@ -1742,13 +1778,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
]
},
......@@ -1782,10 +1814,15 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
}
response = self._put_api_v1_tasks_id_annotations(task["id"], annotator, data)
data["version"] += 1
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)
default_attr_values = self._get_default_attr_values(task)
response = self._get_api_v1_tasks_id_annotations(task["id"], annotator)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)
......@@ -1847,7 +1884,12 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"frame": 0,
"label_id": task["labels"][0]["id"],
"group": None,
"attributes": [],
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [
{
"frame": 0,
......@@ -1856,13 +1898,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"occluded": False,
"outside": False,
"attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{
"spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"]
"value": task["labels"][0]["attributes"][1]["default_value"]
}
]
},
......@@ -1901,6 +1939,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
self._check_response(response, data)
response = self._get_api_v1_tasks_id_annotations(task["id"], annotator)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册