# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import io
import json
import mimetypes
import os
import os.path as osp
import shutil
from enum import Enum
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from PIL import Image

from cvat_sdk.api_client import apis, exceptions, models
from cvat_sdk.core import git
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.jobs import Job
from cvat_sdk.core.proxies.model_proxy import (
    ModelCreateMixin,
    ModelDeleteMixin,
    ModelListMixin,
    ModelRetrieveMixin,
    ModelUpdateMixin,
    build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
from cvat_sdk.core.utils import filter_dict

if TYPE_CHECKING:
    from _typeshed import SupportsWrite


class ResourceType(Enum):
    LOCAL = 0
    SHARE = 1
    REMOTE = 2

    def __str__(self):
        return self.name.lower()

    def __repr__(self):
        return str(self)


_TaskEntityBase, _TaskRepoBase = build_model_bases(
    models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
)


class Task(
    _TaskEntityBase,
    models.ITaskRead,
    ModelUpdateMixin[models.IPatchedTaskWriteRequest],
    ModelDeleteMixin,
    AnnotationCrudMixin,
):
    _model_partial_update_arg = "patched_task_write_request"

    _put_annotations_data_param = "task_annotations_update_request"

    def upload_data(
        self,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        pbar: Optional[ProgressReporter] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add local, remote, or shared files to an existing task.
        """
        params = params or {}

        data = {}
        if resource_type is ResourceType.LOCAL:
            pass  # handled later
        elif resource_type is ResourceType.REMOTE:
            data["remote_files"] = resources
        elif resource_type is ResourceType.SHARE:
            data["server_files"] = resources

        # default image quality; can be overridden via params below
        data["image_quality"] = 70
        data.update(
            filter_dict(
                params,
                keep=[
                    "chunk_size",
                    "copy_data",
                    "image_quality",
                    "sorting_method",
                    "start_frame",
                    "stop_frame",
                    "use_cache",
                    "use_zip_chunks",
                ],
            )
        )
        if params.get("frame_step") is not None:
            data["frame_filter"] = f"step={params.get('frame_step')}"

        if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
            self.api.create_data(
                self.id,
                data_request=models.DataRequest(**data),
            )
        elif resource_type == ResourceType.LOCAL:
            url = self._client.api_map.make_endpoint_url(
                self.api.create_data_endpoint.path, kwsub={"id": self.id}
            )

            DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
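    # A minimal usage sketch for upload_data (illustrative, not part of the
    # original file; assumes `client` is an authenticated cvat_sdk Client and
    # that task #42 exists; the file names and parameter values are made up):
    #
    #   task = client.tasks.retrieve(42)
    #   task.upload_data(
    #       ResourceType.LOCAL,
    #       ["frame_0.png", "frame_1.png"],
    #       params={"image_quality": 90, "frame_step": 2},
    #   )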
""" AnnotationUploader(self._client).upload_file_and_wait( self.api.create_annotations_endpoint, filename, format_name, url_params={"id": self.id}, pbar=pbar, status_check_period=status_check_period, ) self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded") def get_frame( self, frame_id: int, *, quality: Optional[str] = None, ) -> io.RawIOBase: params = {} if quality: params["quality"] = quality (_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame") return io.BytesIO(response.data) def get_preview( self, ) -> io.RawIOBase: (_, response) = self.api.retrieve_data(self.id, type="preview") return io.BytesIO(response.data) def download_chunk( self, chunk_id: int, output_file: SupportsWrite[bytes], *, quality: Optional[str] = None, ) -> None: params = {} if quality: params["quality"] = quality (_, response) = self.api.retrieve_data( self.id, number=chunk_id, **params, type="chunk", _parse_response=False ) with response: shutil.copyfileobj(response, output_file) def download_frames( self, frame_ids: Sequence[int], *, outdir: str = "", quality: str = "original", filename_pattern: str = "frame_{frame_id:06d}{frame_ext}", ) -> Optional[List[Image.Image]]: """ Download the requested frame numbers for a task and save images as outdir/filename_pattern """ # TODO: add arg descriptions in schema os.makedirs(outdir, exist_ok=True) for frame_id in frame_ids: frame_bytes = self.get_frame(frame_id, quality=quality) im = Image.open(frame_bytes) mime_type = im.get_format_mimetype() or "image/jpg" im_ext = mimetypes.guess_extension(mime_type) # FIXME It is better to use meta information from the server # to determine the extension # replace '.jpe' or '.jpeg' with a more used '.jpg' if im_ext in (".jpe", ".jpeg", None): im_ext = ".jpg" outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext) im.save(osp.join(outdir, outfile)) def export_dataset( self, format_name: str, filename: str, *, pbar: Optional[ProgressReporter] = None, status_check_period: Optional[int] = None, include_images: bool = True, ) -> None: """ Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). 
""" if include_images: endpoint = self.api.retrieve_dataset_endpoint else: endpoint = self.api.retrieve_annotations_endpoint Downloader(self._client).prepare_and_download_file_from_endpoint( endpoint=endpoint, filename=filename, url_params={"id": self.id}, query_params={"format": format_name}, pbar=pbar, status_check_period=status_check_period, ) self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}") def download_backup( self, filename: str, *, status_check_period: int = None, pbar: Optional[ProgressReporter] = None, ) -> None: """ Download a task backup """ Downloader(self._client).prepare_and_download_file_from_endpoint( self.api.retrieve_backup_endpoint, filename=filename, pbar=pbar, status_check_period=status_check_period, url_params={"id": self.id}, ) self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}") def get_jobs(self) -> List[Job]: return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]] def get_meta(self) -> models.IDataMetaRead: (meta, _) = self.api.retrieve_data_meta(self.id) return meta def get_frames_info(self) -> List[models.IFrameMeta]: return self.get_meta().frames def remove_frames_by_ids(self, ids: Sequence[int]) -> None: self.api.partial_update_data_meta( self.id, patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids), ) class TasksRepo( _TaskRepoBase, ModelCreateMixin[Task, models.ITaskWriteRequest], ModelRetrieveMixin[Task], ModelListMixin[Task], ModelDeleteMixin, ): _entity_type = Task def create_from_data( self, spec: models.ITaskWriteRequest, resource_type: ResourceType, resources: Sequence[str], *, data_params: Optional[Dict[str, Any]] = None, annotation_path: str = "", annotation_format: str = "CVAT XML 1.1", status_check_period: int = None, dataset_repository_url: str = "", use_lfs: bool = False, pbar: Optional[ProgressReporter] = None, ) -> Task: """ Create a new task with the given name and labels JSON and add the files to it. Returns: id of the created task """ if status_check_period is None: status_check_period = self._client.config.status_check_period if getattr(spec, "project_id", None) and getattr(spec, "labels", None): raise exceptions.ApiValueError( "Can't set labels to a task inside a project. " "Tasks inside a project use project's labels.", ["labels"], ) task = self.create(spec=spec) self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name) task.upload_data(resource_type, resources, pbar=pbar, params=data_params) self._client.logger.info("Awaiting for task %s creation...", task.id) status: models.RqStatus = None while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]: sleep(status_check_period) (status, response) = self.api.retrieve_status(task.id) self._client.logger.info( "Task %s creation status=%s, message=%s", task.id, status.state.value, status.message, ) if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]: raise exceptions.ApiException( status=status.state.value, reason=status.message, http_resp=response ) status = status.state.value if annotation_path: task.import_annotations(annotation_format, annotation_path, pbar=pbar) if dataset_repository_url: git.create_git_repo( self._client, task_id=task.id, repo_url=dataset_repository_url, status_check_period=status_check_period, use_lfs=use_lfs, ) task.fetch() return task def remove_by_ids(self, task_ids: Sequence[int]) -> None: """ Delete a list of tasks, ignoring those which don't exist. 
""" for task_id in task_ids: (_, response) = self.api.destroy(task_id, _check_status=False) if 200 <= response.status <= 299: self._client.logger.info(f"Task ID {task_id} deleted") elif response.status == 404: self._client.logger.info(f"Task ID {task_id} not found") else: self._client.logger.warning( f"Failed to delete task ID {task_id}: " f"{response.msg} (status {response.status})" ) def create_from_backup( self, filename: str, *, status_check_period: int = None, pbar: Optional[ProgressReporter] = None, ) -> Task: """ Import a task from a backup file """ if status_check_period is None: status_check_period = self._client.config.status_check_period params = {"filename": osp.basename(filename)} url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path) uploader = Uploader(self._client) response = uploader.upload_file( url, filename, meta=params, query_params=params, pbar=pbar, logger=self._client.logger.debug, ) rq_id = json.loads(response.data)["rq_id"] response = self._client.wait_for_completion( url, success_status=201, positive_statuses=[202], post_params={"rq_id": rq_id}, status_check_period=status_check_period, ) task_id = json.loads(response.data)["id"] self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}") return self.retrieve(task_id)