# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import io
import itertools
import os
from logging import Logger
from pathlib import Path
from typing import Tuple

import pytest
from cvat_sdk import Client, models
from cvat_sdk.api_client.rest import RESTClientObject
from cvat_sdk.core.proxies.tasks import ResourceType

try:
    import cvat_sdk.pytorch as cvatpt
    import PIL.Image
    import torch
    import torchvision.transforms
    import torchvision.transforms.functional as TF
    from torch.utils.data import DataLoader
except ImportError:
    cvatpt = None

from shared.utils.helpers import generate_image_files


@pytest.fixture(autouse=True)
def _common_setup(
    tmp_path: Path,
    fxt_login: Tuple[Client, str],
    fxt_logger: Tuple[Logger, io.StringIO],
):
    logger = fxt_logger[0]
    client = fxt_login[0]
    client.logger = logger
    client.config.cache_dir = tmp_path / "cache"

    api_client = client.api_client
    for k in api_client.configuration.logger:
        api_client.configuration.logger[k] = logger


def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None:
    def disabled_request(*args, **kwargs):
        raise RuntimeError("Disabled!")

    monkeypatch.setattr(RESTClientObject, "request", disabled_request)


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestTaskVisionDataset:
    @pytest.fixture(autouse=True)
    def setup(
        self,
        tmp_path: Path,
        fxt_login: Tuple[Client, str],
    ):
        self.client = fxt_login[0]
        self.images = generate_image_files(10)

        image_dir = tmp_path / "images"
        image_dir.mkdir()

        image_paths = []
        for image in self.images:
            image_path = image_dir / image.name
            image_path.write_bytes(image.getbuffer())
            image_paths.append(image_path)

        self.task = self.client.tasks.create_from_data(
            models.TaskWriteRequest(
                "PyTorch integration test task",
                labels=[
                    models.PatchedLabelRequest(name="person"),
                    models.PatchedLabelRequest(name="car"),
                ],
            ),
            resource_type=ResourceType.LOCAL,
            resources=list(map(os.fspath, image_paths)),
            data_params={"chunk_size": 3},
        )

        self.label_ids = sorted(l.id for l in self.task.labels)

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                tags=[
                    models.LabeledImageRequest(frame=5, label_id=self.label_ids[0]),
                    models.LabeledImageRequest(frame=6, label_id=self.label_ids[1]),
                    models.LabeledImageRequest(frame=8, label_id=self.label_ids[0]),
                    models.LabeledImageRequest(frame=8, label_id=self.label_ids[1]),
                ],
                shapes=[
                    models.LabeledShapeRequest(
                        frame=6,
                        label_id=self.label_ids[1],
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                    models.LabeledShapeRequest(
                        frame=7,
                        label_id=self.label_ids[0],
                        type=models.ShapeType("points"),
                        points=[1.1, 2.1, 3.1, 4.1],
                    ),
                ],
            )
        )

    def test_basic(self):
        dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)

        # verify that the cache is not empty
        assert list(self.client.config.cache_dir.iterdir())

        assert len(dataset) == self.task.size

        for index, (sample_image, sample_target) in enumerate(dataset):
            sample_image_tensor = TF.pil_to_tensor(sample_image)
            reference_tensor = TF.pil_to_tensor(PIL.Image.open(self.images[index]))
            assert torch.equal(sample_image_tensor, reference_tensor)

            for index, label_id in enumerate(self.label_ids):
                assert sample_target.label_id_to_index[label_id] == index

        assert not dataset[0][1].annotations.tags
        assert not dataset[0][1].annotations.shapes

        assert len(dataset[5][1].annotations.tags) == 1
        assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]
        assert not dataset[5][1].annotations.shapes

        assert len(dataset[6][1].annotations.tags) == 1
        assert dataset[6][1].annotations.tags[0].label_id == self.label_ids[1]
        assert len(dataset[6][1].annotations.shapes) == 1
        assert dataset[6][1].annotations.shapes[0].type.value == "rectangle"
        assert dataset[6][1].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]

        assert not dataset[7][1].annotations.tags
        assert len(dataset[7][1].annotations.shapes) == 1
        assert dataset[7][1].annotations.shapes[0].type.value == "points"
        assert dataset[7][1].annotations.shapes[0].points == [1.1, 2.1, 3.1, 4.1]

    def test_deleted_frame(self):
        self.task.remove_frames_by_ids([1])

        dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)

        assert len(dataset) == self.task.size - 1

        # sample #0 is still frame #0
        assert torch.equal(
            TF.pil_to_tensor(dataset[0][0]), TF.pil_to_tensor(PIL.Image.open(self.images[0]))
        )

        # sample #1 is now frame #2
        assert torch.equal(
            TF.pil_to_tensor(dataset[1][0]), TF.pil_to_tensor(PIL.Image.open(self.images[2]))
        )

        # sample #4 is now frame #5
        assert len(dataset[4][1].annotations.tags) == 1
        assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[0]
        assert not dataset[4][1].annotations.shapes

    def test_extract_single_label_index(self):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            transform=torchvision.transforms.PILToTensor(),
            target_transform=cvatpt.ExtractSingleLabelIndex(),
        )

        assert torch.equal(dataset[5][1], torch.tensor(0))
        assert torch.equal(dataset[6][1], torch.tensor(1))

        with pytest.raises(ValueError):
            # no tags
            _ = dataset[7]

        with pytest.raises(ValueError):
            # multiple tags
            _ = dataset[8]

        # make sure the samples can be batched with the default collater
        loader = DataLoader(dataset, batch_size=2, sampler=[5, 6])

        batch = next(iter(loader))
        assert torch.equal(batch[0][0], TF.pil_to_tensor(PIL.Image.open(self.images[5])))
        assert torch.equal(batch[0][1], TF.pil_to_tensor(PIL.Image.open(self.images[6])))
        assert torch.equal(batch[1], torch.tensor([0, 1]))

    def test_extract_bounding_boxes(self):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            transform=torchvision.transforms.PILToTensor(),
            target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}),
        )

        assert torch.equal(dataset[0][1]["boxes"], torch.tensor([]))
        assert torch.equal(dataset[0][1]["labels"], torch.tensor([]))

        assert torch.equal(dataset[6][1]["boxes"], torch.tensor([(1.0, 2.0, 3.0, 4.0)]))
        assert torch.equal(dataset[6][1]["labels"], torch.tensor([1]))

        # points are filtered out
        assert torch.equal(dataset[7][1]["boxes"], torch.tensor([]))
        assert torch.equal(dataset[7][1]["labels"], torch.tensor([]))

    def test_transforms(self):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            transforms=lambda x, y: (y, x),
        )

        assert isinstance(dataset[0][0], cvatpt.Target)
        assert isinstance(dataset[0][1], PIL.Image.Image)

    def test_custom_label_mapping(self):
        label_name_to_id = {label.name: label.id for label in self.task.labels}

        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            label_name_to_index={"person": 123, "car": 456},
        )

        _, target = dataset[5]
        assert target.label_id_to_index[label_name_to_id["person"]] == 123
        assert target.label_id_to_index[label_name_to_id["car"]] == 456

    def test_offline(self, monkeypatch: pytest.MonkeyPatch):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE,
        )

        fresh_samples = list(dataset)

        _disable_api_requests(monkeypatch)

        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            update_policy=cvatpt.UpdatePolicy.NEVER,
        )

        cached_samples = list(dataset)
        assert fresh_samples == cached_samples


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
    @pytest.fixture(autouse=True)
    def setup(
        self,
        tmp_path: Path,
        fxt_login: Tuple[Client, str],
    ):
        self.client = fxt_login[0]

        self.project = self.client.projects.create(
            models.ProjectWriteRequest(
                "PyTorch integration test project",
                labels=[
                    models.PatchedLabelRequest(name="person"),
                    models.PatchedLabelRequest(name="car"),
                ],
            )
        )

        self.label_ids = sorted(l.id for l in self.project.labels)

        subsets = ["Train", "Test", "Val"]
        num_images_per_task = 3

        all_images = generate_image_files(num_images_per_task * len(subsets))
        self.images_per_task = list(zip(*[iter(all_images)] * num_images_per_task))

        image_dir = tmp_path / "images"
        image_dir.mkdir()

        image_paths_per_task = []
        for images in self.images_per_task:
            image_paths = []
            for image in images:
                image_path = image_dir / image.name
                image_path.write_bytes(image.getbuffer())
                image_paths.append(image_path)
            image_paths_per_task.append(image_paths)

        self.tasks = [
            self.client.tasks.create_from_data(
                models.TaskWriteRequest(
                    "PyTorch integration test task",
                    project_id=self.project.id,
                    subset=subset,
                ),
                resource_type=ResourceType.LOCAL,
                resources=image_paths,
                data_params={"image_quality": 70},
            )
            for subset, image_paths in zip(subsets, image_paths_per_task)
        ]

        # sort both self.tasks and self.images_per_task in the order that ProjectVisionDataset uses
        self.tasks, self.images_per_task = zip(
            *sorted(zip(self.tasks, self.images_per_task), key=lambda t: t[0].id)
        )

        for task_id, label_index in ((0, 0), (1, 1), (2, 0)):
            self.tasks[task_id].update_annotations(
                models.PatchedLabeledDataRequest(
                    tags=[
                        models.LabeledImageRequest(
                            frame=task_id, label_id=self.label_ids[label_index]
                        ),
                    ],
                )
            )

    def test_basic(self):
        dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id)

        assert len(dataset) == sum(task.size for task in self.tasks)

        for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task)):
            assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image)))

        assert dataset[0][1].annotations.tags[0].label_id == self.label_ids[0]
        assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[1]
        assert dataset[8][1].annotations.tags[0].label_id == self.label_ids[0]

    def _test_filtering(self, **kwargs):
        dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id, **kwargs)

        assert len(dataset) == sum(task.size for task in self.tasks[1:])

        for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task[1:])):
            assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image)))

        assert dataset[1][1].annotations.tags[0].label_id == self.label_ids[1]
        assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]

    def test_task_filter(self):
        self._test_filtering(task_filter=lambda t: t.subset != self.tasks[0].subset)

    def test_include_subsets(self):
        self._test_filtering(include_subsets={self.tasks[1].subset, self.tasks[2].subset})

    def test_custom_label_mapping(self):
        label_name_to_id = {label.name: label.id for label in self.project.labels}

        dataset = cvatpt.ProjectVisionDataset(
            self.client, self.project.id, label_name_to_index={"person": 123, "car": 456}
        )

        _, target = dataset[5]
        assert target.label_id_to_index[label_name_to_id["person"]] == 123
        assert target.label_id_to_index[label_name_to_id["car"]] == 456

    def test_separate_transforms(self):
        dataset = cvatpt.ProjectVisionDataset(
            self.client,
            self.project.id,
            transform=torchvision.transforms.ToTensor(),
            target_transform=cvatpt.ExtractSingleLabelIndex(),
        )

        # ToTensor() yields a float tensor scaled to [0, 1], so build the reference
        # with TF.to_tensor (not TF.pil_to_tensor, which returns uint8 values)
        assert torch.equal(
            dataset[0][0], TF.to_tensor(PIL.Image.open(self.images_per_task[0][0]))
        )
        assert torch.equal(dataset[0][1], torch.tensor(0))

    def test_combined_transforms(self):
        dataset = cvatpt.ProjectVisionDataset(
            self.client,
            self.project.id,
            transforms=lambda x, y: (y, x),
        )

        assert isinstance(dataset[0][0], cvatpt.Target)
        assert isinstance(dataset[0][1], PIL.Image.Image)

    def test_offline(self, monkeypatch: pytest.MonkeyPatch):
        dataset = cvatpt.ProjectVisionDataset(
            self.client,
            self.project.id,
            update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE,
        )

        fresh_samples = list(dataset)

        _disable_api_requests(monkeypatch)

        dataset = cvatpt.ProjectVisionDataset(
            self.client,
            self.project.id,
            update_policy=cvatpt.UpdatePolicy.NEVER,
        )

        cached_samples = list(dataset)

        assert fresh_samples == cached_samples