Unverified commit 31f05782, authored by Maxim Zhiltsov, committed by GitHub

Add a way to create a task with custom jobs (#5536)

This PR adds an option to specify an explicit file-to-job mapping during
task creation (see the usage sketch below). This option is incompatible with
most other job-related parameters, such as `sorting_method` and `frame_step`.

- Added a new task creation parameter (`job_file_mapping`) that sets a
custom file-to-job mapping during task creation
Parent b00bc653
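A minimal usage sketch (hedged: the host, credentials, and file names are placeholders; the SDK calls mirror the ones used in the tests added by this PR). The custom file-to-job split is passed via `data_params` when creating a task with the SDK:

```python
from cvat_sdk import Client, Config
from cvat_sdk.core.proxies.tasks import ResourceType

# Placeholder host and credentials
with Client("http://localhost:8080", config=Config(status_check_period=0.01)) as client:
    client.login(("admin", "password"))

    task = client.tasks.create_from_data(
        spec={"name": "task with custom jobs", "labels": [{"name": "car"}]},
        resource_type=ResourceType.LOCAL,
        resources=["img1.jpg", "img2.jpg", "img3.jpg"],
        data_params={
            # One list of file names per job; files must not repeat across jobs
            "job_file_mapping": [
                ["img1.jpg", "img2.jpg"],  # job #1
                ["img3.jpg"],              # job #2
            ],
        },
    )
```

On the server side the same structure is accepted as the `job_file_mapping` field of the task data request (see the `DataSerializer` changes below).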
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/opencv/cvat/pull/5535>)
- \[SDK\] Class to represent a project as a PyTorch dataset
(<https://github.com/opencv/cvat/pull/5523>)
- Support for custom file-to-job splits in tasks (server API & SDK only)
(<https://github.com/opencv/cvat/pull/5536>)
- \[SDK\] A PyTorch adapter setting to disable cache updates
(<https://github.com/opencv/cvat/pull/5549>)
- YOLO v7 serverless feature added using ONNX backend (<https://github.com/opencv/cvat/pull/5552>)
@@ -92,6 +92,7 @@ class Task(
"stop_frame",
"use_cache",
"use_zip_chunks",
"job_file_mapping",
"filename_pattern",
"cloud_storage_id",
],
@@ -346,7 +346,7 @@ class DataUploader(Uploader):
)
response = self._client.api_client.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
post_params={"image_quality": kwargs["image_quality"], **files},
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
@@ -9,6 +9,7 @@ from enum import Enum
import re
import shutil
import tempfile
from typing import Any, Dict, Iterable
import uuid
from zipfile import ZipFile
from datetime import datetime
@@ -37,7 +38,7 @@ from cvat.apps.engine.utils import av_scan_paths, process_failed_job, configure_
from cvat.apps.engine.models import (
StorageChoice, StorageMethodChoice, DataChoice, Task, Project, Location,
CloudStorage as CloudStorageModel)
from cvat.apps.engine.task import _create_thread
from cvat.apps.engine.task import JobFileMapping, _create_thread
from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance
@@ -133,6 +134,8 @@ class _TaskBackupBase(_BackupBase):
'storage',
'sorting_method',
'deleted_frames',
'custom_segments',
'job_file_mapping',
}
self._prepare_meta(allowed_fields, data)
@@ -331,6 +334,9 @@ class TaskExporter(_ExporterBase, _TaskBackupBase):
segment = segment_serailizer.data
segment.update(job_data)
if self._db_task.segment_size == 0:
segment.update(serialize_custom_file_mapping(db_segment))
return segment
def serialize_jobs():
@@ -338,12 +344,28 @@ class TaskExporter(_ExporterBase, _TaskBackupBase):
db_segments.sort(key=lambda i: i.job_set.first().id)
return (serialize_segment(s) for s in db_segments)
def serialize_custom_file_mapping(db_segment: models.Segment):
if self._db_task.mode == 'annotation':
files: Iterable[models.Image] = self._db_data.images.all().order_by('frame')
segment_files = files[db_segment.start_frame : db_segment.stop_frame + 1]
return {'files': list(frame.path for frame in segment_files)}
else:
assert False, (
"Backups with custom file mapping are not supported"
" in the 'interpolation' task mode"
)
def serialize_data():
data_serializer = DataSerializer(self._db_data)
data = data_serializer.data
data['chunk_type'] = data.pop('compressed_chunk_type')
# There are no deleted frames in DataSerializer so we need to pick it
data['deleted_frames'] = self._db_data.deleted_frames
if self._db_task.segment_size == 0:
data['custom_segments'] = True
return self._prepare_data_meta(data)
task = serialize_task()
@@ -491,6 +513,20 @@ class TaskImporter(_ImporterBase, _TaskBackupBase):
return segment_size, overlap
@staticmethod
def _parse_custom_segments(*, jobs: Dict[str, Any]) -> JobFileMapping:
segments = []
for i, segment in enumerate(jobs):
segment_size = segment['stop_frame'] - segment['start_frame'] + 1
segment_files = segment['files']
if len(segment_files) != segment_size:
raise ValidationError(f"segment {i}: segment files do not match segment size")
segments.append(segment_files)
return segments
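For illustration, a hypothetical fragment of the backup manifest's `jobs` list, restricted to the fields that `serialize_custom_file_mapping` produces and this parser reads (the surrounding manifest structure is not shown here):

```python
# Hypothetical `jobs` entries from a backup of a task with custom segments
jobs = [
    {"start_frame": 0, "stop_frame": 1, "files": ["a.jpg", "b.jpg"]},
    {"start_frame": 2, "stop_frame": 2, "files": ["c.jpg"]},
]

# _parse_custom_segments(jobs=jobs) returns [["a.jpg", "b.jpg"], ["c.jpg"]].
# An entry whose file count differs from (stop_frame - start_frame + 1)
# raises a ValidationError instead.
```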
def _import_task(self):
def _write_data(zip_object):
data_path = self._db_task.data.get_upload_dirname()
@@ -519,10 +555,23 @@ class TaskImporter(_ImporterBase, _TaskBackupBase):
jobs = self._manifest.pop('jobs')
self._prepare_task_meta(self._manifest)
self._manifest['segment_size'], self._manifest['overlap'] = self._calculate_segment_size(jobs)
self._manifest['owner_id'] = self._user_id
self._manifest['project_id'] = self._project_id
if custom_segments := data.pop('custom_segments', False):
job_file_mapping = self._parse_custom_segments(jobs=jobs)
data['job_file_mapping'] = job_file_mapping
for d in [self._manifest, data]:
for k in [
'segment_size', 'overlap', 'start_frame', 'stop_frame',
'sorting_method', 'frame_filter', 'filename_pattern'
]:
d.pop(k, None)
else:
self._manifest['segment_size'], self._manifest['overlap'] = \
self._calculate_segment_size(jobs)
self._db_task = models.Task.objects.create(**self._manifest, organization_id=self._org_id)
task_path = self._db_task.get_dirname()
if os.path.isdir(task_path):
@@ -550,7 +599,10 @@ class TaskImporter(_ImporterBase, _TaskBackupBase):
data['use_zip_chunks'] = data.pop('chunk_type') == DataChoice.IMAGESET
data = data_serializer.data
data['client_files'] = uploaded_files
_create_thread(self._db_task.pk, data.copy(), True)
if custom_segments:
data['job_file_mapping'] = job_file_mapping
_create_thread(self._db_task.pk, data.copy(), isBackupRestore=True)
db_data.start_frame = data['start_frame']
db_data.stop_frame = data['stop_frame']
db_data.frame_filter = data['frame_filter']
@@ -169,7 +169,7 @@ class ImageListReader(IMediaReader):
if not source_path:
raise Exception('No image found')
if stop is None:
if not stop:
stop = len(source_path)
else:
stop = min(len(source_path), stop + 1)
@@ -344,6 +344,7 @@ class Task(models.Model):
updated_date = models.DateTimeField(auto_now=True)
overlap = models.PositiveIntegerField(null=True)
# Zero means that there are no limits (default)
# Note that the files can be split into jobs in a custom way in this case
segment_size = models.PositiveIntegerField(default=0)
status = models.CharField(max_length=32, choices=StatusChoice.choices(),
default=StatusChoice.ANNOTATION)
@@ -8,6 +8,7 @@ import re
import shutil
from tempfile import NamedTemporaryFile
import textwrap
from typing import OrderedDict
from rest_framework import serializers, exceptions
@@ -362,6 +363,41 @@ class WriteOnceMixin:
return extra_kwargs
class JobFiles(serializers.ListField):
"""
Read JobFileMapping docs for more info.
"""
def __init__(self, *args, **kwargs):
kwargs.setdefault('child', serializers.CharField(allow_blank=False, max_length=1024))
kwargs.setdefault('allow_empty', False)
super().__init__(*args, **kwargs)
class JobFileMapping(serializers.ListField):
"""
Represents a file-to-job mapping. Useful to specify a custom job
configuration during task creation. This option is not compatible with
most other job split-related options.
Example:
[
["file1.jpg", "file2.jpg"], # job #1 files
["file3.png"], # job #2 files
["file4.jpg", "file5.png", "file6.bmp"], # job #3 files
]
Files in the jobs must not overlap or repeat.
"""
def __init__(self, *args, **kwargs):
kwargs.setdefault('child', JobFiles())
kwargs.setdefault('allow_empty', False)
kwargs.setdefault('help_text', textwrap.dedent(__class__.__doc__))
super().__init__(*args, **kwargs)
class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
image_quality = serializers.IntegerField(min_value=0, max_value=100)
use_zip_chunks = serializers.BooleanField(default=False)
@@ -372,12 +408,14 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
copy_data = serializers.BooleanField(default=False)
cloud_storage_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
filename_pattern = serializers.CharField(allow_null=True, required=False)
job_file_mapping = JobFileMapping(required=False, write_only=True)
class Meta:
model = models.Data
fields = ('chunk_size', 'size', 'image_quality', 'start_frame', 'stop_frame', 'frame_filter',
'compressed_chunk_type', 'original_chunk_type', 'client_files', 'server_files', 'remote_files', 'use_zip_chunks',
'cloud_storage_id', 'use_cache', 'copy_data', 'storage_method', 'storage', 'sorting_method', 'filename_pattern')
'cloud_storage_id', 'use_cache', 'copy_data', 'storage_method', 'storage', 'sorting_method', 'filename_pattern',
'job_file_mapping')
# pylint: disable=no-self-use
def validate_frame_filter(self, value):
@@ -392,6 +430,21 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
raise serializers.ValidationError('Chunk size must be a positive integer')
return value
def validate_job_file_mapping(self, value):
existing_files = set()
for job_files in value:
for filename in job_files:
if filename in existing_files:
raise serializers.ValidationError(
f"The same file '{filename}' cannot be used multiple "
"times in the job file mapping"
)
existing_files.add(filename)
return value
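As a quick illustration (a hedged sketch, assuming `image_quality` is the only other field that must be supplied here, and that a configured Django environment such as a CVAT server shell is available), reusing a file in two jobs is rejected by this check:

```python
from cvat.apps.engine.serializers import DataSerializer

# Hypothetical payload: 'a.jpg' appears in two different jobs
serializer = DataSerializer(data={
    "image_quality": 70,
    "job_file_mapping": [["a.jpg"], ["a.jpg", "b.jpg"]],
})

assert not serializer.is_valid()
assert "job_file_mapping" in serializer.errors  # duplicate file reported here
```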
# pylint: disable=no-self-use
def validate(self, attrs):
if 'start_frame' in attrs and 'stop_frame' in attrs \
@@ -402,6 +455,7 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
def create(self, validated_data):
files = self._pop_data(validated_data)
db_data = models.Data.objects.create(**validated_data)
db_data.make_dirs()
@@ -424,6 +478,8 @@ class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
server_files = validated_data.pop('server_files')
remote_files = validated_data.pop('remote_files')
validated_data.pop('job_file_mapping', None) # optional
for extra_key in { 'use_zip_chunks', 'use_cache', 'copy_data' }:
validated_data.pop(extra_key)
@@ -8,6 +8,7 @@ import itertools
import fnmatch
import os
import sys
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union
from rest_framework.serializers import ValidationError
import rq
import re
@@ -60,6 +61,17 @@ def rq_handler(job, exc_type, exc_value, traceback):
############################# Internal implementation for server API
JobFileMapping = List[List[str]]
class SegmentParams(NamedTuple):
start_frame: int
stop_frame: int
class SegmentsParams(NamedTuple):
segments: Iterator[SegmentParams]
segment_size: int
overlap: int
def _copy_data_from_source(server_files, upload_dir, server_dir=None):
job = rq.get_current_job()
job.meta['status'] = 'Data are being copied from source..'
@@ -79,37 +91,68 @@ def _copy_data_from_source(server_files, upload_dir, server_dir=None):
os.makedirs(target_dir)
shutil.copyfile(source_path, target_path)
def _get_task_segment_data(db_task, data_size):
segment_size = db_task.segment_size
segment_step = segment_size
if segment_size == 0 or segment_size > data_size:
segment_size = data_size
def _get_task_segment_data(
db_task: models.Task,
*,
data_size: Optional[int] = None,
job_file_mapping: Optional[JobFileMapping] = None,
) -> SegmentsParams:
if job_file_mapping is not None:
def _segments():
# It is assumed here that the files are already saved in the task in the required order.
# Here we only need to create segments matching the job sizes.
start_frame = 0
for jf in job_file_mapping:
segment_size = len(jf)
stop_frame = start_frame + segment_size - 1
yield SegmentParams(start_frame, stop_frame)
start_frame = stop_frame + 1
segments = _segments()
segment_size = 0
overlap = 0
else:
# The segments have equal parameters
if data_size is None:
data_size = db_task.data.size
segment_size = db_task.segment_size
segment_step = segment_size
if segment_size == 0 or segment_size > data_size:
segment_size = data_size
# Segment step must be more than segment_size + overlap in single-segment tasks
# Otherwise a task contains an extra segment
segment_step = sys.maxsize
# Segment step must be more than segment_size + overlap in single-segment tasks
# Otherwise a task contains an extra segment
segment_step = sys.maxsize
overlap = 5 if db_task.mode == 'interpolation' else 0
if db_task.overlap is not None:
overlap = min(db_task.overlap, segment_size // 2)
overlap = 5 if db_task.mode == 'interpolation' else 0
if db_task.overlap is not None:
overlap = min(db_task.overlap, segment_size // 2)
segment_step -= overlap
segments = (
SegmentParams(start_frame, min(start_frame + segment_size - 1, data_size - 1))
for start_frame in range(0, data_size, segment_step)
)
segment_step -= overlap
return segment_step, segment_size, overlap
return SegmentsParams(segments, segment_size, overlap)
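A minimal standalone sketch (re-implemented here for illustration, not the server code itself) of how a custom mapping translates into segment frame ranges: each job covers a contiguous range whose length equals its file count.

```python
from typing import List, NamedTuple

class SegmentParams(NamedTuple):
    start_frame: int
    stop_frame: int

def segments_from_mapping(job_file_mapping: List[List[str]]) -> List[SegmentParams]:
    """Mirror of the job_file_mapping branch above: one segment per job."""
    segments = []
    start_frame = 0
    for job_files in job_file_mapping:
        stop_frame = start_frame + len(job_files) - 1
        segments.append(SegmentParams(start_frame, stop_frame))
        start_frame = stop_frame + 1
    return segments

# segments_from_mapping([["f1.jpg"], ["f2.jpg", "f3.jpg", "f4.jpg"]])
# -> [SegmentParams(start_frame=0, stop_frame=0), SegmentParams(start_frame=1, stop_frame=3)]
```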
def _save_task_to_db(db_task, extractor):
def _save_task_to_db(db_task: models.Task, *, job_file_mapping: Optional[JobFileMapping] = None):
job = rq.get_current_job()
job.meta['status'] = 'Task is being saved in database'
job.save_meta()
segment_step, segment_size, overlap = _get_task_segment_data(db_task, db_task.data.size)
segments, segment_size, overlap = _get_task_segment_data(
db_task=db_task, job_file_mapping=job_file_mapping
)
db_task.segment_size = segment_size
db_task.overlap = overlap
for start_frame in range(0, db_task.data.size, segment_step):
stop_frame = min(start_frame + segment_size - 1, db_task.data.size - 1)
slogger.glob.info("New segment for task #{}: start_frame = {}, \
stop_frame = {}".format(db_task.id, start_frame, stop_frame))
for segment_idx, (start_frame, stop_frame) in enumerate(segments):
slogger.glob.info("New segment for task #{}: idx = {}, start_frame = {}, \
stop_frame = {}".format(db_task.id, segment_idx, start_frame, stop_frame))
db_segment = models.Segment()
db_segment.task = db_task
@@ -214,6 +257,41 @@ def _validate_data(counter, manifest_files=None):
return counter, task_modes[0]
def _validate_job_file_mapping(
db_task: models.Task, data: Dict[str, Any]
) -> Optional[JobFileMapping]:
job_file_mapping = data.get('job_file_mapping', None)
if job_file_mapping is None:
return None
elif not list(itertools.chain.from_iterable(job_file_mapping)):
raise ValidationError("job_file_mapping cannot be empty")
if db_task.segment_size:
raise ValidationError("job_file_mapping cannot be used with segment_size")
if (data.get('sorting_method', db_task.data.sorting_method)
!= models.SortingMethod.LEXICOGRAPHICAL
):
raise ValidationError("job_file_mapping cannot be used with sorting_method")
if data.get('start_frame', db_task.data.start_frame):
raise ValidationError("job_file_mapping cannot be used with start_frame")
if data.get('stop_frame', db_task.data.stop_frame):
raise ValidationError("job_file_mapping cannot be used with stop_frame")
if data.get('frame_filter', db_task.data.frame_filter):
raise ValidationError("job_file_mapping cannot be used with frame_filter")
if db_task.data.get_frame_step() != 1:
raise ValidationError("job_file_mapping cannot be used with frame step")
if data.get('filename_pattern'):
raise ValidationError("job_file_mapping cannot be used with filename_pattern")
return job_file_mapping
def _validate_manifest(manifests, root_dir, is_in_cloud, db_cloud_storage, data_storage_method):
if manifests:
if len(manifests) != 1:
@@ -325,12 +403,20 @@ def _create_task_manifest_based_on_cloud_storage_manifest(
manifest.create(sorted_content)
@transaction.atomic
def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
def _create_thread(
db_task: Union[int, models.Task],
data: Dict[str, Any],
*,
isBackupRestore: bool = False,
isDatasetImport: bool = False,
) -> None:
if isinstance(db_task, int):
db_task = models.Task.objects.select_for_update().get(pk=db_task)
slogger.glob.info("create task #{}".format(db_task.id))
job_file_mapping = _validate_job_file_mapping(db_task, data)
db_data = db_task.data
upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT
is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE
@@ -387,11 +473,17 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
media = _count_files(data)
media, task_mode = _validate_data(media, manifest_files)
if job_file_mapping is not None and task_mode != 'annotation':
raise ValidationError("job_file_mapping can't be used with sequence-based data like videos")
if data['server_files']:
if db_data.storage == models.StorageChoice.LOCAL:
_copy_data_from_source(data['server_files'], upload_dir, data.get('server_files_path'))
elif is_data_in_cloud:
sorted_media = sort(media['image'], data['sorting_method'])
if job_file_mapping is not None:
sorted_media = list(itertools.chain.from_iterable(job_file_mapping))
else:
sorted_media = sort(media['image'], data['sorting_method'])
# Define task manifest content based on cloud storage manifest content and uploaded files
_create_task_manifest_based_on_cloud_storage_manifest(
@@ -486,24 +578,44 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
extractor.filter(lambda x: not re.search(r'(^|{0})related_images{0}'.format(os.sep), x))
related_images = detect_related_images(extractor.absolute_source_paths, upload_dir)
if isBackupRestore and not isinstance(extractor, MEDIA_TYPES['video']['extractor']) and db_data.storage_method == models.StorageMethodChoice.CACHE and \
db_data.sorting_method in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED} and validate_dimension.dimension != models.DimensionType.DIM_3D:
# we should sort media_files according to the manifest content sequence
# and we should do this in general after validation step for 3D data and after filtering from related_images
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.set_index()
# Sort the files
if (isBackupRestore and (
not isinstance(extractor, MEDIA_TYPES['video']['extractor'])
and db_data.storage_method == models.StorageMethodChoice.CACHE
and db_data.sorting_method in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED}
and validate_dimension.dimension != models.DimensionType.DIM_3D
) or job_file_mapping
):
sorted_media_files = []
for idx in range(len(extractor.absolute_source_paths)):
properties = manifest[idx]
image_name = properties.get('name', None)
image_extension = properties.get('extension', None)
if job_file_mapping:
sorted_media_files.extend(itertools.chain.from_iterable(job_file_mapping))
else:
# we should sort media_files according to the manifest content sequence
# and we should do this in general after validation step for 3D data and after filtering from related_images
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.set_index()
for idx in range(len(extractor.absolute_source_paths)):
properties = manifest[idx]
image_name = properties.get('name', None)
image_extension = properties.get('extension', None)
full_image_path = f"{image_name}{image_extension}" if image_name and image_extension else None
if full_image_path:
sorted_media_files.append(full_image_path)
sorted_media_files = [os.path.join(upload_dir, fn) for fn in sorted_media_files]
for file_path in sorted_media_files:
if not file_path in extractor:
raise ValidationError(
f"Can't find file '{os.path.basename(file_path)}' in the input files"
)
full_image_path = os.path.join(upload_dir, f"{image_name}{image_extension}") if image_name and image_extension else None
if full_image_path and full_image_path in extractor:
sorted_media_files.append(full_image_path)
media_files = sorted_media_files.copy()
del sorted_media_files
data['sorting_method'] = models.SortingMethod.PREDEFINED
extractor.reconcile(
source_files=media_files,
@@ -720,4 +832,4 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step())
slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id))
_save_task_to_db(db_task, extractor)
_save_task_to_db(db_task, job_file_mapping=job_file_mapping)
@@ -932,6 +932,9 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
self._object.save()
data = {k: v for k, v in serializer.data.items()}
if 'job_file_mapping' in serializer.validated_data:
data['job_file_mapping'] = serializer.validated_data['job_file_mapping']
data['use_zip_chunks'] = serializer.validated_data['use_zip_chunks']
data['use_cache'] = serializer.validated_data['use_cache']
data['copy_data'] = serializer.validated_data['copy_data']
@@ -10,18 +10,22 @@ import subprocess
from copy import deepcopy
from functools import partial
from http import HTTPStatus
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from time import sleep
import pytest
from cvat_sdk import Client, Config
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from deepdiff import DeepDiff
from PIL import Image
import shared.utils.s3 as s3
from shared.fixtures.init import get_server_image_tag
from shared.utils.config import get_method, make_api_client, patch_method
from shared.utils.config import BASE_URL, USER_PASS, get_method, make_api_client, patch_method
from shared.utils.helpers import generate_image_files
from .utils import export_dataset
@@ -431,6 +435,23 @@ class TestPostTaskData:
(task, response) = api_client.tasks_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED
if data.get("client_files") and "json" in content_type:
# Can't encode binary files in json
(_, response) = api_client.tasks_api.create_data(
task.id,
data_request=models.DataRequest(
client_files=data["client_files"],
image_quality=data["image_quality"],
),
upload_multiple=True,
_content_type="multipart/form-data",
**kwargs,
)
assert response.status == HTTPStatus.OK
data = data.copy()
del data["client_files"]
(_, response) = api_client.tasks_api.create_data(
task.id, data_request=deepcopy(data), _content_type=content_type, **kwargs
)
@@ -833,6 +854,49 @@ class TestPostTaskData:
status = self._test_cannot_create_task(self._USERNAME, task_spec, data_spec)
assert "No media data found" in status.message
def test_can_specify_file_job_mapping(self):
task_spec = {
"name": f"test file-job mapping",
"labels": [{"name": "car"}],
}
files = generate_image_files(7)
filenames = [osp.basename(f.name) for f in files]
expected_segments = [
filenames[0:1],
filenames[1:5][::-1], # a reversed fragment
filenames[5:7],
]
data_spec = {
"image_quality": 75,
"client_files": files,
"job_file_mapping": expected_segments,
}
task_id = self._test_create_task(
self._USERNAME, task_spec, data_spec, content_type="application/json"
)
with make_api_client(self._USERNAME) as api_client:
(task, _) = api_client.tasks_api.retrieve(id=task_id)
(task_meta, _) = api_client.tasks_api.retrieve_data_meta(id=task_id)
assert [f.name for f in task_meta.frames] == list(
chain.from_iterable(expected_segments)
)
assert len(task.segments) == len(expected_segments)
start_frame = 0
for i, segment in enumerate(task.segments):
expected_size = len(expected_segments[i])
stop_frame = start_frame + expected_size - 1
assert segment.start_frame == start_frame
assert segment.stop_frame == stop_frame
start_frame = stop_frame + 1
@pytest.mark.usefixtures("restore_db_per_function")
@pytest.mark.usefixtures("restore_cvat_data")
@@ -952,3 +1016,91 @@ class TestGetTaskPreview:
assert len(tasks)
self._test_assigned_users_cannot_see_task_preview(tasks, users, is_task_staff)
class TestUnequalJobs:
def _make_client(self) -> Client:
return Client(BASE_URL, config=Config(status_check_period=0.01))
@pytest.fixture(autouse=True)
def setup(self, restore_db_per_function, tmp_path: Path, admin_user: str):
self.tmp_dir = tmp_path
self.client = self._make_client()
self.user = admin_user
with self.client:
self.client.login((self.user, USER_PASS))
@pytest.fixture
def fxt_task_with_unequal_jobs(self):
task_spec = {
"name": f"test file-job mapping",
"labels": [{"name": "car"}],
}
files = generate_image_files(7)
filenames = [osp.basename(f.name) for f in files]
for file_data in files:
with open(self.tmp_dir / file_data.name, "wb") as f:
f.write(file_data.getvalue())
expected_segments = [
filenames[0:1],
filenames[1:5][::-1], # a reversed fragment
filenames[5:7],
]
data_spec = {
"job_file_mapping": expected_segments,
}
return self.client.tasks.create_from_data(
spec=task_spec,
resource_type=ResourceType.LOCAL,
resources=[self.tmp_dir / fn for fn in filenames],
data_params=data_spec,
)
def test_can_export(self, fxt_task_with_unequal_jobs: Task):
task = fxt_task_with_unequal_jobs
filename = self.tmp_dir / f"task_{task.id}_coco.zip"
task.export_dataset("COCO 1.0", filename)
assert filename.is_file()
assert filename.stat().st_size > 0
def test_can_import_annotations(self, fxt_task_with_unequal_jobs: Task):
task = fxt_task_with_unequal_jobs
format_name = "COCO 1.0"
filename = self.tmp_dir / f"task_{task.id}_coco.zip"
task.export_dataset(format_name, filename)
task.import_annotations(format_name, filename)
def test_can_dump_backup(self, fxt_task_with_unequal_jobs: Task):
task = fxt_task_with_unequal_jobs
filename = self.tmp_dir / f"task_{task.id}_backup.zip"
task.download_backup(filename)
assert filename.is_file()
assert filename.stat().st_size > 0
def test_can_import_backup(self, fxt_task_with_unequal_jobs: Task):
task = fxt_task_with_unequal_jobs
filename = self.tmp_dir / f"task_{task.id}_backup.zip"
task.download_backup(filename)
restored_task = self.client.tasks.create_from_backup(filename)
old_jobs = task.get_jobs()
new_jobs = restored_task.get_jobs()
assert len(old_jobs) == len(new_jobs)
for old_job, new_job in zip(old_jobs, new_jobs):
assert old_job.start_frame == new_job.start_frame
assert old_job.stop_frame == new_job.stop_frame