diff --git a/CHANGELOG.md b/CHANGELOG.md index 4109498c30ff98f35f737bf27b59b80076fb60f6..9297a8a9295eda2c8b2669f2b9aeb0b86227e87a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added missed auto_add argument to Issue model () - \[API\] Performance of several API endpoints () - \[API\] Invalid schema for the owner field in several endpoints () +- \[SDK\] Loading tasks that have been cached with the PyTorch adapter + () ### Security - TDB diff --git a/cvat-sdk/cvat_sdk/pytorch/caching.py b/cvat-sdk/cvat_sdk/pytorch/caching.py index 215ab23147cf9ec9442088e29b4bf46a58e36c35..deaa8abf8721da2f9272b40920af8f77441adaa9 100644 --- a/cvat-sdk/cvat_sdk/pytorch/caching.py +++ b/cvat-sdk/cvat_sdk/pytorch/caching.py @@ -159,7 +159,7 @@ class _CacheManagerOnline(CacheManager): if task_dir.exists(): shutil.rmtree(task_dir) else: - if saved_task.updated_date < task.updated_date: + if saved_task.api_model.updated_date < task.updated_date: self._logger.info( f"Task {task.id} has been updated on the server since it was cached; purging the cache" ) diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 3ec8a8794bdfedc37f6b111ca3bf16bbcedf33d4..bcfedf7b8b3b30f93210de934ecbea2a6883838e 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -7,7 +7,8 @@ import itertools import os from logging import Logger from pathlib import Path -from typing import Tuple +from typing import Container, Tuple +from urllib.parse import urlparse import pytest from cvat_sdk import Client, models @@ -46,11 +47,18 @@ def _common_setup( api_client.configuration.logger[k] = logger -def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None: - def disabled_request(*args, **kwargs): - raise RuntimeError("Disabled!") +def _restrict_api_requests( + monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = () +) -> None: + original_request = RESTClientObject.request - monkeypatch.setattr(RESTClientObject, "request", disabled_request) + def restricted_request(self, method, url, *args, **kwargs): + parsed_url = urlparse(url) + if parsed_url.path in allow_paths: + return original_request(self, method, url, *args, **kwargs) + raise RuntimeError("Disallowed!") + + monkeypatch.setattr(RESTClientObject, "request", restricted_request) @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") @@ -246,7 +254,7 @@ class TestTaskVisionDataset: fresh_samples = list(dataset) - _disable_api_requests(monkeypatch) + _restrict_api_requests(monkeypatch) dataset = cvatpt.TaskVisionDataset( self.client, @@ -258,6 +266,44 @@ class TestTaskVisionDataset: assert fresh_samples == cached_samples + def test_update(self, monkeypatch: pytest.MonkeyPatch): + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + ) + + # Recreating the dataset should only result in minimal requests. + _restrict_api_requests( + monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"} + ) + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + ) + + assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0] + + # After an update, the annotations should be redownloaded. + monkeypatch.undo() + + self.task.update_annotations( + models.PatchedLabeledDataRequest( + tags=[ + models.LabeledImageRequest( + id=dataset[5][1].annotations.tags[0].id, frame=5, label_id=self.label_ids[1] + ), + ] + ) + ) + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + ) + + assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[1] + @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") class TestProjectVisionDataset: @@ -401,7 +447,7 @@ class TestProjectVisionDataset: fresh_samples = list(dataset) - _disable_api_requests(monkeypatch) + _restrict_api_requests(monkeypatch) dataset = cvatpt.ProjectVisionDataset( self.client,