未验证 提交 ba308336 编写于 作者: A Anastasia Yasakova 提交者: GitHub

Fix project export with skeletons (#5004)

上级 fd666d00
......@@ -25,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
non-ascii paths while adding files from "Connected file share" (issue #4428)
- Removed unnecessary volumes defined in docker-compose.serverless.yml
(<https://github.com/openvinotoolkit/cvat/pull/4659>)
- Project import with skeletons (<https://github.com/opencv/cvat/pull/4867>)
- Project import with skeletons (<https://github.com/opencv/cvat/pull/4867>,
<https://github.com/opencv/cvat/pull/5004>)
### Security
- TDB
......
# Copyright (C) 2019-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
......@@ -426,15 +427,21 @@ class JobAnnotation:
)
shapes = {}
elements = {}
for db_shape in db_shapes:
self._extend_attributes(db_shape.labeledshapeattributeval_set,
self.db_attributes[db_shape.label_id]["all"].values())
db_shape.elements = []
if db_shape.parent is None:
shapes[db_shape.id] = db_shape
else:
shapes[db_shape.parent].elements.append(db_shape)
if db_shape.parent not in elements:
elements[db_shape.parent] = []
elements[db_shape.parent].append(db_shape)
for shape_id, shape_elements in elements.items():
shapes[shape_id].elements = shape_elements
serializer = serializers.LabeledShapeSerializer(list(shapes.values()), many=True)
self.ir_data.shapes = serializer.data
......@@ -493,6 +500,7 @@ class JobAnnotation:
)
tracks = {}
elements = {}
for db_track in db_tracks:
db_track["trackedshape_set"] = _merge_table_rows(db_track["trackedshape_set"], {
'trackedshapeattributeval_set': [
......@@ -518,11 +526,15 @@ class JobAnnotation:
self._extend_attributes(db_shape["trackedshapeattributeval_set"], default_attribute_values)
default_attribute_values = db_shape["trackedshapeattributeval_set"]
db_track.elements = []
if db_track.parent is None:
tracks[db_track.id] = db_track
else:
tracks[db_track.parent].elements.append(db_track)
if db_track.parent not in elements:
elements[db_track.parent] = []
elements[db_track.parent].append(db_track)
for track_id, track_elements in elements.items():
tracks[track_id].elements = track_elements
serializer = serializers.LabeledTrackSerializer(list(tracks.values()), many=True)
self.ir_data.tracks = serializer.data
......
......@@ -13,7 +13,7 @@ from cvat_sdk.core.helpers import get_paginated_collection
import pytest
from deepdiff import DeepDiff
from shared.utils.config import make_api_client
from shared.utils.config import get_method, make_api_client, patch_method
from shared.utils.helpers import generate_image_files
from .utils import export_dataset
......@@ -309,6 +309,174 @@ class TestPostTaskData:
(task, _) = api_client.tasks_api.retrieve(task_id)
assert task.size == 4
def test_can_get_annotations_from_new_task_with_skeletons(self):
spec = {
"name": f'test admin1 to create a task with skeleton',
"labels": [
{
"name": "s1",
"color": "#5c5eba",
"attributes": [],
"type": "skeleton",
"sublabels": [
{
"name": "1",
"color": "#d12345",
"attributes": [],
"type": "points"
},
{
"name": "2",
"color": "#350dea",
"attributes": [],
"type": "points"
}
],
"svg": "<line x1=\"19.464284896850586\" y1=\"21.922269821166992\" x2=\"54.08613586425781\" y2=\"43.60293960571289\" " \
"stroke=\"black\" data-type=\"edge\" data-node-from=\"1\" stroke-width=\"0.5\" data-node-to=\"2\"></line>" \
"<circle r=\"1.5\" stroke=\"black\" fill=\"#b3b3b3\" cx=\"19.464284896850586\" cy=\"21.922269821166992\" " \
"stroke-width=\"0.1\" data-type=\"element node\" data-element-id=\"1\" data-node-id=\"1\" data-label-id=\"103\"></circle>" \
"<circle r=\"1.5\" stroke=\"black\" fill=\"#b3b3b3\" cx=\"54.08613586425781\" cy=\"43.60293960571289\" " \
"stroke-width=\"0.1\" data-type=\"element node\" data-element-id=\"2\" data-node-id=\"2\" data-label-id=\"104\"></circle>"
}
]
}
task_data = {
'image_quality': 75,
'client_files': generate_image_files(3),
}
task_id = self._test_create_task(self._USERNAME, spec, task_data,
content_type="multipart/form-data")
response = get_method(self._USERNAME, f"tasks/{task_id}")
label_ids = {}
for label in response.json()["labels"]:
label_ids.setdefault(label["type"], []).append(label["id"])
job_id = response.json()["segments"][0]["jobs"][0]["id"]
patch_data = {
"shapes": [{
"type": "skeleton",
"occluded": False,
"outside": False,
"z_order": 0,
"rotation": 0,
"points": [],
"frame": 0,
"label_id": label_ids["skeleton"][0],
"group": 0,
"source": "manual",
"attributes": [],
"elements": [
{
"type": "points",
"occluded": False,
"outside": False,
"z_order": 0,
"rotation": 0,
"points": [
131.63947368421032,
165.0868421052637
],
"frame": 0,
"label_id": label_ids["points"][0],
"group": 0,
"source": "manual",
"attributes": []
},
{
"type": "points",
"occluded": False,
"outside": False,
"z_order": 0,
"rotation": 0,
"points": [
354.98157894736823,
304.2710526315795
],
"frame": 0,
"label_id": label_ids["points"][1],
"group": 0,
"source": "manual",
"attributes": []
}
]
}],
"tracks": [{
"frame": 0,
"label_id": label_ids["skeleton"][0],
"group": 0,
"source": "manual",
"shapes": [
{
"type": "skeleton",
"occluded": False,
"outside": False,
"z_order": 0,
"rotation": 0,
"points": [],
"frame": 0,
"attributes": []
}
],
"attributes": [],
"elements": [
{
"frame": 0,
"label_id": label_ids["points"][0],
"group": 0,
"source": "manual",
"shapes": [
{
"type": "points",
"occluded": False,
"outside": False,
"z_order": 0,
"rotation": 0,
"points": [
295.6394736842103,
472.5868421052637
],
"frame": 0,
"attributes": []
}
],
"attributes": []
},
{
"frame": 0,
"label_id": label_ids["points"][1],
"group": 0,
"source": "manual",
"shapes": [
{
"type": "points",
"occluded": False,
"outside": False,
"z_order": 0,
"rotation": 0,
"points": [
619.3236842105262,
846.9815789473689
],
"frame": 0,
"attributes": []
}
],
"attributes": []
}
]
}],
"tags": [],
"version": 0
}
response = patch_method(self._USERNAME, f"jobs/{job_id}/annotations", patch_data, action="create")
response = get_method(self._USERNAME, f"jobs/{job_id}/annotations")
assert response.status_code == HTTPStatus.OK
@pytest.mark.parametrize('cloud_storage_id, manifest, use_bucket_content, org', [
(1, 'manifest.jsonl', False, ''), # public bucket
(2, 'sub/manifest.jsonl', True, 'org2'), # private bucket
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册