未验证 提交 4d32c3c6 编写于 作者: R Roman Donchenko 提交者: GitHub

SDK: make the dataset cache directory customizable (#5535)

This is useful for people whose home directory is too small/not fast
enough. It also lets us make the tests less hacky.
上级 72b61250
......@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[2.4.0] - Unreleased
### Added
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)
### Changed
- The Docker Compose files now use the Compose Specification version
......
......@@ -8,9 +8,11 @@ from __future__ import annotations
import logging
import urllib.parse
from contextlib import suppress
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple
import appdirs
import attrs
import packaging.version as pv
import urllib3
......@@ -27,6 +29,8 @@ from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION
_DEFAULT_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
@attrs.define
class Config:
......@@ -43,6 +47,9 @@ class Config:
verify_ssl: Optional[bool] = None
"""Whether to verify host SSL certificate or not"""
cache_dir: Path = attrs.field(converter=Path, default=_DEFAULT_CACHE_DIR)
"""Directory in which to store cached server data"""
class Client:
"""
......
......@@ -6,7 +6,6 @@ import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
Callable,
Dict,
......@@ -20,7 +19,6 @@ from typing import (
TypeVar,
)
import appdirs
import attrs
import attrs.validators
import PIL.Image
......@@ -36,7 +34,6 @@ from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShap
_ModelType = TypeVar("_ModelType")
_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4
......@@ -139,7 +136,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = _CACHE_DIR / f"servers/{server_dir_name}"
server_dir = client.config.cache_dir / f"servers/{server_dir_name}"
self._task_dir = server_dir / f"tasks/{self._task.id}"
self._initialize_task_dir()
......
......@@ -77,7 +77,7 @@ setup(
python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS,
extras_require={
"pytorch": ['appdirs', 'torch', 'torchvision'],
"pytorch": ['torch', 'torchvision'],
},
package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]),
......
-r api_client.txt
appdirs
attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1
......
......@@ -30,7 +30,6 @@ class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
......@@ -41,13 +40,12 @@ class TestTaskVisionDataset:
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger
self.client.config.cache_dir = tmp_path / "cache"
api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger
monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")
self._create_task()
yield
......@@ -107,6 +105,9 @@ class TestTaskVisionDataset:
def test_basic(self):
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)
# verify that the cache is not empty
assert list(self.client.config.cache_dir.iterdir())
assert len(dataset) == self.task.size
for index, (sample_image, sample_target) in enumerate(dataset):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册