未验证 提交 d950d245 编写于 作者: R Roman Donchenko 提交者: GitHub

PyTorch adapter: fix loading tasks that have already been cached (#6396)

<!-- Raise an issue to propose your change
(https://github.com/opencv/cvat/issues).
It helps to avoid duplication of efforts from multiple independent
contributors.
Discuss your ideas with maintainers to be sure that changes will be
approved and merged.
Read the [Contribution
guide](https://opencv.github.io/cvat/docs/contributing/). -->

<!-- Provide a general summary of your changes in the Title above -->

### Motivation and context
<!-- Why is this change required? What problem does it solve? If it
fixes an open
issue, please link to the issue here. Describe your changes in detail,
add
screenshots. -->
This was broken in 4fc494f4.

Add a test to cover this case.

Fixes #6047.
上级 1b281624
...@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added missed auto_add argument to Issue model (<https://github.com/opencv/cvat/pull/6364>) - Added missed auto_add argument to Issue model (<https://github.com/opencv/cvat/pull/6364>)
- \[API\] Performance of several API endpoints (<https://github.com/opencv/cvat/pull/6340>) - \[API\] Performance of several API endpoints (<https://github.com/opencv/cvat/pull/6340>)
- \[API\] Invalid schema for the owner field in several endpoints (<https://github.com/opencv/cvat/pull/6343>) - \[API\] Invalid schema for the owner field in several endpoints (<https://github.com/opencv/cvat/pull/6343>)
- \[SDK\] Loading tasks that have been cached with the PyTorch adapter
(<https://github.com/opencv/cvat/issues/6047>)
### Security ### Security
- TDB - TDB
......
...@@ -159,7 +159,7 @@ class _CacheManagerOnline(CacheManager): ...@@ -159,7 +159,7 @@ class _CacheManagerOnline(CacheManager):
if task_dir.exists(): if task_dir.exists():
shutil.rmtree(task_dir) shutil.rmtree(task_dir)
else: else:
if saved_task.updated_date < task.updated_date: if saved_task.api_model.updated_date < task.updated_date:
self._logger.info( self._logger.info(
f"Task {task.id} has been updated on the server since it was cached; purging the cache" f"Task {task.id} has been updated on the server since it was cached; purging the cache"
) )
......
...@@ -7,7 +7,8 @@ import itertools ...@@ -7,7 +7,8 @@ import itertools
import os import os
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Container, Tuple
from urllib.parse import urlparse
import pytest import pytest
from cvat_sdk import Client, models from cvat_sdk import Client, models
...@@ -46,11 +47,18 @@ def _common_setup( ...@@ -46,11 +47,18 @@ def _common_setup(
api_client.configuration.logger[k] = logger api_client.configuration.logger[k] = logger
def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None: def _restrict_api_requests(
def disabled_request(*args, **kwargs): monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
raise RuntimeError("Disabled!") ) -> None:
original_request = RESTClientObject.request
monkeypatch.setattr(RESTClientObject, "request", disabled_request) def restricted_request(self, method, url, *args, **kwargs):
parsed_url = urlparse(url)
if parsed_url.path in allow_paths:
return original_request(self, method, url, *args, **kwargs)
raise RuntimeError("Disallowed!")
monkeypatch.setattr(RESTClientObject, "request", restricted_request)
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
...@@ -246,7 +254,7 @@ class TestTaskVisionDataset: ...@@ -246,7 +254,7 @@ class TestTaskVisionDataset:
fresh_samples = list(dataset) fresh_samples = list(dataset)
_disable_api_requests(monkeypatch) _restrict_api_requests(monkeypatch)
dataset = cvatpt.TaskVisionDataset( dataset = cvatpt.TaskVisionDataset(
self.client, self.client,
...@@ -258,6 +266,44 @@ class TestTaskVisionDataset: ...@@ -258,6 +266,44 @@ class TestTaskVisionDataset:
assert fresh_samples == cached_samples assert fresh_samples == cached_samples
def test_update(self, monkeypatch: pytest.MonkeyPatch):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)
# Recreating the dataset should only result in minimal requests.
_restrict_api_requests(
monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
)
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]
# After an update, the annotations should be redownloaded.
monkeypatch.undo()
self.task.update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(
id=dataset[5][1].annotations.tags[0].id, frame=5, label_id=self.label_ids[1]
),
]
)
)
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[1]
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset: class TestProjectVisionDataset:
...@@ -401,7 +447,7 @@ class TestProjectVisionDataset: ...@@ -401,7 +447,7 @@ class TestProjectVisionDataset:
fresh_samples = list(dataset) fresh_samples = list(dataset)
_disable_api_requests(monkeypatch) _restrict_api_requests(monkeypatch)
dataset = cvatpt.ProjectVisionDataset( dataset = cvatpt.ProjectVisionDataset(
self.client, self.client,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册