Commit 5733423b authored by Andrey Zhavoronkov, committed by Nikita Manovich

Az/multiformat downloader (#551)

Parent commit: efa47a3a
## Description
The purpose of this application is to add support for multiple annotation formats to CVAT.
It allows downloading and uploading annotations in different formats and makes it easy to add support for new ones.
## How to add support for a new annotation format
1. Write a Python script that will be executed via the exec() function. The following items must be defined in its code:
- **format_spec** - a dictionary with the following structure:
```python
format_spec = {
    "name": "CVAT",
    "dumpers": [
        {
            "display_name": "{name} {format} {version} for videos",
            "format": "XML",
            "version": "1.1",
            "handler": "dump_as_cvat_interpolation"
        },
        {
            "display_name": "{name} {format} {version} for images",
            "format": "XML",
            "version": "1.1",
            "handler": "dump_as_cvat_annotation"
        }
    ],
    "loaders": [
        {
            "display_name": "{name} {format} {version}",
            "format": "XML",
            "version": "1.1",
            "handler": "load",
        }
    ],
}
```
- **name** - a unique name for each format
- **dumpers and loaders** - lists of objects that describe the exposed dumpers and loaders; each object must
have the following keys:
1. display_name - a **unique** string used as an ID for dumpers and loaders.
This string is also displayed in the CVAT UI.
Named placeholders can be used, as with the Python format function
(only the name, format and version variables are supported); see the sketch right after this list of keys.
1. format - a string used as the file extension for a dumped annotation.
1. version - a string with the format version.
1. handler - the name of a function that will be called; it must be defined at the top-level scope.
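For illustration, placeholder expansion works like Python's str.format; this small sketch uses the values from the CVAT format_spec shown above:
```python
# Sketch of how display_name placeholders are expanded when a format is registered;
# the values are taken from the CVAT example above.
template = "{name} {format} {version} for videos"
display_name = template.format(name="CVAT", format="XML", version="1.1")
assert display_name == "CVAT XML 1.1 for videos"
```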
- dumper/loader handler functions. Each function should have the following signature:
```python
def dump_handler(file_object, annotations):
```
Inside the script environment, three variables are available:
- **file_object** - Python's standard file object returned by the open() function, exposing a file-oriented API
(with methods such as read() or write()) to an underlying resource.
- **annotations** - an instance of the [Annotation](annotation.py#L106) class.
- **spec** - a string with the name of the requested specification
(if the annotation format defines any).
It can be useful when one script implements support for more than one format.
The Annotation class exposes an API and some additional pre-defined types that allow getting/adding shapes
inside parser/dumper code.
Short description of the public methods:
- **Annotation.shapes** - property, returns a generator of Annotation.LabeledShape objects
- **Annotation.tracks** - property, returns a generator of Annotation.Track objects
- **Annotation.tags** - property, returns a generator of Annotation.Tag objects
- **Annotation.group_by_frame()** - method, returns an iterator over Annotation.Frame objects,
which group the annotation objects by frame. Note that TrackedShapes will be represented as Annotation.LabeledShape.
- **Annotation.meta** - property, returns a dictionary that represents the task meta information,
for example the video source name, number of frames, number of jobs, etc.
- **Annotation.add_tag(tag)** - tag should be an instance of the Annotation.Tag class
- **Annotation.add_shape(shape)** - shape should be an instance of the Annotation.LabeledShape class
- **Annotation.add_track(track)** - track should be an instance of the Annotation.Track class
- **Annotation.Attribute** = namedtuple('Attribute', 'name, value')
- name - String, name of the attribute
- value - String, value of the attribute
- **Annotation.LabeledShape** = namedtuple('LabeledShape', 'type, frame, label, points, occluded, attributes,
group, z_order')
LabeledShape.\__new\__.\__defaults\__ = (0, 0)
- **TrackedShape** = namedtuple('TrackedShape', 'type, points, occluded, frame, attributes, outside,
keyframe, z_order')
TrackedShape.\__new\__.\__defaults\__ = (0, )
- **Track** = namedtuple('Track', 'label, group, shapes')
- **Tag** = namedtuple('Tag', 'frame, label, attributes, group')
Tag.\__new\__.\__defaults\__ = (0, )
- **Frame** = namedtuple('Frame', 'frame, name, width, height, labeled_shapes, tags')
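As a quick, hedged illustration (the values below are made up), the pre-defined types are ordinary namedtuples and can be built and registered through the API described above:
```python
# Build a rectangle with one attribute and a tag, then register them (example values only).
attr = annotations.Attribute(name="color", value="red")
shape = annotations.LabeledShape(
    type="rectangle",
    frame=0,
    label="car",
    points=[10.0, 10.0, 120.0, 80.0],  # xtl, ytl, xbr, ybr
    occluded=False,
    attributes=[attr],
)  # group and z_order fall back to their defaults
tag = annotations.Tag(frame=0, label="car", attributes=[])  # group defaults to 0
annotations.add_shape(shape)
annotations.add_tag(tag)
```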
Pseudocode for a dumper script
```python
...
# dump meta info if necessary
...
# iterate over all frames
for frame_annotation in annotations.group_by_frame():
    # get frame info
    image_name = frame_annotation.name
    image_width = frame_annotation.width
    image_height = frame_annotation.height
    # iterate over all shapes on the frame
    for shape in frame_annotation.labeled_shapes:
        label = shape.label
        xtl = shape.points[0]
        ytl = shape.points[1]
        xbr = shape.points[2]
        ybr = shape.points[3]
        # iterate over shape attributes
        for attr in shape.attributes:
            attr_name = attr.name
            attr_value = attr.value
...
# dump annotation code
file_object.write(...)
...
```
Pseudocode for a parser script
```python
...
# read file_object
...
for parsed_shape in parsed_shapes:
    shape = annotations.LabeledShape(
        type="rectangle",
        points=[0, 0, 100, 100],
        occluded=False,
        attributes=[],
        label="car",
        frame=99,
    )
    annotations.add_shape(shape)
```
Full examples can be found in the [builtin](builtin) folder.
1. Add the path to the new Python script to the annotation app settings:
```python
BUILTIN_FORMATS = (
    os.path.join(path_prefix, 'cvat.py'),
    os.path.join(path_prefix, 'pascal_voc.py'),
)
```
## Ideas for improvements
- An annotation format manager, like the DL Model manager, with which the user can add custom format support by
writing dumper/loader scripts.
- A custom loader/dumper often requires additional Python packages, so it would be useful if CVAT provided an API
that lets the user install Python dependencies from their own code without changing the CVAT source code.
Possible solutions: install additional modules via a pip call into a separate directory for each annotation format
to reduce version conflicts, etc. Custom code could then run in an extended environment while core CVAT modules
remain unaffected. This functionality could also be useful for the Auto Annotation module. A rough sketch of the
pip-based idea follows.
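The following is only a sketch of the pip-into-a-separate-directory idea, not existing CVAT functionality; the function name, the formats_root directory, and the example package are hypothetical.
```python
# Hypothetical helper illustrating the improvement idea above; not part of CVAT.
import os
import subprocess
import sys

def install_format_requirements(format_name, packages, formats_root="/tmp/cvat_format_deps"):
    """Install extra packages into an isolated directory for one annotation format."""
    target_dir = os.path.join(formats_root, format_name)
    os.makedirs(target_dir, exist_ok=True)
    # pip's --target option keeps the packages out of the global site-packages.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "--target", target_dir, *packages,
    ])
    # Make the directory importable for the format's own code.
    sys.path.insert(0, target_dir)
    return target_dir

# Example (hypothetical): a PASCAL VOC dumper that needs pascal_voc_writer.
# install_format_requirements("pascal_voc", ["pascal_voc_writer"])
```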
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
default_app_config = 'cvat.apps.annotation.apps.AnnotationConfig'
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
import copy
from collections import OrderedDict, namedtuple
from django.utils import timezone
from cvat.apps.engine.data_manager import DataManager, TrackManager
from cvat.apps.engine.serializers import LabeledDataSerializer
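# In-memory intermediate representation of the task annotations: plain lists of
# tags, shapes and tracks plus a data version, independent of the database models.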
class AnnotationIR:
def __init__(self, data=None):
self.reset()
if data:
self._tags = getattr(data, 'tags', []) or data['tags']
self._shapes = getattr(data, 'shapes', []) or data['shapes']
self._tracks = getattr(data, 'tracks', []) or data['tracks']
def add_tag(self, tag):
self._tags.append(tag)
def add_shape(self, shape):
self._shapes.append(shape)
def add_track(self, track):
self._tracks.append(track)
@property
def tags(self):
return self._tags
@property
def shapes(self):
return self._shapes
@property
def tracks(self):
return self._tracks
@property
def version(self):
return self._version
@tags.setter
def tags(self, tags):
self._tags = tags
@shapes.setter
def shapes(self, shapes):
self._shapes = shapes
@tracks.setter
def tracks(self, tracks):
self._tracks = tracks
@version.setter
def version(self, version):
self._version = version
def __getitem__(self, key):
return getattr(self, key)
@property
def data(self):
return {
'version': self.version,
'tags': self.tags,
'shapes': self.shapes,
'tracks': self.tracks,
}
def serialize(self):
serializer = LabeledDataSerializer(data=self.data)
if serializer.is_valid(raise_exception=True):
return serializer.data
# makes a data copy from the specified frame interval
def slice(self, start, stop):
is_frame_inside = lambda x: (start <= int(x['frame']) <= stop)
splitted_data = AnnotationIR()
splitted_data.tags = copy.deepcopy(list(filter(is_frame_inside, self.tags)))
splitted_data.shapes = copy.deepcopy(list(filter(is_frame_inside, self.shapes)))
splitted_data.tracks = copy.deepcopy(list(filter(lambda y: len(list(filter(is_frame_inside, y['shapes']))), self.tracks)))
return splitted_data
@data.setter
def data(self, data):
self.version = data['version']
self.tags = data['tags']
self.shapes = data['shapes']
self.tracks = data['tracks']
def reset(self):
self._version = 0
self._tags = []
self._shapes = []
self._tracks = []
class Annotation:
Attribute = namedtuple('Attribute', 'name, value')
LabeledShape = namedtuple('LabeledShape', 'type, frame, label, points, occluded, attributes, group, z_order')
LabeledShape.__new__.__defaults__ = (0, 0)
TrackedShape = namedtuple('TrackedShape', 'type, points, occluded, frame, attributes, outside, keyframe, z_order')
TrackedShape.__new__.__defaults__ = (0, )
Track = namedtuple('Track', 'label, group, shapes')
Tag = namedtuple('Tag', 'frame, label, attributes, group')
Tag.__new__.__defaults__ = (0, )
Frame = namedtuple('Frame', 'frame, name, width, height, labeled_shapes, tags')
def __init__(self, annotation_ir, db_task, scheme='', host='', create_callback=None):
self._annotation_ir = annotation_ir
self._db_task = db_task
self._scheme = scheme
self._host = host
self._create_callback=create_callback
self._MAX_ANNO_SIZE=30000
db_labels = self._db_task.label_set.all().prefetch_related('attributespec_set')
self._label_mapping = {db_label.id: db_label for db_label in db_labels}
self._attribute_mapping = {
'mutable': {},
'immutable': {},
}
for db_label in db_labels:
for db_attribute in db_label.attributespec_set.all():
if db_attribute.mutable:
self._attribute_mapping['mutable'][db_attribute.id] = db_attribute.name
else:
self._attribute_mapping['immutable'][db_attribute.id] = db_attribute.name
self._attribute_mapping_merged = {
**self._attribute_mapping['mutable'],
**self._attribute_mapping['immutable'],
}
self._init_frame_info()
self._init_meta()
def _get_label_id(self, label_name):
for db_label in self._label_mapping.values():
if label_name == db_label.name:
return db_label.id
return None
def _get_label_name(self, label_id):
return self._label_mapping[label_id].name
def _get_attribute_name(self, attribute_id):
return self._attribute_mapping_merged[attribute_id]
def _get_attribute_id(self, attribute_name, attribute_type=None):
if attribute_type:
container = self._attribute_mapping[attribute_type]
else:
container = self._attribute_mapping_merged
for attr_id, attr_name in container.items():
if attribute_name == attr_name:
return attr_id
return None
def _get_mutable_attribute_id(self, attribute_name):
return self._get_attribute_id(attribute_name, 'mutable')
def _get_immutable_attribute_id(self, attribute_name):
return self._get_attribute_id(attribute_name, 'immutable')
def _init_frame_info(self):
if self._db_task.mode == "interpolation":
self._frame_info = {
frame: {
"path": "frame_{:06d}".format(frame),
"width": self._db_task.video.width,
"height": self._db_task.video.height,
} for frame in range(self._db_task.size)
}
else:
self._frame_info = {db_image.frame: {
"path": db_image.path,
"width": db_image.width,
"height": db_image.height,
} for db_image in self._db_task.image_set.all()}
def _init_meta(self):
db_segments = self._db_task.segment_set.all().prefetch_related('job_set')
self._meta = OrderedDict([
("task", OrderedDict([
("id", str(self._db_task.id)),
("name", self._db_task.name),
("size", str(self._db_task.size)),
("mode", self._db_task.mode),
("overlap", str(self._db_task.overlap)),
("bugtracker", self._db_task.bug_tracker),
("created", str(timezone.localtime(self._db_task.created_date))),
("updated", str(timezone.localtime(self._db_task.updated_date))),
("start_frame", str(self._db_task.start_frame)),
("stop_frame", str(self._db_task.stop_frame)),
("frame_filter", self._db_task.frame_filter),
("z_order", str(self._db_task.z_order)),
("labels", [
("label", OrderedDict([
("name", db_label.name),
("attributes", [
("attribute", OrderedDict([
("name", db_attr.name),
("mutable", str(db_attr.mutable)),
("input_type", db_attr.input_type),
("default_value", db_attr.default_value),
("values", db_attr.values)]))
for db_attr in db_label.attributespec_set.all()])
])) for db_label in self._label_mapping.values()
]),
("segments", [
("segment", OrderedDict([
("id", str(db_segment.id)),
("start", str(db_segment.start_frame)),
("stop", str(db_segment.stop_frame)),
("url", "{0}://{1}/?id={2}".format(
self._scheme, self._host, db_segment.job_set.all()[0].id))]
)) for db_segment in db_segments
]),
("owner", OrderedDict([
("username", self._db_task.owner.username),
("email", self._db_task.owner.email)
]) if self._db_task.owner else ""),
("assignee", OrderedDict([
("username", self._db_task.assignee.username),
("email", self._db_task.assignee.email)
]) if self._db_task.assignee else ""),
])),
("dumped", str(timezone.localtime(timezone.now())))
])
if self._db_task.mode == "interpolation":
self._meta["task"]["original_size"] = OrderedDict([
("width", str(self._db_task.video.width)),
("height", str(self._db_task.video.height))
])
# Add source to dumped file
self._meta["source"] = str(os.path.basename(self._db_task.video.path))
def _export_attributes(self, attributes):
exported_attributes = []
for attr in attributes:
db_attribute = self._attribute_mapping_merged[attr["spec_id"]]
exported_attributes.append(Annotation.Attribute(
name=db_attribute,
value=attr["value"],
))
return exported_attributes
def _export_tracked_shape(self, shape):
return Annotation.TrackedShape(
type=shape["type"],
frame=self._db_task.start_frame + shape["frame"] * self._db_task.get_frame_step(),
points=shape["points"],
occluded=shape["occluded"],
outside=shape.get("outside", False),
keyframe=shape.get("keyframe", True),
z_order=shape["z_order"],
attributes=self._export_attributes(shape["attributes"]),
)
def _export_labeled_shape(self, shape):
return Annotation.LabeledShape(
type=shape["type"],
label=self._get_label_name(shape["label_id"]),
frame=self._db_task.start_frame + shape["frame"] * self._db_task.get_frame_step(),
points=shape["points"],
occluded=shape["occluded"],
z_order=shape.get("z_order", 0),
group=shape.get("group", 0),
attributes=self._export_attributes(shape["attributes"]),
)
def _export_tag(self, tag):
return Annotation.Tag(
frame=self._db_task.start_frame + tag["frame"] * self._db_task.get_frame_step(),
label=self._get_label_name(tag["label_id"]),
group=tag.get("group", 0),
attributes=self._export_attributes(tag["attributes"]),
)
def group_by_frame(self):
def _get_frame(annotations, shape):
db_image = self._frame_info[shape["frame"]]
frame = self._db_task.start_frame + shape["frame"] * self._db_task.get_frame_step()
rpath = db_image['path'].split(os.path.sep)
if len(rpath) != 1:
rpath = os.path.sep.join(rpath[rpath.index(".upload")+1:])
else:
rpath = rpath[0]
if frame not in annotations:
annotations[frame] = Annotation.Frame(
frame=frame,
name=rpath,
height=db_image["height"],
width=db_image["width"],
labeled_shapes=[],
tags=[],
)
return annotations[frame]
annotations = {}
data_manager = DataManager(self._annotation_ir)
for shape in data_manager.to_shapes(self._db_task.size):
_get_frame(annotations, shape).labeled_shapes.append(self._export_labeled_shape(shape))
for tag in self._annotation_ir.tags:
_get_frame(annotations, tag).tags.append(self._export_tag(tag))
return iter(annotations.values())
@property
def shapes(self):
for shape in self._annotation_ir.shapes:
yield self._export_labeled_shape(shape)
@property
def tracks(self):
for track in self._annotation_ir.tracks:
tracked_shapes = TrackManager.get_interpolated_shapes(track, 0, self._db_task.size)
yield Annotation.Track(
label=self._get_label_name(track["label_id"]),
group=track['group'],
shapes=[self._export_tracked_shape(shape) for shape in tracked_shapes],
)
@property
def tags(self):
for tag in self._annotation_ir.tags:
yield self._export_tag(tag)
@property
def meta(self):
return self._meta
def _import_tag(self, tag):
_tag = tag._asdict()
_tag['label_id'] = self._get_label_id(_tag.pop('label'))
_tag['attributes'] = [self._import_attribute(attrib) for attrib in _tag['attributes'] if self._get_attribute_id(attrib.name)]
return _tag
def _import_attribute(self, attribute):
return {
'spec_id': self._get_attribute_id(attribute.name),
'value': attribute.value,
}
def _import_shape(self, shape):
_shape = shape._asdict()
_shape['label_id'] = self._get_label_id(_shape.pop('label'))
_shape['attributes'] = [self._import_attribute(attrib) for attrib in _shape['attributes'] if self._get_attribute_id(attrib.name)]
return _shape
def _import_track(self, track):
_track = track._asdict()
_track['frame'] = min(shape.frame for shape in _track['shapes'])
_track['label_id'] = self._get_label_id(_track.pop('label'))
_track['attributes'] = []
_track['shapes'] = [shape._asdict() for shape in _track['shapes']]
for shape in _track['shapes']:
_track['attributes'] = [self._import_attribute(attrib) for attrib in shape['attributes'] if self._get_immutable_attribute_id(attrib.name)]
shape['attributes'] = [self._import_attribute(attrib) for attrib in shape['attributes'] if self._get_mutable_attribute_id(attrib.name)]
return _track
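# Flush accumulated annotations through the create callback once they exceed
# _MAX_ANNO_SIZE, then reset the intermediate representation.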
def _call_callback(self):
if self._len() > self._MAX_ANNO_SIZE:
self._create_callback(self._annotation_ir.serialize())
self._annotation_ir.reset()
def add_tag(self, tag):
imported_tag = self._import_tag(tag)
if imported_tag['label_id']:
self._annotation_ir.add_tag(imported_tag)
self._call_callback()
def add_shape(self, shape):
imported_shape = self._import_shape(shape)
if imported_shape['label_id']:
self._annotation_ir.add_shape(imported_shape)
self._call_callback()
def add_track(self, track):
imported_track = self._import_track(track)
if imported_track['label_id']:
self._annotation_ir.add_track(imported_track)
self._call_callback()
@property
def data(self):
return self._annotation_ir
def _len(self):
track_len = 0
for track in self._annotation_ir.tracks:
track_len += len(track['shapes'])
return len(self._annotation_ir.tags) + len(self._annotation_ir.shapes) + track_len
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
from django.apps import AppConfig
from django.db.models.signals import post_migrate
from cvat.apps.annotation.settings import BUILTIN_FORMATS
def register_builtins_callback(sender, **kwargs):
from .format import register_format
for builtin_format in BUILTIN_FORMATS:
register_format(builtin_format)
class AnnotationConfig(AppConfig):
name = 'cvat.apps.annotation'
def ready(self):
post_migrate.connect(register_builtins_callback, sender=self)
format_spec = {
"name": "CVAT",
"dumpers": [
{
"display_name": "{name} {format} {version} for videos",
"format": "XML",
"version": "1.1",
"handler": "dump_as_cvat_interpolation"
},
{
"display_name": "{name} {format} {version} for images",
"format": "XML",
"version": "1.1",
"handler": "dump_as_cvat_annotation"
}
],
"loaders": [
{
"display_name": "{name} {format} {version}",
"format": "XML",
"version": "1.1",
"handler": "load",
}
],
}
def pairwise(iterable):
a = iter(iterable)
return zip(a, a)
def create_xml_dumper(file_object):
from xml.sax.saxutils import XMLGenerator
from collections import OrderedDict
class XmlAnnotationWriter:
def __init__(self, file):
self.version = "1.1"
self.file = file
self.xmlgen = XMLGenerator(self.file, 'utf-8')
self._level = 0
def _indent(self, newline = True):
if newline:
self.xmlgen.ignorableWhitespace("\n")
self.xmlgen.ignorableWhitespace(" " * self._level)
def _add_version(self):
self._indent()
self.xmlgen.startElement("version", {})
self.xmlgen.characters(self.version)
self.xmlgen.endElement("version")
def open_root(self):
self.xmlgen.startDocument()
self.xmlgen.startElement("annotations", {})
self._level += 1
self._add_version()
def _add_meta(self, meta):
self._level += 1
for k, v in meta.items():
if isinstance(v, OrderedDict):
self._indent()
self.xmlgen.startElement(k, {})
self._add_meta(v)
self._indent()
self.xmlgen.endElement(k)
elif isinstance(v, list):
self._indent()
self.xmlgen.startElement(k, {})
for tup in v:
self._add_meta(OrderedDict([tup]))
self._indent()
self.xmlgen.endElement(k)
else:
self._indent()
self.xmlgen.startElement(k, {})
self.xmlgen.characters(v)
self.xmlgen.endElement(k)
self._level -= 1
def add_meta(self, meta):
self._indent()
self.xmlgen.startElement("meta", {})
self._add_meta(meta)
self._indent()
self.xmlgen.endElement("meta")
def open_track(self, track):
self._indent()
self.xmlgen.startElement("track", track)
self._level += 1
def open_image(self, image):
self._indent()
self.xmlgen.startElement("image", image)
self._level += 1
def open_box(self, box):
self._indent()
self.xmlgen.startElement("box", box)
self._level += 1
def open_polygon(self, polygon):
self._indent()
self.xmlgen.startElement("polygon", polygon)
self._level += 1
def open_polyline(self, polyline):
self._indent()
self.xmlgen.startElement("polyline", polyline)
self._level += 1
def open_points(self, points):
self._indent()
self.xmlgen.startElement("points", points)
self._level += 1
def add_attribute(self, attribute):
self._indent()
self.xmlgen.startElement("attribute", {"name": attribute["name"]})
self.xmlgen.characters(attribute["value"])
self.xmlgen.endElement("attribute")
def close_box(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("box")
def close_polygon(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("polygon")
def close_polyline(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("polyline")
def close_points(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("points")
def close_image(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("image")
def close_track(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("track")
def close_root(self):
self._level -= 1
self._indent()
self.xmlgen.endElement("annotations")
self.xmlgen.endDocument()
return XmlAnnotationWriter(file_object)
def dump_as_cvat_annotation(file_object, annotations):
from collections import OrderedDict
dumper = create_xml_dumper(file_object)
dumper.open_root()
dumper.add_meta(annotations.meta)
for frame_annotation in annotations.group_by_frame():
frame_id = frame_annotation.frame
dumper.open_image(OrderedDict([
("id", str(frame_id)),
("name", frame_annotation.name),
("width", str(frame_annotation.width)),
("height", str(frame_annotation.height))
]))
for shape in frame_annotation.labeled_shapes:
dump_data = OrderedDict([
("label", shape.label),
("occluded", str(int(shape.occluded))),
])
if shape.type == "rectangle":
dump_data.update(OrderedDict([
("xtl", "{:.2f}".format(shape.points[0])),
("ytl", "{:.2f}".format(shape.points[1])),
("xbr", "{:.2f}".format(shape.points[2])),
("ybr", "{:.2f}".format(shape.points[3]))
]))
else:
dump_data.update(OrderedDict([
("points", ';'.join((
','.join((
"{:.2f}".format(x),
"{:.2f}".format(y)
)) for x, y in pairwise(shape.points))
)),
]))
if annotations.meta["task"]["z_order"] != "False":
dump_data['z_order'] = str(shape.z_order)
if "group" in shape and shape.group:
dump_data['group_id'] = str(shape.group)
if shape.type == "rectangle":
dumper.open_box(dump_data)
elif shape.type == "polygon":
dumper.open_polygon(dump_data)
elif shape.type == "polyline":
dumper.open_polyline(dump_data)
elif shape.type == "points":
dumper.open_points(dump_data)
else:
raise NotImplementedError("unknown shape type")
for attr in shape.attributes:
dumper.add_attribute(OrderedDict([
("name", attr.name),
("value", attr.value)
]))
if shape.type == "rectangle":
dumper.close_box()
elif shape.type == "polygon":
dumper.close_polygon()
elif shape.type == "polyline":
dumper.close_polyline()
elif shape.type == "points":
dumper.close_points()
else:
raise NotImplementedError("unknown shape type")
dumper.close_image()
dumper.close_root()
def dump_as_cvat_interpolation(file_object, annotations):
from collections import OrderedDict
dumper = create_xml_dumper(file_object)
dumper.open_root()
dumper.add_meta(annotations.meta)
def dump_track(idx, track):
track_id = idx
dump_data = OrderedDict([
("id", str(track_id)),
("label", track.label),
])
if track.group:
dump_data['group_id'] = str(track.group)
dumper.open_track(dump_data)
for shape in track.shapes:
dump_data = OrderedDict([
("frame", str(shape.frame)),
("outside", str(int(shape.outside))),
("occluded", str(int(shape.occluded))),
("keyframe", str(int(shape.keyframe))),
])
if shape.type == "rectangle":
dump_data.update(OrderedDict([
("xtl", "{:.2f}".format(shape.points[0])),
("ytl", "{:.2f}".format(shape.points[1])),
("xbr", "{:.2f}".format(shape.points[2])),
("ybr", "{:.2f}".format(shape.points[3])),
]))
else:
dump_data.update(OrderedDict([
("points", ';'.join(['{:.2f},{:.2f}'.format(x, y)
for x,y in pairwise(shape.points)]))
]))
if annotations.meta["task"]["z_order"] != "False":
dump_data["z_order"] = str(shape.z_order)
if shape.type == "rectangle":
dumper.open_box(dump_data)
elif shape.type == "polygon":
dumper.open_polygon(dump_data)
elif shape.type == "polyline":
dumper.open_polyline(dump_data)
elif shape.type == "points":
dumper.open_points(dump_data)
else:
raise NotImplementedError("unknown shape type")
for attr in shape.attributes:
dumper.add_attribute(OrderedDict([
("name", attr.name),
("value", attr.value)
]))
if shape.type == "rectangle":
dumper.close_box()
elif shape.type == "polygon":
dumper.close_polygon()
elif shape.type == "polyline":
dumper.close_polyline()
elif shape.type == "points":
dumper.close_points()
else:
raise NotImplementedError("unknown shape type")
dumper.close_track()
counter = 0
for track in annotations.tracks:
dump_track(counter, track)
counter += 1
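# Single shapes are dumped as one-shape tracks: a keyframe on the shape's frame
# followed by an 'outside' copy on the next frame that terminates the track.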
for shape in annotations.shapes:
dump_track(counter, annotations.Track(
label=shape.label,
group=shape.group,
shapes=[annotations.TrackedShape(
type=shape.type,
points=shape.points,
occluded=shape.occluded,
outside=False,
keyframe=True,
z_order=shape.z_order,
frame=shape.frame,
attributes=shape.attributes,
),
annotations.TrackedShape(
type=shape.type,
points=shape.points,
occluded=shape.occluded,
outside=True,
keyframe=True,
z_order=shape.z_order,
frame=shape.frame + 1,
attributes=shape.attributes,
),
],
))
counter += 1
dumper.close_root()
def load(file_object, annotations):
import xml.etree.ElementTree as et
context = et.iterparse(file_object, events=("start", "end"))
context = iter(context)
ev, _ = next(context)
supported_shapes = ('box', 'polygon', 'polyline', 'points')
track = None
shape = None
image_is_opened = False
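# Stream-parse the XML: shapes met inside a <track> element become TrackedShapes,
# shapes met inside an <image> element become LabeledShapes.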
for ev, el in context:
if ev == 'start':
if el.tag == 'track':
track = annotations.Track(
label=el.attrib['label'],
group=int(el.attrib.get('group_id', 0)),
shapes=[],
)
elif el.tag == 'image':
image_is_opened = True
frame_id = int(el.attrib['id'])
elif el.tag in supported_shapes and (track is not None or image_is_opened):
shape = {
'attributes': [],
'points': [],
}
elif ev == 'end':
if el.tag == 'attribute' and shape is not None:
shape['attributes'].append(annotations.Attribute(
name=el.attrib['name'],
value=el.text,
))
if el.tag in supported_shapes:
if track is not None:
shape['frame'] = el.attrib['frame']
shape['outside'] = el.attrib['outside'] == "1"
shape['keyframe'] = el.attrib['keyframe'] == "1"
else:
shape['frame'] = frame_id
shape['label'] = el.attrib['label']
shape['group'] = int(el.attrib.get('group_id', 0))
shape['type'] = 'rectangle' if el.tag == 'box' else el.tag
shape['occluded'] = el.attrib['occluded'] == '1'
shape['z_order'] = int(el.attrib.get('z_order', 0))
if el.tag == 'box':
shape['points'].append(el.attrib['xtl'])
shape['points'].append(el.attrib['ytl'])
shape['points'].append(el.attrib['xbr'])
shape['points'].append(el.attrib['ybr'])
else:
for pair in el.attrib['points'].split(';'):
shape['points'].extend(map(float, pair.split(',')))
if track is not None:
track.shapes.append(annotations.TrackedShape(**shape))
else:
annotations.add_shape(annotations.LabeledShape(**shape))
shape = None
elif el.tag == 'track':
annotations.add_track(track)
track = None
elif el.tag == 'image':
image_is_opened = False
el.clear()
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
from cvat.apps.annotation import models
from django.conf import settings
from cvat.apps.annotation.serializers import AnnotationFormatSerializer
import os
from copy import deepcopy
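# Execute a format script with an empty __builtins__, read its format_spec and,
# if the format is not registered yet, expand the display_name templates and
# store its dumpers/loaders in the database.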
def register_format(format_file):
source_code = open(format_file, 'r').read()
global_vars = {
"__builtins__": {},
}
exec(source_code, global_vars)
if "format_spec" not in global_vars or not isinstance(global_vars["format_spec"], dict):
raise Exception("Could not find \'format_spec\' definition in format file specification")
format_spec = deepcopy(global_vars["format_spec"])
if not models.AnnotationFormat.objects.filter(name=format_spec["name"]).exists():
format_spec["handler_file"] = os.path.relpath(format_file, settings.BASE_DIR)
for spec in format_spec["loaders"] + format_spec["dumpers"]:
spec["display_name"] = spec["display_name"].format(
name=format_spec["name"],
format=spec["format"],
version=spec["version"],
)
serializer = AnnotationFormatSerializer(data=format_spec)
if serializer.is_valid(raise_exception=True):
serializer.save()
def get_annotation_formats():
return AnnotationFormatSerializer(
models.AnnotationFormat.objects.all(),
many=True).data
# Generated by Django 2.1.9 on 2019-07-31 15:20
import cvat.apps.annotation.models
import cvat.apps.engine.models
from django.conf import settings
import django.core.files.storage
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='AnnotationFormat',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', cvat.apps.engine.models.SafeCharField(max_length=256)),
('created_date', models.DateTimeField(auto_now_add=True)),
('updated_date', models.DateTimeField(auto_now_add=True)),
('handler_file', models.FileField(storage=django.core.files.storage.FileSystemStorage(location=settings.BASE_DIR), upload_to=cvat.apps.annotation.models.upload_file_handler)),
('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)),
],
options={
'default_permissions': (),
},
),
migrations.CreateModel(
name='AnnotationHandler',
fields=[
('type', models.CharField(choices=[('dumper', 'DUMPER'), ('loader', 'LOADER')], max_length=16)),
('display_name', cvat.apps.engine.models.SafeCharField(max_length=256, primary_key=True, serialize=False)),
('format', models.CharField(max_length=16)),
('version', models.CharField(max_length=16)),
('handler', models.CharField(max_length=256)),
('annotation_format', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='annotation.AnnotationFormat')),
],
options={
'default_permissions': (),
},
),
]
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
from enum import Enum
from django.db import models
from django.conf import settings
from django.core.files.storage import FileSystemStorage
from cvat.apps.engine.models import SafeCharField
from django.contrib.auth.models import User
def upload_file_handler(instance, filename):
return os.path.join('formats', str(instance.id), filename)
class HandlerType(str, Enum):
DUMPER = 'dumper'
LOADER = 'loader'
@classmethod
def choices(cls):
return tuple((x.value, x.name) for x in cls)
def __str__(self):
return self.value
class AnnotationFormat(models.Model):
name = SafeCharField(max_length=256)
owner = models.ForeignKey(User, null=True, blank=True,
on_delete=models.SET_NULL)
created_date = models.DateTimeField(auto_now_add=True)
updated_date = models.DateTimeField(auto_now_add=True)
handler_file = models.FileField(
upload_to=upload_file_handler,
storage=FileSystemStorage(location=os.path.join(settings.BASE_DIR)),
)
class Meta:
default_permissions = ()
class AnnotationHandler(models.Model):
type = models.CharField(max_length=16,
choices=HandlerType.choices())
display_name = SafeCharField(max_length=256, primary_key=True)
format = models.CharField(max_length=16)
version = models.CharField(max_length=16)
handler = models.CharField(max_length=256)
annotation_format = models.ForeignKey(AnnotationFormat, on_delete=models.CASCADE)
class Meta:
default_permissions = ()
format_spec = {
"name": "PASCAL VOC",
"dumpers": [
{
"display_name": "{name} {format} {version}",
"format": "ZIP",
"version": "1.0",
"handler": "dump"
},
],
"loaders": [],
}
def load(file_object, annotations, spec):
raise NotImplementedError
def dump(file_object, annotations):
from pascal_voc_writer import Writer
import os
from zipfile import ZipFile
from tempfile import TemporaryDirectory
with TemporaryDirectory() as out_dir:
with ZipFile(file_object, 'w') as output_zip:
for frame_annotation in annotations.group_by_frame():
image_name = frame_annotation.name
width = frame_annotation.width
height = frame_annotation.height
writer = Writer(image_name, width, height)
writer.template_parameters['path'] = ''
writer.template_parameters['folder'] = ''
for shape in frame_annotation.labeled_shapes:
if shape.type != "rectangle":
continue
label = shape.label
xtl = shape.points[0]
ytl = shape.points[1]
xbr = shape.points[2]
ybr = shape.points[3]
writer.addObject(label, xtl, ytl, xbr, ybr)
anno_name = os.path.basename('{}.{}'.format(os.path.splitext(image_name)[0], 'xml'))
anno_file = os.path.join(out_dir, anno_name)
writer.save(anno_file)
output_zip.write(filename=anno_file, arcname=anno_name)
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
from rest_framework import serializers
from cvat.apps.annotation import models
class AnnotationHandlerSerializer(serializers.ModelSerializer):
class Meta:
model = models.AnnotationHandler
exclude = ('annotation_format',)
class AnnotationFormatSerializer(serializers.ModelSerializer):
handlers = AnnotationHandlerSerializer(many=True, source='annotationhandler_set')
class Meta:
model = models.AnnotationFormat
exclude = ("handler_file", )
# pylint: disable=no-self-use
def create(self, validated_data):
handlers = validated_data.pop('handlers')
annotation_format = models.AnnotationFormat.objects.create(**validated_data)
handlers = [models.AnnotationHandler(annotation_format=annotation_format, **handler) for handler in handlers]
models.AnnotationHandler.objects.bulk_create(handlers)
return annotation_format
# pylint: disable=no-self-use
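# Fold the incoming 'dumpers' and 'loaders' lists into a single 'handlers' list,
# tagging each entry with its HandlerType so both kinds are stored uniformly.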
def to_internal_value(self, data):
_data = data.copy()
_data["handlers"] = []
for d in _data.pop("dumpers"):
d["type"] = models.HandlerType.DUMPER
_data["handlers"].append(d)
for l in _data.pop("loaders"):
l["type"] = models.HandlerType.LOADER
_data["handlers"].append(l)
return _data
def to_representation(self, instance):
data = super().to_representation(instance)
data['dumpers'] = []
data['loaders'] = []
for handler in data.pop("handlers"):
handler_type = handler.pop("type")
if handler_type == models.HandlerType.DUMPER:
data["dumpers"].append(handler)
else:
data["loaders"].append(handler)
return data
class AnnotationFileSerializer(serializers.Serializer):
annotation_file = serializers.FileField()
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
path_prefix = os.path.join('cvat', 'apps', 'annotation')
BUILTIN_FORMATS = (
os.path.join(path_prefix, 'cvat.py'),
os.path.join(path_prefix,'pascal_voc.py'),
)
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
# Copyright (C) 2018 Intel Corporation
#
# SPDX-License-Identifier: MIT
......@@ -26,7 +26,7 @@ from cvat.apps.engine.annotation import put_task_data, patch_task_data
from .models import AnnotationModel, FrameworkChoice
from .model_loader import ModelLoader, load_labelmap
from .image_loader import ImageLoader
from .import_modules import import_modules
from cvat.apps.engine.utils.import_modules import import_modules
def _remove_old_file(model_file_field):
......
......@@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: MIT
import os
from enum import Enum
from django.db import models
......@@ -13,7 +14,7 @@ from django.core.files.storage import FileSystemStorage
fs = FileSystemStorage()
def upload_path_handler(instance, filename):
return "{models_root}/{id}/{file}".format(models_root=settings.MODELS_ROOT, id=instance.id, file=filename)
return os.path.join(settings.MODELS_ROOT, str(instance.id), filename)
class FrameworkChoice(Enum):
OPENVINO = 'openvino'
......
......@@ -5,17 +5,18 @@
*/
/* global
AnnotationParser:false
userConfirm:false
dumpAnnotationRequest: false
dumpAnnotationRequest:false
uploadTaskAnnotationRequest:false
LabelsInfo:false
showMessage:false
showOverlay:false
*/
class TaskView {
constructor(task) {
constructor(task, annotationFormats) {
this.init(task);
this._annotationFormats = annotationFormats;
this._UI = null;
}
......@@ -75,147 +76,34 @@ class TaskView {
});
}
_upload() {
async function saveChunk(parsed) {
const CHUNK_SIZE = 30000;
let chunk = null;
class Chunk {
constructor() {
this.shapes = [];
this.tracks = [];
this.tags = [];
this.capasity = CHUNK_SIZE;
this.version = 0;
}
length() {
return this.tags.length
+ this.shapes.length
+ this.tracks.reduce((sum, track) => sum + track.shapes.length, 0);
}
isFull() {
return this.length() >= this.capasity;
}
isEmpty() {
return this.length() === 0;
}
clear() {
this.shapes = [];
this.tracks = [];
this.tags = [];
}
export() {
return {
shapes: this.shapes,
tracks: this.tracks,
tags: this.tags,
version: this.version,
};
}
async save(taskID) {
try {
const response = await $.ajax({
url: `/api/v1/tasks/${taskID}/annotations?action=create`,
type: 'PATCH',
data: JSON.stringify(chunk.export()),
contentType: 'application/json',
});
this.version = response.version;
this.clear();
} catch (error) {
throw error;
}
}
}
const splitAndSave = async (chunkForSave, prop, splitStep) => {
for (let start = 0; start < parsed[prop].length; start += splitStep) {
Array.prototype.push.apply(chunkForSave[prop],
parsed[prop].slice(start, start + splitStep));
if (chunkForSave.isFull()) {
await chunkForSave.save(this._task.id);
}
}
// save tail
if (!chunkForSave.isEmpty()) {
await chunkForSave.save(this._task.id);
}
};
chunk = new Chunk();
// TODO tags aren't supported by parser
// await split(chunk, "tags", CHUNK_SIZE);
await splitAndSave(chunk, 'shapes', CHUNK_SIZE);
await splitAndSave(chunk, 'tracks', 1);
}
async function save(parsed) {
await $.ajax({
url: `/api/v1/tasks/${this._task.id}/annotations`,
type: 'DELETE',
});
await saveChunk.call(this, parsed);
}
async function onload(overlay, text) {
try {
overlay.setMessage('Required data are being downloaded from the server..');
const imageCache = await $.get(`/api/v1/tasks/${this._task.id}/frames/meta`);
const labelsCopy = JSON.parse(JSON.stringify(this._task.labels
.map(el => el.toJSON())));
const parser = new AnnotationParser({
start: 0,
stop: this._task.size,
image_meta_data: imageCache,
}, new LabelsInfo(labelsCopy));
overlay.setMessage('The annotation file is being parsed..');
const parsed = parser.parse(text);
overlay.setMessage('The annotation is being saved..');
await save.call(this, parsed);
const message = 'Annotation have been successfully uploaded';
showMessage(message);
} catch (errorData) {
let message = null;
if (typeof (errorData) === 'string') {
message = `Can not upload annotations. ${errorData}`;
} else {
message = `Can not upload annotations. Code: ${errorData.status}. `
+ `Message: ${errorData.responseText || errorData.statusText}`;
}
showMessage(message);
} finally {
overlay.remove();
}
}
$('<input type="file" accept="text/xml">').on('change', (onChangeEvent) => {
_upload(uploadAnnotationButton) {
const button = $(uploadAnnotationButton);
const CVATformat = this._annotationFormats.find(el => el.name === 'CVAT');
$('<input type="file" accept="text/xml">').on('change', async (onChangeEvent) => {
const file = onChangeEvent.target.files[0];
$(onChangeEvent.target).remove();
if (file) {
const overlay = showOverlay('File is being parsed..');
const fileReader = new FileReader();
fileReader.onload = (onloadEvent) => {
onload.call(this, overlay, onloadEvent.target.result);
};
fileReader.readAsText(file);
button.text('Uploading..');
button.prop('disabled', true);
const annotationData = new FormData();
annotationData.append('annotation_file', file);
try {
await uploadTaskAnnotationRequest(this._task.id, annotationData,
CVATformat.loaders[0].display_name);
} catch (error) {
showMessage(error.message);
} finally {
button.prop('disabled', false);
button.text('Upload Annotation');
}
}
}).click();
}
async _dump(button) {
async _dump(button, format) {
button.disabled = true;
try {
await dumpAnnotationRequest(this._task.id, this._task.name);
await dumpAnnotationRequest(this._task.id, this._task.name, format);
} catch (error) {
showMessage(error.message);
} finally {
......@@ -242,14 +130,27 @@ class TaskView {
}),
);
const buttonsContainer = $('<div class="dashboardButtonsUI"> </div>').appendTo(this._UI);
$('<button class="regular dashboardButtonUI"> Dump Annotation </button>').on('click', (e) => {
this._dump(e.target);
}).appendTo(buttonsContainer);
$('<button class="regular dashboardButtonUI"> Upload Annotation </button>').on('click', () => {
userConfirm('The current annotation will be lost. Are you sure?', () => this._upload());
const downloadButton = $('<button class="regular dashboardButtonUI"> Dump Annotation </button>');
const dropdownMenu = $('<ul class="dropdown-content hidden"></ul>');
for (const format of this._annotationFormats) {
for (const dumpSpec of format.dumpers) {
dropdownMenu.append($(`<li>${dumpSpec.display_name}</li>`).on('click', () => {
dropdownMenu.addClass('hidden');
this._dump(downloadButton[0], dumpSpec.display_name);
}));
}
}
$('<div class="dropdown"></div>').append(
downloadButton.on('click', () => {
dropdownMenu.toggleClass('hidden');
}),
).append(dropdownMenu).appendTo(buttonsContainer);
$('<button class="regular dashboardButtonUI"> Upload Annotation </button>').on('click', (e) => {
userConfirm('The current annotation will be lost. Are you sure?', () => this._upload(e.target));
}).appendTo(buttonsContainer);
$('<button class="regular dashboardButtonUI"> Update Task </button>').on('click', () => {
......@@ -290,13 +191,14 @@ class TaskView {
class DashboardView {
constructor(metaData, taskData) {
constructor(metaData, taskData, annotationFormats) {
this._dashboardList = taskData.results;
this._maxUploadSize = metaData.max_upload_size;
this._maxUploadCount = metaData.max_upload_count;
this._baseURL = metaData.base_url;
this._sharePath = metaData.share_path;
this._params = {};
this._annotationFormats = annotationFormats;
this._setupList();
this._setupTaskSearch();
......@@ -348,7 +250,7 @@ class DashboardView {
}));
for (const task of tasks) {
const taskView = new TaskView(task);
const taskView = new TaskView(task, this._annotationFormats);
dashboardList.append(taskView.render(baseURL));
}
......@@ -807,9 +709,10 @@ window.addEventListener('DOMContentLoaded', () => {
// TODO: Use REST API in order to get meta
$.get('/dashboard/meta'),
$.get(`/api/v1/tasks${window.location.search}`),
).then((metaData, taskData) => {
$.get('/api/v1/server/annotation/formats'),
).then((metaData, taskData, annotationFormats) => {
try {
new DashboardView(metaData[0], taskData[0]);
new DashboardView(metaData[0], taskData[0], annotationFormats[0]);
} catch (exception) {
$('#content').empty();
const message = `Can not build CVAT dashboard. Exception: ${exception}.`;
......
This diff has been collapsed.
import copy
import numpy as np
from scipy.optimize import linear_sum_assignment
from shapely import geometry
from . import models
class DataManager:
def __init__(self, data):
self.data = data
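# Merge annotations coming from an overlapping segment: tags, shapes and tracks
# are matched and united independently by the corresponding managers.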
def merge(self, data, start_frame, overlap):
tags = TagManager(self.data.tags)
tags.merge(data.tags, start_frame, overlap)
shapes = ShapeManager(self.data.shapes)
shapes.merge(data.shapes, start_frame, overlap)
tracks = TrackManager(self.data.tracks)
tracks.merge(data.tracks, start_frame, overlap)
def to_shapes(self, end_frame):
shapes = self.data.shapes
tracks = TrackManager(self.data.tracks)
return shapes + tracks.to_shapes(end_frame)
def to_tracks(self):
tracks = self.data.tracks
shapes = ShapeManager(self.data.shapes)
return tracks + shapes.to_tracks()
class ObjectManager:
def __init__(self, objects):
self.objects = objects
@staticmethod
def _get_objects_by_frame(objects, start_frame):
objects_by_frame = {}
for obj in objects:
if obj["frame"] >= start_frame:
if obj["frame"] in objects_by_frame:
objects_by_frame[obj["frame"]].append(obj)
else:
objects_by_frame[obj["frame"]] = [obj]
return objects_by_frame
@staticmethod
def _get_cost_threshold():
raise NotImplementedError()
@staticmethod
def _calc_objects_similarity(obj0, obj1, start_frame, overlap):
raise NotImplementedError()
@staticmethod
def _unite_objects(obj0, obj1):
raise NotImplementedError()
@staticmethod
def _modify_unmached_object(obj, end_frame):
raise NotImplementedError()
def merge(self, objects, start_frame, overlap):
# 1. Split objects into two parts: new objects and objects which can intersect
# with existing objects.
new_objects = [obj for obj in objects
if obj["frame"] >= start_frame + overlap]
int_objects = [obj for obj in objects
if obj["frame"] < start_frame + overlap]
assert len(new_objects) + len(int_objects) == len(objects)
# 2. Convert to more convenient data structure (objects by frame)
int_objects_by_frame = self._get_objects_by_frame(int_objects, start_frame)
old_objects_by_frame = self._get_objects_by_frame(self.objects, start_frame)
# 3. Add new objects as is. It should be done only after old_objects_by_frame
# variable is initialized.
self.objects.extend(new_objects)
# Nothing to merge here. Just add all int_objects if any.
if not old_objects_by_frame or not int_objects_by_frame:
for frame in old_objects_by_frame:
for old_obj in old_objects_by_frame[frame]:
self._modify_unmached_object(old_obj, start_frame + overlap)
self.objects.extend(int_objects)
return
# 4. Build cost matrix for each frame and find correspondence using
# Hungarian algorithm. In this case min_cost_thresh is stronger
# because we compare only on one frame.
min_cost_thresh = self._get_cost_threshold()
for frame in int_objects_by_frame:
if frame in old_objects_by_frame:
int_objects = int_objects_by_frame[frame]
old_objects = old_objects_by_frame[frame]
cost_matrix = np.empty(shape=(len(int_objects), len(old_objects)),
dtype=float)
# 5.1 Construct cost matrix for the frame.
for i, int_obj in enumerate(int_objects):
for j, old_obj in enumerate(old_objects):
cost_matrix[i][j] = 1 - self._calc_objects_similarity(
int_obj, old_obj, start_frame, overlap)
# 6. Find optimal solution using Hungarian algorithm.
row_ind, col_ind = linear_sum_assignment(cost_matrix)
old_objects_indexes = list(range(0, len(old_objects)))
int_objects_indexes = list(range(0, len(int_objects)))
for i, j in zip(row_ind, col_ind):
# Reject the solution if the cost is too high. Mark handled objects by
# setting their positions in int_objects_indexes/old_objects_indexes to -1.
if cost_matrix[i][j] <= min_cost_thresh:
old_objects[j] = self._unite_objects(int_objects[i], old_objects[j])
int_objects_indexes[i] = -1
old_objects_indexes[j] = -1
# 7. Add all new objects which were not processed.
for i in int_objects_indexes:
if i != -1:
self.objects.append(int_objects[i])
# 8. Modify all old objects which were not processed
# (e.g. generate a shape with outside=True at the end).
for j in old_objects_indexes:
if j != -1:
self._modify_unmached_object(old_objects[j],
start_frame + overlap)
else:
# We don't have old objects on the frame. Let's add all new ones.
self.objects.extend(int_objects_by_frame[frame])
class TagManager(ObjectManager):
@staticmethod
def _get_cost_threshold():
return 0.25
@staticmethod
def _calc_objects_similarity(obj0, obj1, start_frame, overlap):
# TODO: improve the trivial implementation, compare attributes
return 1 if obj0["label_id"] == obj1["label_id"] else 0
@staticmethod
def _unite_objects(obj0, obj1):
# TODO: improve the trivial implementation
return obj0 if obj0["frame"] < obj1["frame"] else obj1
@staticmethod
def _modify_unmached_object(obj, end_frame):
pass
def pairwise(iterable):
a = iter(iterable)
return zip(a, a)
class ShapeManager(ObjectManager):
def to_tracks(self):
tracks = []
for shape in self.objects:
shape0 = copy.copy(shape)
shape0["keyframe"] = True
shape0["outside"] = False
# TODO: Separate attributes on mutable and unmutable
shape0["attributes"] = []
shape0.pop("group", None)
shape1 = copy.copy(shape0)
shape1["outside"] = True
shape1["frame"] += 1
track = {
"label_id": shape["label_id"],
"frame": shape["frame"],
"group": shape.get("group", None),
"attributes": shape["attributes"],
"shapes": [shape0, shape1]
}
tracks.append(track)
return tracks
@staticmethod
def _get_cost_threshold():
return 0.25
@staticmethod
def _calc_objects_similarity(obj0, obj1, start_frame, overlap):
def _calc_polygons_similarity(p0, p1):
overlap_area = p0.intersection(p1).area
return overlap_area / (p0.area + p1.area - overlap_area)
has_same_type = obj0["type"] == obj1["type"]
has_same_label = obj0.get("label_id") == obj1.get("label_id")
if has_same_type and has_same_label:
if obj0["type"] == models.ShapeType.RECTANGLE:
p0 = geometry.box(*obj0["points"])
p1 = geometry.box(*obj1["points"])
return _calc_polygons_similarity(p0, p1)
elif obj0["type"] == models.ShapeType.POLYGON:
p0 = geometry.Polygon(pairwise(obj0["points"]))
p1 = geometry.Polygon(pairwise(obj1["points"]))
return _calc_polygons_similarity(p0, p1)
else:
return 0 # FIXME: need some similarity for points and polylines
return 0
@staticmethod
def _unite_objects(obj0, obj1):
# TODO: improve the trivial implementation
return obj0 if obj0["frame"] < obj1["frame"] else obj1
@staticmethod
def _modify_unmached_object(obj, end_frame):
pass
class TrackManager(ObjectManager):
def to_shapes(self, end_frame):
shapes = []
for idx, track in enumerate(self.objects):
for shape in TrackManager.get_interpolated_shapes(track, 0, end_frame):
if not shape["outside"]:
shape["label_id"] = track["label_id"]
shape["group"] = track["group"]
shape["track_id"] = idx
shape["attributes"] += track["attributes"]
shapes.append(shape)
return shapes
@staticmethod
def _get_objects_by_frame(objects, start_frame):
# Just for unification. All tracks are assigned on the same frame
objects_by_frame = {0: []}
for obj in objects:
shape = obj["shapes"][-1] # optimization for old tracks
if shape["frame"] >= start_frame or not shape["outside"]:
objects_by_frame[0].append(obj)
if not objects_by_frame[0]:
objects_by_frame = {}
return objects_by_frame
@staticmethod
def _get_cost_threshold():
return 0.5
@staticmethod
def _calc_objects_similarity(obj0, obj1, start_frame, overlap):
if obj0["label_id"] == obj1["label_id"]:
# Here start_frame is the start frame of next segment
# and stop_frame is the stop frame of current segment
# end_frame == stop_frame + 1
end_frame = start_frame + overlap
obj0_shapes = TrackManager.get_interpolated_shapes(obj0, start_frame, end_frame)
obj1_shapes = TrackManager.get_interpolated_shapes(obj1, start_frame, end_frame)
obj0_shapes_by_frame = {shape["frame"]:shape for shape in obj0_shapes}
obj1_shapes_by_frame = {shape["frame"]:shape for shape in obj1_shapes}
assert obj0_shapes_by_frame and obj1_shapes_by_frame
count, error = 0, 0
for frame in range(start_frame, end_frame):
shape0 = obj0_shapes_by_frame.get(frame)
shape1 = obj1_shapes_by_frame.get(frame)
if shape0 and shape1:
if shape0["outside"] != shape1["outside"]:
error += 1
else:
error += 1 - ShapeManager._calc_objects_similarity(shape0, shape1, start_frame, overlap)
count += 1
elif shape0 or shape1:
error += 1
count += 1
return 1 - error / count
else:
return 0
@staticmethod
def _modify_unmached_object(obj, end_frame):
shape = obj["shapes"][-1]
if not shape["outside"]:
shape = copy.deepcopy(shape)
shape["frame"] = end_frame
shape["outside"] = True
obj["shapes"].append(shape)
@staticmethod
def normalize_shape(shape):
points = np.asarray(shape["points"]).reshape(-1, 2)
broken_line = geometry.LineString(points)
points = []
for off in range(0, 100, 1):
p = broken_line.interpolate(off / 100, True)
points.append(p.x)
points.append(p.y)
shape = copy.copy(shape)
shape["points"] = points
return shape
@staticmethod
def get_interpolated_shapes(track, start_frame, end_frame):
def interpolate(shape0, shape1):
shapes = []
is_same_type = shape0["type"] == shape1["type"]
is_polygon = shape0["type"] == models.ShapeType.POLYGON
is_polyline = shape0["type"] == models.ShapeType.POLYLINE
is_same_size = len(shape0["points"]) == len(shape1["points"])
if not is_same_type or is_polygon or is_polyline or not is_same_size:
shape0 = TrackManager.normalize_shape(shape0)
shape1 = TrackManager.normalize_shape(shape1)
distance = shape1["frame"] - shape0["frame"]
step = np.subtract(shape1["points"], shape0["points"]) / distance
for frame in range(shape0["frame"] + 1, shape1["frame"]):
off = frame - shape0["frame"]
if shape1["outside"]:
points = np.asarray(shape0["points"]).reshape(-1, 2)
else:
points = (shape0["points"] + step * off).reshape(-1, 2)
shape = copy.deepcopy(shape0)
if len(points) == 1:
shape["points"] = points.flatten()
else:
broken_line = geometry.LineString(points).simplify(0.05, False)
shape["points"] = [x for p in broken_line.coords for x in p]
shape["keyframe"] = False
shape["frame"] = frame
shapes.append(shape)
return shapes
if track.get("interpolated_shapes"):
return track["interpolated_shapes"]
# TODO: should be return an iterator?
shapes = []
curr_frame = track["shapes"][0]["frame"]
prev_shape = {}
for shape in track["shapes"]:
if prev_shape:
assert shape["frame"] > curr_frame
for attr in prev_shape["attributes"]:
if attr["spec_id"] not in map(lambda el: el["spec_id"], shape["attributes"]):
shape["attributes"].append(copy.deepcopy(attr))
if not prev_shape["outside"]:
shapes.extend(interpolate(prev_shape, shape))
shape["keyframe"] = True
shapes.append(shape)
curr_frame = shape["frame"]
prev_shape = shape
# TODO: Need to modify a client and a database (append "outside" shapes for polytracks)
if not prev_shape["outside"] and prev_shape["type"] == models.ShapeType.RECTANGLE:
shape = copy.copy(prev_shape)
shape["frame"] = end_frame
shapes.extend(interpolate(prev_shape, shape))
track["interpolated_shapes"] = shapes
return shapes
@staticmethod
def _unite_objects(obj0, obj1):
track = obj0 if obj0["frame"] < obj1["frame"] else obj1
assert obj0["label_id"] == obj1["label_id"]
shapes = {shape["frame"]:shape for shape in obj0["shapes"]}
for shape in obj1["shapes"]:
frame = shape["frame"]
if frame in shapes:
shapes[frame] = ShapeManager._unite_objects(shapes[frame], shape)
else:
shapes[frame] = shape
track["frame"] = min(obj0["frame"], obj1["frame"])
track["shapes"] = list(sorted(shapes.values(), key=lambda shape: shape["frame"]))
track["interpolated_shapes"] = []
return track
......@@ -311,7 +311,6 @@ class TrackedShape(Shape):
class TrackedShapeAttributeVal(AttributeVal):
shape = models.ForeignKey(TrackedShape, on_delete=models.CASCADE)
class Plugin(models.Model):
name = models.SlugField(max_length=32, primary_key=True)
description = SafeCharField(max_length=8192)
......
......@@ -153,4 +153,29 @@ html {
-moz-user-select: text;
-ms-user-select: text;
user-select: text;
}
\ No newline at end of file
}
.dropdown {
position: relative;
}
.dropdown-content {
position: absolute;
background-color: #f1f1f1;
min-width: 160px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
z-index: 1;
list-style-type: none;
padding-inline-start: 0;
margin-left: 20%;
margin-top: -1%;
}
.dropdown-content li {
color: black;
padding: 12px 16px;
text-decoration: none;
display: block;
}
.dropdown-content li:hover {background-color: #ddd}
\ No newline at end of file
......@@ -22,7 +22,7 @@ class AnnotationSaverModel extends Listener {
this._shapeCollection = shapeCollection;
this._initialObjects = [];
this._hash = this._getHash();
this.update();
// We need use data from export instead of initialData
// Otherwise we have differ keys order and JSON comparison code incorrect
......@@ -36,6 +36,10 @@ class AnnotationSaverModel extends Listener {
}
}
update() {
this._hash = this._getHash();
}
async _request(data, action) {
return new Promise((resolve, reject) => {
$.ajax({
......@@ -399,4 +403,5 @@ function buildAnnotationSaver(initialData, shapeCollection) {
const model = new AnnotationSaverModel(initialData, shapeCollection);
const controller = new AnnotationSaverController(model);
new AnnotationSaverView(model, controller);
return model;
}
......@@ -43,9 +43,9 @@
ShapeMergerModel:false
ShapeMergerView:false
showMessage:false
showOverlay:false
buildAnnotationSaver:false
LabelsInfo:false
uploadJobAnnotationRequest:false
*/
async function initLogger(jobID) {
......@@ -64,52 +64,30 @@ function blurAllElements() {
document.activeElement.blur();
}
function uploadAnnotation(shapeCollectionModel, historyModel,
annotationParser, uploadAnnotationButton) {
$('#annotationFileSelector').one('change', (changedFileEvent) => {
function uploadAnnotation(jobId, shapeCollectionModel, historyModel, annotationSaverModel,
uploadAnnotationButton, formatId, parseSpec) {
$('#annotationFileSelector').one('change', async (changedFileEvent) => {
const file = changedFileEvent.target.files['0'];
changedFileEvent.target.value = '';
if (!file || file.type !== 'text/xml') return;
uploadAnnotationButton.text('Preparing..');
if (!file) return;
uploadAnnotationButton.text('Uploading..');
uploadAnnotationButton.prop('disabled', true);
const overlay = showOverlay('File is being uploaded..');
const fileReader = new FileReader();
fileReader.onload = (loadedFileEvent) => {
let data = null;
const asyncParse = () => {
try {
data = annotationParser.parse(loadedFileEvent.target.result);
} catch (err) {
overlay.remove();
showMessage(err.message);
return;
} finally {
uploadAnnotationButton.text('Upload Annotation');
uploadAnnotationButton.prop('disabled', false);
}
const asyncImport = () => {
try {
historyModel.empty();
shapeCollectionModel.empty();
shapeCollectionModel.import(data);
shapeCollectionModel.update();
} finally {
overlay.remove();
}
};
overlay.setMessage('Data are being imported..');
setTimeout(asyncImport);
};
overlay.setMessage('File is being parsed..');
setTimeout(asyncParse);
};
fileReader.readAsText(file);
const annotationData = new FormData();
annotationData.append('annotation_file', file);
try {
await uploadJobAnnotationRequest(jobId, annotationData, formatId, parseSpec);
historyModel.empty();
shapeCollectionModel.empty();
const data = await $.get(`/api/v1/jobs/${jobId}/annotations`);
shapeCollectionModel.import(data);
shapeCollectionModel.update();
annotationSaverModel.update();
} catch (error) {
showMessage(error.message);
} finally {
uploadAnnotationButton.prop('disabled', false);
uploadAnnotationButton.text('Upload Annotation');
}
}).click();
}
......@@ -287,12 +265,15 @@ function setupSettingsWindow() {
function setupMenu(job, task, shapeCollectionModel,
annotationParser, aamModel, playerModel, historyModel) {
annotationParser, aamModel, playerModel, historyModel,
annotationFormats, annotationSaverModel) {
const annotationMenu = $('#annotationMenu');
const menuButton = $('#menuButton');
const downloadDropdownMenu = $('#downloadDropdownMenu');
function hide() {
annotationMenu.addClass('hidden');
downloadDropdownMenu.addClass('hidden');
}
function setupVisibility() {
......@@ -406,22 +387,39 @@ function setupMenu(job, task, shapeCollectionModel,
$('#settingsButton').attr('title', `
${shortkeys.open_settings.view_value} - ${shortkeys.open_settings.description}`);
$('#downloadAnnotationButton').on('click', async (e) => {
e.target.disabled = true;
try {
await dumpAnnotationRequest(task.id, task.name);
} catch (error) {
showMessage(error.message);
} finally {
e.target.disabled = false;
for (const format of annotationFormats) {
for (const dumpSpec of format.dumpers) {
$(`<li>${dumpSpec.display_name}</li>`).on('click', async () => {
$('#downloadAnnotationButton')[0].disabled = true;
$('#downloadDropdownMenu').addClass('hidden');
try {
await dumpAnnotationRequest(task.id, task.name, dumpSpec.display_name);
} catch (error) {
showMessage(error.message);
} finally {
$('#downloadAnnotationButton')[0].disabled = false;
}
}).appendTo('#downloadDropdownMenu');
}
}
$('#downloadAnnotationButton').on('click', () => {
$('#downloadDropdownMenu').toggleClass('hidden');
});
$('#uploadAnnotationButton').on('click', () => {
hide();
const CVATformat = annotationFormats.find(el => el.name === 'CVAT');
userConfirm('Current annotation will be removed from the client. Continue?',
() => {
uploadAnnotation(shapeCollectionModel, historyModel, annotationParser, $('#uploadAnnotationButton'));
uploadAnnotation(
job.id,
shapeCollectionModel,
historyModel,
annotationSaverModel,
$('#uploadAnnotationButton'),
CVATformat.loaders[0].display_name,
);
});
});
......@@ -460,7 +458,8 @@ function setupMenu(job, task, shapeCollectionModel,
}
function buildAnnotationUI(jobData, taskData, imageMetaData, annotationData, loadJobEvent) {
function buildAnnotationUI(jobData, taskData, imageMetaData, annotationData, annotationFormats,
loadJobEvent) {
// Setup some API
window.cvat = {
labelsInfo: new LabelsInfo(taskData.labels),
......@@ -537,7 +536,7 @@ function buildAnnotationUI(jobData, taskData, imageMetaData, annotationData, loa
const shapeCollectionView = new ShapeCollectionView(shapeCollectionModel,
shapeCollectionController);
buildAnnotationSaver(annotationData, shapeCollectionModel);
const annotationSaverModel = buildAnnotationSaver(annotationData, shapeCollectionModel);
window.cvat.data = {
get: () => shapeCollectionModel.export()[0],
......@@ -620,7 +619,8 @@ function buildAnnotationUI(jobData, taskData, imageMetaData, annotationData, loa
setupHelpWindow(shortkeys);
setupSettingsWindow();
setupMenu(jobData, taskData, shapeCollectionModel,
annotationParser, aamModel, playerModel, historyModel);
annotationParser, aamModel, playerModel, historyModel,
annotationFormats, annotationSaverModel);
setupFrameFilters();
setupShortkeys(shortkeys, {
aam: aamModel,
......@@ -677,11 +677,12 @@ function callAnnotationUI(jid) {
$.get(`/api/v1/tasks/${jobData.task_id}`),
$.get(`/api/v1/tasks/${jobData.task_id}/frames/meta`),
$.get(`/api/v1/jobs/${jid}/annotations`),
).then((taskData, imageMetaData, annotationData) => {
$.get('/api/v1/server/annotation/formats'),
).then((taskData, imageMetaData, annotationData, annotationFormats) => {
$('#loadingOverlay').remove();
setTimeout(() => {
buildAnnotationUI(jobData, taskData[0],
imageMetaData[0], annotationData[0], loadJobEvent);
imageMetaData[0], annotationData[0], annotationFormats[0], loadJobEvent);
});
}).fail(onError);
}).fail(onError);
......
......@@ -9,6 +9,8 @@
dumpAnnotationRequest
showMessage
showOverlay
uploadJobAnnotationRequest
uploadTaskAnnotationRequest
*/
/* global
......@@ -127,21 +129,24 @@ function showOverlay(message) {
return overlayWindow[0];
}
async function dumpAnnotationRequest(tid, taskName) {
async function dumpAnnotationRequest(tid, taskName, format) {
// URL Router on the server doesn't work correctly with slashes.
// So, we have to replace them on the client side
taskName = taskName.replace(/\//g, '_');
const name = encodeURIComponent(`${tid}_${taskName}`);
return new Promise((resolve, reject) => {
const url = `/api/v1/tasks/${tid}/annotations/${name}`;
let queryString = `format=${format}`;
async function request() {
$.get(url)
$.get(`${url}?${queryString}`)
.done((...args) => {
if (args[2].status === 202) {
setTimeout(request, 3000);
} else {
const a = document.createElement('a');
a.href = `${url}?action=download`;
queryString = `${queryString}&action=download`;
a.href = `${url}?${queryString}`;
document.body.appendChild(a);
a.click();
a.remove();
......@@ -158,6 +163,42 @@ async function dumpAnnotationRequest(tid, taskName) {
});
}
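The dump request above follows a simple polling protocol: the client repeats the `GET` with the chosen `format` while the server answers `202 Accepted`, and once the dump is ready it requests the same URL again with `action=download`. Below is a minimal Python sketch of the same protocol; it assumes the third-party `requests` library and an already authenticated `session`, neither of which is part of this change.

```python
import time
import requests  # assumed third-party client library, not used by CVAT itself

def dump_annotations(base_url, tid, name, dump_format, session):
    # Mirrors dumpAnnotationRequest(): poll while the dump is being prepared,
    # then fetch the finished file with action=download.
    url = "{}/api/v1/tasks/{}/annotations/{}".format(base_url, tid, name)
    while True:
        response = session.get(url, params={"format": dump_format})
        if response.status_code == 202:  # the dump is still in progress
            time.sleep(3)
            continue
        break
    return session.get(url, params={"format": dump_format, "action": "download"})
```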
async function uploadAnnoRequest(url, formData, format) {
return new Promise((resolve, reject) => {
const queryString = `format=${format}`;
async function request(data) {
try {
await $.ajax({
url: `${url}?${queryString}`,
type: 'PUT',
data,
contentType: false,
processData: false,
}).done((...args) => {
if (args[2].status === 202) {
setTimeout(() => request(''), 3000);
} else {
resolve();
}
});
} catch (errorData) {
const message = `Cannot upload annotations for the job. Code: ${errorData.status}. `
+ `Message: ${errorData.responseText || errorData.statusText}`;
reject(new Error(message));
}
}
setTimeout(() => request(formData));
});
}
async function uploadJobAnnotationRequest(jid, formData, format) {
return uploadAnnoRequest(`/api/v1/jobs/${jid}/annotations`, formData, format);
}
async function uploadTaskAnnotationRequest(tid, formData, format) {
return uploadAnnoRequest(`/api/v1/tasks/${tid}/annotations`, formData, format);
}
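`uploadAnnoRequest()` uses the same polling idea in the other direction: the first `PUT` carries the annotation file, and the client keeps issuing empty `PUT` requests with the same `format` parameter until the server stops replying `202 Accepted`. A hedged Python sketch of this flow, again assuming the `requests` library and an authenticated `session`:

```python
import time
import requests  # assumed third-party client library, not used by CVAT itself

def upload_job_annotations(base_url, jid, file_path, upload_format, session):
    url = "{}/api/v1/jobs/{}/annotations".format(base_url, jid)
    with open(file_path, "rb") as annotation_file:
        # The first PUT carries the file; the server enqueues the import job.
        response = session.put(url, params={"format": upload_format},
                               files={"annotation_file": annotation_file})
    while response.status_code == 202:
        # Import still running; poll with an empty body, as the JS code does.
        time.sleep(3)
        response = session.put(url, params={"format": upload_format})
    response.raise_for_status()  # a 500 is returned if the import job failed
```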
/* These HTTP methods do not require CSRF protection */
function csrfSafeMethod(method) {
......
......@@ -557,5 +557,3 @@
50% {stroke-dashoffset: 100; stroke: #f44;}
100% {stroke-dashoffset: 300; stroke: #09c;}
}
......@@ -330,7 +330,10 @@
<div id="annotationMenu" class="hidden regular">
<center style="float:left; width: 28%; height: 100%;" id="engineMenuButtons">
<button id="downloadAnnotationButton" class="menuButton semiBold h2"> Dump Annotation </button>
<div class="dropdown">
<button id="downloadAnnotationButton" class="menuButton semiBold h2"> Dump Annotation </button>
<ul id="downloadDropdownMenu" class="dropdown-content hidden"></ul>
</div>
<button id="uploadAnnotationButton" class="menuButton semiBold h2"> Upload Annotation </button>
<button id="removeAnnotationButton" class="menuButton semiBold h2"> Remove Annotation </button>
<button id="settingsButton" class="menuButton semiBold h2"> Settings </button>
......
......@@ -13,6 +13,8 @@ from django.conf import settings
from django.contrib.auth.models import User, Group
from cvat.apps.engine.models import (Task, Segment, Job, StatusChoice,
AttributeType)
from cvat.apps.annotation.models import AnnotationFormat
from cvat.apps.annotation.models import HandlerType
from unittest import mock
import io
import xml.etree.ElementTree as ET
......@@ -1636,7 +1638,7 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
def _dump_api_v1_tasks_id_annotations(self, pk, user, query_params=""):
with ForceLogin(user, self.client):
response = self.client.get(
"/api/v1/tasks/{0}/annotations/my_task_{0}{1}".format(pk, query_params))
"/api/v1/tasks/{0}/annotations/my_task_{0}?{1}".format(pk, query_params))
return response
......@@ -2022,15 +2024,20 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"create", data)
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator)
self.assertEqual(response.status_code, HTTP_202_ACCEPTED)
cvat_format = AnnotationFormat.objects.get(name="CVAT")
for annotation_handler in cvat_format.annotationhandler_set.filter(type=HandlerType.DUMPER):
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator,
"format={}".format(annotation_handler.display_name))
self.assertEqual(response.status_code, HTTP_202_ACCEPTED)
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator)
self.assertEqual(response.status_code, HTTP_201_CREATED)
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator,
"format={}".format(annotation_handler.display_name))
self.assertEqual(response.status_code, HTTP_201_CREATED)
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator, "?action=download")
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_dump_response(response, task, jobs, data)
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator,
"action=download&format={}".format(annotation_handler.display_name))
self.assertEqual(response.status_code, HTTP_200_OK)
self._check_dump_response(response, task, jobs, data)
def _check_dump_response(self, response, task, jobs, data):
if response.status_code == status.HTTP_200_OK:
......
......@@ -4,6 +4,20 @@ import importlib
Import = namedtuple("Import", ["module", "name", "alias"])
def parse_imports(source_code: str):
root = ast.parse(source_code)
for node in ast.iter_child_nodes(root):
if isinstance(node, ast.Import):
module = []
elif isinstance(node, ast.ImportFrom):
module = node.module
else:
continue
for n in node.names:
yield Import(module, n.name, n.asname)
def import_modules(source_code: str):
results = {}
imports = parse_imports(source_code)
......@@ -20,17 +34,3 @@ def import_modules(source_code: str):
results[import_.name] = loaded_module
return results
def parse_imports(source_code: str):
root = ast.parse(source_code)
for node in ast.iter_child_nodes(root):
if isinstance(node, ast.Import):
module = []
elif isinstance(node, ast.ImportFrom):
module = node.module
else:
continue
for n in node.names:
yield Import(module, n.name, n.asname)
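These two helpers let the server collect the `import` statements of a user-supplied format script and pre-load the referenced modules before the script itself is executed. A hypothetical usage sketch follows; the sample source string is invented for illustration.

```python
# Hypothetical usage of parse_imports() / import_modules() defined above.
source_code = (
    "from collections import OrderedDict\n"
    "import xml.etree.ElementTree as ET\n"
)

for imported in parse_imports(source_code):
    # e.g. Import(module='collections', name='OrderedDict', alias=None)
    print(imported)

# Pre-load everything the script imports; the resulting dict of modules can
# be merged into the globals used when the format script is run via exec().
loaded_modules = import_modules(source_code)
```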
......@@ -8,6 +8,7 @@ import traceback
from ast import literal_eval
import shutil
from datetime import datetime
from tempfile import mkstemp
from django.http import HttpResponseBadRequest
from django.shortcuts import redirect, render
......@@ -36,9 +37,13 @@ from cvat.apps.engine.serializers import (TaskSerializer, UserSerializer,
ExceptionSerializer, AboutSerializer, JobSerializer, ImageMetaSerializer,
RqStatusSerializer, TaskDataSerializer, LabeledDataSerializer,
PluginSerializer, FileInfoSerializer, LogEventSerializer)
from cvat.apps.annotation.serializers import AnnotationFileSerializer
from django.contrib.auth.models import User
from django.core.exceptions import ObjectDoesNotExist
from cvat.apps.authentication import auth
from rest_framework.permissions import SAFE_METHODS
from cvat.apps.annotation.models import AnnotationHandler
from cvat.apps.annotation.format import get_annotation_formats
# Server REST API
@login_required
......@@ -147,6 +152,12 @@ class ServerViewSet(viewsets.ViewSet):
return Response("{} is an invalid directory".format(param),
status=status.HTTP_400_BAD_REQUEST)
@staticmethod
@action(detail=False, methods=['GET'], url_path='annotation/formats')
def formats(request):
data = get_annotation_formats()
return Response(data)
class TaskFilter(filters.FilterSet):
name = filters.CharFilter(field_name="name", lookup_expr="icontains")
owner = filters.CharFilter(field_name="owner__username", lookup_expr="icontains")
......@@ -223,10 +234,18 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
if serializer.is_valid(raise_exception=True):
return Response(serializer.data)
elif request.method == 'PUT':
serializer = LabeledDataSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
data = annotation.put_task_data(pk, request.user, serializer.data)
return Response(data)
if request.query_params.get("format", ""):
return load_data_proxy(
request=request,
rq_id="{}@/api/v1/tasks/{}/annotations/upload".format(request.user, pk),
rq_func=annotation.load_task_data,
pk=pk,
)
else:
serializer = LabeledDataSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
data = annotation.put_task_data(pk, request.user, serializer.data)
return Response(data)
elif request.method == 'DELETE':
annotation.delete_task_data(pk, request.user)
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -247,19 +266,25 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
url_path='annotations/(?P<filename>[^/]+)')
def dump(self, request, pk, filename):
filename = re.sub(r'[\\/*?:"<>|]', '_', filename)
queue = django_rq.get_queue("default")
username = request.user.username
db_task = self.get_object()
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
file_ext = request.query_params.get("format", "xml")
action = request.query_params.get("action")
if action not in [None, "download"]:
raise serializers.ValidationError(
"Please specify a correct 'action' for the request")
dump_format = request.query_params.get("format", "")
try:
db_dumper = AnnotationHandler.objects.get(display_name=dump_format)
except ObjectDoesNotExist:
raise serializers.ValidationError(
"Please specify a correct 'format' parameter for the request")
file_path = os.path.join(db_task.get_task_dirname(),
filename + ".{}.{}.".format(username, timestamp) + "xml")
"{}.{}.{}.{}".format(filename, username, timestamp, db_dumper.format.lower()))
queue = django_rq.get_queue("default")
rq_id = "{}@/api/v1/tasks/{}/annotations/{}".format(username, pk, filename)
rq_job = queue.fetch_job(rq_id)
......@@ -270,7 +295,7 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
rq_job.meta[action] = True
rq_job.save_meta()
return sendfile(request, rq_job.meta["file_path"], attachment=True,
attachment_filename=filename + "." + file_ext)
attachment_filename="{}.{}".format(filename, db_dumper.format.lower()))
else:
return Response(status=status.HTTP_201_CREATED)
else: # Remove the old dump file
......@@ -286,10 +311,12 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
else:
return Response(status=status.HTTP_202_ACCEPTED)
rq_job = queue.enqueue_call(func=annotation.dump_task_data,
args=(pk, request.user, file_path, request.scheme,
request.get_host(), request.query_params),
job_id=rq_id)
rq_job = queue.enqueue_call(
func=annotation.dump_task_data,
args=(pk, request.user, file_path, db_dumper,
request.scheme, request.get_host()),
job_id=rq_id,
)
rq_job.meta["file_path"] = file_path
rq_job.save_meta()
......@@ -380,13 +407,21 @@ class JobViewSet(viewsets.GenericViewSet,
data = annotation.get_job_data(pk, request.user)
return Response(data)
elif request.method == 'PUT':
serializer = LabeledDataSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
try:
data = annotation.put_job_data(pk, request.user, serializer.data)
except (AttributeError, IntegrityError) as e:
return Response(data=str(e), status=status.HTTP_400_BAD_REQUEST)
return Response(data)
if request.query_params.get("format", ""):
return load_data_proxy(
request=request,
rq_id="{}@/api/v1/jobs/{}/annotations/upload".format(request.user, pk),
rq_func=annotation.load_job_data,
pk=pk,
)
else:
serializer = LabeledDataSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
try:
data = annotation.put_job_data(pk, request.user, serializer.data)
except (AttributeError, IntegrityError) as e:
return Response(data=str(e), status=status.HTTP_400_BAD_REQUEST)
return Response(data)
elif request.method == 'DELETE':
annotation.delete_job_data(pk, request.user)
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -404,7 +439,6 @@ class JobViewSet(viewsets.GenericViewSet,
return Response(data=str(e), status=status.HTTP_400_BAD_REQUEST)
return Response(data)
class UserViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.UpdateModelMixin):
queryset = User.objects.all().order_by('id')
......@@ -463,3 +497,44 @@ def rq_handler(job, exc_type, exc_value, tb):
return task.rq_handler(job, exc_type, exc_value, tb)
return True
def load_data_proxy(request, rq_id, rq_func, pk):
queue = django_rq.get_queue("default")
rq_job = queue.fetch_job(rq_id)
upload_format = request.query_params.get("format", "")
if not rq_job:
serializer = AnnotationFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
try:
db_parser = AnnotationHandler.objects.get(pk=upload_format)
except ObjectDoesNotExist:
raise serializers.ValidationError(
"Please specify a correct 'format' parameter for the upload request")
anno_file = serializer.validated_data['annotation_file']
fd, filename = mkstemp(prefix='cvat_{}'.format(pk))
with open(filename, 'wb+') as f:
for chunk in anno_file.chunks():
f.write(chunk)
rq_job = queue.enqueue_call(
func=rq_func,
args=(pk, request.user, filename, db_parser),
job_id=rq_id
)
rq_job.meta['tmp_file'] = filename
rq_job.meta['tmp_file_descriptor'] = fd
rq_job.save_meta()
else:
if rq_job.is_finished:
os.close(rq_job.meta['tmp_file_descriptor'])
os.remove(rq_job.meta['tmp_file'])
rq_job.delete()
return Response(status=status.HTTP_201_CREATED)
elif rq_job.is_failed:
os.close(rq_job.meta['tmp_file_descriptor'])
os.remove(rq_job.meta['tmp_file'])
rq_job.delete()
return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return Response(status=status.HTTP_202_ACCEPTED)
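The `rq_func` that `load_data_proxy()` enqueues (for example `annotation.load_job_data`) later runs in an RQ worker with exactly the arguments prepared above: the primary key, the user, the path of the temporary file, and the selected `AnnotationHandler`. The sketch below only illustrates that contract; the helper names are hypothetical and the real logic lives in `cvat.apps.engine.annotation`.

```python
# Illustrative sketch only; the helper names below are hypothetical.
def load_job_data_sketch(pk, user, filename, db_parser):
    annotations = build_annotations_for_job(pk, user)  # hypothetical helper
    load_handler = resolve_loader_handler(db_parser)    # hypothetical helper
    with open(filename, "rb") as file_object:
        # Loader handlers are assumed to take (file_object, annotations),
        # matching the handler signature used for dumpers.
        load_handler(file_object, annotations)
    # load_data_proxy() removes the temporary file once the job completes.
```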
......@@ -10,6 +10,7 @@ from cvat.apps.engine.models import Task, Job, User
from cvat.apps.engine.annotation import dump_task_data
from cvat.apps.engine.plugins import add_plugin
from cvat.apps.git.models import GitStatusChoice
from cvat.apps.annotation.models import AnnotationHandler
from cvat.apps.git.models import GitData
from collections import OrderedDict
......@@ -62,6 +63,7 @@ class Git:
}
self._cwd = os.path.join(os.getcwd(), "data", str(tid), "repos")
self._diffs_dir = os.path.join(os.getcwd(), "data", str(tid), "repos_diffs_v2")
self._task_mode = Task.objects.get(pk = tid).mode
self._task_name = re.sub(r'[\\/*?:"<>|\s]', '_', Task.objects.get(pk = tid).name)[:100]
self._branch_name = 'cvat_{}_{}'.format(tid, self._task_name)
self._annotation_file = os.path.join(self._cwd, self._path)
......@@ -259,17 +261,19 @@ class Git:
self._rep.git.add(['.gitattributes'])
# Dump an annotation
# TODO: Fix dump, query params
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
display_name = "CVAT XML 1.1"
display_name += " for images" if self._task_mode == "annotation" else " for videos"
cvat_dumper = AnnotationHandler.objects.get(display_name=display_name)
dump_name = os.path.join(db_task.get_task_dirname(),
"git_annotation_{}.".format(timestamp) + "dump")
dump_task_data(
pk=self._tid,
user=user,
file_path=dump_name,
filename=dump_name,
dumper=cvat_dumper,
scheme=scheme,
host=host,
query_params={},
)
ext = os.path.splitext(self._path)[1]
......
......@@ -34,3 +34,4 @@ Pygments==2.3.1
drf-yasg==1.15.0
Shapely==1.6.4.post2
pdf2image==1.6.0
pascal_voc_writer==0.1.4
......@@ -92,6 +92,7 @@ INSTALLED_APPS = [
'cvat.apps.authentication',
'cvat.apps.documentation',
'cvat.apps.git',
'cvat.apps.annotation',
'django_rq',
'compressor',
'cacheops',
......@@ -121,7 +122,10 @@ REST_FRAMEWORK = {
'DEFAULT_FILTER_BACKENDS': (
'rest_framework.filters.SearchFilter',
'django_filters.rest_framework.DjangoFilterBackend',
'rest_framework.filters.OrderingFilter')
'rest_framework.filters.OrderingFilter'),
# Disable default handling of the 'format' query parameter by REST framework
'URL_FORMAT_OVERRIDE': None,
}
if 'yes' == os.environ.get('TF_ANNOTATION', 'no'):
......@@ -334,11 +338,14 @@ if os.getenv('DJANGO_LOG_SERVER_HOST'):
STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
os.makedirs(STATIC_ROOT, exist_ok=True)
DATA_ROOT = os.path.join(BASE_DIR, 'data')
os.makedirs(DATA_ROOT, exist_ok=True)
SHARE_ROOT = os.path.join(BASE_DIR, 'share')
os.makedirs(SHARE_ROOT, exist_ok=True)
MODELS_ROOT=os.path.join(BASE_DIR, 'models')
MODELS_ROOT = os.path.join(BASE_DIR, 'models')
os.makedirs(MODELS_ROOT, exist_ok=True)
DATA_UPLOAD_MAX_MEMORY_SIZE = 100 * 1024 * 1024 # 100 MB
......