未验证 提交 53697eca 编写于 作者: M Maxim Zhiltsov 提交者: GitHub

SDK layer 2 - cover RC1 usecases (#4813)

上级 b60d3b48
......@@ -6,3 +6,4 @@
# B406 : import_xml_sax
# B410 : import_lxml
skips: B101,B102,B320,B404,B406,B410
exclude: **/tests/**,tests
......@@ -33,7 +33,7 @@ jobs:
echo "Bandit version: "$(bandit --version | head -1)
echo "The files will be checked: "$(echo $CHANGED_FILES)
bandit $CHANGED_FILES --exclude '**/tests/**' -a file --ini ./.bandit -f html -o ./bandit_report/bandit_checks.html
bandit -a file --ini .bandit -f html -o ./bandit_report/bandit_checks.html $CHANGED_FILES
deactivate
else
echo "No files with the \"py\" extension found"
......
......@@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Possibility to display tags on frame
- Support source and target storages (server part)
- Tests for import/export annotation, dataset, backup from/to cloud storage
- Added Python SDK package (`cvat-sdk`)
- Added Python SDK package (`cvat-sdk`) (<https://github.com/opencv/cvat/pull/4813>)
- Previews for jobs
- Documentation for LDAP authentication (<https://github.com/cvat-ai/cvat/pull/39>)
- OpenCV.js caching and autoload (<https://github.com/cvat-ai/cvat/pull/30>)
......@@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Bumped nuclio version to 1.8.14
- Simplified running REST API tests. Extended CI-nightly workflow
- REST API tests are partially moved to Python SDK (`users`, `projects`, `tasks`)
- REST API tests are partially moved to Python SDK (`users`, `projects`, `tasks`, `issues`)
- cvat-ui: Improve UI/UX on label, create task and create project forms (<https://github.com/cvat-ai/cvat/pull/7>)
- Removed link to OpenVINO documentation (<https://github.com/cvat-ai/cvat/pull/35>)
- Clarified meaning of chunking for videos
......@@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Image search in cloud storage (<https://github.com/cvat-ai/cvat/pull/8>)
- Reset password functionality (<https://github.com/cvat-ai/cvat/pull/52>)
- Creating task with cloud storage data (<https://github.com/cvat-ai/cvat/pull/116>)
- Show empty tasks (<https://github.com/cvat-ai/cvat/pull/100>)
### Security
- TDB
......
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
......@@ -11,7 +10,7 @@ from typing import Dict, List, Sequence, Tuple
import tqdm
from cvat_sdk import Client, models
from cvat_sdk.core.helpers import TqdmProgressReporter
from cvat_sdk.core.types import ResourceType
from cvat_sdk.core.proxies.tasks import ResourceType
class CLI:
......@@ -26,7 +25,7 @@ class CLI:
def tasks_list(self, *, use_json_output: bool = False, **kwargs):
"""List all tasks in either basic or JSON format."""
results = self.client.list_tasks(return_json=use_json_output, **kwargs)
results = self.client.tasks.list(return_json=use_json_output, **kwargs)
if use_json_output:
print(json.dumps(json.loads(results), indent=2))
else:
......@@ -50,7 +49,7 @@ class CLI:
"""
Create a new task with the given name and labels JSON and add the files to it.
"""
task = self.client.create_task(
task = self.client.tasks.create_from_data(
spec=models.TaskWriteRequest(name=name, labels=labels, **kwargs),
resource_type=resource_type,
resources=resources,
......@@ -66,7 +65,7 @@ class CLI:
def tasks_delete(self, task_ids: Sequence[int]) -> None:
"""Delete a list of tasks, ignoring those which don't exist."""
self.client.delete_tasks(task_ids=task_ids)
self.client.tasks.remove_by_ids(task_ids=task_ids)
def tasks_frames(
self,
......@@ -80,11 +79,11 @@ class CLI:
Download the requested frame numbers for a task and save images as
task_<ID>_frame_<FRAME>.jpg.
"""
self.client.retrieve_task(task_id=task_id).download_frames(
self.client.tasks.retrieve(obj_id=task_id).download_frames(
frame_ids=frame_ids,
outdir=outdir,
quality=quality,
filename_pattern="task_{task_id}_frame_{frame_id:06d}{frame_ext}",
filename_pattern=f"task_{task_id}" + "_frame_{frame_id:06d}{frame_ext}",
)
def tasks_dump(
......@@ -99,7 +98,7 @@ class CLI:
"""
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
self.client.retrieve_task(task_id=task_id).export_dataset(
self.client.tasks.retrieve(obj_id=task_id).export_dataset(
format_name=fileformat,
filename=filename,
pbar=self._make_pbar(),
......@@ -112,7 +111,7 @@ class CLI:
) -> None:
"""Upload annotations for a task in the specified format
(e.g. 'YOLO ZIP 1.0')."""
self.client.retrieve_task(task_id=task_id).import_annotations(
self.client.tasks.retrieve(obj_id=task_id).import_annotations(
format_name=fileformat,
filename=filename,
status_check_period=status_check_period,
......@@ -121,13 +120,13 @@ class CLI:
def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None:
"""Download a task backup"""
self.client.retrieve_task(task_id=task_id).download_backup(
self.client.tasks.retrieve(obj_id=task_id).download_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
)
def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None:
"""Import a task from a backup file"""
self.client.create_task_from_backup(
self.client.tasks.create_from_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
)
......
......@@ -10,7 +10,7 @@ import logging
import os
from distutils.util import strtobool
from cvat_sdk.core.types import ResourceType
from cvat_sdk.core.proxies.tasks import ResourceType
from .version import VERSION
......
......@@ -74,4 +74,4 @@ cvat_sdk/api_client/
requirements/
docs/
setup.py
README.md
\ No newline at end of file
README.md
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
......@@ -6,23 +5,22 @@
from __future__ import annotations
import json
import logging
import os.path as osp
import urllib.parse
from time import sleep
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple
import attrs
import urllib3
from cvat_sdk.api_client import ApiClient, ApiException, ApiValueError, Configuration, models
from cvat_sdk.core.git import create_git_repo
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.tasks import TaskProxy
from cvat_sdk.core.types import ResourceType
from cvat_sdk.core.uploading import Uploader
from cvat_sdk.core.utils import assert_status
from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.proxies.jobs import JobsRepo
from cvat_sdk.core.proxies.model_proxy import Repo
from cvat_sdk.core.proxies.projects import ProjectsRepo
from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo
@attrs.define
......@@ -43,11 +41,13 @@ class Client:
):
# TODO: use requests instead of urllib3 in ApiClient
# TODO: try to autodetect schema
self._api_map = _CVAT_API_V2(url)
self.api_map = CVAT_API_V2(url)
self.api = ApiClient(Configuration(host=url))
self.logger = logger or logging.getLogger(__name__)
self.config = config or Config()
self._repos: Dict[str, Repo] = {}
def __enter__(self):
self.api.__enter__()
return self
......@@ -67,150 +67,93 @@ class Client:
assert "csrftoken" in self.api.cookies
self.api.set_default_header("Authorization", "Token " + auth.key)
def create_task(
self,
spec: models.ITaskWriteRequest,
resource_type: ResourceType,
resources: Sequence[str],
def _has_credentials(self):
return (
("sessionid" in self.api.cookies)
or ("csrftoken" in self.api.cookies)
or (self.api.get_common_headers().get("Authorization", ""))
)
def logout(self):
if self._has_credentials():
self.api.auth_api.create_logout()
def wait_for_completion(
self: Client,
url: str,
*,
data_params: Optional[Dict[str, Any]] = None,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
status_check_period: int = None,
dataset_repository_url: str = "",
use_lfs: bool = False,
pbar: Optional[ProgressReporter] = None,
) -> TaskProxy:
"""
Create a new task with the given name and labels JSON and
add the files to it.
Returns: id of the created task
"""
success_status: int,
status_check_period: Optional[int] = None,
query_params: Optional[Dict[str, Any]] = None,
post_params: Optional[Dict[str, Any]] = None,
method: str = "POST",
positive_statuses: Optional[Sequence[int]] = None,
) -> urllib3.HTTPResponse:
if status_check_period is None:
status_check_period = self.config.status_check_period
if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
raise ApiValueError(
"Can't set labels to a task inside a project. "
"Tasks inside a project use project's labels.",
["labels"],
)
(task, _) = self.api.tasks_api.create(spec)
self.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
positive_statuses = set(positive_statuses) | {success_status}
task = TaskProxy(self, task)
task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
self.logger.info("Awaiting for task %s creation...", task.id)
status = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
while True:
sleep(status_check_period)
(status, _) = self.api.tasks_api.retrieve_status(task.id)
self.logger.info(
"Task %s creation status=%s, message=%s",
task.id,
status.state.value,
status.message,
)
if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
raise ApiException(status=status.state.value, reason=status.message)
status = status.state.value
if annotation_path:
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
if dataset_repository_url:
create_git_repo(
self,
task_id=task.id,
repo_url=dataset_repository_url,
status_check_period=status_check_period,
use_lfs=use_lfs,
response = self.api.rest_client.request(
method=method,
url=url,
headers=self.api.get_common_headers(),
query_params=query_params,
post_params=post_params,
)
task.fetch()
self.logger.debug("STATUS %s", response.status)
expect_status(positive_statuses, response)
if response.status == success_status:
break
return task
return response
def list_tasks(
self, *, return_json: bool = False, **kwargs
) -> Union[List[TaskProxy], List[Dict[str, Any]]]:
"""List all tasks in either basic or JSON format."""
def _get_repo(self, key: str) -> Repo:
_repo_map = {
"tasks": TasksRepo,
"projects": ProjectsRepo,
"jobs": JobsRepo,
"users": UsersRepo,
"issues": IssuesRepo,
"comments": CommentsRepo,
}
results = get_paginated_collection(
endpoint=self.api.tasks_api.list_endpoint, return_json=return_json, **kwargs
)
repo = self._repos.get(key, None)
if repo is None:
repo = _repo_map[key](self)
self._repos[key] = repo
return repo
if return_json:
return json.dumps(results)
return [TaskProxy(self, v) for v in results]
def retrieve_task(self, task_id: int) -> TaskProxy:
(task, _) = self.api.tasks_api.retrieve(task_id)
return TaskProxy(self, task)
def delete_tasks(self, task_ids: Sequence[int]):
"""
Delete a list of tasks, ignoring those which don't exist.
"""
for task_id in task_ids:
(_, response) = self.api.tasks_api.destroy(task_id, _check_status=False)
if 200 <= response.status <= 299:
self.logger.info(f"Task ID {task_id} deleted")
elif response.status == 404:
self.logger.info(f"Task ID {task_id} not found")
else:
self.logger.warning(
f"Failed to delete task ID {task_id}: "
f"{response.msg} (status {response.status})"
)
def create_task_from_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> TaskProxy:
"""
Import a task from a backup file
"""
if status_check_period is None:
status_check_period = self.config.status_check_period
@property
def tasks(self) -> TasksRepo:
return self._get_repo("tasks")
params = {"filename": osp.basename(filename)}
url = self._api_map.make_endpoint_url(self.api.tasks_api.create_backup_endpoint.path)
uploader = Uploader(self)
response = uploader.upload_file(
url, filename, meta=params, query_params=params, pbar=pbar, logger=self.logger.debug
)
rq_id = json.loads(response.data)["rq_id"]
@property
def projects(self) -> ProjectsRepo:
return self._get_repo("projects")
# check task status
while True:
sleep(status_check_period)
@property
def jobs(self) -> JobsRepo:
return self._get_repo("jobs")
response = self.api.rest_client.POST(
url, post_params={"rq_id": rq_id}, headers=self.api.get_common_headers()
)
if response.status == 201:
break
assert_status(202, response)
@property
def users(self) -> UsersRepo:
return self._get_repo("users")
task_id = json.loads(response.data)["id"]
self.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}")
@property
def issues(self) -> IssuesRepo:
return self._get_repo("issues")
return self.retrieve_task(task_id)
@property
def comments(self) -> CommentsRepo:
return self._get_repo("comments")
class _CVAT_API_V2:
class CVAT_API_V2:
"""Build parameterized API URLs"""
def __init__(self, host, https=False):
......
......@@ -8,8 +8,9 @@ from __future__ import annotations
import os
import os.path as osp
from contextlib import closing
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
if TYPE_CHECKING:
......@@ -17,8 +18,12 @@ if TYPE_CHECKING:
class Downloader:
"""
Implements common downloading protocols
"""
def __init__(self, client: Client):
self.client = client
self._client = client
def download_file(
self,
......@@ -29,8 +34,7 @@ class Downloader:
pbar: Optional[ProgressReporter] = None,
) -> None:
"""
Downloads the file from url into a temporary file, then renames it
to the requested name.
Downloads the file from url into a temporary file, then renames it to the requested name.
"""
CHUNK_SIZE = 10 * 2**20
......@@ -41,10 +45,10 @@ class Downloader:
if osp.exists(tmp_path):
raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists")
response = self.client.api.rest_client.GET(
response = self._client.api.rest_client.GET(
url,
_request_timeout=timeout,
headers=self.client.api.get_common_headers(),
headers=self._client.api.get_common_headers(),
_parse_response=False,
)
with closing(response):
......@@ -72,3 +76,38 @@ class Downloader:
except:
os.unlink(tmp_path)
raise
def prepare_and_download_file_from_endpoint(
self,
endpoint: Endpoint,
filename: str,
*,
url_params: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
client.logger.info("Waiting for the server to prepare the file...")
url = client.api_map.make_endpoint_url(
endpoint.path, kwsub=url_params, query_params=query_params
)
client.wait_for_completion(
url,
method="GET",
positive_statuses=[202],
success_status=201,
status_check_period=status_check_period,
)
query_params = dict(query_params or {})
query_params["action"] = "download"
url = client.api_map.make_endpoint_url(
endpoint.path, kwsub=url_params, query_params=query_params
)
downloader = Downloader(client)
downloader.download_file(url, output_path=filename, pbar=pbar)
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
......@@ -27,7 +26,7 @@ def create_git_repo(
common_headers = client.api.get_common_headers()
response = client.api.rest_client.POST(
client._api_map.git_create(task_id),
client.api_map.git_create(task_id),
post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id},
headers=common_headers,
)
......@@ -36,7 +35,7 @@ def create_git_repo(
client.logger.info(f"Create RQ ID: {rq_id}")
client.logger.debug("Awaiting a dataset repository to be created for the task %s...", task_id)
check_url = client._api_map.git_check(rq_id)
check_url = client.api_map.git_check(rq_id)
status = None
while status != "finished":
sleep(status_check_period)
......
......@@ -6,13 +6,14 @@ from __future__ import annotations
import io
import json
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union
import tqdm
import urllib3
from cvat_sdk import exceptions
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.utils import assert_status
def get_paginated_collection(
......@@ -26,7 +27,7 @@ def get_paginated_collection(
page = 1
while True:
(page_contents, response) = endpoint.call_with_http_info(**kwargs, page=page)
assert_status(200, response)
expect_status(200, response)
if return_json:
results.extend(json.loads(response.data).get("results", []))
......@@ -86,3 +87,18 @@ class StreamWithProgress:
def tell(self):
return self.stream.tell()
def expect_status(codes: Union[int, Iterable[int]], response: urllib3.HTTPResponse) -> None:
if not hasattr(codes, "__iter__"):
codes = [codes]
if response.status in codes:
return
if 300 <= response.status <= 500:
raise exceptions.ApiException(response.status, reason=response.msg, http_resp=response)
else:
raise exceptions.ApiException(
response.status, reason="Unexpected status code received", http_resp=response
)
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from abc import ABC
from enum import Enum
from typing import Optional, Sequence
from cvat_sdk import models
from cvat_sdk.core.proxies.model_proxy import _EntityT
class AnnotationUpdateAction(Enum):
CREATE = "create"
UPDATE = "update"
DELETE = "delete"
class AnnotationCrudMixin(ABC):
# TODO: refactor
@property
def _put_annotations_data_param(self) -> str:
...
def get_annotations(self: _EntityT) -> models.ILabeledData:
(annotations, _) = self.api.retrieve_annotations(getattr(self, self._model_id_field))
return annotations
def set_annotations(self: _EntityT, data: models.ILabeledDataRequest):
self.api.update_annotations(
getattr(self, self._model_id_field), **{self._put_annotations_data_param: data}
)
def update_annotations(
self: _EntityT,
data: models.IPatchedLabeledDataRequest,
*,
action: AnnotationUpdateAction = AnnotationUpdateAction.UPDATE,
):
self.api.partial_update_annotations(
action=action.value,
id=getattr(self, self._model_id_field),
patched_labeled_data_request=data,
)
def remove_annotations(self: _EntityT, *, ids: Optional[Sequence[int]] = None):
if ids:
anns = self.get_annotations()
if not isinstance(ids, set):
ids = set(ids)
anns_to_remove = models.PatchedLabeledDataRequest(
tags=[models.LabeledImageRequest(**a.to_dict()) for a in anns.tags if a.id in ids],
tracks=[
models.LabeledTrackRequest(**a.to_dict()) for a in anns.tracks if a.id in ids
],
shapes=[
models.LabeledShapeRequest(**a.to_dict()) for a in anns.shapes if a.id in ids
],
)
self.update_annotations(anns_to_remove, action=AnnotationUpdateAction.DELETE)
else:
self.api.destroy_annotations(getattr(self, self._model_id_field))
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.proxies.model_proxy import (
ModelCreateMixin,
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
_CommentEntityBase, _CommentRepoBase = build_model_bases(
models.CommentRead, apis.CommentsApi, api_member_name="comments_api"
)
class Comment(
models.ICommentRead,
_CommentEntityBase,
ModelUpdateMixin[models.IPatchedCommentWriteRequest],
ModelDeleteMixin,
):
_model_partial_update_arg = "patched_comment_write_request"
class CommentsRepo(
_CommentRepoBase,
ModelListMixin[Comment],
ModelCreateMixin[Comment, models.ICommentWriteRequest],
ModelRetrieveMixin[Comment],
):
_entity_type = Comment
_IssueEntityBase, _IssueRepoBase = build_model_bases(
models.IssueRead, apis.IssuesApi, api_member_name="issues_api"
)
class Issue(
models.IIssueRead,
_IssueEntityBase,
ModelUpdateMixin[models.IPatchedIssueWriteRequest],
ModelDeleteMixin,
):
_model_partial_update_arg = "patched_issue_write_request"
class IssuesRepo(
_IssueRepoBase,
ModelListMixin[Issue],
ModelCreateMixin[Issue, models.IIssueWriteRequest],
ModelRetrieveMixin[Issue],
):
_entity_type = Issue
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import io
import mimetypes
import os
import os.path as osp
from typing import List, Optional, Sequence
from PIL import Image
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.issues import Issue
from cvat_sdk.core.proxies.model_proxy import (
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader
_JobEntityBase, _JobRepoBase = build_model_bases(
models.JobRead, apis.JobsApi, api_member_name="jobs_api"
)
class Job(
models.IJobRead,
_JobEntityBase,
ModelUpdateMixin[models.IPatchedJobWriteRequest],
AnnotationCrudMixin,
):
_model_partial_update_arg = "patched_job_write_request"
_put_annotations_data_param = "job_annotations_update_request"
def import_annotations(
self,
format_name: str,
filename: str,
*,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
AnnotationUploader(self._client).upload_file_and_wait(
self.api.create_annotations_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}")
def get_frame(
self,
frame_id: int,
*,
quality: Optional[str] = None,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(
self.id, number=frame_id, quality=quality, type="frame"
)
return io.BytesIO(response.data)
def get_preview(
self,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(self.id, type="preview")
return io.BytesIO(response.data)
def download_frames(
self,
frame_ids: Sequence[int],
*,
outdir: str = "",
quality: str = "original",
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
"""
Download the requested frame numbers for a job and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
os.makedirs(outdir, exist_ok=True)
for frame_id in frame_ids:
frame_bytes = self.get_frame(frame_id, quality=quality)
im = Image.open(frame_bytes)
mime_type = im.get_format_mimetype() or "image/jpg"
im_ext = mimetypes.guess_extension(mime_type)
# FIXME It is better to use meta information from the server
# to determine the extension
# replace '.jpe' or '.jpeg' with a more used '.jpg'
if im_ext in (".jpe", ".jpeg", None):
im_ext = ".jpg"
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
im.save(osp.join(outdir, outfile))
def get_meta(self) -> models.IDataMetaRead:
(meta, _) = self.api.retrieve_data_meta(self.id)
return meta
def get_frames_info(self) -> List[models.IFrameMeta]:
return self.get_meta().frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
self._client.api.tasks_api.jobs_partial_update_data_meta(
self.id,
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
)
def get_issues(self) -> List[Issue]:
return [Issue(self._client, m) for m in self.api.list_issues(id=self.id)[0]]
def get_commits(self) -> List[models.IJobCommit]:
return get_paginated_collection(self.api.list_commits_endpoint, id=self.id)
class JobsRepo(
_JobRepoBase,
ModelListMixin[Job],
ModelRetrieveMixin[Job],
):
_entity_type = Job
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
from abc import ABC
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Self
from cvat_sdk.api_client.model_utils import IModelData, ModelNormal, to_json
from cvat_sdk.core.helpers import get_paginated_collection
if TYPE_CHECKING:
from cvat_sdk.core.client import Client
IModel = TypeVar("IModel", bound=IModelData)
ModelType = TypeVar("ModelType", bound=ModelNormal)
ApiType = TypeVar("ApiType")
class ModelProxy(ABC, Generic[ModelType, ApiType]):
_client: Client
@property
def _api_member_name(self) -> str:
...
def __init__(self, client: Client) -> None:
self.__dict__["_client"] = client
@classmethod
def get_api(cls, client: Client) -> ApiType:
return getattr(client.api, cls._api_member_name)
@property
def api(self) -> ApiType:
return self.get_api(self._client)
class Entity(ModelProxy[ModelType, ApiType]):
"""
Represents a single object. Implements related operations and provides access to data members.
"""
_model: ModelType
def __init__(self, client: Client, model: ModelType) -> None:
super().__init__(client)
self.__dict__["_model"] = model
@property
def _model_id_field(self) -> str:
return "id"
def __getattr__(self, __name: str) -> Any:
# NOTE: be aware of potential problems with throwing AttributeError from @property
# in derived classes!
# https://medium.com/@ceshine/python-debugging-pitfall-mixed-use-of-property-and-getattr-f89e0ede13f1
return self._model[__name]
def __str__(self) -> str:
return str(self._model)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={getattr(self, self._model_id_field)}>"
class Repo(ModelProxy[ModelType, ApiType]):
"""
Represents a collection of corresponding Entity objects.
Implements group and management operations for entities.
"""
_entity_type: Type[Entity[ModelType, ApiType]]
### Utilities
def build_model_bases(
mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None
) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]:
"""
Helps to remove code duplication in declarations of derived classes
"""
class _EntityBase(Entity[ModelType, ApiType]):
if api_member_name:
_api_member_name = api_member_name
class _RepoBase(Repo[ModelType, ApiType]):
if api_member_name:
_api_member_name = api_member_name
return _EntityBase, _RepoBase
### CRUD mixins
_EntityT = TypeVar("_EntityT", bound=Entity)
#### Repo mixins
class ModelCreateMixin(Generic[_EntityT, IModel]):
def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT:
"""
Creates a new object on the server and returns corresponding local object
"""
(model, _) = self.api.create(spec)
return self._entity_type(self._client, model)
class ModelRetrieveMixin(Generic[_EntityT]):
def retrieve(self: Repo, obj_id: int) -> _EntityT:
"""
Retrieves an object from server by ID
"""
(model, _) = self.api.retrieve(id=obj_id)
return self._entity_type(self._client, model)
class ModelListMixin(Generic[_EntityT]):
@overload
def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]:
...
@overload
def list(self: Repo, *, return_json: Literal[True] = False) -> List[Any]:
...
def list(self: Repo, *, return_json: bool = False) -> List[Union[_EntityT, Any]]:
"""
Retrieves all objects from the server and returns them in basic or JSON format.
"""
results = get_paginated_collection(endpoint=self.api.list_endpoint, return_json=return_json)
if return_json:
return json.dumps(results)
return [self._entity_type(self._client, model) for model in results]
#### Entity mixins
class ModelUpdateMixin(ABC, Generic[IModel]):
@property
def _model_partial_update_arg(self: Entity) -> str:
...
def _export_update_fields(
self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None
) -> Dict[str, Any]:
# TODO: support field conversion and assignment updating
# fields = to_json(self._model)
if isinstance(overrides, ModelNormal):
overrides = to_json(overrides)
fields = deepcopy(overrides)
return fields
def fetch(self: Entity) -> Self:
"""
Updates current object from the server
"""
# TODO: implement revision checking
(self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field))
return self
def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self:
"""
Commits local model changes to the server
"""
# TODO: implement revision checking
self.api.partial_update(
id=getattr(self, self._model_id_field),
**{self._model_partial_update_arg: self._export_update_fields(values)},
)
# TODO: use the response model, once input and output models are same
return self.fetch()
class ModelDeleteMixin:
def remove(self: Entity) -> None:
"""
Removes current object on the server
"""
self.api.destroy(id=getattr(self, self._model_id_field))
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
import os.path as osp
from typing import Optional
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.model_proxy import (
ModelCreateMixin,
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import DatasetUploader, Uploader
_ProjectEntityBase, _ProjectRepoBase = build_model_bases(
models.ProjectRead, apis.ProjectsApi, api_member_name="projects_api"
)
class Project(
_ProjectEntityBase, models.IProjectRead, ModelUpdateMixin[models.IPatchedProjectWriteRequest]
):
_model_partial_update_arg = "patched_project_write_request"
def import_dataset(
self,
format_name: str,
filename: str,
*,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Import dataset for a project in the specified format (e.g. 'YOLO ZIP 1.0').
"""
DatasetUploader(self._client).upload_file_and_wait(
self.api.create_dataset_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Annotation file '{filename}' for project #{self.id} uploaded")
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a project in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Dataset for project {self.id} has been downloaded to {filename}")
def download_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> None:
"""
Download a project backup
"""
Downloader(self._client).prepare_and_download_file_from_endpoint(
self.api.retrieve_backup_endpoint,
filename=filename,
pbar=pbar,
status_check_period=status_check_period,
url_params={"id": self.id},
)
self._client.logger.info(f"Backup for project {self.id} has been downloaded to {filename}")
def get_annotations(self) -> models.ILabeledData:
(annotations, _) = self.api.retrieve_annotations(self.id)
return annotations
class ProjectsRepo(
_ProjectRepoBase,
ModelCreateMixin[Project, models.IProjectWriteRequest],
ModelListMixin[Project],
ModelRetrieveMixin[Project],
ModelDeleteMixin,
):
_entity_type = Project
def create_from_dataset(
self,
spec: models.IProjectWriteRequest,
*,
dataset_path: str = "",
dataset_format: str = "CVAT XML 1.1",
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> Project:
"""
Create a new project with the given name and labels JSON and
add the files to it.
Returns: id of the created project
"""
project = self.create(spec=spec)
self._client.logger.info("Created project ID: %s NAME: %s", project.id, project.name)
if dataset_path:
project.import_dataset(
format_name=dataset_format,
filename=dataset_path,
pbar=pbar,
status_check_period=status_check_period,
)
project.fetch()
return project
def create_from_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> Project:
"""
Import a project from a backup file
"""
if status_check_period is None:
status_check_period = self.config.status_check_period
params = {"filename": osp.basename(filename)}
url = self.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
uploader = Uploader(self)
response = uploader.upload_file(
url,
filename,
meta=params,
query_params=params,
pbar=pbar,
logger=self._client.logger.debug,
)
rq_id = json.loads(response.data)["rq_id"]
response = self._client.wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
post_params={"rq_id": rq_id},
status_check_period=status_check_period,
)
project_id = json.loads(response.data)["id"]
self._client.logger.info(f"Project has been imported sucessfully. Project ID: {project_id}")
return self.retrieve(project_id)
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
......@@ -6,71 +5,60 @@
from __future__ import annotations
import io
import json
import mimetypes
import os
import os.path as osp
from abc import ABC, abstractmethod
from io import BytesIO
from enum import Enum
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence
from PIL import Image
from cvat_sdk import models
from cvat_sdk.api_client.model_utils import OpenApiModel
from cvat_sdk.api_client import apis, exceptions, models
from cvat_sdk.core import git
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.types import ResourceType
from cvat_sdk.core.uploading import Uploader
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.jobs import Job
from cvat_sdk.core.proxies.model_proxy import (
ModelCreateMixin,
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
from cvat_sdk.core.utils import filter_dict
if TYPE_CHECKING:
from cvat_sdk.core.client import Client
class ResourceType(Enum):
LOCAL = 0
SHARE = 1
REMOTE = 2
class ModelProxy(ABC):
_client: Client
_model: OpenApiModel
def __str__(self):
return self.name.lower()
def __init__(self, client: Client, model: OpenApiModel) -> None:
self.__dict__["_client"] = client
self.__dict__["_model"] = model
def __repr__(self):
return str(self)
def __getattr__(self, __name: str) -> Any:
return self._model[__name]
def __setattr__(self, __name: str, __value: Any) -> None:
if __name in self.__dict__:
self.__dict__[__name] = __value
else:
self._model[__name] = __value
@abstractmethod
def fetch(self, force: bool = False):
"""Fetches model data from the server"""
...
@abstractmethod
def commit(self, force: bool = False):
"""Commits local changes to the server"""
...
def sync(self):
"""Pulls server state and commits local model changes"""
raise NotImplementedError
@abstractmethod
def update(self, **kwargs):
"""Updates multiple fields at once"""
...
_TaskEntityBase, _TaskRepoBase = build_model_bases(
models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
)
class TaskProxy(ModelProxy, models.ITaskRead):
def __init__(self, client: Client, task: models.TaskRead):
ModelProxy.__init__(self, client=client, model=task)
def remove(self):
self._client.api.tasks_api.destroy(self.id)
class Task(
_TaskEntityBase,
models.ITaskRead,
ModelUpdateMixin[models.IPatchedTaskWriteRequest],
ModelDeleteMixin,
AnnotationCrudMixin,
):
_model_partial_update_arg = "patched_task_write_request"
_put_annotations_data_param = "task_annotations_update_request"
def upload_data(
self,
......@@ -83,9 +71,6 @@ class TaskProxy(ModelProxy, models.ITaskRead):
"""
Add local, remote, or shared files to an existing task.
"""
client = self._client
task_id = self.id
params = params or {}
data = {}
......@@ -116,73 +101,58 @@ class TaskProxy(ModelProxy, models.ITaskRead):
data["frame_filter"] = f"step={params.get('frame_step')}"
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
client.api.tasks_api.create_data(
task_id,
self.api.create_data(
self.id,
data_request=models.DataRequest(**data),
_content_type="multipart/form-data",
)
elif resource_type == ResourceType.LOCAL:
url = client._api_map.make_endpoint_url(
client.api.tasks_api.create_data_endpoint.path, kwsub={"id": task_id}
url = self._client.api_map.make_endpoint_url(
self.api.create_data_endpoint.path, kwsub={"id": self.id}
)
uploader = Uploader(client)
uploader.upload_files(url, resources, pbar=pbar, **data)
DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
def import_annotations(
self,
format_name: str,
filename: str,
*,
status_check_period: int = None,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Upload annotations for a task in the specified format
(e.g. 'YOLO ZIP 1.0').
Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
task_id = self.id
url = client._api_map.make_endpoint_url(
client.api.tasks_api.create_annotations_endpoint.path,
kwsub={"id": task_id},
AnnotationUploader(self._client).upload_file_and_wait(
self.api.create_annotations_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
params = {"format": format_name, "filename": osp.basename(filename)}
uploader = Uploader(client)
uploader.upload_file(
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
)
while True:
response = client.api.rest_client.POST(
url, headers=client.api.get_common_headers(), query_params=params
)
if response.status == 201:
break
sleep(status_check_period)
self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded")
client.logger.info(
f"Upload job for Task ID {task_id} with annotation file {filename} finished"
)
def retrieve_frame(
def get_frame(
self,
frame_id: int,
*,
quality: Optional[str] = None,
) -> io.RawIOBase:
client = self._client
task_id = self.id
params = {}
if quality:
params["quality"] = quality
(_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame")
return io.BytesIO(response.data)
(_, response) = client.api.tasks_api.retrieve_data(task_id, frame_id, quality, type="frame")
return BytesIO(response.data)
def get_preview(
self,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(self.id, type="preview")
return io.BytesIO(response.data)
def download_frames(
self,
......@@ -190,19 +160,16 @@ class TaskProxy(ModelProxy, models.ITaskRead):
*,
outdir: str = "",
quality: str = "original",
filename_pattern: str = "task_{task_id}_frame_{frame_id:06d}{frame_ext}",
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
"""
Download the requested frame numbers for a task and save images as
outdir/filename_pattern
Download the requested frame numbers for a task and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
task_id = self.id
os.makedirs(outdir, exist_ok=True)
for frame_id in frame_ids:
frame_bytes = self.retrieve_frame(frame_id, quality=quality)
frame_bytes = self.get_frame(frame_id, quality=quality)
im = Image.open(frame_bytes)
mime_type = im.get_format_mimetype() or "image/jpg"
......@@ -214,7 +181,7 @@ class TaskProxy(ModelProxy, models.ITaskRead):
if im_ext in (".jpe", ".jpeg", None):
im_ext = ".jpg"
outfile = filename_pattern.format(task_id=task_id, frame_id=frame_id, frame_ext=im_ext)
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
im.save(osp.join(outdir, outfile))
def export_dataset(
......@@ -223,40 +190,27 @@ class TaskProxy(ModelProxy, models.ITaskRead):
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: int = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
task_id = self.id
params = {"filename": self.name, "format": format_name}
if include_images:
endpoint = client.api.tasks_api.retrieve_dataset_endpoint
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = client.api.tasks_api.retrieve_annotations_endpoint
client.logger.info("Waiting for the server to prepare the file...")
while True:
(_, response) = endpoint.call_with_http_info(id=task_id, **params)
client.logger.debug("STATUS {}".format(response.status))
if response.status == 201:
break
sleep(status_check_period)
params["action"] = "download"
url = client._api_map.make_endpoint_url(
endpoint.path, kwsub={"id": task_id}, query_params=params
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
downloader = Downloader(client)
downloader.download_file(url, output_path=filename, pbar=pbar)
client.logger.info(f"Dataset has been exported to {filename}")
self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}")
def download_backup(
self,
......@@ -264,45 +218,171 @@ class TaskProxy(ModelProxy, models.ITaskRead):
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
):
) -> None:
"""
Download a task backup
"""
client = self._client
Downloader(self._client).prepare_and_download_file_from_endpoint(
self.api.retrieve_backup_endpoint,
filename=filename,
pbar=pbar,
status_check_period=status_check_period,
url_params={"id": self.id},
)
self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}")
def get_jobs(self) -> List[Job]:
return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]]
def get_meta(self) -> models.IDataMetaRead:
(meta, _) = self.api.retrieve_data_meta(self.id)
return meta
def get_frames_info(self) -> List[models.IFrameMeta]:
return self.get_meta().frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
self.api.partial_update_data_meta(
self.id,
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
)
class TasksRepo(
_TaskRepoBase,
ModelCreateMixin[Task, models.ITaskWriteRequest],
ModelRetrieveMixin[Task],
ModelListMixin[Task],
ModelDeleteMixin,
):
_entity_type = Task
def create_from_data(
self,
spec: models.ITaskWriteRequest,
resource_type: ResourceType,
resources: Sequence[str],
*,
data_params: Optional[Dict[str, Any]] = None,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
status_check_period: int = None,
dataset_repository_url: str = "",
use_lfs: bool = False,
pbar: Optional[ProgressReporter] = None,
) -> Task:
"""
Create a new task with the given name and labels JSON and
add the files to it.
Returns: id of the created task
"""
if status_check_period is None:
status_check_period = client.config.status_check_period
status_check_period = self._client.config.status_check_period
if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
raise exceptions.ApiValueError(
"Can't set labels to a task inside a project. "
"Tasks inside a project use project's labels.",
["labels"],
)
task_id = self.id
task = self.create(spec=spec)
self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
endpoint = client.api.tasks_api.retrieve_backup_endpoint
client.logger.info("Waiting for the server to prepare the file...")
while True:
(_, response) = endpoint.call_with_http_info(id=task_id)
client.logger.debug("STATUS {}".format(response.status))
if response.status == 201:
break
task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
self._client.logger.info("Awaiting for task %s creation...", task.id)
status: models.RqStatus = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
sleep(status_check_period)
(status, response) = self.api.retrieve_status(task.id)
url = client._api_map.make_endpoint_url(
endpoint.path, kwsub={"id": task_id}, query_params={"action": "download"}
)
downloader = Downloader(client)
downloader.download_file(url, output_path=filename, pbar=pbar)
self._client.logger.info(
"Task %s creation status=%s, message=%s",
task.id,
status.state.value,
status.message,
)
client.logger.info(
f"Task {task_id} has been exported sucessfully to {osp.abspath(filename)}"
)
if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
raise exceptions.ApiException(
status=status.state.value, reason=status.message, http_resp=response
)
status = status.state.value
if annotation_path:
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
def fetch(self, force: bool = False):
# TODO: implement revision checking
model, _ = self._client.api.tasks_api.retrieve(self.id)
self._model = model
if dataset_repository_url:
git.create_git_repo(
self,
task_id=task.id,
repo_url=dataset_repository_url,
status_check_period=status_check_period,
use_lfs=use_lfs,
)
task.fetch()
return task
def remove_by_ids(self, task_ids: Sequence[int]) -> None:
"""
Delete a list of tasks, ignoring those which don't exist.
"""
for task_id in task_ids:
(_, response) = self.api.destroy(task_id, _check_status=False)
if 200 <= response.status <= 299:
self._client.logger.info(f"Task ID {task_id} deleted")
elif response.status == 404:
self._client.logger.info(f"Task ID {task_id} not found")
else:
self._client.logger.warning(
f"Failed to delete task ID {task_id}: "
f"{response.msg} (status {response.status})"
)
def commit(self, force: bool = False):
return super().commit(force)
def create_from_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> Task:
"""
Import a task from a backup file
"""
if status_check_period is None:
status_check_period = self._client.config.status_check_period
params = {"filename": osp.basename(filename)}
url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
uploader = Uploader(self._client)
response = uploader.upload_file(
url,
filename,
meta=params,
query_params=params,
pbar=pbar,
logger=self._client.logger.debug,
)
rq_id = json.loads(response.data)["rq_id"]
response = self._client.wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
post_params={"rq_id": rq_id},
status_check_period=status_check_period,
)
def update(self, **kwargs):
return super().update(**kwargs)
task_id = json.loads(response.data)["id"]
self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}")
def __str__(self) -> str:
return str(self._model)
return self.retrieve(task_id)
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.proxies.model_proxy import (
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
_UserEntityBase, _UserRepoBase = build_model_bases(
models.User, apis.UsersApi, api_member_name="users_api"
)
class User(
models.IUser, _UserEntityBase, ModelUpdateMixin[models.IPatchedUserRequest], ModelDeleteMixin
):
_model_partial_update_arg = "patched_user_request"
class UsersRepo(
_UserRepoBase,
ModelListMixin[User],
ModelRetrieveMixin[User],
):
_entity_type = User
def retrieve_current_user(self) -> User:
return User(self._client, self.api.retrieve_self()[0])
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from enum import Enum
class ResourceType(Enum):
LOCAL = 0
SHARE = 1
REMOTE = 2
def __str__(self):
return self.name.lower()
def __repr__(self):
return str(self)
......@@ -7,16 +7,15 @@ from __future__ import annotations
import os
import os.path as osp
from contextlib import ExitStack, closing
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
import requests
import urllib3
from cvat_sdk.api_client import ApiClient
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.api_client.rest import RESTClientObject
from cvat_sdk.core.helpers import StreamWithProgress
from cvat_sdk.core.helpers import StreamWithProgress, expect_status
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.utils import assert_status
if TYPE_CHECKING:
from cvat_sdk.core.client import Client
......@@ -25,57 +24,12 @@ MAX_REQUEST_SIZE = 100 * 2**20
class Uploader:
def __init__(self, client: Client):
self.client = client
def upload_files(
self,
url: str,
resources: List[str],
*,
pbar: Optional[ProgressReporter] = None,
**kwargs,
):
bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources)
if pbar is not None:
pbar.start(total_size, desc="Uploading data")
self._tus_start_upload(url)
"""
Implements common uploading protocols
"""
for group, group_size in bulk_file_groups:
with ExitStack() as es:
files = {}
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
filename,
es.enter_context(closing(open(filename, "rb"))).read(),
)
response = self.client.api.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self.client.api.get_common_headers(),
},
)
assert_status(200, response)
if pbar is not None:
pbar.advance(group_size)
for filename in separate_files:
# TODO: check if basename produces invalid paths here, can lead to overwriting
self._upload_file_data_with_tus(
url,
filename,
meta={"filename": osp.basename(filename)},
pbar=pbar,
logger=self.client.logger.debug,
)
self._tus_finish_upload(url, fields=kwargs)
def __init__(self, client: Client):
self._client = client
def upload_file(
self,
......@@ -121,6 +75,27 @@ class Uploader:
)
return self._tus_finish_upload(url, query_params=query_params, fields=fields)
def _wait_for_completion(
self,
url: str,
*,
success_status: int,
status_check_period: Optional[int] = None,
query_params: Optional[Dict[str, Any]] = None,
post_params: Optional[Dict[str, Any]] = None,
method: str = "POST",
positive_statuses: Optional[Sequence[int]] = None,
) -> urllib3.HTTPResponse:
return self._client.wait_for_completion(
url,
success_status=success_status,
status_check_period=status_check_period,
query_params=query_params,
post_params=post_params,
method=method,
positive_statuses=positive_statuses,
)
def _split_files_by_requests(
self, filenames: List[str]
) -> Tuple[List[Tuple[List[str], int]], List[str], int]:
......@@ -268,7 +243,7 @@ class Uploader:
input_file = StreamWithProgress(input_file, pbar, length=file_size)
tus_uploader = self._make_tus_uploader(
self.client.api,
self._client.api,
url=url.rstrip("/") + "/",
metadata=meta,
file_stream=input_file,
......@@ -278,26 +253,131 @@ class Uploader:
tus_uploader.upload()
def _tus_start_upload(self, url, *, query_params=None):
response = self.client.api.rest_client.POST(
response = self._client.api.rest_client.POST(
url,
query_params=query_params,
headers={
"Upload-Start": "",
**self.client.api.get_common_headers(),
**self._client.api.get_common_headers(),
},
)
assert_status(202, response)
expect_status(202, response)
return response
def _tus_finish_upload(self, url, *, query_params=None, fields=None):
response = self.client.api.rest_client.POST(
response = self._client.api.rest_client.POST(
url,
headers={
"Upload-Finish": "",
**self.client.api.get_common_headers(),
**self._client.api.get_common_headers(),
},
query_params=query_params,
post_params=fields,
)
assert_status(202, response)
expect_status(202, response)
return response
class AnnotationUploader(Uploader):
def upload_file_and_wait(
self,
endpoint: Endpoint,
filename: str,
format_name: str,
*,
url_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
params = {"format": format_name, "filename": osp.basename(filename)}
self.upload_file(
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
)
self._wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
status_check_period=status_check_period,
query_params=params,
method="POST",
)
class DatasetUploader(Uploader):
def upload_file_and_wait(
self,
endpoint: Endpoint,
filename: str,
format_name: str,
*,
url_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
params = {"format": format_name, "filename": osp.basename(filename)}
self.upload_file(
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
)
self._wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
status_check_period=status_check_period,
query_params=params,
method="GET",
)
class DataUploader(Uploader):
def upload_files(
self,
url: str,
resources: List[str],
*,
pbar: Optional[ProgressReporter] = None,
**kwargs,
):
bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources)
if pbar is not None:
pbar.start(total_size, desc="Uploading data")
self._tus_start_upload(url)
for group, group_size in bulk_file_groups:
with ExitStack() as es:
files = {}
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
filename,
es.enter_context(closing(open(filename, "rb"))).read(),
)
response = self._client.api.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self._client.api.get_common_headers(),
},
)
expect_status(200, response)
if pbar is not None:
pbar.advance(group_size)
for filename in separate_files:
# TODO: check if basename produces invalid paths here, can lead to overwriting
self._upload_file_data_with_tus(
url,
filename,
meta={"filename": osp.basename(filename)},
pbar=pbar,
logger=self._client.logger.debug,
)
self._tus_finish_upload(url, fields=kwargs)
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
......@@ -7,13 +6,6 @@ from __future__ import annotations
from typing import Any, Dict, Sequence
import urllib3
def assert_status(code: int, response: urllib3.HTTPResponse) -> None:
if response.status != code:
raise Exception(f"Unexpected status code received {response.status}")
def filter_dict(
d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None
......
......@@ -48,7 +48,7 @@ class Processor:
tokenized_path = tokenized_path[2:]
prefix = tokenized_path[0] + "_"
if new_name.startswith(prefix):
if new_name.startswith(prefix) and tokenized_path[0] in operation["tags"]:
new_name = new_name[len(prefix) :]
return new_name
......
......@@ -345,6 +345,9 @@ class ApiClient(object):
"""
if response_schema == (file_type,):
# TODO: response schema can be "oneOf" with a file option,
# this implementation does not cover this.
# handle file downloading
# save response body into a tmp file and return the instance
content_disposition = response.getheader("Content-Disposition")
......
......@@ -9,6 +9,7 @@ import sys # noqa: F401
from {{packageName}}.model_utils import ( # noqa: F401
ApiTypeError,
IModelData,
ModelComposed,
ModelNormal,
ModelSimple,
......
class I{{classname}}:
class I{{classname}}(IModelData):
"""
NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
......
class I{{classname}}:
class I{{classname}}(IModelData):
"""
NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
......
......@@ -113,6 +113,11 @@ def composed_model_input_classes(cls):
return []
class IModelData:
"""
The base class for model data. Declares model fields and their types for better introspection
"""
class OpenApiModel(object):
"""The base class for all OpenAPIModels"""
......
......@@ -3,3 +3,4 @@
attrs >= 21.4.0
tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0
{
"name": "cvat-ui",
"version": "1.41.0",
"version": "1.41.1",
"description": "CVAT single-page application",
"main": "src/index.tsx",
"scripts": {
......
......@@ -70,7 +70,7 @@ function TasksPageComponent(props: Props): JSX.Element {
<Button
type='link'
onClick={(): void => {
dispatch(hideEmptyTasks(true));
dispatch(hideEmptyTasks(false));
message.destroy();
}}
>
......
......@@ -109,10 +109,8 @@ def update_git_repo(request, tid):
status=http.HTTPStatus.OK,
)
except Exception as ex:
try:
with contextlib.suppress(Exception):
slogger.task[tid].error("error occurred during changing repository request", exc_info=True)
except Exception:
pass
return HttpResponseBadRequest(str(ex))
......
......@@ -15,7 +15,7 @@ from rest_framework.exceptions import ValidationError
class SearchFilter(filters.SearchFilter):
def get_search_fields(self, view, request):
search_fields = getattr(view, 'search_fields', [])
search_fields = getattr(view, 'search_fields') or []
lookup_fields = {field:field for field in search_fields}
view_lookup_fields = getattr(view, 'lookup_fields', {})
keys_to_update = set(search_fields) & set(view_lookup_fields.keys())
......
......@@ -9,7 +9,7 @@ import uuid
from django.conf import settings
from django.core.cache import cache
from distutils.util import strtobool
from rest_framework import status
from rest_framework import status, mixins
from rest_framework.response import Response
from cvat.apps.engine.models import Location
......@@ -315,3 +315,17 @@ class SerializeMixin:
file_name = request.query_params.get("filename", "")
return import_func(request, filename=file_name)
return self.upload_data(request)
class PartialUpdateModelMixin:
"""
Update fields of a model instance.
Almost the same as UpdateModelMixin, but has no public PUT / update() method.
"""
def perform_update(self, serializer):
mixins.UpdateModelMixin.perform_update(self, serializer=serializer)
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return mixins.UpdateModelMixin.update(self, request=request, *args, **kwargs)
......@@ -2,12 +2,26 @@
#
# SPDX-License-Identifier: MIT
from typing import Type
from rest_framework import serializers
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import force_instance
from drf_spectacular.plumbing import force_instance, build_basic_type
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.serializers import PolymorphicProxySerializerExtension
def _copy_serializer(
instance: serializers.Serializer,
*,
_new_type: Type[serializers.Serializer] = None,
**kwargs
) -> serializers.Serializer:
_new_type = _new_type or type(instance)
instance_kwargs = instance._kwargs
instance_kwargs['partial'] = instance.partial # this can be set separately
instance_kwargs.update(kwargs)
return _new_type(*instance._args, **instance._kwargs)
class DataSerializerExtension(OpenApiSerializerExtension):
# *FileSerializer mimics a FileField
# but it is mapped as an object with a file field, which
......@@ -23,40 +37,106 @@ class DataSerializerExtension(OpenApiSerializerExtension):
target_class = 'cvat.apps.engine.serializers.DataSerializer'
def map_serializer(self, auto_schema, direction):
assert isinstance(self.target_class, type)
assert issubclass(self.target_class, serializers.ModelSerializer)
instance = force_instance(self.target_class)
instance = self.target
assert isinstance(instance, serializers.ModelSerializer)
def _get_field(instance, source_name, field_name):
def _get_field(
instance: serializers.ModelSerializer,
source_name: str,
field_name: str
) -> serializers.ModelField:
child_instance = force_instance(instance.fields[source_name].child)
assert isinstance(child_instance, serializers.ModelSerializer)
child_fields = child_instance.fields
assert child_fields.keys() == {'file'} # protect from changes
assert child_fields.keys() == {'file'} # protection from implementation changes
return child_fields[field_name]
def _sanitize_field(field):
def _sanitize_field(field: serializers.ModelField) -> serializers.ModelField:
field.source = None
field.source_attrs = []
return field
def _make_field(source_name, field_name):
def _make_field(source_name: str, field_name: str) -> serializers.ModelField:
return _sanitize_field(_get_field(instance, source_name, field_name))
class _Override(self.target_class): # pylint: disable=inherit-non-class
client_files = serializers.ListField(child=_make_field('client_files', 'file'), default=[])
server_files = serializers.ListField(child=_make_field('server_files', 'file'), default=[])
remote_files = serializers.ListField(child=_make_field('remote_files', 'file'), default=[])
client_files = serializers.ListField(
child=_make_field('client_files', 'file'), default=[])
server_files = serializers.ListField(
child=_make_field('server_files', 'file'), default=[])
remote_files = serializers.ListField(
child=_make_field('remote_files', 'file'), default=[])
return auto_schema._map_serializer(
_copy_serializer(instance, _new_type=_Override, context={'view': auto_schema.view}),
direction, bypass_extensions=False)
class WriteOnceSerializerExtension(OpenApiSerializerExtension):
"""
Enables support for cvat.apps.engine.serializers.WriteOnceMixin in drf-spectacular.
Doesn't block other extensions on the target serializer.
"""
return auto_schema._map_serializer(_Override(), direction, bypass_extensions=False)
match_subclasses = True
target_class = 'cvat.apps.engine.serializers.WriteOnceMixin'
_PROCESSED_INDICATOR_NAME = 'write_once_serializer_extension_processed'
class CustomProxySerializerExtension(PolymorphicProxySerializerExtension):
"""
Allows to patch PolymorphicProxySerializer-based schema.
@classmethod
def _matches(cls, target) -> bool:
if super()._matches(target):
# protect from recursive invocations
assert isinstance(target, serializers.Serializer)
processed = target.context.get(cls._PROCESSED_INDICATOR_NAME, False)
return not processed
return False
Override "target_component" in children classes.
def map_serializer(self, auto_schema, direction):
return auto_schema._map_serializer(
_copy_serializer(self.target, context={
'view': auto_schema.view,
self._PROCESSED_INDICATOR_NAME: True
}),
direction, bypass_extensions=False)
class OpenApiTypeProxySerializerExtension(PolymorphicProxySerializerExtension):
"""
Provides support for OpenApiTypes in the PolymorphicProxySerializer list
"""
priority = 0 # restore normal priority
def _process_serializer(self, auto_schema, serializer, direction):
if isinstance(serializer, OpenApiTypes):
schema = build_basic_type(serializer)
return (None, schema)
else:
return super()._process_serializer(auto_schema=auto_schema,
serializer=serializer, direction=direction)
def map_serializer(self, auto_schema, direction):
""" custom handling for @extend_schema's injection of PolymorphicProxySerializer """
result = super().map_serializer(auto_schema=auto_schema, direction=direction)
if isinstance(self.target.serializers, dict):
required = OpenApiTypes.NONE not in self.target.serializers.values()
else:
required = OpenApiTypes.NONE not in self.target.serializers
if not required:
result['nullable'] = True
return result
class ComponentProxySerializerExtension(OpenApiTypeProxySerializerExtension):
"""
Allows to patch PolymorphicProxySerializer-based component schema.
Override the "target_component" field in children classes.
"""
priority = 1 # higher than in the parent class
target_component: str = ''
@classmethod
......@@ -69,7 +149,7 @@ class CustomProxySerializerExtension(PolymorphicProxySerializerExtension):
return target.component_name == cls.target_component
class AnyOfProxySerializerExtension(CustomProxySerializerExtension):
class AnyOfProxySerializerExtension(ComponentProxySerializerExtension):
"""
Replaces oneOf with anyOf in the generated schema. Useful when
no disciminator field is available, and the options are
......
......@@ -198,7 +198,9 @@ class JobReadSerializer(serializers.ModelSerializer):
class JobWriteSerializer(serializers.ModelSerializer):
assignee = serializers.IntegerField(allow_null=True, required=False)
def to_representation(self, instance):
# FIXME: deal with resquest/response separation
serializer = JobReadSerializer(instance, context=self.context)
return serializer.data
......@@ -307,8 +309,8 @@ class RqStatusSerializer(serializers.Serializer):
progress = serializers.FloatField(max_value=100, default=0)
class WriteOnceMixin:
"""Adds support for write once fields to serializers.
"""
Adds support for write once fields to serializers.
To use it, specify a list of fields as `write_once_fields` on the
serializer's Meta:
......@@ -329,12 +331,15 @@ class WriteOnceMixin:
# We're only interested in PATCH/PUT.
if 'update' in getattr(self.context.get('view'), 'action', ''):
return self._set_write_once_fields(extra_kwargs)
extra_kwargs = self._set_write_once_fields(extra_kwargs)
return extra_kwargs
def _set_write_once_fields(self, extra_kwargs):
"""Set all fields in `Meta.write_once_fields` to read_only."""
"""
Set all fields in `Meta.write_once_fields` to read_only.
"""
write_once_fields = getattr(self.Meta, 'write_once_fields', None)
if not write_once_fields:
return extra_kwargs
......@@ -352,7 +357,7 @@ class WriteOnceMixin:
return extra_kwargs
class DataSerializer(serializers.ModelSerializer):
class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
image_quality = serializers.IntegerField(min_value=0, max_value=100)
use_zip_chunks = serializers.BooleanField(default=False)
client_files = ClientFileSerializer(many=True, default=[])
......@@ -876,16 +881,16 @@ class AnnotationSerializer(serializers.Serializer):
id = serializers.IntegerField(default=None, allow_null=True)
frame = serializers.IntegerField(min_value=0)
label_id = serializers.IntegerField(min_value=0)
group = serializers.IntegerField(min_value=0, allow_null=True)
source = serializers.CharField(default = 'manual')
group = serializers.IntegerField(min_value=0, allow_null=True, default=None)
source = serializers.CharField(default='manual')
class LabeledImageSerializer(AnnotationSerializer):
attributes = AttributeValSerializer(many=True,
source="labeledimageattributeval_set")
source="labeledimageattributeval_set", default=[])
class ShapeSerializer(serializers.Serializer):
type = serializers.ChoiceField(choices=models.ShapeType.choices())
occluded = serializers.BooleanField()
occluded = serializers.BooleanField(default=False)
outside = serializers.BooleanField(default=False, required=False)
z_order = serializers.IntegerField(default=0)
rotation = serializers.FloatField(default=0, min_value=0, max_value=360)
......@@ -896,7 +901,7 @@ class ShapeSerializer(serializers.Serializer):
class SubLabeledShapeSerializer(ShapeSerializer, AnnotationSerializer):
attributes = AttributeValSerializer(many=True,
source="labeledshapeattributeval_set")
source="labeledshapeattributeval_set", default=[])
class LabeledShapeSerializer(SubLabeledShapeSerializer):
elements = SubLabeledShapeSerializer(many=True, required=False)
......@@ -905,22 +910,22 @@ class TrackedShapeSerializer(ShapeSerializer):
id = serializers.IntegerField(default=None, allow_null=True)
frame = serializers.IntegerField(min_value=0)
attributes = AttributeValSerializer(many=True,
source="trackedshapeattributeval_set")
source="trackedshapeattributeval_set", default=[])
class SubLabeledTrackSerializer(AnnotationSerializer):
shapes = TrackedShapeSerializer(many=True, allow_empty=True,
source="trackedshape_set")
attributes = AttributeValSerializer(many=True,
source="labeledtrackattributeval_set")
source="labeledtrackattributeval_set", default=[])
class LabeledTrackSerializer(SubLabeledTrackSerializer):
elements = SubLabeledTrackSerializer(many=True, required=False)
class LabeledDataSerializer(serializers.Serializer):
version = serializers.IntegerField()
tags = LabeledImageSerializer(many=True)
shapes = LabeledShapeSerializer(many=True)
tracks = LabeledTrackSerializer(many=True)
version = serializers.IntegerField(default=0) # TODO: remove
tags = LabeledImageSerializer(many=True, default=[])
shapes = LabeledShapeSerializer(many=True, default=[])
tracks = LabeledTrackSerializer(many=True, default=[])
class FileInfoSerializer(serializers.Serializer):
name = serializers.CharField(max_length=1024)
......@@ -991,6 +996,10 @@ class IssueReadSerializer(serializers.ModelSerializer):
fields = ('id', 'frame', 'position', 'job', 'owner', 'assignee',
'created_date', 'updated_date', 'comments', 'resolved')
read_only_fields = fields
extra_kwargs = {
'created_date': { 'allow_null': True },
'updated_date': { 'allow_null': True },
}
class IssueWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
......@@ -1010,6 +1019,12 @@ class IssueWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
message=message, owner=db_issue.owner)
return db_issue
def update(self, instance, validated_data):
message = validated_data.pop('message', None)
if message:
raise NotImplementedError('Check https://github.com/cvat-ai/cvat/issues/122')
return super().update(instance, validated_data)
class Meta:
model = models.Issue
fields = ('id', 'frame', 'position', 'job', 'owner', 'assignee',
......
......@@ -313,7 +313,7 @@ class JobGetAPITestCase(APITestCase):
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class JobUpdateAPITestCase(APITestCase):
class JobPartialUpdateAPITestCase(APITestCase):
def setUp(self):
self.client = APIClient()
self.task = create_dummy_db_tasks(self)[0]
......@@ -327,7 +327,7 @@ class JobUpdateAPITestCase(APITestCase):
def _run_api_v2_jobs_id(self, jid, user, data):
with ForceLogin(user, self.client):
response = self.client.put('/api/jobs/{}'.format(jid), data=data, format='json')
response = self.client.patch('/api/jobs/{}'.format(jid), data=data, format='json')
return response
......@@ -382,22 +382,43 @@ class JobUpdateAPITestCase(APITestCase):
response = self._run_api_v2_jobs_id(self.job.id + 10, None, data)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class JobPartialUpdateAPITestCase(JobUpdateAPITestCase):
def test_api_v2_jobs_id_annotator_partial(self):
data = {"stage": StageChoice.ANNOTATION}
response = self._run_api_v2_jobs_id(self.job.id, self.annotator, data)
self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN, response)
def test_api_v2_jobs_id_admin_partial(self):
data = {"assignee_id": self.user.id}
response = self._run_api_v2_jobs_id(self.job.id, self.owner, data)
self._check_request(response, data)
class JobUpdateAPITestCase(APITestCase):
def setUp(self):
self.client = APIClient()
self.task = create_dummy_db_tasks(self)[0]
self.job = Job.objects.filter(segment__task_id=self.task.id).first()
self.job.assignee = self.annotator
self.job.save()
@classmethod
def setUpTestData(cls):
create_db_users(cls)
def _run_api_v2_jobs_id(self, jid, user, data):
with ForceLogin(user, self.client):
response = self.client.patch('/api/jobs/{}'.format(jid), data=data, format='json')
response = self.client.put('/api/jobs/{}'.format(jid), data=data, format='json')
return response
def test_api_v2_jobs_id_annotator_partial(self):
def test_api_v2_jobs_id_annotator(self):
data = {"stage": StageChoice.ANNOTATION}
response = self._run_api_v2_jobs_id(self.job.id, self.annotator, data)
self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN, response)
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED, response)
def test_api_v2_jobs_id_admin_partial(self):
def test_api_v2_jobs_id_admin(self):
data = {"assignee_id": self.user.id}
response = self._run_api_v2_jobs_id(self.job.id, self.owner, data)
self._check_request(response, data)
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED, response)
class JobDataMetaPartialUpdateAPITestCase(APITestCase):
def setUp(self):
......@@ -1987,7 +2008,6 @@ class TaskDeleteAPITestCase(APITestCase):
self.assertFalse(os.path.exists(task_dir))
class TaskUpdateAPITestCase(APITestCase):
def setUp(self):
self.client = APIClient()
......@@ -2003,6 +2023,39 @@ class TaskUpdateAPITestCase(APITestCase):
return response
def _check_api_v2_tasks_id(self, user, data):
for db_task in self.tasks:
response = self._run_api_v2_tasks_id(db_task.id, user, data)
if user is None:
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
else:
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_api_v2_tasks_id_admin(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(self.admin, data)
def test_api_v2_tasks_id_user(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(self.user, data)
def test_api_v2_tasks_id_somebody(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(self.somebody, data)
def test_api_v2_tasks_id_no_auth(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(None, data)
class TaskPartialUpdateAPITestCase(APITestCase):
def setUp(self):
self.client = APIClient()
@classmethod
def setUpTestData(cls):
create_db_users(cls)
cls.tasks = create_dummy_db_tasks(cls)
def _check_response(self, response, db_task, data):
self.assertEqual(response.status_code, status.HTTP_200_OK)
name = data.get("name", db_task.name)
......@@ -2034,6 +2087,13 @@ class TaskUpdateAPITestCase(APITestCase):
[label["name"] for label in response.data["labels"]]
)
def _run_api_v2_tasks_id(self, tid, user, data):
with ForceLogin(user, self.client):
response = self.client.patch('/api/tasks/{}'.format(tid),
data=data, format="json")
return response
def _check_api_v2_tasks_id(self, user, data):
for db_task in self.tasks:
response = self._run_api_v2_tasks_id(db_task.id, user, data)
......@@ -2077,32 +2137,6 @@ class TaskUpdateAPITestCase(APITestCase):
}
self._check_api_v2_tasks_id(self.user, data)
def test_api_v2_tasks_id_somebody(self):
data = {
"name": "new name for the task",
"labels": [{
"name": "test",
}]
}
self._check_api_v2_tasks_id(self.somebody, data)
def test_api_v2_tasks_id_no_auth(self):
data = {
"name": "new name for the task",
"labels": [{
"name": "test",
}]
}
self._check_api_v2_tasks_id(None, data)
class TaskPartialUpdateAPITestCase(TaskUpdateAPITestCase):
def _run_api_v2_tasks_id(self, tid, user, data):
with ForceLogin(user, self.client):
response = self.client.patch('/api/tasks/{}'.format(tid),
data=data, format="json")
return response
def test_api_v2_tasks_id_admin_partial(self):
data = {
"name": "new name for the task #2",
......
此差异已折叠。
......@@ -45,10 +45,10 @@ def get_git_changeset():
so it's sufficient for generating the development version numbers.
"""
repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
git_log = subprocess.Popen(
'git log --pretty=format:%ct --quiet -1 HEAD',
git_log = subprocess.Popen( # nosec: B603, B607
['git', 'log', '--pretty=format:%ct', '--quiet', '-1', 'HEAD'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=True, cwd=repo_dir, universal_newlines=True,
cwd=repo_dir, universal_newlines=True,
)
timestamp = git_log.communicate()[0]
try:
......@@ -56,4 +56,3 @@ def get_git_changeset():
except ValueError:
return None
return timestamp.strftime('%Y%m%d%H%M%S')
......@@ -8,9 +8,9 @@ import os
from pathlib import Path
import pytest
from cvat_sdk import exceptions, make_client
from cvat_sdk.core.tasks import TaskProxy
from cvat_sdk.core.types import ResourceType
from cvat_sdk import make_client
from cvat_sdk.api_client import exceptions
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image
from sdk.util import generate_coco_json
......@@ -41,8 +41,6 @@ class TestCLI:
yield
self.tmp_path = None
@pytest.fixture
def fxt_image_file(self):
img_path = self.tmp_path / "img_0.png"
......@@ -61,7 +59,7 @@ class TestCLI:
yield ann_filename
@pytest.fixture
def fxt_backup_file(self, fxt_new_task: TaskProxy, fxt_coco_file: str):
def fxt_backup_file(self, fxt_new_task: Task, fxt_coco_file: str):
backup_path = self.tmp_path / "backup.zip"
fxt_new_task.import_annotations("COCO 1.0", filename=fxt_coco_file)
......@@ -73,7 +71,7 @@ class TestCLI:
def fxt_new_task(self):
files = generate_images(str(self.tmp_path), 5)
task = self.client.create_task(
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
......@@ -114,30 +112,28 @@ class TestCLI:
)
task_id = int(stdout.split()[-1])
assert self.client.retrieve_task(task_id).size == 5
assert self.client.tasks.retrieve(task_id).size == 5
def test_can_list_tasks_in_simple_format(self, fxt_new_task: TaskProxy):
def test_can_list_tasks_in_simple_format(self, fxt_new_task: Task):
output = self.run_cli("ls")
results = output.split("\n")
assert any(str(fxt_new_task.id) in r for r in results)
def test_can_list_tasks_in_json_format(self, fxt_new_task: TaskProxy):
def test_can_list_tasks_in_json_format(self, fxt_new_task: Task):
output = self.run_cli("ls", "--json")
results = json.loads(output)
assert any(r["id"] == fxt_new_task.id for r in results)
def test_can_delete_task(self, fxt_new_task: TaskProxy):
def test_can_delete_task(self, fxt_new_task: Task):
self.run_cli("delete", str(fxt_new_task.id))
with pytest.raises(exceptions.ApiException) as capture:
with pytest.raises(exceptions.NotFoundException):
fxt_new_task.fetch()
assert capture.value.status == 404
def test_can_download_task_annotations(self, fxt_new_task: TaskProxy):
filename: Path = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
def test_can_download_task_annotations(self, fxt_new_task: Task):
filename = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
self.run_cli(
"dump",
str(fxt_new_task.id),
......@@ -152,8 +148,8 @@ class TestCLI:
assert 0 < filename.stat().st_size
def test_can_download_task_backup(self, fxt_new_task: TaskProxy):
filename: Path = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
def test_can_download_task_backup(self, fxt_new_task: Task):
filename = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
self.run_cli(
"export",
str(fxt_new_task.id),
......@@ -165,7 +161,7 @@ class TestCLI:
assert 0 < filename.stat().st_size
@pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_task_frames(self, fxt_new_task: TaskProxy, quality: str):
def test_can_download_task_frames(self, fxt_new_task: Task, quality: str):
out_dir = str(self.tmp_path / "downloads")
self.run_cli(
"frames",
......@@ -182,13 +178,13 @@ class TestCLI:
"task_{}_frame_{:06d}.jpg".format(fxt_new_task.id, i) for i in range(2)
}
def test_can_upload_annotations(self, fxt_new_task: TaskProxy, fxt_coco_file: Path):
def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
self.run_cli("upload", str(fxt_new_task.id), str(fxt_coco_file), "--format", "COCO 1.0")
def test_can_create_from_backup(self, fxt_new_task: TaskProxy, fxt_backup_file: Path):
def test_can_create_from_backup(self, fxt_new_task: Task, fxt_backup_file: Path):
stdout = self.run_cli("import", str(fxt_backup_file))
task_id = int(stdout.split()[-1])
assert task_id
assert task_id != fxt_new_task.id
assert self.client.retrieve_task(task_id).size == fxt_new_task.size
assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import json
from http import HTTPStatus
import pytest
from cvat_sdk.api_client import ApiClient, Configuration, models
from shared.utils.config import BASE_URL, USER_PASS, make_api_client
@pytest.mark.usefixtures("dontchangedb")
class TestBasicAuth:
def test_can_do_basic_auth(self, admin_user: str):
username = admin_user
config = Configuration(host=BASE_URL, username=username, password=USER_PASS)
with ApiClient(config) as client:
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
@pytest.mark.usefixtures("changedb")
class TestTokenAuth:
@staticmethod
def login(client: ApiClient, username: str) -> models.Token:
(auth, _) = client.auth_api.create_login(
models.LoginRequest(username=username, password=USER_PASS)
)
client.set_default_header("Authorization", "Token " + auth.key)
return auth
@classmethod
def make_client(cls, username: str) -> ApiClient:
with ApiClient(Configuration(host=BASE_URL)) as client:
cls.login(client, username)
return client
def test_can_do_token_auth_and_manage_cookies(self, admin_user: str):
username = admin_user
with ApiClient(Configuration(host=BASE_URL)) as client:
auth = self.login(client, username=username)
assert "sessionid" in client.cookies
assert "csrftoken" in client.cookies
assert auth.key
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
def test_can_do_logout(self, admin_user: str):
username = admin_user
with self.make_client(username) as client:
(_, response) = client.auth_api.create_logout()
assert response.status == HTTPStatus.OK
(_, response) = client.users_api.retrieve_self(
_parse_response=False, _check_status=False
)
assert response.status == HTTPStatus.UNAUTHORIZED
@pytest.mark.usefixtures("changedb")
class TestCredentialsManagement:
def test_can_register(self):
username = "newuser"
email = "123@456.com"
with ApiClient(Configuration(host=BASE_URL)) as client:
(user, response) = client.auth_api.create_register(
models.RestrictedRegisterRequest(
username=username, password1=USER_PASS, password2=USER_PASS, email=email
)
)
assert response.status == HTTPStatus.CREATED
assert user.username == username
with make_api_client(username) as client:
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
assert user.email == email
def test_can_change_password(self, admin_user: str):
username = admin_user
new_pass = "5w4knrqaW#$@gewa"
with make_api_client(username) as client:
(info, response) = client.auth_api.create_password_change(
models.PasswordChangeRequest(
old_password=USER_PASS, new_password1=new_pass, new_password2=new_pass
)
)
assert response.status == HTTPStatus.OK
assert info.detail == "New password has been saved."
(_, response) = client.users_api.retrieve_self(
_parse_response=False, _check_status=False
)
assert response.status == HTTPStatus.UNAUTHORIZED
client.configuration.password = new_pass
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
def test_can_report_weak_password(self, admin_user: str):
username = admin_user
new_pass = "pass"
with make_api_client(username) as client:
(_, response) = client.auth_api.create_password_change(
models.PasswordChangeRequest(
old_password=USER_PASS, new_password1=new_pass, new_password2=new_pass
),
_parse_response=False,
_check_status=False,
)
assert response.status == HTTPStatus.BAD_REQUEST
assert json.loads(response.data) == {
"new_password2": [
"This password is too short. It must contain at least 8 characters.",
"This password is too common.",
]
}
def test_can_report_mismatching_passwords(self, admin_user: str):
username = admin_user
with make_api_client(username) as client:
(_, response) = client.auth_api.create_password_change(
models.PasswordChangeRequest(
old_password=USER_PASS, new_password1="3j4tb13/T$#", new_password2="q#@$n34g5"
),
_parse_response=False,
_check_status=False,
)
assert response.status == HTTPStatus.BAD_REQUEST
assert json.loads(response.data) == {
"new_password2": ["The two password fields didn’t match."]
}
......@@ -3,48 +3,78 @@
#
# SPDX-License-Identifier: MIT
import pytest
import json
from copy import deepcopy
from http import HTTPStatus
import pytest
from cvat_sdk import models
from deepdiff import DeepDiff
from copy import deepcopy
from shared.utils.config import post_method, patch_method
from cvat_sdk.api_client import exceptions
from shared.utils.config import make_api_client
@pytest.mark.usefixtures('changedb')
@pytest.mark.usefixtures("changedb")
class TestPostIssues:
def _test_check_response(self, user, data, is_allow, **kwargs):
response = post_method(user, 'issues', data, **kwargs)
with make_api_client(user) as client:
(_, response) = client.issues_api.create(
models.IssueWriteRequest(**data),
**kwargs,
_parse_response=False,
_check_status=False,
)
if is_allow:
assert response.status_code == HTTPStatus.CREATED
assert user == response.json()['owner']['username']
assert data['message'] == response.json()['comments'][0]['message']
assert DeepDiff(data, response.json(),
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner|message'\]") == {}
assert response.status == HTTPStatus.CREATED
response_json = json.loads(response.data)
assert user == response_json["owner"]["username"]
assert data["message"] == response_json["comments"][0]["message"]
assert (
DeepDiff(
data,
response_json,
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner|message'\]",
)
== {}
)
else:
assert response.status_code == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', [''])
@pytest.mark.parametrize('privilege, job_staff, is_allow', [
('admin', True, True), ('admin', False, True),
('business', True, True), ('business', False, False),
('worker', True, True), ('worker', False, False),
('user', True, True), ('user', False, False)
])
def test_user_create_issue(self, org, privilege, job_staff, is_allow,
find_job_staff_user, find_users, jobs_by_org):
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize("org", [""])
@pytest.mark.parametrize(
"privilege, job_staff, is_allow",
[
("admin", True, True),
("admin", False, True),
("business", True, True),
("business", False, False),
("worker", True, True),
("worker", False, False),
("user", True, True),
("user", False, False),
],
)
def test_user_create_issue(
self, org, privilege, job_staff, is_allow, find_job_staff_user, find_users, jobs_by_org
):
users = find_users(privilege=privilege)
jobs = jobs_by_org[org]
username, jid = find_job_staff_user(jobs, users, job_staff)
job, = filter(lambda job: job['id'] == jid, jobs)
(job,) = filter(lambda job: job["id"] == jid, jobs)
data = {
"assignee": None,
"comments": [],
"job": jid,
"frame": job['start_frame'],
"frame": job["start_frame"],
"position": [
0., 0., 1., 1.,
0.0,
0.0,
1.0,
1.0,
],
"resolved": False,
"message": "lorem ipsum",
......@@ -52,16 +82,23 @@ class TestPostIssues:
self._test_check_response(username, data, is_allow)
@pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, job_staff, is_allow', [
('maintainer', False, True), ('owner', False, True),
('supervisor', False, False), ('worker', False, False),
('maintainer', True, True), ('owner', True, True),
('supervisor', True, True), ('worker', True, True)
])
def test_member_create_issue(self, org, role, job_staff, is_allow,
find_job_staff_user, find_users, jobs_by_org, jobs):
@pytest.mark.parametrize("org", [2])
@pytest.mark.parametrize(
"role, job_staff, is_allow",
[
("maintainer", False, True),
("owner", False, True),
("supervisor", False, False),
("worker", False, False),
("maintainer", True, True),
("owner", True, True),
("supervisor", True, True),
("worker", True, True),
],
)
def test_member_create_issue(
self, org, role, job_staff, is_allow, find_job_staff_user, find_users, jobs_by_org, jobs
):
users = find_users(role=role, org=org)
username, jid = find_job_staff_user(jobs_by_org[org], users, job_staff)
job = jobs[jid]
......@@ -70,50 +107,85 @@ class TestPostIssues:
"assignee": None,
"comments": [],
"job": jid,
"frame": job['start_frame'],
"frame": job["start_frame"],
"position": [
0., 0., 1., 1.,
0.0,
0.0,
1.0,
1.0,
],
"resolved": False,
"message": "lorem ipsum",
}
self._test_check_response(username, data, is_allow, org_id=org)
@pytest.mark.usefixtures('changedb')
@pytest.mark.usefixtures("changedb")
class TestPatchIssues:
def _test_check_response(self, user, issue_id, data, is_allow, **kwargs):
response = patch_method(user, f'issues/{issue_id}', data,
action='update', **kwargs)
with make_api_client(user) as client:
(_, response) = client.issues_api.partial_update(
issue_id,
patched_issue_write_request=models.PatchedIssueWriteRequest(**data),
**kwargs,
_parse_response=False,
_check_status=False,
)
if is_allow:
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json(),
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner'\]") == {}
assert response.status == HTTPStatus.OK
assert (
DeepDiff(
data,
json.loads(response.data),
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner'\]",
)
== {}
)
else:
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.status == HTTPStatus.FORBIDDEN
@pytest.fixture(scope='class')
@pytest.fixture(scope="class")
def request_data(self, issues):
def get_data(issue_id):
data = deepcopy(issues[issue_id])
data['resolved'] = not data['resolved']
data.pop('comments')
data.pop('updated_date')
data.pop('id')
data.pop('owner')
data["resolved"] = not data["resolved"]
data.pop("comments")
data.pop("updated_date")
data.pop("id")
data.pop("owner")
return data
return get_data
@pytest.mark.parametrize('org', [''])
@pytest.mark.parametrize('privilege, issue_staff, issue_admin, is_allow', [
('admin', True, None, True), ('admin', False, None, True),
('business', True, None, True), ('business', False, None, False),
('user', True, None, True), ('user', False, None, False),
('worker', False, True, True), ('worker', True, False, False),
('worker', False, False, False)
])
def test_user_update_issue(self, org, privilege, issue_staff, issue_admin, is_allow,
find_issue_staff_user, find_users, issues_by_org, request_data):
@pytest.mark.parametrize("org", [""])
@pytest.mark.parametrize(
"privilege, issue_staff, issue_admin, is_allow",
[
("admin", True, None, True),
("admin", False, None, True),
("business", True, None, True),
("business", False, None, False),
("user", True, None, True),
("user", False, None, False),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_user_update_issue(
self,
org,
privilege,
issue_staff,
issue_admin,
is_allow,
find_issue_staff_user,
find_users,
issues_by_org,
request_data,
):
users = find_users(privilege=privilege)
issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
......@@ -121,19 +193,135 @@ class TestPatchIssues:
data = request_data(issue_id)
self._test_check_response(username, issue_id, data, is_allow)
@pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, issue_staff, issue_admin, is_allow', [
('maintainer', True, None, True), ('maintainer', False, None, True),
('supervisor', True, None, True), ('supervisor', False, None, False),
('owner', True, None, True), ('owner', False, None, True),
('worker', False, True, True), ('worker', True, False, False),
('worker', False, False, False)
])
def test_member_update_issue(self, org, role, issue_staff, issue_admin, is_allow,
find_issue_staff_user, find_users, issues_by_org, request_data):
@pytest.mark.parametrize("org", [2])
@pytest.mark.parametrize(
"role, issue_staff, issue_admin, is_allow",
[
("maintainer", True, None, True),
("maintainer", False, None, True),
("supervisor", True, None, True),
("supervisor", False, None, False),
("owner", True, None, True),
("owner", False, None, True),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_member_update_issue(
self,
org,
role,
issue_staff,
issue_admin,
is_allow,
find_issue_staff_user,
find_users,
issues_by_org,
request_data,
):
users = find_users(role=role, org=org)
issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
data = request_data(issue_id)
self._test_check_response(username, issue_id, data, is_allow, org_id=org)
@pytest.mark.xfail(raises=exceptions.ServiceException,
reason="server bug, https://github.com/cvat-ai/cvat/issues/122")
def test_cant_update_message(self, admin_user: str, issues_by_org):
org = 2
issue_id = issues_by_org[org][0]['id']
with make_api_client(admin_user) as client:
client.issues_api.partial_update(
issue_id,
patched_issue_write_request=models.PatchedIssueWriteRequest(message="foo"),
org_id=org,
)
@pytest.mark.usefixtures("changedb")
class TestDeleteIssues:
def _test_check_response(self, user, issue_id, expect_success, **kwargs):
with make_api_client(user) as client:
(_, response) = client.issues_api.destroy(
issue_id,
**kwargs,
_parse_response=False,
_check_status=False,
)
if expect_success:
assert response.status == HTTPStatus.NO_CONTENT
(_, response) = client.issues_api.retrieve(
issue_id, _parse_response=False, _check_status=False
)
assert response.status == HTTPStatus.NOT_FOUND
else:
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize("org", [""])
@pytest.mark.parametrize(
"privilege, issue_staff, issue_admin, expect_success",
[
("admin", True, None, True),
("admin", False, None, True),
("business", True, None, True),
("business", False, None, False),
("user", True, None, True),
("user", False, None, False),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_user_delete_issue(
self,
org,
privilege,
issue_staff,
issue_admin,
expect_success,
find_issue_staff_user,
find_users,
issues_by_org,
):
users = find_users(privilege=privilege)
issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
self._test_check_response(username, issue_id, expect_success)
@pytest.mark.parametrize("org", [2])
@pytest.mark.parametrize(
"role, issue_staff, issue_admin, expect_success",
[
("maintainer", True, None, True),
("maintainer", False, None, True),
("supervisor", True, None, True),
("supervisor", False, None, False),
("owner", True, None, True),
("owner", False, None, True),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_org_member_delete_issue(
self,
org,
role,
issue_staff,
issue_admin,
expect_success,
find_issue_staff_user,
find_users,
issues_by_org,
):
users = find_users(role=role, org=org)
issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
self._test_check_response(username, issue_id, expect_success, org_id=org)
......@@ -4,10 +4,14 @@
# SPDX-License-Identifier: MIT
from http import HTTPStatus
import json
from typing import List
from cvat_sdk.core.helpers import get_paginated_collection
from deepdiff import DeepDiff
import pytest
from copy import deepcopy
from shared.utils.config import get_method, patch_method
from shared.utils.config import make_api_client
from .utils import export_dataset
def get_job_staff(job, tasks, projects):
job_staff = []
......@@ -42,15 +46,17 @@ def filter_jobs(jobs, tasks, org):
@pytest.mark.usefixtures('dontchangedb')
class TestGetJobs:
def _test_get_job_200(self, user, jid, data, **kwargs):
response = get_method(user, f'jobs/{jid}', **kwargs)
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json(), exclude_paths="root['updated_date']",
ignore_order=True) == {}
with make_api_client(user) as client:
(_, response) = client.jobs_api.retrieve(jid, **kwargs)
assert response.status == HTTPStatus.OK
assert DeepDiff(data, json.loads(response.data), exclude_paths="root['updated_date']",
ignore_order=True) == {}
def _test_get_job_403(self, user, jid, **kwargs):
response = get_method(user, f'jobs/{jid}', **kwargs)
assert response.status_code == HTTPStatus.FORBIDDEN
with make_api_client(user) as client:
(_, response) = client.jobs_api.retrieve(jid, **kwargs,
_check_status=False, _parse_response=False)
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', [None, '', 1, 2])
def test_admin_get_job(self, jobs, tasks, org):
......@@ -82,15 +88,17 @@ class TestGetJobs:
@pytest.mark.usefixtures('dontchangedb')
class TestListJobs:
def _test_list_jobs_200(self, user, data, **kwargs):
response = get_method(user, 'jobs', **kwargs, page_size='all')
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json()['results'], exclude_paths="root['updated_date']",
ignore_order=True) == {}
with make_api_client(user) as client:
results = get_paginated_collection(client.jobs_api.list_endpoint,
return_json=True, **kwargs)
assert DeepDiff(data, results, exclude_paths="root['updated_date']",
ignore_order=True) == {}
def _test_list_jobs_403(self, user, **kwargs):
response = get_method(user, 'jobs', **kwargs)
assert response.status_code == HTTPStatus.FORBIDDEN
with make_api_client(user) as client:
(_, response) = client.jobs_api.list(**kwargs,
_check_status=False, _parse_response=False)
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', [None, '', 1, 2])
def test_admin_list_jobs(self, jobs, tasks, org):
......@@ -119,52 +127,54 @@ class TestListJobs:
@pytest.mark.usefixtures('dontchangedb')
class TestGetAnnotations:
def _test_get_job_annotations_200(self, user, jid, data, **kwargs):
response = get_method(user, f'jobs/{jid}/annotations', **kwargs)
response_data = response.json()
response_data['shapes'] = sorted(response_data['shapes'], key=lambda a: a['id'])
with make_api_client(user) as client:
(_, response) = client.jobs_api.retrieve_annotations(jid, **kwargs)
assert response.status == HTTPStatus.OK
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response_data,
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
response_data = json.loads(response.data)
response_data['shapes'] = sorted(response_data['shapes'], key=lambda a: a['id'])
assert DeepDiff(data, response_data,
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
def _test_get_job_annotations_403(self, user, jid, **kwargs):
response = get_method(user, f'jobs/{jid}/annotations', **kwargs)
assert response.status_code == HTTPStatus.FORBIDDEN
with make_api_client(user) as client:
(_, response) = client.jobs_api.retrieve_annotations(jid, **kwargs,
_check_status=False, _parse_response=False)
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', [''])
@pytest.mark.parametrize('groups, job_staff, is_allow', [
@pytest.mark.parametrize('groups, job_staff, expect_success', [
(['admin'], True, True), (['admin'], False, True),
(['business'], True, True), (['business'], False, False),
(['worker'], True, True), (['worker'], False, False),
(['user'], True, True), (['user'], False, False)
])
def test_user_get_job_annotations(self, org, groups, job_staff,
is_allow, users, jobs, tasks, annotations, find_job_staff_user):
expect_success, users, jobs, tasks, annotations, find_job_staff_user):
users = [u for u in users if u['groups'] == groups]
jobs, kwargs = filter_jobs(jobs, tasks, org)
username, job_id = find_job_staff_user(jobs, users, job_staff)
if is_allow:
if expect_success:
self._test_get_job_annotations_200(username,
job_id, annotations['job'][str(job_id)], **kwargs)
else:
self._test_get_job_annotations_403(username, job_id, **kwargs)
@pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, job_staff, is_allow', [
@pytest.mark.parametrize('role, job_staff, expect_success', [
('owner', True, True), ('owner', False, True),
('maintainer', True, True), ('maintainer', False, True),
('supervisor', True, True), ('supervisor', False, False),
('worker', True, True), ('worker', False, False),
])
def test_member_get_job_annotations(self, org, role, job_staff, is_allow,
def test_member_get_job_annotations(self, org, role, job_staff, expect_success,
jobs, tasks, find_job_staff_user, annotations, find_users):
users = find_users(org=org, role=role)
jobs, kwargs = filter_jobs(jobs, tasks, org)
username, jid = find_job_staff_user(jobs, users, job_staff)
if is_allow:
if expect_success:
data = annotations['job'][str(jid)]
data['shapes'] = sorted(data['shapes'], key=lambda a: a['id'])
self._test_get_job_annotations_200(username, jid, data, **kwargs)
......@@ -172,17 +182,17 @@ class TestGetAnnotations:
self._test_get_job_annotations_403(username, jid, **kwargs)
@pytest.mark.parametrize('org', [1])
@pytest.mark.parametrize('privilege, is_allow', [
@pytest.mark.parametrize('privilege, expect_success', [
('admin', True), ('business', False), ('worker', False), ('user', False)
])
def test_non_member_get_job_annotations(self, org, privilege, is_allow,
def test_non_member_get_job_annotations(self, org, privilege, expect_success,
jobs, tasks, find_job_staff_user, annotations, find_users):
users = find_users(privilege=privilege, exclude_org=org)
jobs, kwargs = filter_jobs(jobs, tasks, org)
username, job_id = find_job_staff_user(jobs, users, False)
kwargs = {'org_id': org}
if is_allow:
if expect_success:
self._test_get_job_annotations_200(username,
job_id, annotations['job'][str(job_id)], **kwargs)
else:
......@@ -190,15 +200,25 @@ class TestGetAnnotations:
@pytest.mark.usefixtures('changedb')
class TestPatchJobAnnotations:
_ORG = 2
def _check_respone(self, username, jid, expect_success, data=None, org=None):
kwargs = {}
if org is not None:
if isinstance(org, str):
kwargs['org'] = org
else:
kwargs['org_id'] = org
def _test_check_respone(self, is_allow, response, data=None):
if is_allow:
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json(),
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
else:
assert response.status_code == HTTPStatus.FORBIDDEN
with make_api_client(username) as client:
(_, response) = client.jobs_api.partial_update_annotations(id=jid,
patched_labeled_data_request=deepcopy(data), action='update', **kwargs,
_parse_response=expect_success, _check_status=expect_success)
if expect_success:
assert response.status == HTTPStatus.OK
assert DeepDiff(data, json.loads(response.data),
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
else:
assert response.status == HTTPStatus.FORBIDDEN
@pytest.fixture(scope='class')
def request_data(self, annotations):
......@@ -210,13 +230,13 @@ class TestPatchJobAnnotations:
return get_data
@pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, job_staff, is_allow', [
@pytest.mark.parametrize('role, job_staff, expect_success', [
('maintainer', False, True), ('owner', False, True),
('supervisor', False, False), ('worker', False, False),
('maintainer', True, True), ('owner', True, True),
('supervisor', True, True), ('worker', True, True)
])
def test_member_update_job_annotations(self, org, role, job_staff, is_allow,
def test_member_update_job_annotations(self, org, role, job_staff, expect_success,
find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes):
users = find_users(role=role, org=org)
jobs = jobs_by_org[org]
......@@ -224,17 +244,13 @@ class TestPatchJobAnnotations:
username, jid = find_job_staff_user(filtered_jobs, users, job_staff)
data = request_data(jid)
response = patch_method(username, f'jobs/{jid}/annotations',
data, org_id=org, action='update')
self._test_check_respone(is_allow, response, data)
self._check_respone(username, jid, expect_success, data, org=org)
@pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('privilege, is_allow', [
@pytest.mark.parametrize('privilege, expect_success', [
('admin', True), ('business', False), ('worker', False), ('user', False)
])
def test_non_member_update_job_annotations(self, org, privilege, is_allow,
def test_non_member_update_job_annotations(self, org, privilege, expect_success,
find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes):
users = find_users(privilege=privilege, exclude_org=org)
jobs = jobs_by_org[org]
......@@ -242,19 +258,16 @@ class TestPatchJobAnnotations:
username, jid = find_job_staff_user(filtered_jobs, users, False)
data = request_data(jid)
response = patch_method(username, f'jobs/{jid}/annotations', data,
org_id=org, action='update')
self._test_check_respone(is_allow, response, data)
self._check_respone(username, jid, expect_success, data, org=org)
@pytest.mark.parametrize('org', [''])
@pytest.mark.parametrize('privilege, job_staff, is_allow', [
@pytest.mark.parametrize('privilege, job_staff, expect_success', [
('admin', True, True), ('admin', False, True),
('business', True, True), ('business', False, False),
('worker', True, True), ('worker', False, False),
('user', True, True), ('user', False, False)
])
def test_user_update_job_annotations(self, org, privilege, job_staff, is_allow,
def test_user_update_job_annotations(self, org, privilege, job_staff, expect_success,
find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes):
users = find_users(privilege=privilege)
jobs = jobs_by_org[org]
......@@ -262,15 +275,10 @@ class TestPatchJobAnnotations:
username, jid = find_job_staff_user(filtered_jobs, users, job_staff)
data = request_data(jid)
response = patch_method(username, f'jobs/{jid}/annotations', data,
org_id=org, action='update')
self._test_check_respone(is_allow, response, data)
self._check_respone(username, jid, expect_success, data, org=org)
@pytest.mark.usefixtures('changedb')
class TestPatchJob:
_ORG = 2
@pytest.fixture(scope='class')
def find_task_staff_user(self, is_task_staff):
def find(jobs, users, is_staff):
......@@ -300,24 +308,47 @@ class TestPatchJob:
return find_new_assignee
@pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, task_staff, is_allow', [
@pytest.mark.parametrize('role, task_staff, expect_success', [
('maintainer', False, True), ('owner', False, True),
('supervisor', False, False), ('worker', False, False),
('maintainer', True, True), ('owner', True, True),
('supervisor', True, True), ('worker', True, True)
])
def test_member_update_job_assignee(self, org, role, task_staff, is_allow,
def test_member_update_job_assignee(self, org, role, task_staff, expect_success,
find_task_staff_user, find_users, jobs_by_org, new_assignee, expected_data):
users, jobs = find_users(role=role, org=org), jobs_by_org[org]
user, jid = find_task_staff_user(jobs, users, task_staff)
assignee = new_assignee(jid, user['id'])
response = patch_method(user['username'], f'jobs/{jid}',
{'assignee': assignee}, org_id=self._ORG)
with make_api_client(user['username']) as client:
(_, response) = client.jobs_api.partial_update(id=jid,
patched_job_write_request={'assignee': assignee}, org_id=org,
_parse_response=expect_success, _check_status=expect_success)
if expect_success:
assert response.status == HTTPStatus.OK
assert DeepDiff(expected_data(jid, assignee), json.loads(response.data),
exclude_paths="root['updated_date']", ignore_order=True) == {}
else:
assert response.status == HTTPStatus.FORBIDDEN
if is_allow:
assert response.status_code == HTTPStatus.OK
assert DeepDiff(expected_data(jid, assignee), response.json(),
exclude_paths="root['updated_date']", ignore_order=True) == {}
else:
assert response.status_code == HTTPStatus.FORBIDDEN
@pytest.mark.usefixtures('dontchangedb')
class TestJobDataset:
def _export_dataset(self, username, jid, **kwargs):
with make_api_client(username) as api_client:
return export_dataset(api_client.jobs_api.retrieve_dataset_endpoint, id=jid, **kwargs)
def _export_annotations(self, username, jid, **kwargs):
with make_api_client(username) as api_client:
return export_dataset(api_client.jobs_api.retrieve_annotations_endpoint,
id=jid, **kwargs)
def test_can_export_dataset(self, admin_user: str, jobs_with_shapes: List):
job = jobs_with_shapes[0]
response = self._export_dataset(admin_user, job['id'], format='CVAT for images 1.1')
assert response.data
def test_can_export_annotations(self, admin_user: str, jobs_with_shapes: List):
job = jobs_with_shapes[0]
response = self._export_annotations(admin_user, job['id'], format='CVAT for images 1.1')
assert response.data
......@@ -13,9 +13,8 @@ import pytest
from copy import deepcopy
from deepdiff import DeepDiff
from cvat_sdk.models import DatasetFileRequest, ProjectWriteRequest
from shared.utils.config import get_method, patch_method, make_api_client
from .utils import export_dataset
@pytest.mark.usefixtures('dontchangedb')
......@@ -229,12 +228,12 @@ class TestGetProjectBackup:
class TestPostProjects:
def _test_create_project_201(self, user, spec, **kwargs):
with make_api_client(user) as api_client:
(_, response) = api_client.projects_api.create(ProjectWriteRequest(**spec), **kwargs)
(_, response) = api_client.projects_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED
def _test_create_project_403(self, user, spec, **kwargs):
with make_api_client(user) as api_client:
(_, response) = api_client.projects_api.create(ProjectWriteRequest(**spec), **kwargs,
(_, response) = api_client.projects_api.create(spec, **kwargs,
_parse_response=False, _check_status=False)
assert response.status == HTTPStatus.FORBIDDEN
......@@ -316,43 +315,30 @@ class TestPostProjects:
self._test_create_project_201(user['username'], spec, org_id=user['org'])
@pytest.mark.usefixtures("changedb")
@pytest.mark.usefixtures("restore_cvat_data")
class TestImportExportDatasetProject:
def _test_export_project(self, username, project_id, format_name):
def _test_export_project(self, username, pid, format_name):
with make_api_client(username) as api_client:
while True:
(_, response) = api_client.projects_api.retrieve_dataset(id=project_id,
format=format_name)
if response.status == HTTPStatus.CREATED:
break
(_, response) = api_client.projects_api.retrieve_dataset(id=project_id,
format=format_name, action='download')
assert response.status == HTTPStatus.OK
return response
return export_dataset(api_client.projects_api.retrieve_dataset_endpoint,
id=pid, format=format_name)
def _test_import_project(self, username, project_id, format_name, data):
with make_api_client(username) as api_client:
(_, response) = api_client.projects_api.create_dataset(id=project_id,
format=format_name, dataset_file_request=DatasetFileRequest(**data),
format=format_name, dataset_write_request=deepcopy(data),
_content_type="multipart/form-data")
assert response.status == HTTPStatus.ACCEPTED
while True:
# TODO: Request schema doesn't describe this capability.
# It's better be refactored to a separate endpoint to get request status
response = get_method(username, f'projects/{project_id}/dataset',
# TODO: It's better be refactored to a separate endpoint to get request status
(_, response) = api_client.projects_api.retrieve_dataset(project_id,
action='import_status')
response.raise_for_status()
if response.status_code == HTTPStatus.CREATED:
if response.status == HTTPStatus.CREATED:
break
def test_can_import_dataset_in_org(self):
username = 'admin1'
def test_can_import_dataset_in_org(self, admin_user):
project_id = 4
response = self._test_export_project(username, project_id, 'CVAT for images 1.1')
response = self._test_export_project(admin_user, project_id, 'CVAT for images 1.1')
tmp_file = io.BytesIO(response.data)
tmp_file.name = 'dataset.zip'
......@@ -361,7 +347,7 @@ class TestImportExportDatasetProject:
'dataset_file': tmp_file,
}
self._test_import_project(username, project_id, 'CVAT 1.1', import_data)
self._test_import_project(admin_user, project_id, 'CVAT 1.1', import_data)
@pytest.mark.usefixtures('changedb')
class TestPatchProjectLabel:
......
......@@ -7,14 +7,15 @@ import json
from copy import deepcopy
from http import HTTPStatus
from time import sleep
from cvat_sdk.api_client.apis import TasksApi
from cvat_sdk.api_client import models
from cvat_sdk.api_client import models, apis
from cvat_sdk.core.helpers import get_paginated_collection
import pytest
from deepdiff import DeepDiff
from shared.utils.config import make_api_client
from shared.utils.helpers import generate_image_files
from .utils import export_dataset
def get_cloud_storage_content(username, cloud_storage_id, manifest):
with make_api_client(username) as api_client:
......@@ -27,12 +28,9 @@ def get_cloud_storage_content(username, cloud_storage_id, manifest):
class TestGetTasks:
def _test_task_list_200(self, user, project_id, data, exclude_paths = '', **kwargs):
with make_api_client(user) as api_client:
(_, response) = api_client.projects_api.list_tasks(project_id, **kwargs,
_parse_response=False)
assert response.status == HTTPStatus.OK
response_data = json.loads(response.data)
assert DeepDiff(data, response_data['results'], ignore_order=True, exclude_paths=exclude_paths) == {}
results = get_paginated_collection(api_client.projects_api.list_tasks_endpoint,
return_json=True, id=project_id, **kwargs)
assert DeepDiff(data, results, ignore_order=True, exclude_paths=exclude_paths) == {}
def _test_task_list_403(self, user, project_id, **kwargs):
with make_api_client(user) as api_client:
......@@ -60,7 +58,7 @@ class TestGetTasks:
for user in staff_users:
with make_api_client(user['username']) as api_client:
(_, response) = api_client.tasks_api.list(**kwargs, _parse_response=False)
(_, response) = api_client.tasks_api.list(**kwargs)
assert response.status == HTTPStatus.OK
response_data = json.loads(response.data)
......@@ -113,12 +111,12 @@ class TestGetTasks:
class TestPostTasks:
def _test_create_task_201(self, user, spec, **kwargs):
with make_api_client(user) as api_client:
(_, response) = api_client.tasks_api.create(models.TaskWriteRequest(**spec), **kwargs)
(_, response) = api_client.tasks_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED
def _test_create_task_403(self, user, spec, **kwargs):
with make_api_client(user) as api_client:
(_, response) = api_client.tasks_api.create(models.TaskWriteRequest(**spec), **kwargs,
(_, response) = api_client.tasks_api.create(spec, **kwargs,
_parse_response=False, _check_status=False)
assert response.status == HTTPStatus.FORBIDDEN
......@@ -210,10 +208,9 @@ class TestPatchTaskAnnotations:
data = request_data(tid)
with make_api_client(username) as api_client:
patched_data = models.PatchedTaskWriteRequest(**deepcopy(data))
(_, response) = api_client.tasks_api.partial_update_annotations(
id=tid, action='update', org=org,
patched_task_write_request=patched_data,
patched_labeled_data_request=deepcopy(data),
_parse_response=False, _check_status=False)
self._test_check_response(is_allow, response, data)
......@@ -233,30 +230,23 @@ class TestPatchTaskAnnotations:
data = request_data(tid)
with make_api_client(username) as api_client:
patched_data = models.PatchedTaskWriteRequest(**deepcopy(data))
(_, response) = api_client.tasks_api.partial_update_annotations(
id=tid, org_id=org, action='update',
patched_task_write_request=patched_data,
patched_labeled_data_request=deepcopy(data),
_parse_response=False, _check_status=False)
self._test_check_response(is_allow, response, data)
@pytest.mark.usefixtures('dontchangedb')
class TestGetTaskDataset:
def _test_export_project(self, username, tid, **kwargs):
def _test_export_task(self, username, tid, **kwargs):
with make_api_client(username) as api_client:
(_, response) = api_client.tasks_api.retrieve_dataset(id=tid, **kwargs)
assert response.status == HTTPStatus.ACCEPTED
(_, response) = api_client.tasks_api.retrieve_dataset(id=tid, **kwargs)
assert response.status == HTTPStatus.CREATED
(_, response) = api_client.tasks_api.retrieve_dataset(id=tid, **kwargs, action='download')
assert response.status == HTTPStatus.OK
return export_dataset(api_client.tasks_api.retrieve_dataset_endpoint, id=tid, **kwargs)
def test_admin_can_export_task_dataset(self, tasks_with_shapes):
def test_can_export_task_dataset(self, admin_user, tasks_with_shapes):
task = tasks_with_shapes[0]
self._test_export_project('admin1', task['id'], format='CVAT for images 1.1')
response = self._test_export_task(admin_user, task['id'], format='CVAT for images 1.1')
assert response.data
@pytest.mark.usefixtures("changedb")
@pytest.mark.usefixtures("restore_cvat_data")
......@@ -264,7 +254,7 @@ class TestPostTaskData:
_USERNAME = 'admin1'
@staticmethod
def _wait_until_task_is_created(api: TasksApi, task_id: int) -> models.RqStatus:
def _wait_until_task_is_created(api: apis.TasksApi, task_id: int) -> models.RqStatus:
for _ in range(100):
(status, _) = api.retrieve_status(task_id)
if status.state.value in ['Finished', 'Failed']:
......@@ -274,11 +264,10 @@ class TestPostTaskData:
def _test_create_task(self, username, spec, data, content_type, **kwargs):
with make_api_client(username) as api_client:
(task, response) = api_client.tasks_api.create(models.TaskWriteRequest(**spec), **kwargs)
(task, response) = api_client.tasks_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED
task_data = models.DataRequest(**data)
(_, response) = api_client.tasks_api.create_data(task.id, task_data,
(_, response) = api_client.tasks_api.create_data(task.id, data_request=deepcopy(data),
_content_type=content_type, **kwargs)
assert response.status == HTTPStatus.ACCEPTED
......
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from http import HTTPStatus
from time import sleep
from cvat_sdk.api_client.api_client import Endpoint
from urllib3 import HTTPResponse
def export_dataset(
endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
) -> HTTPResponse:
for _ in range(max_retries):
(_, response) = endpoint.call_with_http_info(**kwargs, _parse_response=False)
if response.status == HTTPStatus.CREATED:
break
assert response.status == HTTPStatus.ACCEPTED
sleep(interval)
assert response.status == HTTPStatus.CREATED
(_, response) = endpoint.call_with_http_info(**kwargs, action="download", _parse_response=False)
assert response.status == HTTPStatus.OK
return response
......@@ -2,10 +2,16 @@
#
# SPDX-License-Identifier: MIT
from pathlib import Path
import pytest
from cvat_sdk import Client
from PIL import Image
from shared.utils.config import BASE_URL
from shared.utils.helpers import generate_image_file
from .util import generate_coco_json
@pytest.fixture
......@@ -20,3 +26,22 @@ def fxt_client(fxt_logger):
with client:
yield client
@pytest.fixture
def fxt_image_file(tmp_path: Path):
img_path = tmp_path / "img.png"
with img_path.open("wb") as f:
f.write(generate_image_file(filename=str(img_path), size=(5, 10)).getvalue())
return img_path
@pytest.fixture
def fxt_coco_file(tmp_path: Path, fxt_image_file: Path):
img_filename = fxt_image_file
img_size = Image.open(img_filename).size
ann_filename = tmp_path / "coco.json"
generate_coco_json(ann_filename, img_info=(img_filename, *img_size))
yield ann_filename
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client
from cvat_sdk.api_client import exceptions, models
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from shared.utils.config import USER_PASS
class TestIssuesUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
@pytest.fixture
def fxt_new_task(self, fxt_image_file: Path):
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
)
return task
def test_can_retrieve_issue(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
retrieved_issue = self.client.issues.retrieve(issue.id)
assert issue.id == retrieved_issue.id
assert self.stdout.getvalue() == ""
def test_can_list_issues(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
assignee=self.client.users.list()[0].id,
)
)
issues = self.client.issues.list()
assert any(issue.id == j.id for j in issues)
assert self.stdout.getvalue() == ""
def test_can_list_comments(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
issue.fetch()
comment_ids = {c.id for c in issue.comments}
assert len(comment_ids) == 2
assert comment.id in comment_ids
assert self.stdout.getvalue() == ""
def test_can_modify_issue(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
issue.update(models.PatchedIssueWriteRequest(resolved=True))
retrieved_issue = self.client.issues.retrieve(issue.id)
assert retrieved_issue.resolved is True
assert issue.resolved == retrieved_issue.resolved
assert self.stdout.getvalue() == ""
def test_can_remove_issue(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
issue.remove()
with pytest.raises(exceptions.NotFoundException):
issue.fetch()
with pytest.raises(exceptions.NotFoundException):
self.client.comments.retrieve(issue.comments[0].id)
assert self.stdout.getvalue() == ""
class TestCommentsUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
@pytest.fixture
def fxt_new_task(self, fxt_image_file: Path):
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
)
return task
def test_can_retrieve_comment(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
retrieved_comment = self.client.comments.retrieve(comment.id)
assert comment.id == retrieved_comment.id
assert self.stdout.getvalue() == ""
def test_can_list_comments(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
comments = self.client.comments.list()
assert any(comment.id == c.id for c in comments)
assert self.stdout.getvalue() == ""
def test_can_modify_comment(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
comment.update(models.PatchedCommentWriteRequest(message="bar"))
retrieved_comment = self.client.comments.retrieve(comment.id)
assert retrieved_comment.message == "bar"
assert comment.message == retrieved_comment.message
assert self.stdout.getvalue() == ""
def test_can_remove_comment(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
comment.remove()
with pytest.raises(exceptions.NotFoundException):
comment.fetch()
assert self.stdout.getvalue() == ""
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
import os.path as osp
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client
from cvat_sdk.api_client import models
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image
from shared.utils.config import USER_PASS
from .util import make_pbar
class TestJobUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
@pytest.fixture
def fxt_new_task(self, fxt_image_file: Path):
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
)
return task
@pytest.fixture
def fxt_task_with_shapes(self, fxt_new_task: Task):
fxt_new_task.set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
),
],
)
)
return fxt_new_task
def test_can_retrieve_job(self, fxt_new_task: Task):
job_id = fxt_new_task.get_jobs()[0].id
job = self.client.jobs.retrieve(job_id)
assert job.id == job_id
assert self.stdout.getvalue() == ""
def test_can_list_jobs(self, fxt_new_task: Task):
task_job_ids = set(j.id for j in fxt_new_task.get_jobs())
jobs = self.client.jobs.list()
assert len(task_job_ids) != 0
assert task_job_ids.issubset(j.id for j in jobs)
assert self.stdout.getvalue() == ""
def test_can_update_job_field_directly(self, fxt_new_task: Task):
job = self.client.jobs.list()[0]
assert not job.assignee
new_assignee = self.client.users.list()[0]
job.update({"assignee": new_assignee.id})
updated_job = self.client.jobs.retrieve(job.id)
assert updated_job.assignee.id == new_assignee.id
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("include_images", (True, False))
def test_can_download_dataset(self, fxt_new_task: Task, include_images: bool):
pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out)
task_id = fxt_new_task.id
path = str(self.tmp_path / f"task_{task_id}-cvat.zip")
job = self.client.jobs.retrieve(task_id)
job.export_dataset(
format_name="CVAT for images 1.1",
filename=path,
pbar=pbar,
include_images=include_images,
)
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert osp.isfile(path)
assert self.stdout.getvalue() == ""
def test_can_download_preview(self, fxt_new_task: Task):
frame_encoded = fxt_new_task.get_jobs()[0].get_preview()
assert Image.open(frame_encoded).size != 0
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_frame(self, fxt_new_task: Task, quality: str):
frame_encoded = fxt_new_task.get_jobs()[0].get_frame(0, quality=quality)
assert Image.open(frame_encoded).size != 0
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_frames(self, fxt_new_task: Task, quality: str):
fxt_new_task.get_jobs()[0].download_frames(
[0],
quality=quality,
outdir=str(self.tmp_path),
filename_pattern="frame-{frame_id}{frame_ext}",
)
assert osp.isfile(self.tmp_path / "frame-0.jpg")
assert self.stdout.getvalue() == ""
def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out)
fxt_new_task.get_jobs()[0].import_annotations(
format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar
)
assert "uploaded" in self.logger_stream.getvalue()
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert self.stdout.getvalue() == ""
def test_can_get_meta(self, fxt_new_task: Task):
meta = fxt_new_task.get_jobs()[0].get_meta()
assert meta.image_quality == 80
assert meta.size == 1
assert len(meta.frames) == meta.size
assert meta.frames[0].name == "img.png"
assert meta.frames[0].width == 5
assert meta.frames[0].height == 10
assert not meta.deleted_frames
assert self.stdout.getvalue() == ""
def test_can_remove_frames(self, fxt_new_task: Task):
fxt_new_task.get_jobs()[0].remove_frames_by_ids([0])
meta = fxt_new_task.get_jobs()[0].get_meta()
assert meta.deleted_frames == [0]
assert self.stdout.getvalue() == ""
def test_can_get_issues(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
job_issue_ids = set(j.id for j in fxt_new_task.get_jobs()[0].get_issues())
assert {issue.id} == job_issue_ids
assert self.stdout.getvalue() == ""
def test_can_get_annotations(self, fxt_task_with_shapes: Task):
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
assert len(anns.shapes) == 1
assert anns.shapes[0].type.value == "rectangle"
assert self.stdout.getvalue() == ""
def test_can_set_annotations(self, fxt_new_task: Task):
fxt_new_task.get_jobs()[0].set_annotations(
models.LabeledDataRequest(
tags=[models.LabeledImageRequest(frame=0, label_id=fxt_new_task.labels[0].id)],
)
)
anns = fxt_new_task.get_jobs()[0].get_annotations()
assert len(anns.tags) == 1
assert self.stdout.getvalue() == ""
def test_can_clear_annotations(self, fxt_task_with_shapes: Task):
fxt_task_with_shapes.get_jobs()[0].remove_annotations()
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
assert len(anns.tags) == 0
assert len(anns.tracks) == 0
assert len(anns.shapes) == 0
assert self.stdout.getvalue() == ""
def test_can_remove_annotations(self, fxt_new_task: Task):
fxt_new_task.get_jobs()[0].set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
),
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[2, 2, 3, 3],
),
],
)
)
anns = fxt_new_task.get_jobs()[0].get_annotations()
fxt_new_task.get_jobs()[0].remove_annotations(ids=[anns.shapes[0].id])
anns = fxt_new_task.get_jobs()[0].get_annotations()
assert len(anns.tags) == 0
assert len(anns.tracks) == 0
assert len(anns.shapes) == 1
assert self.stdout.getvalue() == ""
def test_can_update_annotations(self, fxt_task_with_shapes: Task):
fxt_task_with_shapes.get_jobs()[0].update_annotations(
models.PatchedLabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_task_with_shapes.labels[0].id,
type="rectangle",
points=[0, 1, 2, 3],
),
],
tracks=[
models.LabeledTrackRequest(
frame=0,
label_id=fxt_task_with_shapes.labels[0].id,
shapes=[
models.TrackedShapeRequest(
frame=0, type="polygon", points=[3, 2, 2, 3, 3, 4]
),
],
)
],
tags=[
models.LabeledImageRequest(frame=0, label_id=fxt_task_with_shapes.labels[0].id)
],
)
)
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
assert len(anns.shapes) == 2
assert len(anns.tracks) == 1
assert len(anns.tags) == 1
assert self.stdout.getvalue() == ""
此差异已折叠。
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client, models
from cvat_sdk.api_client import exceptions
from shared.utils.config import USER_PASS
class TestUserUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
def test_can_retrieve_user(self):
me = self.client.users.retrieve_current_user()
user = self.client.users.retrieve(me.id)
assert user.id == me.id
assert user.username == self.user
assert self.stdout.getvalue() == ""
def test_can_list_users(self):
users = self.client.users.list()
assert self.user in set(u.username for u in users)
assert self.stdout.getvalue() == ""
def test_can_update_user(self):
user = self.client.users.retrieve_current_user()
user.update(models.PatchedUserRequest(first_name="foo", last_name="bar"))
retrieved_user = self.client.users.retrieve(user.id)
assert retrieved_user.first_name == "foo"
assert retrieved_user.last_name == "bar"
assert user.first_name == retrieved_user.first_name
assert user.last_name == retrieved_user.last_name
assert self.stdout.getvalue() == ""
def test_can_remove_user(self):
users = self.client.users.list()
removed_user = next(u for u in users if u.username != self.user)
removed_user.remove()
with pytest.raises(exceptions.NotFoundException):
removed_user.fetch()
assert self.stdout.getvalue() == ""
此差异已折叠。
......@@ -279,6 +279,10 @@ def filter_tasks_with_shapes(annotations):
return list(filter(lambda t: annotations['task'][str(t['id'])]['shapes'], tasks))
return find
@pytest.fixture(scope='session')
def jobs_with_shapes(jobs, filter_jobs_with_shapes):
return filter_jobs_with_shapes(jobs)
@pytest.fixture(scope='session')
def tasks_with_shapes(tasks, filter_tasks_with_shapes):
return filter_tasks_with_shapes(tasks)
......
......@@ -48,5 +48,6 @@ def post_files_method(username, endpoint, data, files, **kwargs):
def server_get(username, endpoint, **kwargs):
return requests.get(get_server_url(endpoint, **kwargs), auth=(username, USER_PASS))
def make_api_client(user: str) -> ApiClient:
return ApiClient(configuration=Configuration(host=BASE_URL, username=user, password=USER_PASS))
def make_api_client(user: str, *, password: str = None) -> ApiClient:
return ApiClient(configuration=Configuration(host=BASE_URL,
username=user, password=password or USER_PASS))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册