未验证 提交 ce37be1f 编写于 作者: R Roman Donchenko 提交者: GitHub

SDK: add a ProjectVisionDataset class (#5523)

上级 06f3359a
......@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)
- \[SDK\] Class to represent a project as a PyTorch dataset
(<https://github.com/opencv/cvat/pull/5523>)
### Changed
- The Docker Compose files now use the Compose Specification version
......
......@@ -6,31 +6,22 @@ import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import (
Callable,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from pathlib import Path
from typing import Callable, Container, Dict, FrozenSet, List, Mapping, Optional, Type, TypeVar
import attrs
import attrs.validators
import PIL.Image
import torch
import torch.utils.data
import torchvision.datasets
from typing_extensions import TypedDict
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead
_ModelType = TypeVar("_ModelType")
......@@ -47,8 +38,8 @@ class FrameAnnotations:
Contains annotations that pertain to a single frame.
"""
tags: List[LabeledImage] = attrs.Factory(list)
shapes: List[LabeledShape] = attrs.Factory(list)
tags: List[models.LabeledImage] = attrs.Factory(list)
shapes: List[models.LabeledShape] = attrs.Factory(list)
@attrs.frozen
......@@ -67,6 +58,12 @@ class Target:
"""
def _get_server_dir(client: cvat_sdk.core.Client) -> Path:
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
return client.config.cache_dir / f"servers/{server_dir_name}"
class TaskVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a task on a CVAT server as a PyTorch Dataset.
......@@ -132,13 +129,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
f" current chunk type is {self._task.data_original_chunk_type!r}"
)
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = client.config.cache_dir / f"servers/{server_dir_name}"
self._task_dir = server_dir / f"tasks/{self._task.id}"
self._task_dir = _get_server_dir(client) / f"tasks/{self._task.id}"
self._initialize_task_dir()
super().__init__(
......@@ -149,7 +140,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
)
data_meta = self._ensure_model(
"data_meta.json", DataMetaRead, self._task.get_meta, "data metadata"
"data_meta.json", models.DataMetaRead, self._task.get_meta, "data metadata"
)
self._active_frame_indexes = sorted(
set(range(self._task.size)) - set(data_meta.deleted_frames)
......@@ -186,7 +177,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
)
annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
"annotations.json", models.LabeledData, self._task.get_annotations, "annotations"
)
self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
......@@ -206,7 +197,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
try:
with open(task_json_path, "rb") as task_json_file:
saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file))
saved_task = models.TaskRead._new_from_openapi_data(**json.load(task_json_file))
except Exception:
self._logger.info("Task is not yet cached or the cache is corrupted")
......@@ -295,6 +286,109 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
return len(self._active_frame_indexes)
class ProjectVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a project on a CVAT server as a PyTorch Dataset.
The dataset contains one sample for each frame of each task in the project
(except for tasks that are filtered out - see the description of `task_filter`
in the constructor). The sequence of samples is formed by concatening sequences
of samples from all included tasks in an arbitrary order that's consistent
between executions. Each task's sequence of samples corresponds to the sequence
of frames on the server.
See `TaskVisionDataset` for information on sample format, caching, and
current limitations.
"""
def __init__(
self,
client: cvat_sdk.core.Client,
project_id: int,
*,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
label_name_to_index: Mapping[str, int] = None,
task_filter: Optional[Callable[[models.ITaskRead], bool]] = None,
include_subsets: Optional[Container[str]] = None,
) -> None:
"""
Creates a dataset corresponding to the project with ID `project_id` on the
server that `client` is connected to.
`transforms`, `transform` and `target_transforms` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.
See `TaskVisionDataset.__init__` for information on `label_name_to_index`.
By default, all of the project's tasks will be included in the dataset.
The following parameters can be specified to exclude some tasks:
* If `task_filter` is set to a callable object, it will be applied to every task.
Tasks for which it returns a false value will be excluded.
* If `include_subsets` is set to a container, then tasks whose subset is
not a member of this container will be excluded.
"""
self._logger = client.logger
self._logger.info(f"Fetching project {project_id}...")
project = client.projects.retrieve(project_id)
# We don't actually need to save anything to this directory (yet),
# but VisionDataset.__init__ requires a root, so make one.
# It could be useful in the future to store the project data for
# offline-only mode.
project_dir = _get_server_dir(client) / f"projects/{project_id}"
project_dir.mkdir(parents=True, exist_ok=True)
super().__init__(
os.fspath(project_dir),
transforms=transforms,
transform=transform,
target_transform=target_transform,
)
self._logger.info("Fetching project tasks...")
tasks = project.get_tasks()
if task_filter is not None:
tasks = list(filter(task_filter, tasks))
if include_subsets is not None:
tasks = [task for task in tasks if task.subset in include_subsets]
tasks.sort(key=lambda t: t.id) # ensure consistent order between executions
self._underlying = torch.utils.data.ConcatDataset(
[
TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index)
for task in tasks
]
)
def __getitem__(self, sample_index: int):
"""
Returns the sample with index `sample_index`.
`sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
"""
sample_image, sample_target = self._underlying[sample_index]
if self.transforms:
sample_image, sample_target = self.transforms(sample_image, sample_target)
return sample_image, sample_target
def __len__(self) -> int:
"""Returns the number of samples in the dataset."""
return len(self._underlying)
@attrs.frozen
class ExtractSingleLabelIndex:
"""
......
......@@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT
import io
import itertools
import os
from logging import Logger
from pathlib import Path
......@@ -25,6 +26,22 @@ except ImportError:
from shared.utils.helpers import generate_image_files
@pytest.fixture(autouse=True)
def _common_setup(
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
):
logger = fxt_logger[0]
client = fxt_login[0]
client.logger = logger
client.config.cache_dir = tmp_path / "cache"
api_client = client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
......@@ -32,28 +49,11 @@ class TestTaskVisionDataset:
self,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
fxt_stdout: io.StringIO,
):
self.tmp_path = tmp_path
logger, self.logger_stream = fxt_logger
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger
self.client.config.cache_dir = tmp_path / "cache"
api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger
self._create_task()
yield
def _create_task(self):
self.client = fxt_login[0]
self.images = generate_image_files(10)
image_dir = self.tmp_path / "images"
image_dir = tmp_path / "images"
image_dir.mkdir()
image_paths = []
......@@ -225,3 +225,137 @@ class TestTaskVisionDataset:
_, target = dataset[5]
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
tmp_path: Path,
fxt_login: Tuple[Client, str],
):
self.client = fxt_login[0]
self.project = self.client.projects.create(
models.ProjectWriteRequest(
"PyTorch integration test project",
labels=[
models.PatchedLabelRequest(name="person"),
models.PatchedLabelRequest(name="car"),
],
)
)
self.label_ids = sorted(l.id for l in self.project.labels)
subsets = ["Train", "Test", "Val"]
num_images_per_task = 3
all_images = generate_image_files(num_images_per_task * len(subsets))
self.images_per_task = list(zip(*[iter(all_images)] * num_images_per_task))
image_dir = tmp_path / "images"
image_dir.mkdir()
image_paths_per_task = []
for images in self.images_per_task:
image_paths = []
for image in images:
image_path = image_dir / image.name
image_path.write_bytes(image.getbuffer())
image_paths.append(image_path)
image_paths_per_task.append(image_paths)
self.tasks = [
self.client.tasks.create_from_data(
models.TaskWriteRequest(
"PyTorch integration test task",
project_id=self.project.id,
subset=subset,
),
ResourceType.LOCAL,
image_paths,
data_params={"image_quality": 70},
)
for subset, image_paths in zip(subsets, image_paths_per_task)
]
# sort both self.tasks and self.images_per_task in the order that ProjectVisionDataset uses
self.tasks, self.images_per_task = zip(
*sorted(zip(self.tasks, self.images_per_task), key=lambda t: t[0].id)
)
for task_id, label_index in ((0, 0), (1, 1), (2, 0)):
self.tasks[task_id].update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(
frame=task_id, label_id=self.label_ids[label_index]
),
],
)
)
def test_basic(self):
dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id)
assert len(dataset) == sum(task.size for task in self.tasks)
for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task)):
assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image)))
assert dataset[0][1].annotations.tags[0].label_id == self.label_ids[0]
assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[1]
assert dataset[8][1].annotations.tags[0].label_id == self.label_ids[0]
def _test_filtering(self, **kwargs):
dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id, **kwargs)
assert len(dataset) == sum(task.size for task in self.tasks[1:])
for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task[1:])):
assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image)))
assert dataset[1][1].annotations.tags[0].label_id == self.label_ids[1]
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]
def test_task_filter(self):
self._test_filtering(task_filter=lambda t: t.subset != self.tasks[0].subset)
def test_include_subsets(self):
self._test_filtering(include_subsets={self.tasks[1].subset, self.tasks[2].subset})
def test_custom_label_mapping(self):
label_name_to_id = {label.name: label.id for label in self.project.labels}
dataset = cvatpt.ProjectVisionDataset(
self.client, self.project.id, label_name_to_index={"person": 123, "car": 456}
)
_, target = dataset[5]
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456
def test_separate_transforms(self):
dataset = cvatpt.ProjectVisionDataset(
self.client,
self.project.id,
transform=torchvision.transforms.ToTensor(),
target_transform=cvatpt.ExtractSingleLabelIndex(),
)
assert torch.equal(
dataset[0][0], TF.pil_to_tensor(PIL.Image.open(self.images_per_task[0][0]))
)
assert torch.equal(dataset[0][1], torch.tensor(0))
def test_combined_transforms(self):
dataset = cvatpt.ProjectVisionDataset(
self.client,
self.project.id,
transforms=lambda x, y: (y, x),
)
assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册