From 31f0578220d9cd9d20136cf0bde40996da334946 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 13 Jan 2023 19:23:57 +0300 Subject: [PATCH] Add a way to create task with custom jobs (#5536) This PR adds an option to specify file to job mapping explicitly during task creation. This option is incompatible with most other job-related parameters like `sorting_method` and `frame_step`. - Added a new task creation parameter (`job_file_mapping`) to set a custom file to job mapping during task creation --- CHANGELOG.md | 2 + cvat-sdk/cvat_sdk/core/proxies/tasks.py | 1 + cvat-sdk/cvat_sdk/core/uploading.py | 2 +- cvat/apps/engine/backup.py | 58 +++++++- cvat/apps/engine/media_extractors.py | 2 +- cvat/apps/engine/models.py | 1 + cvat/apps/engine/serializers.py | 58 +++++++- cvat/apps/engine/task.py | 184 +++++++++++++++++++----- cvat/apps/engine/views.py | 3 + tests/python/rest_api/test_tasks.py | 154 +++++++++++++++++++- 10 files changed, 422 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 700f75e40..8c8b4d6b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - \[SDK\] Class to represent a project as a PyTorch dataset () +- Support for custom file to job splits in tasks (server API & SDK only) + () - \[SDK\] A PyTorch adapter setting to disable cache updates () - YOLO v7 serverless feature added using ONNX backend () diff --git a/cvat-sdk/cvat_sdk/core/proxies/tasks.py b/cvat-sdk/cvat_sdk/core/proxies/tasks.py index 97dcbdbcb..6be5e8e8f 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/tasks.py +++ b/cvat-sdk/cvat_sdk/core/proxies/tasks.py @@ -92,6 +92,7 @@ class Task( "stop_frame", "use_cache", "use_zip_chunks", + "job_file_mapping", "filename_pattern", "cloud_storage_id", ], diff --git a/cvat-sdk/cvat_sdk/core/uploading.py b/cvat-sdk/cvat_sdk/core/uploading.py index 95d129a06..ceacda782 100644 --- a/cvat-sdk/cvat_sdk/core/uploading.py +++ b/cvat-sdk/cvat_sdk/core/uploading.py @@ -346,7 +346,7 @@ class DataUploader(Uploader): ) response = self._client.api_client.rest_client.POST( url, - post_params=dict(**kwargs, **files), + post_params={"image_quality": kwargs["image_quality"], **files}, headers={ "Content-Type": "multipart/form-data", "Upload-Multiple": "", diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index d22739bc6..65c75da6b 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -9,6 +9,7 @@ from enum import Enum import re import shutil import tempfile +from typing import Any, Dict, Iterable import uuid from zipfile import ZipFile from datetime import datetime @@ -37,7 +38,7 @@ from cvat.apps.engine.utils import av_scan_paths, process_failed_job, configure_ from cvat.apps.engine.models import ( StorageChoice, StorageMethodChoice, DataChoice, Task, Project, Location, CloudStorage as CloudStorageModel) -from cvat.apps.engine.task import _create_thread +from cvat.apps.engine.task import JobFileMapping, _create_thread from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance @@ -133,6 +134,8 @@ class _TaskBackupBase(_BackupBase): 'storage', 'sorting_method', 'deleted_frames', + 'custom_segments', + 'job_file_mapping', } self._prepare_meta(allowed_fields, data) @@ -331,6 +334,9 @@ class TaskExporter(_ExporterBase, 
_TaskBackupBase): segment = segment_serailizer.data segment.update(job_data) + if self._db_task.segment_size == 0: + segment.update(serialize_custom_file_mapping(db_segment)) + return segment def serialize_jobs(): @@ -338,12 +344,28 @@ class TaskExporter(_ExporterBase, _TaskBackupBase): db_segments.sort(key=lambda i: i.job_set.first().id) return (serialize_segment(s) for s in db_segments) + def serialize_custom_file_mapping(db_segment: models.Segment): + if self._db_task.mode == 'annotation': + files: Iterable[models.Image] = self._db_data.images.all().order_by('frame') + segment_files = files[db_segment.start_frame : db_segment.stop_frame + 1] + return {'files': list(frame.path for frame in segment_files)} + else: + assert False, ( + "Backups with custom file mapping are not supported" + " in the 'interpolation' task mode" + ) + def serialize_data(): data_serializer = DataSerializer(self._db_data) data = data_serializer.data data['chunk_type'] = data.pop('compressed_chunk_type') + # There are no deleted frames in DataSerializer so we need to pick it data['deleted_frames'] = self._db_data.deleted_frames + + if self._db_task.segment_size == 0: + data['custom_segments'] = True + return self._prepare_data_meta(data) task = serialize_task() @@ -491,6 +513,20 @@ class TaskImporter(_ImporterBase, _TaskBackupBase): return segment_size, overlap + @staticmethod + def _parse_custom_segments(*, jobs: Dict[str, Any]) -> JobFileMapping: + segments = [] + + for i, segment in enumerate(jobs): + segment_size = segment['stop_frame'] - segment['start_frame'] + 1 + segment_files = segment['files'] + if len(segment_files) != segment_size: + raise ValidationError(f"segment {i}: segment files do not match segment size") + + segments.append(segment_files) + + return segments + def _import_task(self): def _write_data(zip_object): data_path = self._db_task.data.get_upload_dirname() @@ -519,10 +555,23 @@ class TaskImporter(_ImporterBase, _TaskBackupBase): jobs = self._manifest.pop('jobs') self._prepare_task_meta(self._manifest) - self._manifest['segment_size'], self._manifest['overlap'] = self._calculate_segment_size(jobs) self._manifest['owner_id'] = self._user_id self._manifest['project_id'] = self._project_id + if custom_segments := data.pop('custom_segments', False): + job_file_mapping = self._parse_custom_segments(jobs=jobs) + data['job_file_mapping'] = job_file_mapping + + for d in [self._manifest, data]: + for k in [ + 'segment_size', 'overlap', 'start_frame', 'stop_frame', + 'sorting_method', 'frame_filter', 'filename_pattern' + ]: + d.pop(k, None) + else: + self._manifest['segment_size'], self._manifest['overlap'] = \ + self._calculate_segment_size(jobs) + self._db_task = models.Task.objects.create(**self._manifest, organization_id=self._org_id) task_path = self._db_task.get_dirname() if os.path.isdir(task_path): @@ -550,7 +599,10 @@ class TaskImporter(_ImporterBase, _TaskBackupBase): data['use_zip_chunks'] = data.pop('chunk_type') == DataChoice.IMAGESET data = data_serializer.data data['client_files'] = uploaded_files - _create_thread(self._db_task.pk, data.copy(), True) + if custom_segments: + data['job_file_mapping'] = job_file_mapping + + _create_thread(self._db_task.pk, data.copy(), isBackupRestore=True) db_data.start_frame = data['start_frame'] db_data.stop_frame = data['stop_frame'] db_data.frame_filter = data['frame_filter'] diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 02a4b5fd3..bec571ad5 100644 --- a/cvat/apps/engine/media_extractors.py +++ 
b/cvat/apps/engine/media_extractors.py
@@ -169,7 +169,7 @@ class ImageListReader(IMediaReader):
         if not source_path:
             raise Exception('No image found')
 
-        if stop is None:
+        if not stop:
             stop = len(source_path)
         else:
             stop = min(len(source_path), stop + 1)
diff --git a/cvat/apps/engine/models.py
index 73b309335..cd47e516c 100644
--- a/cvat/apps/engine/models.py
+++ b/cvat/apps/engine/models.py
@@ -344,6 +344,7 @@ class Task(models.Model):
     updated_date = models.DateTimeField(auto_now=True)
     overlap = models.PositiveIntegerField(null=True)
     # Zero means that there are no limits (default)
+    # Note that the files can be split into jobs in a custom way in this case
     segment_size = models.PositiveIntegerField(default=0)
     status = models.CharField(max_length=32, choices=StatusChoice.choices(),
                               default=StatusChoice.ANNOTATION)
diff --git a/cvat/apps/engine/serializers.py
index abb29bd36..71a211b9b 100644
--- a/cvat/apps/engine/serializers.py
+++ b/cvat/apps/engine/serializers.py
@@ -8,6 +8,7 @@
 import re
 import shutil
 from tempfile import NamedTemporaryFile
+import textwrap
 from typing import OrderedDict
 
 from rest_framework import serializers, exceptions
@@ -362,6 +363,41 @@ class WriteOnceMixin:
 
         return extra_kwargs
 
+
+class JobFiles(serializers.ListField):
+    """
+    Read JobFileMapping docs for more info.
+    """
+
+    def __init__(self, *args, **kwargs):
+        kwargs.setdefault('child', serializers.CharField(allow_blank=False, max_length=1024))
+        kwargs.setdefault('allow_empty', False)
+        super().__init__(*args, **kwargs)
+
+
+class JobFileMapping(serializers.ListField):
+    """
+    Represents a file-to-job mapping. Useful to specify a custom job
+    configuration during task creation. This option is not compatible with
+    most other job split-related options.
+
+    Example:
+    [
+        ["file1.jpg", "file2.jpg"], # job #1 files
+        ["file3.png"], # job #2 files
+        ["file4.jpg", "file5.png", "file6.bmp"], # job #3 files
+    ]
+
+    Files in the jobs must not overlap or repeat.
+ """ + + def __init__(self, *args, **kwargs): + kwargs.setdefault('child', JobFiles()) + kwargs.setdefault('allow_empty', False) + kwargs.setdefault('help_text', textwrap.dedent(__class__.__doc__)) + super().__init__(*args, **kwargs) + + class DataSerializer(WriteOnceMixin, serializers.ModelSerializer): image_quality = serializers.IntegerField(min_value=0, max_value=100) use_zip_chunks = serializers.BooleanField(default=False) @@ -372,12 +408,14 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer): copy_data = serializers.BooleanField(default=False) cloud_storage_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) filename_pattern = serializers.CharField(allow_null=True, required=False) + job_file_mapping = JobFileMapping(required=False, write_only=True) class Meta: model = models.Data fields = ('chunk_size', 'size', 'image_quality', 'start_frame', 'stop_frame', 'frame_filter', 'compressed_chunk_type', 'original_chunk_type', 'client_files', 'server_files', 'remote_files', 'use_zip_chunks', - 'cloud_storage_id', 'use_cache', 'copy_data', 'storage_method', 'storage', 'sorting_method', 'filename_pattern') + 'cloud_storage_id', 'use_cache', 'copy_data', 'storage_method', 'storage', 'sorting_method', 'filename_pattern', + 'job_file_mapping') # pylint: disable=no-self-use def validate_frame_filter(self, value): @@ -392,6 +430,21 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer): raise serializers.ValidationError('Chunk size must be a positive integer') return value + def validate_job_file_mapping(self, value): + existing_files = set() + + for job_files in value: + for filename in job_files: + if filename in existing_files: + raise serializers.ValidationError( + f"The same file '{filename}' cannot be used multiple " + "times in the job file mapping" + ) + + existing_files.add(filename) + + return value + # pylint: disable=no-self-use def validate(self, attrs): if 'start_frame' in attrs and 'stop_frame' in attrs \ @@ -402,6 +455,7 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer): def create(self, validated_data): files = self._pop_data(validated_data) + db_data = models.Data.objects.create(**validated_data) db_data.make_dirs() @@ -424,6 +478,8 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer): server_files = validated_data.pop('server_files') remote_files = validated_data.pop('remote_files') + validated_data.pop('job_file_mapping', None) # optional + for extra_key in { 'use_zip_chunks', 'use_cache', 'copy_data' }: validated_data.pop(extra_key) diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 401e8274f..b6807aadf 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -8,6 +8,7 @@ import itertools import fnmatch import os import sys +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union from rest_framework.serializers import ValidationError import rq import re @@ -60,6 +61,17 @@ def rq_handler(job, exc_type, exc_value, traceback): ############################# Internal implementation for server API +JobFileMapping = List[List[str]] + +class SegmentParams(NamedTuple): + start_frame: int + stop_frame: int + +class SegmentsParams(NamedTuple): + segments: Iterator[SegmentParams] + segment_size: int + overlap: int + def _copy_data_from_source(server_files, upload_dir, server_dir=None): job = rq.get_current_job() job.meta['status'] = 'Data are being copied from source..' 
@@ -79,37 +91,68 @@ def _copy_data_from_source(server_files, upload_dir, server_dir=None): os.makedirs(target_dir) shutil.copyfile(source_path, target_path) -def _get_task_segment_data(db_task, data_size): - segment_size = db_task.segment_size - segment_step = segment_size - if segment_size == 0 or segment_size > data_size: - segment_size = data_size +def _get_task_segment_data( + db_task: models.Task, + *, + data_size: Optional[int] = None, + job_file_mapping: Optional[JobFileMapping] = None, +) -> SegmentsParams: + if job_file_mapping is not None: + def _segments(): + # It is assumed here that files are already saved ordered in the task + # Here we just need to create segments by the job sizes + start_frame = 0 + for jf in job_file_mapping: + segment_size = len(jf) + stop_frame = start_frame + segment_size - 1 + yield SegmentParams(start_frame, stop_frame) + + start_frame = stop_frame + 1 + + segments = _segments() + segment_size = 0 + overlap = 0 + else: + # The segments have equal parameters + if data_size is None: + data_size = db_task.data.size + + segment_size = db_task.segment_size + segment_step = segment_size + if segment_size == 0 or segment_size > data_size: + segment_size = data_size + + # Segment step must be more than segment_size + overlap in single-segment tasks + # Otherwise a task contains an extra segment + segment_step = sys.maxsize - # Segment step must be more than segment_size + overlap in single-segment tasks - # Otherwise a task contains an extra segment - segment_step = sys.maxsize + overlap = 5 if db_task.mode == 'interpolation' else 0 + if db_task.overlap is not None: + overlap = min(db_task.overlap, segment_size // 2) - overlap = 5 if db_task.mode == 'interpolation' else 0 - if db_task.overlap is not None: - overlap = min(db_task.overlap, segment_size // 2) + segment_step -= overlap + + segments = ( + SegmentParams(start_frame, min(start_frame + segment_size - 1, data_size - 1)) + for start_frame in range(0, data_size, segment_step) + ) - segment_step -= overlap - return segment_step, segment_size, overlap + return SegmentsParams(segments, segment_size, overlap) -def _save_task_to_db(db_task, extractor): +def _save_task_to_db(db_task: models.Task, *, job_file_mapping: Optional[JobFileMapping] = None): job = rq.get_current_job() job.meta['status'] = 'Task is being saved in database' job.save_meta() - segment_step, segment_size, overlap = _get_task_segment_data(db_task, db_task.data.size) + segments, segment_size, overlap = _get_task_segment_data( + db_task=db_task, job_file_mapping=job_file_mapping + ) db_task.segment_size = segment_size db_task.overlap = overlap - for start_frame in range(0, db_task.data.size, segment_step): - stop_frame = min(start_frame + segment_size - 1, db_task.data.size - 1) - - slogger.glob.info("New segment for task #{}: start_frame = {}, \ - stop_frame = {}".format(db_task.id, start_frame, stop_frame)) + for segment_idx, (start_frame, stop_frame) in enumerate(segments): + slogger.glob.info("New segment for task #{}: idx = {}, start_frame = {}, \ + stop_frame = {}".format(db_task.id, segment_idx, start_frame, stop_frame)) db_segment = models.Segment() db_segment.task = db_task @@ -214,6 +257,41 @@ def _validate_data(counter, manifest_files=None): return counter, task_modes[0] +def _validate_job_file_mapping( + db_task: models.Task, data: Dict[str, Any] +) -> Optional[JobFileMapping]: + job_file_mapping = data.get('job_file_mapping', None) + + if job_file_mapping is None: + return None + elif not 
list(itertools.chain.from_iterable(job_file_mapping)): + raise ValidationError("job_file_mapping cannot be empty") + + if db_task.segment_size: + raise ValidationError("job_file_mapping cannot be used with segment_size") + + if (data.get('sorting_method', db_task.data.sorting_method) + != models.SortingMethod.LEXICOGRAPHICAL + ): + raise ValidationError("job_file_mapping cannot be used with sorting_method") + + if data.get('start_frame', db_task.data.start_frame): + raise ValidationError("job_file_mapping cannot be used with start_frame") + + if data.get('stop_frame', db_task.data.stop_frame): + raise ValidationError("job_file_mapping cannot be used with stop_frame") + + if data.get('frame_filter', db_task.data.frame_filter): + raise ValidationError("job_file_mapping cannot be used with frame_filter") + + if db_task.data.get_frame_step() != 1: + raise ValidationError("job_file_mapping cannot be used with frame step") + + if data.get('filename_pattern'): + raise ValidationError("job_file_mapping cannot be used with filename_pattern") + + return job_file_mapping + def _validate_manifest(manifests, root_dir, is_in_cloud, db_cloud_storage, data_storage_method): if manifests: if len(manifests) != 1: @@ -325,12 +403,20 @@ def _create_task_manifest_based_on_cloud_storage_manifest( manifest.create(sorted_content) @transaction.atomic -def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): +def _create_thread( + db_task: Union[int, models.Task], + data: Dict[str, Any], + *, + isBackupRestore: bool = False, + isDatasetImport: bool = False, +) -> None: if isinstance(db_task, int): db_task = models.Task.objects.select_for_update().get(pk=db_task) slogger.glob.info("create task #{}".format(db_task.id)) + job_file_mapping = _validate_job_file_mapping(db_task, data) + db_data = db_task.data upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE @@ -387,11 +473,17 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): media = _count_files(data) media, task_mode = _validate_data(media, manifest_files) + if job_file_mapping is not None and task_mode != 'annotation': + raise ValidationError("job_file_mapping can't be used with sequence-based data like videos") + if data['server_files']: if db_data.storage == models.StorageChoice.LOCAL: _copy_data_from_source(data['server_files'], upload_dir, data.get('server_files_path')) elif is_data_in_cloud: - sorted_media = sort(media['image'], data['sorting_method']) + if job_file_mapping is not None: + sorted_media = list(itertools.chain.from_iterable(job_file_mapping)) + else: + sorted_media = sort(media['image'], data['sorting_method']) # Define task manifest content based on cloud storage manifest content and uploaded files _create_task_manifest_based_on_cloud_storage_manifest( @@ -486,24 +578,44 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): extractor.filter(lambda x: not re.search(r'(^|{0})related_images{0}'.format(os.sep), x)) related_images = detect_related_images(extractor.absolute_source_paths, upload_dir) - if isBackupRestore and not isinstance(extractor, MEDIA_TYPES['video']['extractor']) and db_data.storage_method == models.StorageMethodChoice.CACHE and \ - db_data.sorting_method in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED} and validate_dimension.dimension != models.DimensionType.DIM_3D: - # we should sort 
media_files according to the manifest content sequence - # and we should do this in general after validation step for 3D data and after filtering from related_images - manifest = ImageManifestManager(db_data.get_manifest_path()) - manifest.set_index() + # Sort the files + if (isBackupRestore and ( + not isinstance(extractor, MEDIA_TYPES['video']['extractor']) + and db_data.storage_method == models.StorageMethodChoice.CACHE + and db_data.sorting_method in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED} + and validate_dimension.dimension != models.DimensionType.DIM_3D + ) or job_file_mapping + ): sorted_media_files = [] - for idx in range(len(extractor.absolute_source_paths)): - properties = manifest[idx] - image_name = properties.get('name', None) - image_extension = properties.get('extension', None) + if job_file_mapping: + sorted_media_files.extend(itertools.chain.from_iterable(job_file_mapping)) + else: + # we should sort media_files according to the manifest content sequence + # and we should do this in general after validation step for 3D data and after filtering from related_images + manifest = ImageManifestManager(db_data.get_manifest_path()) + manifest.set_index() + + for idx in range(len(extractor.absolute_source_paths)): + properties = manifest[idx] + image_name = properties.get('name', None) + image_extension = properties.get('extension', None) + + full_image_path = f"{image_name}{image_extension}" if image_name and image_extension else None + if full_image_path: + sorted_media_files.append(full_image_path) + + sorted_media_files = [os.path.join(upload_dir, fn) for fn in sorted_media_files] + + for file_path in sorted_media_files: + if not file_path in extractor: + raise ValidationError( + f"Can't find file '{os.path.basename(file_path)}' in the input files" + ) - full_image_path = os.path.join(upload_dir, f"{image_name}{image_extension}") if image_name and image_extension else None - if full_image_path and full_image_path in extractor: - sorted_media_files.append(full_image_path) media_files = sorted_media_files.copy() del sorted_media_files + data['sorting_method'] = models.SortingMethod.PREDEFINED extractor.reconcile( source_files=media_files, @@ -720,4 +832,4 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step()) slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id)) - _save_task_to_db(db_task, extractor) + _save_task_to_db(db_task, job_file_mapping=job_file_mapping) diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 8685ba836..1bb8dd6bb 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -932,6 +932,9 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, self._object.save() data = {k: v for k, v in serializer.data.items()} + if 'job_file_mapping' in serializer.validated_data: + data['job_file_mapping'] = serializer.validated_data['job_file_mapping'] + data['use_zip_chunks'] = serializer.validated_data['use_zip_chunks'] data['use_cache'] = serializer.validated_data['use_cache'] data['copy_data'] = serializer.validated_data['copy_data'] diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index 7b5bdc09b..853435a34 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -10,18 +10,22 @@ import subprocess from copy import deepcopy from functools import partial from http import HTTPStatus +from itertools import 
chain +from pathlib import Path from tempfile import TemporaryDirectory from time import sleep import pytest +from cvat_sdk import Client, Config from cvat_sdk.api_client import apis, models from cvat_sdk.core.helpers import get_paginated_collection +from cvat_sdk.core.proxies.tasks import ResourceType, Task from deepdiff import DeepDiff from PIL import Image import shared.utils.s3 as s3 from shared.fixtures.init import get_server_image_tag -from shared.utils.config import get_method, make_api_client, patch_method +from shared.utils.config import BASE_URL, USER_PASS, get_method, make_api_client, patch_method from shared.utils.helpers import generate_image_files from .utils import export_dataset @@ -431,6 +435,23 @@ class TestPostTaskData: (task, response) = api_client.tasks_api.create(spec, **kwargs) assert response.status == HTTPStatus.CREATED + if data.get("client_files") and "json" in content_type: + # Can't encode binary files in json + (_, response) = api_client.tasks_api.create_data( + task.id, + data_request=models.DataRequest( + client_files=data["client_files"], + image_quality=data["image_quality"], + ), + upload_multiple=True, + _content_type="multipart/form-data", + **kwargs, + ) + assert response.status == HTTPStatus.OK + + data = data.copy() + del data["client_files"] + (_, response) = api_client.tasks_api.create_data( task.id, data_request=deepcopy(data), _content_type=content_type, **kwargs ) @@ -833,6 +854,49 @@ class TestPostTaskData: status = self._test_cannot_create_task(self._USERNAME, task_spec, data_spec) assert "No media data found" in status.message + def test_can_specify_file_job_mapping(self): + task_spec = { + "name": f"test file-job mapping", + "labels": [{"name": "car"}], + } + + files = generate_image_files(7) + filenames = [osp.basename(f.name) for f in files] + expected_segments = [ + filenames[0:1], + filenames[1:5][::-1], # a reversed fragment + filenames[5:7], + ] + + data_spec = { + "image_quality": 75, + "client_files": files, + "job_file_mapping": expected_segments, + } + + task_id = self._test_create_task( + self._USERNAME, task_spec, data_spec, content_type="application/json" + ) + + with make_api_client(self._USERNAME) as api_client: + (task, _) = api_client.tasks_api.retrieve(id=task_id) + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(id=task_id) + + assert [f.name for f in task_meta.frames] == list( + chain.from_iterable(expected_segments) + ) + + assert len(task.segments) == len(expected_segments) + + start_frame = 0 + for i, segment in enumerate(task.segments): + expected_size = len(expected_segments[i]) + stop_frame = start_frame + expected_size - 1 + assert segment.start_frame == start_frame + assert segment.stop_frame == stop_frame + + start_frame = stop_frame + 1 + @pytest.mark.usefixtures("restore_db_per_function") @pytest.mark.usefixtures("restore_cvat_data") @@ -952,3 +1016,91 @@ class TestGetTaskPreview: assert len(tasks) self._test_assigned_users_cannot_see_task_preview(tasks, users, is_task_staff) + + +class TestUnequalJobs: + def _make_client(self) -> Client: + return Client(BASE_URL, config=Config(status_check_period=0.01)) + + @pytest.fixture(autouse=True) + def setup(self, restore_db_per_function, tmp_path: Path, admin_user: str): + self.tmp_dir = tmp_path + + self.client = self._make_client() + self.user = admin_user + + with self.client: + self.client.login((self.user, USER_PASS)) + + @pytest.fixture + def fxt_task_with_unequal_jobs(self): + task_spec = { + "name": f"test file-job mapping", + "labels": [{"name": 
"car"}], + } + + files = generate_image_files(7) + filenames = [osp.basename(f.name) for f in files] + for file_data in files: + with open(self.tmp_dir / file_data.name, "wb") as f: + f.write(file_data.getvalue()) + + expected_segments = [ + filenames[0:1], + filenames[1:5][::-1], # a reversed fragment + filenames[5:7], + ] + + data_spec = { + "job_file_mapping": expected_segments, + } + + return self.client.tasks.create_from_data( + spec=task_spec, + resource_type=ResourceType.LOCAL, + resources=[self.tmp_dir / fn for fn in filenames], + data_params=data_spec, + ) + + def test_can_export(self, fxt_task_with_unequal_jobs: Task): + task = fxt_task_with_unequal_jobs + + filename = self.tmp_dir / f"task_{task.id}_coco.zip" + task.export_dataset("COCO 1.0", filename) + + assert filename.is_file() + assert filename.stat().st_size > 0 + + def test_can_import_annotations(self, fxt_task_with_unequal_jobs: Task): + task = fxt_task_with_unequal_jobs + + format_name = "COCO 1.0" + filename = self.tmp_dir / f"task_{task.id}_coco.zip" + task.export_dataset(format_name, filename) + + task.import_annotations(format_name, filename) + + def test_can_dump_backup(self, fxt_task_with_unequal_jobs: Task): + task = fxt_task_with_unequal_jobs + + filename = self.tmp_dir / f"task_{task.id}_backup.zip" + task.download_backup(filename) + + assert filename.is_file() + assert filename.stat().st_size > 0 + + def test_can_import_backup(self, fxt_task_with_unequal_jobs: Task): + task = fxt_task_with_unequal_jobs + + filename = self.tmp_dir / f"task_{task.id}_backup.zip" + task.download_backup(filename) + + restored_task = self.client.tasks.create_from_backup(filename) + + old_jobs = task.get_jobs() + new_jobs = restored_task.get_jobs() + assert len(old_jobs) == len(new_jobs) + + for old_job, new_job in zip(old_jobs, new_jobs): + assert old_job.start_frame == new_job.start_frame + assert old_job.stop_frame == new_job.stop_frame -- GitLab