Unverified commit bc5036fd, authored by Roman Donchenko and committed by GitHub

Create cvat_sdk.datasets, a framework-agnostic version of cvat_sdk.pytorch (#6428)

The new `TaskDataset` class provides conveniences like per-frame
annotations, bulk data downloading, and caching, without forcing a
dependency on PyTorch (or a somewhat awkward conformance to the PyTorch
dataset interface). It also provides a few extra niceties, such as easy
access to labels and original frame numbers.
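Here is a minimal usage sketch (the server URL, credentials, and task ID are
placeholders, not part of this change):

```python
from cvat_sdk import make_client
from cvat_sdk.datasets import TaskDataset

# Placeholder connection details; substitute your own server and task.
with make_client("http://localhost:8080", credentials=("user", "password")) as client:
    dataset = TaskDataset(client, task_id=123)

    # Labels are available directly, with no separate request needed.
    labels_by_id = {label.id: label.name for label in dataset.labels}

    # Samples keep their original frame numbers; deleted frames are omitted.
    for sample in dataset.samples:
        image = sample.media.load_image()  # PIL.Image.Image
        for shape in sample.annotations.shapes:
            print(sample.frame_index, labels_by_id[shape.label_id], shape.points)
```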

Note that it's called `TaskDataset` rather than `TaskVisionDataset`, as
my plan is to keep it domain-agnostic. The `MediaElement` class is
extensible, so we could later add support for, say, point clouds by
adding another `load_*` method.
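As an illustration of that extension pattern, a point cloud variant might look
like the sketch below; the `PointCloudMediaElement` name, the
`load_point_cloud` method, and its return type are all hypothetical, not part
of this PR:

```python
import abc


class MediaElement(metaclass=abc.ABCMeta):
    """Simplified stand-in for the MediaElement class added in this PR."""


# Hypothetical future subclass for 3D tasks; domain-specific media types
# would subclass MediaElement and add further load_* methods.
class PointCloudMediaElement(MediaElement):
    @abc.abstractmethod
    def load_point_cloud(self) -> bytes:
        """Would load the point cloud data, e.g. as PCD file contents."""
```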

There is currently no `ProjectDataset` equivalent, although one could
(and probably should) be added later. If we add one, we should probably
also add a `task_id` field to `Sample`.
Parent 9fc6b00e
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ## \[Unreleased]
 ### Added
 - Multi-line text attributes supported (<https://github.com/opencv/cvat/pull/6458>)
+- \[SDK] `cvat_sdk.datasets`, a framework-agnostic equivalent of `cvat_sdk.pytorch`
+  (<https://github.com/opencv/cvat/pull/6428>)
 ### Changed
 - TBD
......
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from .caching import UpdatePolicy
from .common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError
from .task_dataset import TaskDataset

# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import abc
from typing import List

import attrs
import attrs.validators
import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
    pass


@attrs.frozen
class FrameAnnotations:
    """
    Contains annotations that pertain to a single frame.
    """

    tags: List[models.LabeledImage] = attrs.Factory(list)
    shapes: List[models.LabeledShape] = attrs.Factory(list)


class MediaElement(metaclass=abc.ABCMeta):
    """
    The media part of a dataset sample.
    """

    @abc.abstractmethod
    def load_image(self) -> PIL.Image.Image:
        """
        Loads the media data and returns it as a PIL Image object.
        """
        ...


@attrs.frozen
class Sample:
    """
    Represents an element of a dataset.
    """

    frame_index: int
    """Index of the corresponding frame in its task."""

    annotations: FrameAnnotations
    """Annotations belonging to the frame."""

    media: MediaElement
    """Media data of the frame."""

# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence

import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError

_NUM_DOWNLOAD_THREADS = 4


class TaskDataset:
    """
    Represents a task on a CVAT server as a collection of samples.

    Each sample corresponds to one frame in the task, and provides access to
    the corresponding annotations and media data. Deleted frames are omitted.

    This class caches all data and annotations for the task on the local file system
    during construction.

    Limitations:

    * Only tasks with image (not video) data are supported at the moment.
    * Track annotations are currently not accessible.
    """

    class _TaskMediaElement(MediaElement):
        def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
            self._dataset = dataset
            self._frame_index = frame_index

        def load_image(self) -> PIL.Image.Image:
            return self._dataset._load_frame_image(self._frame_index)

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `update_policy` determines when and if the local cache will be updated.
        """

        self._logger = client.logger

        cache_manager = make_cache_manager(client, update_policy)
        self._task = cache_manager.retrieve_task(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                f"{self.__class__.__name__} only supports tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        self._logger.info("Fetching labels...")
        self._labels = tuple(self._task.get_labels())

        data_meta = cache_manager.ensure_task_model(
            self._task.id,
            "data_meta.json",
            models.DataMetaRead,
            self._task.get_meta,
            "data metadata",
        )

        active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)

        self._logger.info("Downloading chunks...")

        self._chunk_dir = cache_manager.chunk_dir(task_id)
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

        needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:

            def ensure_chunk(chunk_index):
                cache_manager.ensure_chunk(self._task, chunk_index)

            for _ in pool.map(ensure_chunk, sorted(needed_chunks)):
                # just need to loop through all results so that any exceptions are propagated
                pass

        self._logger.info("All chunks downloaded")

        annotations = cache_manager.ensure_task_model(
            self._task.id,
            "annotations.json",
            models.LabeledData,
            self._task.get_annotations,
            "annotations",
        )

        self._frame_annotations = {
            frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes)
        }

        for tag in annotations.tags:
            # Some annotations may belong to deleted frames; skip those.
            if tag.frame in self._frame_annotations:
                self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            if shape.frame in self._frame_annotations:
                self._frame_annotations[shape.frame].shapes.append(shape)

        # TODO: tracks?

        self._samples = [
            Sample(frame_index=k, annotations=v, media=self._TaskMediaElement(self, k))
            for k, v in self._frame_annotations.items()
        ]

    @property
    def labels(self) -> Sequence[models.ILabel]:
        """
        Returns the labels configured in the task.

        Clients must not modify the object returned by this property or its components.
        """
        return self._labels

    @property
    def samples(self) -> Sequence[Sample]:
        """
        Returns a sequence of all samples, in order of their frame indices.

        Note that the frame indices may not be contiguous, as deleted frames will not be included.

        Clients must not modify the object returned by this property or its components.
        """
        return self._samples

    def _load_frame_image(self, frame_index: int) -> PIL.Image.Image:
        assert frame_index in self._frame_annotations

        # Chunks are zip archives, each holding data_chunk_size consecutive frames.
        chunk_index = frame_index // self._task.data_chunk_size
        member_index = frame_index % self._task.data_chunk_size

        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                image = PIL.Image.open(chunk_member)
                image.load()

        return image

@@ -2,8 +2,12 @@
 #
 # SPDX-License-Identifier: MIT

-from .caching import UpdatePolicy
-from .common import FrameAnnotations, Target, UnsupportedDatasetError
+from .common import Target
 from .project_dataset import ProjectVisionDataset
 from .task_dataset import TaskVisionDataset
 from .transforms import ExtractBoundingBoxes, ExtractSingleLabelIndex, LabeledBoxes
+
+# isort: split
+
+# Compatibility imports
+from ..datasets.caching import UpdatePolicy
+from ..datasets.common import FrameAnnotations, UnsupportedDatasetError
@@ -2,28 +2,11 @@
 #
 # SPDX-License-Identifier: MIT

-from typing import List, Mapping
+from typing import Mapping

 import attrs
 import attrs.validators

-import cvat_sdk.core
-import cvat_sdk.core.exceptions
 import cvat_sdk.models as models
-
-
-class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
-    pass
-
-
-@attrs.frozen
-class FrameAnnotations:
-    """
-    Contains annotations that pertain to a single frame.
-    """
-
-    tags: List[models.LabeledImage] = attrs.Factory(list)
-    shapes: List[models.LabeledShape] = attrs.Factory(list)
+from cvat_sdk.datasets.common import FrameAnnotations


 @attrs.frozen
......
@@ -12,7 +12,7 @@ import torchvision.datasets

 import cvat_sdk.core
 import cvat_sdk.core.exceptions
 import cvat_sdk.models as models
-from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
+from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
 from cvat_sdk.pytorch.task_dataset import TaskVisionDataset
......
@@ -2,21 +2,17 @@
 #
 # SPDX-License-Identifier: MIT

-import collections
 import os
 import types
-import zipfile
-from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Dict, Mapping, Optional
+from typing import Callable, Mapping, Optional

 import PIL.Image
 import torchvision.datasets

 import cvat_sdk.core
 import cvat_sdk.core.exceptions
 import cvat_sdk.models as models
-from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
-from cvat_sdk.pytorch.common import FrameAnnotations, Target, UnsupportedDatasetError
+from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
+from cvat_sdk.datasets.task_dataset import TaskDataset
+from cvat_sdk.pytorch.common import Target

-_NUM_DOWNLOAD_THREADS = 4
@@ -75,92 +71,31 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
         `update_policy` determines when and if the local cache will be updated.
         """

         self._logger = client.logger

+        self._underlying = TaskDataset(client, task_id, update_policy=update_policy)
+
         cache_manager = make_cache_manager(client, update_policy)
-        self._task = cache_manager.retrieve_task(task_id)
-
-        if not self._task.size or not self._task.data_chunk_size:
-            raise UnsupportedDatasetError("The task has no data")
-
-        if self._task.data_original_chunk_type != "imageset":
-            raise UnsupportedDatasetError(
-                f"{self.__class__.__name__} only supports tasks with image chunks;"
-                f" current chunk type is {self._task.data_original_chunk_type!r}"
-            )

         super().__init__(
-            os.fspath(cache_manager.task_dir(self._task.id)),
+            os.fspath(cache_manager.task_dir(task_id)),
             transforms=transforms,
             transform=transform,
             target_transform=target_transform,
         )

-        data_meta = cache_manager.ensure_task_model(
-            self._task.id,
-            "data_meta.json",
-            models.DataMetaRead,
-            self._task.get_meta,
-            "data metadata",
-        )
-        self._active_frame_indexes = sorted(
-            set(range(self._task.size)) - set(data_meta.deleted_frames)
-        )
-
-        self._logger.info("Downloading chunks...")
-
-        self._chunk_dir = cache_manager.chunk_dir(task_id)
-        self._chunk_dir.mkdir(exist_ok=True, parents=True)
-
-        needed_chunks = {
-            index // self._task.data_chunk_size for index in self._active_frame_indexes
-        }
-
-        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
-
-            def ensure_chunk(chunk_index):
-                cache_manager.ensure_chunk(self._task, chunk_index)
-
-            for _ in pool.map(ensure_chunk, sorted(needed_chunks)):
-                # just need to loop through all results so that any exceptions are propagated
-                pass
-
-        self._logger.info("All chunks downloaded")
-
         if label_name_to_index is None:
             self._label_id_to_index = types.MappingProxyType(
                 {
                     label.id: label_index
                     for label_index, label in enumerate(
-                        sorted(self._task.get_labels(), key=lambda l: l.id)
+                        sorted(self._underlying.labels, key=lambda l: l.id)
                     )
                 }
             )
         else:
             self._label_id_to_index = types.MappingProxyType(
-                {label.id: label_name_to_index[label.name] for label in self._task.get_labels()}
+                {label.id: label_name_to_index[label.name] for label in self._underlying.labels}
             )

-        annotations = cache_manager.ensure_task_model(
-            self._task.id,
-            "annotations.json",
-            models.LabeledData,
-            self._task.get_annotations,
-            "annotations",
-        )
-
-        self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
-            FrameAnnotations
-        )
-
-        for tag in annotations.tags:
-            self._frame_annotations[tag.frame].tags.append(tag)
-
-        for shape in annotations.shapes:
-            self._frame_annotations[shape.frame].shapes.append(shape)
-
-        # TODO: tracks?
-
     def __getitem__(self, sample_index: int):
         """
         Returns the sample with index `sample_index`.
@@ -168,19 +103,10 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):

         `sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
         """

-        frame_index = self._active_frame_indexes[sample_index]
-        chunk_index = frame_index // self._task.data_chunk_size
-        member_index = frame_index % self._task.data_chunk_size
+        sample = self._underlying.samples[sample_index]

-        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
-            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
-                sample_image = PIL.Image.open(chunk_member)
-                sample_image.load()
-
-        sample_target = Target(
-            annotations=self._frame_annotations[frame_index],
-            label_id_to_index=self._label_id_to_index,
-        )
+        sample_image = sample.media.load_image()
+        sample_target = Target(sample.annotations, self._label_id_to_index)

         if self.transforms:
             sample_image, sample_target = self.transforms(sample_image, sample_target)
@@ -188,4 +114,4 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):

     def __len__(self) -> int:
         """Returns the number of samples in the dataset."""
-        return len(self._active_frame_indexes)
+        return len(self._underlying.samples)
@@ -10,7 +10,8 @@ import torch
 import torch.utils.data
 from typing_extensions import TypedDict

-from cvat_sdk.pytorch.common import Target, UnsupportedDatasetError
+from cvat_sdk.datasets.common import UnsupportedDatasetError
+from cvat_sdk.pytorch.common import Target


 @attrs.frozen
......
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import io
from logging import Logger
from pathlib import Path
from typing import Tuple

import cvat_sdk.datasets as cvatds
import PIL.Image
import pytest
from cvat_sdk import Client, models
from cvat_sdk.core.proxies.tasks import ResourceType

from shared.utils.helpers import generate_image_files

from .util import restrict_api_requests


@pytest.fixture(autouse=True)
def _common_setup(
    tmp_path: Path,
    fxt_login: Tuple[Client, str],
    fxt_logger: Tuple[Logger, io.StringIO],
):
    logger = fxt_logger[0]
    client = fxt_login[0]
    client.logger = logger
    client.config.cache_dir = tmp_path / "cache"

    api_client = client.api_client
    for k in api_client.configuration.logger:
        api_client.configuration.logger[k] = logger


class TestTaskDataset:
    @pytest.fixture(autouse=True)
    def setup(
        self,
        tmp_path: Path,
        fxt_login: Tuple[Client, str],
    ):
        self.client = fxt_login[0]

        self.images = generate_image_files(10)

        image_dir = tmp_path / "images"
        image_dir.mkdir()

        image_paths = []
        for image in self.images:
            image_path = image_dir / image.name
            image_path.write_bytes(image.getbuffer())
            image_paths.append(image_path)

        self.task = self.client.tasks.create_from_data(
            models.TaskWriteRequest(
                "Dataset layer test task",
                labels=[
                    models.PatchedLabelRequest(name="person"),
                    models.PatchedLabelRequest(name="car"),
                ],
            ),
            resource_type=ResourceType.LOCAL,
            resources=image_paths,
            data_params={"chunk_size": 3},
        )

        self.expected_labels = sorted(self.task.get_labels(), key=lambda l: l.id)

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                tags=[
                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[0].id),
                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[1].id),
                ],
                shapes=[
                    models.LabeledShapeRequest(
                        frame=6,
                        label_id=self.expected_labels[1].id,
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                ],
            )
        )

    def test_basic(self):
        dataset = cvatds.TaskDataset(self.client, self.task.id)

        # verify that the cache is not empty
        assert list(self.client.config.cache_dir.iterdir())

        for expected_label, actual_label in zip(
            self.expected_labels, sorted(dataset.labels, key=lambda l: l.id)
        ):
            assert expected_label.id == actual_label.id
            assert expected_label.name == actual_label.name

        assert len(dataset.samples) == self.task.size

        for index, sample in enumerate(dataset.samples):
            assert sample.frame_index == index

            actual_image = sample.media.load_image()
            expected_image = PIL.Image.open(self.images[index])
            assert actual_image == expected_image

        assert not dataset.samples[0].annotations.tags
        assert not dataset.samples[1].annotations.shapes

        assert {tag.label_id for tag in dataset.samples[8].annotations.tags} == {
            label.id for label in self.expected_labels
        }
        assert not dataset.samples[8].annotations.shapes

        assert not dataset.samples[6].annotations.tags
        assert len(dataset.samples[6].annotations.shapes) == 1
        assert dataset.samples[6].annotations.shapes[0].type.value == "rectangle"
        assert dataset.samples[6].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]

    def test_deleted_frame(self):
        self.task.remove_frames_by_ids([1])

        dataset = cvatds.TaskDataset(self.client, self.task.id)

        assert len(dataset.samples) == self.task.size - 1

        # sample #0 is still frame #0
        assert dataset.samples[0].frame_index == 0
        assert dataset.samples[0].media.load_image() == PIL.Image.open(self.images[0])

        # sample #1 is now frame #2
        assert dataset.samples[1].frame_index == 2
        assert dataset.samples[1].media.load_image() == PIL.Image.open(self.images[2])

        # sample #5 is now frame #6
        assert dataset.samples[5].frame_index == 6
        assert dataset.samples[5].media.load_image() == PIL.Image.open(self.images[6])
        assert len(dataset.samples[5].annotations.shapes) == 1

    def test_offline(self, monkeypatch: pytest.MonkeyPatch):
        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
            update_policy=cvatds.UpdatePolicy.IF_MISSING_OR_STALE,
        )

        fresh_samples = list(dataset.samples)

        restrict_api_requests(monkeypatch)

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
            update_policy=cvatds.UpdatePolicy.NEVER,
        )

        cached_samples = list(dataset.samples)

        for fresh_sample, cached_sample in zip(fresh_samples, cached_samples):
            assert fresh_sample.frame_index == cached_sample.frame_index
            assert fresh_sample.annotations == cached_sample.annotations
            assert fresh_sample.media.load_image() == cached_sample.media.load_image()

    def test_update(self, monkeypatch: pytest.MonkeyPatch):
        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        # Recreating the dataset should only result in minimal requests.
        restrict_api_requests(
            monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
        )

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[1].id

        # After an update, the annotations should be redownloaded.
        monkeypatch.undo()

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                shapes=[
                    models.LabeledShapeRequest(
                        id=dataset.samples[6].annotations.shapes[0].id,
                        frame=6,
                        label_id=self.expected_labels[0].id,
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                ]
            )
        )

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id

@@ -7,12 +7,10 @@ import itertools
 import os
 from logging import Logger
 from pathlib import Path
-from typing import Container, Tuple
-from urllib.parse import urlparse
+from typing import Tuple

 import pytest
 from cvat_sdk import Client, models
-from cvat_sdk.api_client.rest import RESTClientObject
 from cvat_sdk.core.proxies.tasks import ResourceType

 try:
@@ -30,6 +28,8 @@ except ModuleNotFoundError as e:

 from shared.utils.helpers import generate_image_files

+from .util import restrict_api_requests
+

 @pytest.fixture(autouse=True)
 def _common_setup(
@@ -47,20 +47,6 @@ def _common_setup(
         api_client.configuration.logger[k] = logger


-def _restrict_api_requests(
-    monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
-) -> None:
-    original_request = RESTClientObject.request
-
-    def restricted_request(self, method, url, *args, **kwargs):
-        parsed_url = urlparse(url)
-        if parsed_url.path in allow_paths:
-            return original_request(self, method, url, *args, **kwargs)
-
-        raise RuntimeError("Disallowed!")
-
-    monkeypatch.setattr(RESTClientObject, "request", restricted_request)
-
-
 @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
 class TestTaskVisionDataset:
     @pytest.fixture(autouse=True)
@@ -254,7 +240,7 @@ class TestTaskVisionDataset:

         fresh_samples = list(dataset)

-        _restrict_api_requests(monkeypatch)
+        restrict_api_requests(monkeypatch)

         dataset = cvatpt.TaskVisionDataset(
             self.client,
@@ -273,7 +259,7 @@ class TestTaskVisionDataset:
         )

         # Recreating the dataset should only result in minimal requests.
-        _restrict_api_requests(
+        restrict_api_requests(
             monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
         )
@@ -447,7 +433,7 @@ class TestProjectVisionDataset:

         fresh_samples = list(dataset)

-        _restrict_api_requests(monkeypatch)
+        restrict_api_requests(monkeypatch)

         dataset = cvatpt.ProjectVisionDataset(
             self.client,
......
@@ -4,8 +4,11 @@

 import textwrap
 from pathlib import Path
-from typing import Tuple
+from typing import Container, Tuple
+from urllib.parse import urlparse

 import pytest
+from cvat_sdk.api_client.rest import RESTClientObject
 from cvat_sdk.core.helpers import TqdmProgressReporter
 from tqdm import tqdm
@@ -82,3 +85,17 @@ def generate_coco_anno(image_path: str, image_width: int, image_height: int) ->
             "image_width": image_width,
         }
     )
+
+
+def restrict_api_requests(
+    monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
+) -> None:
+    original_request = RESTClientObject.request
+
+    def restricted_request(self, method, url, *args, **kwargs):
+        parsed_url = urlparse(url)
+        if parsed_url.path in allow_paths:
+            return original_request(self, method, url, *args, **kwargs)
+
+        raise RuntimeError("Disallowed!")
+
+    monkeypatch.setattr(RESTClientObject, "request", restricted_request)