未验证 提交 4d32c3c6 编写于 作者: Roman Donchenko 提交者: GitHub

SDK: make the dataset cache directory customizable (#5535)

This is useful for people whose home directory is too small or too slow.
It also lets us make the tests less hacky.
上级 72b61250
...@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[2.4.0] - Unreleased ## \[2.4.0] - Unreleased
### Added ### Added
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>) - Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)
### Changed ### Changed
- The Docker Compose files now use the Compose Specification version - The Docker Compose files now use the Compose Specification version
......
...@@ -8,9 +8,11 @@ from __future__ import annotations ...@@ -8,9 +8,11 @@ from __future__ import annotations
import logging import logging
import urllib.parse import urllib.parse
from contextlib import suppress from contextlib import suppress
from pathlib import Path
from time import sleep from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple from typing import Any, Dict, Optional, Sequence, Tuple
import appdirs
import attrs import attrs
import packaging.version as pv import packaging.version as pv
import urllib3 import urllib3
...@@ -27,6 +29,8 @@ from cvat_sdk.core.proxies.tasks import TasksRepo ...@@ -27,6 +29,8 @@ from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION from cvat_sdk.version import VERSION
_DEFAULT_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
@attrs.define @attrs.define
class Config: class Config:
...@@ -43,6 +47,9 @@ class Config: ...@@ -43,6 +47,9 @@ class Config:
verify_ssl: Optional[bool] = None verify_ssl: Optional[bool] = None
"""Whether to verify host SSL certificate or not""" """Whether to verify host SSL certificate or not"""
cache_dir: Path = attrs.field(converter=Path, default=_DEFAULT_CACHE_DIR)
"""Directory in which to store cached server data"""
class Client: class Client:
""" """
......
...@@ -6,7 +6,6 @@ import shutil ...@@ -6,7 +6,6 @@ import shutil
import types import types
import zipfile import zipfile
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import ( from typing import (
Callable, Callable,
Dict, Dict,
...@@ -20,7 +19,6 @@ from typing import ( ...@@ -20,7 +19,6 @@ from typing import (
TypeVar, TypeVar,
) )
import appdirs
import attrs import attrs
import attrs.validators import attrs.validators
import PIL.Image import PIL.Image
...@@ -36,7 +34,6 @@ from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShap ...@@ -36,7 +34,6 @@ from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShap
_ModelType = TypeVar("_ModelType") _ModelType = TypeVar("_ModelType")
_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4 _NUM_DOWNLOAD_THREADS = 4
...@@ -139,7 +136,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): ...@@ -139,7 +136,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
server_dir_name = ( server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
) )
server_dir = _CACHE_DIR / f"servers/{server_dir_name}" server_dir = client.config.cache_dir / f"servers/{server_dir_name}"
self._task_dir = server_dir / f"tasks/{self._task.id}" self._task_dir = server_dir / f"tasks/{self._task.id}"
self._initialize_task_dir() self._initialize_task_dir()
......
...@@ -77,7 +77,7 @@ setup( ...@@ -77,7 +77,7 @@ setup(
python_requires="{{{generatorLanguageVersion}}}", python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS, install_requires=BASE_REQUIREMENTS,
extras_require={ extras_require={
"pytorch": ['appdirs', 'torch', 'torchvision'], "pytorch": ['torch', 'torchvision'],
}, },
package_dir={"": "."}, package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]), packages=find_packages(include=["cvat_sdk*"]),
......
-r api_client.txt -r api_client.txt
appdirs
attrs >= 21.4.0 attrs >= 21.4.0
packaging >= 21.3 packaging >= 21.3
Pillow >= 9.0.1 Pillow >= 9.0.1
......
...@@ -30,7 +30,6 @@ class TestTaskVisionDataset: ...@@ -30,7 +30,6 @@ class TestTaskVisionDataset:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup( def setup(
self, self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path, tmp_path: Path,
fxt_login: Tuple[Client, str], fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO], fxt_logger: Tuple[Logger, io.StringIO],
...@@ -41,13 +40,12 @@ class TestTaskVisionDataset: ...@@ -41,13 +40,12 @@ class TestTaskVisionDataset:
self.stdout = fxt_stdout self.stdout = fxt_stdout
self.client, self.user = fxt_login self.client, self.user = fxt_login
self.client.logger = logger self.client.logger = logger
self.client.config.cache_dir = tmp_path / "cache"
api_client = self.client.api_client api_client = self.client.api_client
for k in api_client.configuration.logger: for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger api_client.configuration.logger[k] = logger
monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")
self._create_task() self._create_task()
yield yield
...@@ -107,6 +105,9 @@ class TestTaskVisionDataset: ...@@ -107,6 +105,9 @@ class TestTaskVisionDataset:
def test_basic(self): def test_basic(self):
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id) dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)
# verify that the cache is not empty
assert list(self.client.config.cache_dir.iterdir())
assert len(dataset) == self.task.size assert len(dataset) == self.task.size
for index, (sample_image, sample_target) in enumerate(dataset): for index, (sample_image, sample_target) in enumerate(dataset):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册