Unverified commit c808471d, authored by Maxim Zhiltsov, committed by GitHub

Improve SDK UX with task creation (#5502)

Extracted from https://github.com/opencv/cvat/pull/5083

- Added a default value for the resource type arg in task data uploading (see the sketch below)
- Added an option to wait for data processing to finish during task data uploading
- Moved the splitting of data into requests for TUS uploads closer to its point of use
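
For illustration, a minimal sketch of the resulting SDK usage, assuming a running CVAT server and local image files (the host, credentials, file names, and label set below are placeholders):

```python
from cvat_sdk import make_client

with make_client(host="localhost", credentials=("user", "password")) as client:
    # resource_type now defaults to local files, and the call blocks until
    # the server finishes processing the uploaded data before returning
    task = client.tasks.create_from_data(
        spec={"name": "example task", "labels": [{"name": "car"}]},
        resources=["image1.jpg", "image2.jpg"],
    )
```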
Parent 1d00e515
......@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## \[2.4.0] - Unreleased
 ### Added
+- \[SDK\] An arg to wait for data processing in the task data uploading function
+  (<https://github.com/opencv/cvat/pull/5502>)
 - Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>, <https://github.com/opencv/cvat/pull/5525>)
 - \[SDK\] Configuration setting to change the dataset cache directory
   (<https://github.com/opencv/cvat/pull/5535>)
......@@ -17,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - The Docker Compose files now use the Compose Specification version
   of the format. This version is supported by Docker Compose 1.27.0+
   (<https://github.com/opencv/cvat/pull/5524>).
+- \[SDK\] The `resource_type` args now have the default value of `local` in task creation functions.
+  The corresponding arguments are keyword-only now.
+  (<https://github.com/opencv/cvat/pull/5502>)
 ### Deprecated
 - TBD
......
......@@ -38,9 +38,9 @@ class CLI:
         self,
         name: str,
         labels: List[Dict[str, str]],
-        resource_type: ResourceType,
         resources: Sequence[str],
         *,
+        resource_type: ResourceType = ResourceType.LOCAL,
         annotation_path: str = "",
         annotation_format: str = "CVAT XML 1.1",
         status_check_period: int = 2,
......
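
A hedged sketch of the effect on this CLI helper, assuming the enclosing method is the CLI class's task-creation function (`tasks_create`) and that `cli` is an instance; the file names are placeholders. `resource_type` no longer needs to be passed, and when it is, it must be a keyword argument:

```python
# Previously: cli.tasks_create("task 1", labels, ResourceType.LOCAL, resources, ...)
cli.tasks_create(
    "task 1",
    [{"name": "car"}],
    ["img1.png", "img2.png"],
    # resource_type=ResourceType.SHARE,  # keyword-only; defaults to ResourceType.LOCAL
)
```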
......@@ -65,11 +65,13 @@ class Task(
     def upload_data(
         self,
-        resource_type: ResourceType,
         resources: Sequence[StrPath],
         *,
+        resource_type: ResourceType = ResourceType.LOCAL,
         pbar: Optional[ProgressReporter] = None,
         params: Optional[Dict[str, Any]] = None,
+        wait_for_completion: bool = True,
+        status_check_period: Optional[int] = None,
     ) -> None:
         """
         Add local, remote, or shared files to an existing task.
......@@ -121,6 +123,37 @@
             url, list(map(Path, resources)), pbar=pbar, **data
         )
 
+        if wait_for_completion:
+            if status_check_period is None:
+                status_check_period = self._client.config.status_check_period
+
+            self._client.logger.info("Awaiting for task %s creation...", self.id)
+            while True:
+                sleep(status_check_period)
+
+                (status, response) = self.api.retrieve_status(self.id)
+
+                self._client.logger.info(
+                    "Task %s creation status: %s (message=%s)",
+                    self.id,
+                    status.state.value,
+                    status.message,
+                )
+
+                if (
+                    status.state.value
+                    == models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]
+                ):
+                    break
+                elif (
+                    status.state.value
+                    == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]
+                ):
+                    raise exceptions.ApiException(
+                        status=status.state.value, reason=status.message, http_resp=response
+                    )
+
+        self.fetch()
+
     def import_annotations(
         self,
         format_name: str,
......@@ -296,9 +329,9 @@ class TasksRepo(
     def create_from_data(
         self,
         spec: models.ITaskWriteRequest,
-        resource_type: ResourceType,
         resources: Sequence[str],
         *,
+        resource_type: ResourceType = ResourceType.LOCAL,
         data_params: Optional[Dict[str, Any]] = None,
         annotation_path: str = "",
         annotation_format: str = "CVAT XML 1.1",
......@@ -313,9 +346,6 @@
 
         Returns: id of the created task
         """
-        if status_check_period is None:
-            status_check_period = self._client.config.status_check_period
-
         if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
             raise exceptions.ApiValueError(
                 "Can't set labels to a task inside a project. "
......@@ -326,27 +356,14 @@
         task = self.create(spec=spec)
         self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
 
-        task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
-
-        self._client.logger.info("Awaiting for task %s creation...", task.id)
-        status: models.RqStatus = None
-        while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
-            sleep(status_check_period)
-            (status, response) = self.api.retrieve_status(task.id)
-
-            self._client.logger.info(
-                "Task %s creation status=%s, message=%s",
-                task.id,
-                status.state.value,
-                status.message,
-            )
-
-            if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
-                raise exceptions.ApiException(
-                    status=status.state.value, reason=status.message, http_resp=response
-                )
-
-            status = status.state.value
+        task.upload_data(
+            resource_type=resource_type,
+            resources=resources,
+            pbar=pbar,
+            params=data_params,
+            wait_for_completion=True,
+            status_check_period=status_check_period,
+        )
 
         if annotation_path:
             task.import_annotations(annotation_format, annotation_path, pbar=pbar)
......
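
With these changes, `create_from_data` delegates waiting to `upload_data`, and SDK callers can opt out of blocking and poll on their own. A rough usage sketch against an existing `task` proxy (the archive name and parameters are illustrative):

```python
# Enqueue the data and return immediately
task.upload_data(
    resources=["frames.zip"],
    params={"image_quality": 70},
    wait_for_completion=False,
)

# Or block until the server reports FINISHED (raising ApiException on FAILED),
# checking the request status every 5 seconds
task.upload_data(
    resources=["frames.zip"],
    params={"image_quality": 70},
    status_check_period=5,
)
```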
......@@ -5,7 +5,6 @@
 from __future__ import annotations
 
 import os
-from contextlib import ExitStack, closing
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
......@@ -206,40 +205,6 @@ class Uploader:
             positive_statuses=positive_statuses,
         )
 
-    def _split_files_by_requests(
-        self, filenames: List[Path]
-    ) -> Tuple[List[Tuple[List[Path], int]], List[Path], int]:
-        bulk_files: Dict[str, int] = {}
-        separate_files: Dict[str, int] = {}
-
-        # sort by size
-        for filename in filenames:
-            filename = filename.resolve()
-            file_size = filename.stat().st_size
-            if MAX_REQUEST_SIZE < file_size:
-                separate_files[filename] = file_size
-            else:
-                bulk_files[filename] = file_size
-
-        total_size = sum(bulk_files.values()) + sum(separate_files.values())
-
-        # group small files by requests
-        bulk_file_groups: List[Tuple[List[str], int]] = []
-        current_group_size: int = 0
-        current_group: List[str] = []
-        for filename, file_size in bulk_files.items():
-            if MAX_REQUEST_SIZE < current_group_size + file_size:
-                bulk_file_groups.append((current_group, current_group_size))
-                current_group_size = 0
-                current_group = []
-
-            current_group.append(filename)
-            current_group_size += file_size
-
-        if current_group:
-            bulk_file_groups.append((current_group, current_group_size))
-
-        return bulk_file_groups, separate_files, total_size
-
     @staticmethod
     def _make_tus_uploader(api_client: ApiClient, url: str, **kwargs):
         # Add headers required by CVAT server
......@@ -353,6 +318,10 @@ class DatasetUploader(Uploader):
 
 class DataUploader(Uploader):
+    def __init__(self, client: Client, *, max_request_size: int = MAX_REQUEST_SIZE):
+        super().__init__(client)
+        self.max_request_size = max_request_size
+
     def upload_files(
         self,
         url: str,
......@@ -369,22 +338,21 @@
         self._tus_start_upload(url)
 
         for group, group_size in bulk_file_groups:
-            with ExitStack() as es:
-                files = {}
-                for i, filename in enumerate(group):
-                    files[f"client_files[{i}]"] = (
-                        os.fspath(filename),
-                        es.enter_context(closing(open(filename, "rb"))).read(),
-                    )
-                response = self._client.api_client.rest_client.POST(
-                    url,
-                    post_params=dict(**kwargs, **files),
-                    headers={
-                        "Content-Type": "multipart/form-data",
-                        "Upload-Multiple": "",
-                        **self._client.api_client.get_common_headers(),
-                    },
-                )
+            files = {}
+            for i, filename in enumerate(group):
+                files[f"client_files[{i}]"] = (
+                    os.fspath(filename),
+                    filename.read_bytes(),
+                )
+            response = self._client.api_client.rest_client.POST(
+                url,
+                post_params=dict(**kwargs, **files),
+                headers={
+                    "Content-Type": "multipart/form-data",
+                    "Upload-Multiple": "",
+                    **self._client.api_client.get_common_headers(),
+                },
+            )
             expect_status(200, response)
             if pbar is not None:
......@@ -401,3 +369,38 @@
             )
 
         self._tus_finish_upload(url, fields=kwargs)
+
+    def _split_files_by_requests(
+        self, filenames: List[Path]
+    ) -> Tuple[List[Tuple[List[Path], int]], List[Path], int]:
+        bulk_files: Dict[str, int] = {}
+        separate_files: Dict[str, int] = {}
+        max_request_size = self.max_request_size
+
+        # sort by size
+        for filename in filenames:
+            filename = filename.resolve()
+            file_size = filename.stat().st_size
+            if max_request_size < file_size:
+                separate_files[filename] = file_size
+            else:
+                bulk_files[filename] = file_size
+
+        total_size = sum(bulk_files.values()) + sum(separate_files.values())
+
+        # group small files by requests
+        bulk_file_groups: List[Tuple[List[str], int]] = []
+        current_group_size: int = 0
+        current_group: List[str] = []
+        for filename, file_size in bulk_files.items():
+            if max_request_size < current_group_size + file_size:
+                bulk_file_groups.append((current_group, current_group_size))
+                current_group_size = 0
+                current_group = []
+
+            current_group.append(filename)
+            current_group_size += file_size
+
+        if current_group:
+            bulk_file_groups.append((current_group, current_group_size))
+
+        return bulk_file_groups, separate_files, total_size
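
To illustrate the grouping policy: files larger than `max_request_size` are each uploaded separately over TUS, while smaller files are greedily packed, in order, into shared multipart requests. A self-contained sketch of the same logic with a toy limit (this mirrors the method above rather than calling the SDK):

```python
from typing import Dict, List, Tuple

MAX_REQUEST_SIZE = 10  # toy limit, in bytes, for illustration


def split_by_requests(sizes: Dict[str, int]) -> Tuple[List[Tuple[List[str], int]], Dict[str, int]]:
    """Greedy packing that mirrors DataUploader._split_files_by_requests."""
    bulk = {name: size for name, size in sizes.items() if size <= MAX_REQUEST_SIZE}
    separate = {name: size for name, size in sizes.items() if size > MAX_REQUEST_SIZE}

    groups: List[Tuple[List[str], int]] = []
    group: List[str] = []
    group_size = 0
    for name, size in bulk.items():
        if MAX_REQUEST_SIZE < group_size + size:
            # the current request is full: start a new one
            groups.append((group, group_size))
            group, group_size = [], 0
        group.append(name)
        group_size += size
    if group:
        groups.append((group, group_size))

    return groups, separate


# Files of 4, 5 and 3 bytes pack into requests of 9 and 3 bytes;
# the 12-byte file exceeds the limit and goes into a separate TUS upload.
print(split_by_requests({"a": 4, "b": 5, "c": 3, "d": 12}))
# ([(['a', 'b'], 9), (['c'], 3)], {'d': 12})
```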
......@@ -70,8 +70,8 @@ class TestTaskVisionDataset:
                     models.PatchedLabelRequest(name="car"),
                 ],
             ),
-            ResourceType.LOCAL,
-            list(map(os.fspath, image_paths)),
+            resource_type=ResourceType.LOCAL,
+            resources=list(map(os.fspath, image_paths)),
            data_params={"chunk_size": 3},
        )
......@@ -274,8 +274,8 @@ class TestProjectVisionDataset:
                     project_id=self.project.id,
                     subset=subset,
                 ),
-                ResourceType.LOCAL,
-                image_paths,
+                resource_type=ResourceType.LOCAL,
+                resources=image_paths,
                 data_params={"image_quality": 70},
             )
             for subset, image_paths in zip(subsets, image_paths_per_task)
......
......@@ -58,7 +58,6 @@ class TestTaskUsecases:
                 "name": "test_task",
                 "labels": [{"name": "car"}, {"name": "person"}],
             },
-            resource_type=ResourceType.LOCAL,
             resources=[fxt_image_file],
             data_params={"image_quality": 80},
         )
......@@ -202,6 +201,38 @@ class TestTaskUsecases:
         assert response_json["format"] == "CVAT for images 1.1"
         assert response_json["lfs"] is False
 
+    def test_can_upload_data_to_empty_task(self):
+        pbar_out = io.StringIO()
+        pbar = make_pbar(file=pbar_out)
+
+        task = self.client.tasks.create(
+            {
+                "name": f"test task",
+                "labels": [{"name": "car"}],
+            }
+        )
+
+        data_params = {
+            "image_quality": 75,
+        }
+
+        task_files = generate_image_files(7)
+        for i, f in enumerate(task_files):
+            fname = self.tmp_path / f.name
+            fname.write_bytes(f.getvalue())
+            task_files[i] = fname
+
+        task.upload_data(
+            resources=task_files,
+            resource_type=ResourceType.LOCAL,
+            params=data_params,
+            pbar=pbar,
+        )
+
+        assert task.size == 7
+        assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
+        assert self.stdout.getvalue() == ""
+
     def test_can_retrieve_task(self, fxt_new_task: Task):
         task_id = fxt_new_task.id