未验证 提交 d950d245 编写于 作者: R Roman Donchenko 提交者: GitHub

PyTorch adapter: fix loading tasks that have already been cached (#6396)

<!-- Raise an issue to propose your change
(https://github.com/opencv/cvat/issues).
It helps to avoid duplication of efforts from multiple independent
contributors.
Discuss your ideas with maintainers to be sure that changes will be
approved and merged.
Read the [Contribution
guide](https://opencv.github.io/cvat/docs/contributing/). -->

<!-- Provide a general summary of your changes in the Title above -->

### Motivation and context
<!-- Why is this change required? What problem does it solve? If it
fixes an open
issue, please link to the issue here. Describe your changes in detail,
add
screenshots. -->
This was broken in 4fc494f4.

Add a test to cover this case.

Fixes #6047.
上级 1b281624
......@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added missed auto_add argument to Issue model (<https://github.com/opencv/cvat/pull/6364>)
- \[API\] Performance of several API endpoints (<https://github.com/opencv/cvat/pull/6340>)
- \[API\] Invalid schema for the owner field in several endpoints (<https://github.com/opencv/cvat/pull/6343>)
- \[SDK\] Loading tasks that have been cached with the PyTorch adapter
(<https://github.com/opencv/cvat/issues/6047>)
### Security
- TDB
......
......@@ -159,7 +159,7 @@ class _CacheManagerOnline(CacheManager):
if task_dir.exists():
shutil.rmtree(task_dir)
else:
if saved_task.updated_date < task.updated_date:
if saved_task.api_model.updated_date < task.updated_date:
self._logger.info(
f"Task {task.id} has been updated on the server since it was cached; purging the cache"
)
......
......@@ -7,7 +7,8 @@ import itertools
import os
from logging import Logger
from pathlib import Path
from typing import Tuple
from typing import Container, Tuple
from urllib.parse import urlparse
import pytest
from cvat_sdk import Client, models
......@@ -46,11 +47,18 @@ def _common_setup(
api_client.configuration.logger[k] = logger
def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None:
def disabled_request(*args, **kwargs):
raise RuntimeError("Disabled!")
def _restrict_api_requests(
monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
) -> None:
original_request = RESTClientObject.request
monkeypatch.setattr(RESTClientObject, "request", disabled_request)
def restricted_request(self, method, url, *args, **kwargs):
parsed_url = urlparse(url)
if parsed_url.path in allow_paths:
return original_request(self, method, url, *args, **kwargs)
raise RuntimeError("Disallowed!")
monkeypatch.setattr(RESTClientObject, "request", restricted_request)
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
......@@ -246,7 +254,7 @@ class TestTaskVisionDataset:
fresh_samples = list(dataset)
_disable_api_requests(monkeypatch)
_restrict_api_requests(monkeypatch)
dataset = cvatpt.TaskVisionDataset(
self.client,
......@@ -258,6 +266,44 @@ class TestTaskVisionDataset:
assert fresh_samples == cached_samples
def test_update(self, monkeypatch: pytest.MonkeyPatch):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)
# Recreating the dataset should only result in minimal requests.
_restrict_api_requests(
monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
)
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]
# After an update, the annotations should be redownloaded.
monkeypatch.undo()
self.task.update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(
id=dataset[5][1].annotations.tags[0].id, frame=5, label_id=self.label_ids[1]
),
]
)
)
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[1]
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
......@@ -401,7 +447,7 @@ class TestProjectVisionDataset:
fresh_samples = list(dataset)
_disable_api_requests(monkeypatch)
_restrict_api_requests(monkeypatch)
dataset = cvatpt.ProjectVisionDataset(
self.client,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册