From d950d2456ef1515578fed1bfaf3b816797de9e8e Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Wed, 28 Jun 2023 23:26:24 +0300 Subject: [PATCH] PyTorch adapter: fix loading tasks that have already been cached (#6396) ### Motivation and context This was broken in 4fc494f4. Add a test to cover this case. Fixes #6047. --- CHANGELOG.md | 2 + cvat-sdk/cvat_sdk/pytorch/caching.py | 2 +- tests/python/sdk/test_pytorch.py | 60 ++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4109498c3..9297a8a92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added missed auto_add argument to Issue model () - \[API\] Performance of several API endpoints () - \[API\] Invalid schema for the owner field in several endpoints () +- \[SDK\] Loading tasks that have been cached with the PyTorch adapter + () ### Security - TDB diff --git a/cvat-sdk/cvat_sdk/pytorch/caching.py b/cvat-sdk/cvat_sdk/pytorch/caching.py index 215ab2314..deaa8abf8 100644 --- a/cvat-sdk/cvat_sdk/pytorch/caching.py +++ b/cvat-sdk/cvat_sdk/pytorch/caching.py @@ -159,7 +159,7 @@ class _CacheManagerOnline(CacheManager): if task_dir.exists(): shutil.rmtree(task_dir) else: - if saved_task.updated_date < task.updated_date: + if saved_task.api_model.updated_date < task.updated_date: self._logger.info( f"Task {task.id} has been updated on the server since it was cached; purging the cache" ) diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 3ec8a8794..bcfedf7b8 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -7,7 +7,8 @@ import itertools import os from logging import Logger from pathlib import Path -from typing import Tuple +from typing import Container, Tuple +from urllib.parse import urlparse import pytest from cvat_sdk import Client, models @@ -46,11 +47,18 @@ def _common_setup( api_client.configuration.logger[k] = logger -def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None: - def disabled_request(*args, **kwargs): - raise RuntimeError("Disabled!") +def _restrict_api_requests( + monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = () +) -> None: + original_request = RESTClientObject.request - monkeypatch.setattr(RESTClientObject, "request", disabled_request) + def restricted_request(self, method, url, *args, **kwargs): + parsed_url = urlparse(url) + if parsed_url.path in allow_paths: + return original_request(self, method, url, *args, **kwargs) + raise RuntimeError("Disallowed!") + + monkeypatch.setattr(RESTClientObject, "request", restricted_request) @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") @@ -246,7 +254,7 @@ class TestTaskVisionDataset: fresh_samples = list(dataset) - _disable_api_requests(monkeypatch) + _restrict_api_requests(monkeypatch) dataset = cvatpt.TaskVisionDataset( self.client, @@ -258,6 +266,44 @@ class TestTaskVisionDataset: assert fresh_samples == cached_samples + def test_update(self, monkeypatch: pytest.MonkeyPatch): + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + ) + + # Recreating the dataset should only result in minimal requests. + _restrict_api_requests( + monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"} + ) + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + ) + + assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0] + + # After an update, the annotations should be redownloaded. + monkeypatch.undo() + + self.task.update_annotations( + models.PatchedLabeledDataRequest( + tags=[ + models.LabeledImageRequest( + id=dataset[5][1].annotations.tags[0].id, frame=5, label_id=self.label_ids[1] + ), + ] + ) + ) + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + ) + + assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[1] + @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") class TestProjectVisionDataset: @@ -401,7 +447,7 @@ class TestProjectVisionDataset: fresh_samples = list(dataset) - _disable_api_requests(monkeypatch) + _restrict_api_requests(monkeypatch) dataset = cvatpt.ProjectVisionDataset( self.client, -- GitLab