# project_dataset.py
# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import os
from typing import Callable, Container, Mapping, Optional

import torch
import torch.utils.data
import torchvision.datasets

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.pytorch.task_dataset import TaskVisionDataset


class ProjectVisionDataset(torchvision.datasets.VisionDataset):
    """
    Represents a project on a CVAT server as a PyTorch Dataset.

    The dataset contains one sample for each frame of each task in the project
    (except for tasks that are filtered out - see the description of `task_filter`
    in the constructor). The sequence of samples is formed by concatenating sequences
    of samples from all included tasks in an arbitrary order that's consistent
    between executions. Each task's sequence of samples corresponds to the sequence
    of frames on the server.

    If the project has no tasks, or all of its tasks are filtered out, the dataset
    is empty.

    See `TaskVisionDataset` for information on sample format, caching, and
    current limitations.
    """

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        project_id: int,
        *,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        label_name_to_index: Optional[Mapping[str, int]] = None,
        task_filter: Optional[Callable[[models.ITaskRead], bool]] = None,
        include_subsets: Optional[Container[str]] = None,
        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
    ) -> None:
        """
        Creates a dataset corresponding to the project with ID `project_id` on the
        server that `client` is connected to.

        `transforms`, `transform` and `target_transform` are optional transformation
        functions; see the documentation for `torchvision.datasets.VisionDataset` for
        more information.

        See `TaskVisionDataset.__init__` for information on `label_name_to_index`.

        By default, all of the project's tasks will be included in the dataset.
        The following parameters can be specified to exclude some tasks:

        * If `task_filter` is set to a callable object, it will be applied to every task.
          Tasks for which it returns a false value will be excluded.

        * If `include_subsets` is set to a container, then tasks whose subset is
          not a member of this container will be excluded.

        If no task remains after filtering, the resulting dataset is empty.

        `update_policy` determines when and if the local cache will be updated.
        """

        self._logger = client.logger

        cache_manager = make_cache_manager(client, update_policy)
        project = cache_manager.retrieve_project(project_id)

        super().__init__(
            os.fspath(cache_manager.project_dir(project_id)),
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )

        self._logger.info("Fetching project tasks...")
        tasks = [cache_manager.retrieve_task(task_id) for task_id in project.tasks]

        if task_filter is not None:
            tasks = list(filter(task_filter, tasks))

        if include_subsets is not None:
            tasks = [task for task in tasks if task.subset in include_subsets]

        tasks.sort(key=lambda t: t.id)  # ensure consistent order between executions

        # Per-task datasets are built without any transforms; the project-level
        # transforms are applied once, in __getitem__.
        task_datasets = [
            TaskVisionDataset(
                client,
                task.id,
                label_name_to_index=label_name_to_index,
                update_policy=update_policy,
            )
            for task in tasks
        ]

        # ConcatDataset refuses an empty list of datasets, so use an empty
        # sequence instead when no task survives filtering. Both support len()
        # and integer indexing, which is all __getitem__ and __len__ rely on;
        # indexing the empty list raises IndexError as expected.
        self._underlying = (
            torch.utils.data.ConcatDataset(task_datasets) if task_datasets else []
        )

    def __getitem__(self, sample_index: int):
        """
        Returns the sample with index `sample_index`.

        `sample_index` must satisfy the condition `0 <= sample_index < len(self)`.

        The sample is taken from the corresponding task dataset and is then passed
        through this dataset's `transforms` (as set up by `VisionDataset.__init__`),
        if any.
        """

        sample_image, sample_target = self._underlying[sample_index]

        if self.transforms is not None:
            sample_image, sample_target = self.transforms(sample_image, sample_target)

        return sample_image, sample_target

    def __len__(self) -> int:
        """Returns the number of samples in the dataset."""
        return len(self._underlying)