Unverified commit 33c624ae authored by Roman Donchenko, committed by GitHub

PyTorch adapter: add a way to disable cache updates (#5549)

This will let users run their PyTorch code without network access,
provided that they have already cached the data.
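A minimal usage sketch (the host, credentials, and task ID are placeholders; `make_client` is the SDK's client factory):

```python
from cvat_sdk import make_client
from cvat_sdk.pytorch import TaskVisionDataset, UpdatePolicy

with make_client("app.cvat.ai", credentials=("user", "password")) as client:
    # First run, with network access: downloads the data and populates the cache.
    dataset = TaskVisionDataset(
        client, 123, update_policy=UpdatePolicy.IF_MISSING_OR_STALE
    )

    # Later runs: everything is read from the local cache and no network
    # requests are made. This fails if the required data was never cached.
    dataset = TaskVisionDataset(client, 123, update_policy=UpdatePolicy.NEVER)
```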

### How has this been tested?
Unit tests.
Parent fd7d8024
CHANGELOG.md
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/opencv/cvat/pull/5535>)
- \[SDK\] Class to represent a project as a PyTorch dataset
(<https://github.com/opencv/cvat/pull/5523>)
- \[SDK\] A PyTorch adapter setting to disable cache updates
(<https://github.com/opencv/cvat/pull/5549>)
### Changed
- The Docker Compose files now use the Compose Specification version
cvat_sdk/pytorch/__init__.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT
from .caching import UpdatePolicy
from .common import FrameAnnotations, Target, UnsupportedDatasetError
from .project_dataset import ProjectVisionDataset
from .task_dataset import TaskVisionDataset
cvat_sdk/pytorch/caching.py (new file)
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import base64
import json
import shutil
from abc import ABCMeta, abstractmethod
from enum import Enum, auto
from pathlib import Path
from typing import Callable, Mapping, Type, TypeVar
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import OpenApiModel, to_json
from cvat_sdk.core.client import Client
from cvat_sdk.core.proxies.projects import Project
from cvat_sdk.core.proxies.tasks import Task
from cvat_sdk.core.utils import atomic_writer
class UpdatePolicy(Enum):
"""
Defines policies for when the local cache is updated from the CVAT server.
"""
IF_MISSING_OR_STALE = auto()
"""
Update the cache whenever cached data is missing or the server has a newer version.
"""
NEVER = auto()
"""
Never update the cache. If an operation requires data that is not cached,
it will fail.
No network access will be performed if this policy is used.
"""
_ModelType = TypeVar("_ModelType", bound=OpenApiModel)
class CacheManager(metaclass=ABCMeta):
def __init__(self, client: Client) -> None:
self._client = client
self._logger = client.logger
self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}"
@property
def server_dir_name(self) -> str:
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
return base64.urlsafe_b64encode(self._client.api_map.host.encode()).rstrip(b"=").decode()
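# For example, a hypothetical host "https://app.cvat.ai" yields the directory
# name "aHR0cHM6Ly9hcHAuY3ZhdC5haQ" (urlsafe Base64 with the padding stripped).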
def task_dir(self, task_id: int) -> Path:
return self._server_dir / f"tasks/{task_id}"
def task_json_path(self, task_id: int) -> Path:
return self.task_dir(task_id) / "task.json"
def chunk_dir(self, task_id: int) -> Path:
return self.task_dir(task_id) / "chunks"
def project_dir(self, project_id: int) -> Path:
return self._server_dir / f"projects/{project_id}"
def project_json_path(self, project_id: int) -> Path:
return self.project_dir(project_id) / "project.json"
def load_model(self, path: Path, model_type: Type[_ModelType]) -> _ModelType:
with open(path, "rb") as f:
return model_type._new_from_openapi_data(**json.load(f))
def save_model(self, path: Path, model: OpenApiModel) -> None:
with atomic_writer(path, "w", encoding="UTF-8") as f:
json.dump(to_json(model), f, indent=4)
print(file=f) # add final newline
@abstractmethod
def retrieve_task(self, task_id: int) -> Task:
...
@abstractmethod
def ensure_task_model(
self,
task_id: int,
filename: str,
model_type: Type[_ModelType],
downloader: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
...
@abstractmethod
def ensure_chunk(self, task: Task, chunk_index: int) -> None:
...
@abstractmethod
def retrieve_project(self, project_id: int) -> Project:
...
class _CacheManagerOnline(CacheManager):
def retrieve_task(self, task_id: int) -> Task:
self._logger.info(f"Fetching task {task_id}...")
task = self._client.tasks.retrieve(task_id)
self._initialize_task_dir(task)
return task
def _initialize_task_dir(self, task: Task) -> None:
task_dir = self.task_dir(task.id)
task_json_path = self.task_json_path(task.id)
try:
saved_task = self.load_model(task_json_path, models.TaskRead)
except Exception:
self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted")
# If the cache was corrupted, the directory might already be there; clear it.
if task_dir.exists():
shutil.rmtree(task_dir)
else:
if saved_task.updated_date < task.updated_date:
self._logger.info(
f"Task {task.id} has been updated on the server since it was cached; purging the cache"
)
shutil.rmtree(task_dir)
task_dir.mkdir(exist_ok=True, parents=True)
self.save_model(task_json_path, task._model)
def ensure_task_model(
self,
task_id: int,
filename: str,
model_type: Type[_ModelType],
downloader: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
path = self.task_dir(task_id) / filename
try:
model = self.load_model(path, model_type)
self._logger.info(f"Loaded {model_description} from cache")
return model
except FileNotFoundError:
pass
except Exception:
self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)
self._logger.info(f"Downloading {model_description}...")
model = downloader()
self._logger.info(f"Downloaded {model_description}")
self.save_model(path, model)
return model
def ensure_chunk(self, task: Task, chunk_index: int) -> None:
chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
if chunk_path.exists():
return # already downloaded previously
self._logger.info(f"Downloading chunk #{chunk_index}...")
with atomic_writer(chunk_path, "wb") as chunk_file:
task.download_chunk(chunk_index, chunk_file, quality="original")
def retrieve_project(self, project_id: int) -> Project:
self._logger.info(f"Fetching project {project_id}...")
project = self._client.projects.retrieve(project_id)
project_dir = self.project_dir(project_id)
project_dir.mkdir(parents=True, exist_ok=True)
project_json_path = self.project_json_path(project_id)
# There are currently no files cached alongside project.json,
# so we don't need to check if we need to purge them.
self.save_model(project_json_path, project._model)
return project
class _CacheManagerOffline(CacheManager):
def retrieve_task(self, task_id: int) -> Task:
self._logger.info(f"Retrieving task {task_id} from cache...")
return Task(self._client, self.load_model(self.task_json_path(task_id), models.TaskRead))
def ensure_task_model(
self,
task_id: int,
filename: str,
model_type: Type[_ModelType],
downloader: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
self._logger.info(f"Loading {model_description} from cache...")
return self.load_model(self.task_dir(task_id) / filename, model_type)
def ensure_chunk(self, task: Task, chunk_index: int) -> None:
chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
if not chunk_path.exists():
raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached")
def retrieve_project(self, project_id: int) -> Project:
self._logger.info(f"Retrieving project {project_id} from cache...")
return Project(
self._client, self.load_model(self.project_json_path(project_id), models.ProjectRead)
)
_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, Type[CacheManager]] = {
UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline,
UpdatePolicy.NEVER: _CacheManagerOffline,
}
def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager:
return _CACHE_MANAGER_CLASSES[update_policy](client)
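A minimal sketch of how this factory is used (the task ID is hypothetical; the dataset classes below obtain their cache managers exactly this way):

cache_manager = make_cache_manager(client, UpdatePolicy.NEVER)
task = cache_manager.retrieve_task(42)  # raises FileNotFoundError if task 42 is not cached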
cvat_sdk/pytorch/common.py
@@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: MIT
import base64
from pathlib import Path
from typing import List, Mapping
import attrs
@@ -42,9 +40,3 @@ class Target:
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an integer index. This mapping is consistent across all samples for a given task.
"""
def get_server_cache_dir(client: cvat_sdk.core.Client) -> Path:
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
return client.config.cache_dir / f"servers/{server_dir_name}"
cvat_sdk/pytorch/project_dataset.py
@@ -12,7 +12,7 @@ import torchvision.datasets
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.pytorch.common import get_server_cache_dir
from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.pytorch.task_dataset import TaskVisionDataset
@@ -42,6 +42,7 @@ class ProjectVisionDataset(torchvision.datasets.VisionDataset):
label_name_to_index: Mapping[str, int] = None,
task_filter: Optional[Callable[[models.ITaskRead], bool]] = None,
include_subsets: Optional[Container[str]] = None,
update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
) -> None:
"""
Creates a dataset corresponding to the project with ID `project_id` on the
@@ -61,29 +62,24 @@ class ProjectVisionDataset(torchvision.datasets.VisionDataset):
* If `include_subsets` is set to a container, then tasks whose subset is
not a member of this container will be excluded.
`update_policy` determines when and if the local cache will be updated.
"""
self._logger = client.logger
self._logger.info(f"Fetching project {project_id}...")
project = client.projects.retrieve(project_id)
# We don't actually need to save anything to this directory (yet),
# but VisionDataset.__init__ requires a root, so make one.
# It could be useful in the future to store the project data for
# offline-only mode.
project_dir = get_server_cache_dir(client) / f"projects/{project_id}"
project_dir.mkdir(parents=True, exist_ok=True)
cache_manager = make_cache_manager(client, update_policy)
project = cache_manager.retrieve_project(project_id)
super().__init__(
os.fspath(project_dir),
os.fspath(cache_manager.project_dir(project_id)),
transforms=transforms,
transform=transform,
target_transform=target_transform,
)
self._logger.info("Fetching project tasks...")
tasks = project.get_tasks()
tasks = [cache_manager.retrieve_task(task_id) for task_id in project.tasks]
if task_filter is not None:
tasks = list(filter(task_filter, tasks))
@@ -95,7 +91,12 @@ class ProjectVisionDataset(torchvision.datasets.VisionDataset):
self._underlying = torch.utils.data.ConcatDataset(
[
TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index)
TaskVisionDataset(
client,
task.id,
label_name_to_index=label_name_to_index,
update_policy=update_policy,
)
for task in tasks
]
)
cvat_sdk/pytorch/task_dataset.py
@@ -3,13 +3,11 @@
# SPDX-License-Identifier: MIT
import collections
import json
import os
import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Mapping, Optional, Type, TypeVar
from typing import Callable, Dict, Mapping, Optional
import PIL.Image
import torchvision.datasets
@@ -17,16 +15,8 @@ import torchvision.datasets
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.pytorch.common import (
FrameAnnotations,
Target,
UnsupportedDatasetError,
get_server_cache_dir,
)
_ModelType = TypeVar("_ModelType")
from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.pytorch.common import FrameAnnotations, Target, UnsupportedDatasetError
_NUM_DOWNLOAD_THREADS = 4
@@ -44,7 +34,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
* target is a `Target` object containing annotations for the frame.
This class caches all data and annotations for the task on the local file system
during construction. If the task is updated on the server, the cache is updated.
during construction.
Limitations:
@@ -61,6 +51,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
label_name_to_index: Mapping[str, int] = None,
update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
@@ -80,12 +71,14 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
will map each label ID to a distinct integer in the range [0, `num_labels`), where
`num_labels` is the number of labels defined in the task. This mapping will be
generally unpredictable, but consistent for a given task.
`update_policy` determines when and if the local cache will be updated.
"""
self._logger = client.logger
self._logger.info(f"Fetching task {task_id}...")
self._task = client.tasks.retrieve(task_id)
cache_manager = make_cache_manager(client, update_policy)
self._task = cache_manager.retrieve_task(task_id)
if not self._task.size or not self._task.data_chunk_size:
raise UnsupportedDatasetError("The task has no data")
@@ -96,18 +89,19 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
f" current chunk type is {self._task.data_original_chunk_type!r}"
)
self._task_dir = get_server_cache_dir(client) / f"tasks/{self._task.id}"
self._initialize_task_dir()
super().__init__(
os.fspath(self._task_dir),
os.fspath(cache_manager.task_dir(self._task.id)),
transforms=transforms,
transform=transform,
target_transform=target_transform,
)
data_meta = self._ensure_model(
"data_meta.json", models.DataMetaRead, self._task.get_meta, "data metadata"
data_meta = cache_manager.ensure_task_model(
self._task.id,
"data_meta.json",
models.DataMetaRead,
self._task.get_meta,
"data metadata",
)
self._active_frame_indexes = sorted(
set(range(self._task.size)) - set(data_meta.deleted_frames)
@@ -115,7 +109,7 @@
self._logger.info("Downloading chunks...")
self._chunk_dir = self._task_dir / "chunks"
self._chunk_dir = cache_manager.chunk_dir(task_id)
self._chunk_dir.mkdir(exist_ok=True, parents=True)
needed_chunks = {
@@ -123,7 +117,11 @@
}
with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)):
def ensure_chunk(chunk_index):
cache_manager.ensure_chunk(self._task, chunk_index)
for _ in pool.map(ensure_chunk, sorted(needed_chunks)):
# just need to loop through all results so that any exceptions are propagated
pass
@@ -143,8 +141,12 @@
{label.id: label_name_to_index[label.name] for label in self._task.labels}
)
annotations = self._ensure_model(
"annotations.json", models.LabeledData, self._task.get_annotations, "annotations"
annotations = cache_manager.ensure_task_model(
self._task.id,
"annotations.json",
models.LabeledData,
self._task.get_annotations,
"annotations",
)
self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
@@ -159,70 +161,6 @@
# TODO: tracks?
def _initialize_task_dir(self) -> None:
task_json_path = self._task_dir / "task.json"
try:
with open(task_json_path, "rb") as task_json_file:
saved_task = models.TaskRead._new_from_openapi_data(**json.load(task_json_file))
except Exception:
self._logger.info("Task is not yet cached or the cache is corrupted")
# If the cache was corrupted, the directory might already be there; clear it.
if self._task_dir.exists():
shutil.rmtree(self._task_dir)
else:
if saved_task.updated_date < self._task.updated_date:
self._logger.info(
"Task has been updated on the server since it was cached; purging the cache"
)
shutil.rmtree(self._task_dir)
self._task_dir.mkdir(exist_ok=True, parents=True)
with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file:
json.dump(to_json(self._task._model), task_json_file, indent=4)
print(file=task_json_file) # add final newline
def _ensure_chunk(self, chunk_index: int) -> None:
chunk_path = self._chunk_dir / f"{chunk_index}.zip"
if chunk_path.exists():
return # already downloaded previously
self._logger.info(f"Downloading chunk #{chunk_index}...")
with atomic_writer(chunk_path, "wb") as chunk_file:
self._task.download_chunk(chunk_index, chunk_file, quality="original")
def _ensure_model(
self,
filename: str,
model_type: Type[_ModelType],
download: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
path = self._task_dir / filename
try:
with open(path, "rb") as f:
model = model_type._new_from_openapi_data(**json.load(f))
self._logger.info(f"Loaded {model_description} from cache")
return model
except FileNotFoundError:
pass
except Exception:
self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)
self._logger.info(f"Downloading {model_description}...")
model = download()
self._logger.info(f"Downloaded {model_description}")
with atomic_writer(path, "w", encoding="UTF-8") as f:
json.dump(to_json(model), f, indent=4)
print(file=f) # add final newline
return model
def __getitem__(self, sample_index: int):
"""
Returns the sample with index `sample_index`.
tests/python/sdk/test_pytorch.py
@@ -11,6 +11,7 @@ from typing import Tuple
import pytest
from cvat_sdk import Client, models
from cvat_sdk.api_client.rest import RESTClientObject
from cvat_sdk.core.proxies.tasks import ResourceType
try:
@@ -42,6 +43,13 @@ def _common_setup(
api_client.configuration.logger[k] = logger
def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None:
def disabled_request(*args, **kwargs):
raise RuntimeError("Disabled!")
monkeypatch.setattr(RESTClientObject, "request", disabled_request)
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
@@ -226,6 +234,27 @@ class TestTaskVisionDataset:
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456
def test_offline(self, monkeypatch: pytest.MonkeyPatch):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE,
)
fresh_samples = list(dataset)
_disable_api_requests(monkeypatch)
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
update_policy=cvatpt.UpdatePolicy.NEVER,
)
cached_samples = list(dataset)
assert fresh_samples == cached_samples
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
@@ -359,3 +388,24 @@ class TestProjectVisionDataset:
assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)
def test_offline(self, monkeypatch: pytest.MonkeyPatch):
dataset = cvatpt.ProjectVisionDataset(
self.client,
self.project.id,
update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE,
)
fresh_samples = list(dataset)
_disable_api_requests(monkeypatch)
dataset = cvatpt.ProjectVisionDataset(
self.client,
self.project.id,
update_policy=cvatpt.UpdatePolicy.NEVER,
)
cached_samples = list(dataset)
assert fresh_samples == cached_samples