Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
疯人忠
Cvat
提交
53697eca
C
Cvat
项目概览
疯人忠
/
Cvat
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
Cvat
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
53697eca
编写于
8月 26, 2022
作者:
M
Maxim Zhiltsov
提交者:
GitHub
8月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
SDK layer 2 - cover RC1 usecases (#4813)
上级
b60d3b48
变更
53
展开全部
显示空白变更内容
内联
并排
Showing
53 changed file
with
3196 addition
and
731 deletion
+3196
-731
.bandit
.bandit
+1
-0
.github/workflows/bandit.yml
.github/workflows/bandit.yml
+1
-1
CHANGELOG.md
CHANGELOG.md
+3
-2
cvat-cli/src/cvat_cli/cli.py
cvat-cli/src/cvat_cli/cli.py
+10
-11
cvat-cli/src/cvat_cli/parser.py
cvat-cli/src/cvat_cli/parser.py
+1
-1
cvat-sdk/.gitignore
cvat-sdk/.gitignore
+1
-1
cvat-sdk/cvat_sdk/core/client.py
cvat-sdk/cvat_sdk/core/client.py
+80
-137
cvat-sdk/cvat_sdk/core/downloading.py
cvat-sdk/cvat_sdk/core/downloading.py
+45
-6
cvat-sdk/cvat_sdk/core/git.py
cvat-sdk/cvat_sdk/core/git.py
+2
-3
cvat-sdk/cvat_sdk/core/helpers.py
cvat-sdk/cvat_sdk/core/helpers.py
+19
-3
cvat-sdk/cvat_sdk/core/proxies/__init__.py
cvat-sdk/cvat_sdk/core/proxies/__init__.py
+0
-0
cvat-sdk/cvat_sdk/core/proxies/annotations.py
cvat-sdk/cvat_sdk/core/proxies/annotations.py
+66
-0
cvat-sdk/cvat_sdk/core/proxies/issues.py
cvat-sdk/cvat_sdk/core/proxies/issues.py
+60
-0
cvat-sdk/cvat_sdk/core/proxies/jobs.py
cvat-sdk/cvat_sdk/core/proxies/jobs.py
+166
-0
cvat-sdk/cvat_sdk/core/proxies/model_proxy.py
cvat-sdk/cvat_sdk/core/proxies/model_proxy.py
+213
-0
cvat-sdk/cvat_sdk/core/proxies/projects.py
cvat-sdk/cvat_sdk/core/proxies/projects.py
+187
-0
cvat-sdk/cvat_sdk/core/proxies/tasks.py
cvat-sdk/cvat_sdk/core/proxies/tasks.py
+388
-0
cvat-sdk/cvat_sdk/core/proxies/users.py
cvat-sdk/cvat_sdk/core/proxies/users.py
+35
-0
cvat-sdk/cvat_sdk/core/types.py
cvat-sdk/cvat_sdk/core/types.py
+0
-18
cvat-sdk/cvat_sdk/core/uploading.py
cvat-sdk/cvat_sdk/core/uploading.py
+141
-61
cvat-sdk/cvat_sdk/core/utils.py
cvat-sdk/cvat_sdk/core/utils.py
+0
-8
cvat-sdk/gen/postprocess.py
cvat-sdk/gen/postprocess.py
+1
-1
cvat-sdk/gen/templates/openapi-generator/api_client.mustache
cvat-sdk/gen/templates/openapi-generator/api_client.mustache
+3
-0
cvat-sdk/gen/templates/openapi-generator/model.mustache
cvat-sdk/gen/templates/openapi-generator/model.mustache
+1
-0
cvat-sdk/gen/templates/openapi-generator/model_templates/model_normal.mustache
...s/openapi-generator/model_templates/model_normal.mustache
+1
-1
cvat-sdk/gen/templates/openapi-generator/model_templates/model_simple.mustache
...s/openapi-generator/model_templates/model_simple.mustache
+1
-1
cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
...-sdk/gen/templates/openapi-generator/model_utils.mustache
+5
-0
cvat-sdk/gen/templates/requirements/base.txt
cvat-sdk/gen/templates/requirements/base.txt
+1
-0
cvat-ui/package.json
cvat-ui/package.json
+1
-1
cvat-ui/src/components/tasks-page/tasks-page.tsx
cvat-ui/src/components/tasks-page/tasks-page.tsx
+1
-1
cvat/apps/dataset_repo/views.py
cvat/apps/dataset_repo/views.py
+1
-3
cvat/apps/engine/filters.py
cvat/apps/engine/filters.py
+1
-1
cvat/apps/engine/mixins.py
cvat/apps/engine/mixins.py
+15
-1
cvat/apps/engine/schema.py
cvat/apps/engine/schema.py
+96
-16
cvat/apps/engine/serializers.py
cvat/apps/engine/serializers.py
+31
-16
cvat/apps/engine/tests/test_rest_api.py
cvat/apps/engine/tests/test_rest_api.py
+69
-35
cvat/apps/engine/views.py
cvat/apps/engine/views.py
+139
-87
cvat/utils/version.py
cvat/utils/version.py
+3
-4
tests/python/cli/test_cli.py
tests/python/cli/test_cli.py
+18
-22
tests/python/rest_api/test_auth.py
tests/python/rest_api/test_auth.py
+138
-0
tests/python/rest_api/test_issues.py
tests/python/rest_api/test_issues.py
+257
-69
tests/python/rest_api/test_jobs.py
tests/python/rest_api/test_jobs.py
+103
-72
tests/python/rest_api/test_projects.py
tests/python/rest_api/test_projects.py
+13
-27
tests/python/rest_api/test_tasks.py
tests/python/rest_api/test_tasks.py
+19
-30
tests/python/rest_api/utils.py
tests/python/rest_api/utils.py
+26
-0
tests/python/sdk/fixtures.py
tests/python/sdk/fixtures.py
+25
-0
tests/python/sdk/test_issues_comments.py
tests/python/sdk/test_issues_comments.py
+236
-0
tests/python/sdk/test_jobs.py
tests/python/sdk/test_jobs.py
+280
-0
tests/python/sdk/test_tasks.py
tests/python/sdk/test_tasks.py
+184
-61
tests/python/sdk/test_users.py
tests/python/sdk/test_users.py
+73
-0
tests/python/sdk/util.py
tests/python/sdk/util.py
+28
-28
tests/python/shared/fixtures/data.py
tests/python/shared/fixtures/data.py
+4
-0
tests/python/shared/utils/config.py
tests/python/shared/utils/config.py
+3
-2
未找到文件。
.bandit
浏览文件 @
53697eca
...
...
@@ -6,3 +6,4 @@
# B406 : import_xml_sax
# B410 : import_lxml
skips: B101,B102,B320,B404,B406,B410
exclude: **/tests/**,tests
.github/workflows/bandit.yml
浏览文件 @
53697eca
...
...
@@ -33,7 +33,7 @@ jobs:
echo "Bandit version: "$(bandit --version | head -1)
echo "The files will be checked: "$(echo $CHANGED_FILES)
bandit
$CHANGED_FILES --exclude '**/tests/**' -a file --ini ./.bandit -f html -o ./bandit_report/bandit_checks.html
bandit
-a file --ini .bandit -f html -o ./bandit_report/bandit_checks.html $CHANGED_FILES
deactivate
else
echo "No files with the \"py\" extension found"
...
...
CHANGELOG.md
浏览文件 @
53697eca
...
...
@@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
Possibility to display tags on frame
-
Support source and target storages (server part)
-
Tests for import/export annotation, dataset, backup from/to cloud storage
-
Added Python SDK package (
`cvat-sdk`
)
-
Added Python SDK package (
`cvat-sdk`
)
(
<https://github.com/opencv/cvat/pull/4813>
)
-
Previews for jobs
-
Documentation for LDAP authentication (
<https://github.com/cvat-ai/cvat/pull/39>
)
-
OpenCV.js caching and autoload (
<https://github.com/cvat-ai/cvat/pull/30>
)
...
...
@@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
-
Bumped nuclio version to 1.8.14
-
Simplified running REST API tests. Extended CI-nightly workflow
-
REST API tests are partially moved to Python SDK (
`users`
,
`projects`
,
`tasks`
)
-
REST API tests are partially moved to Python SDK (
`users`
,
`projects`
,
`tasks`
,
`issues`
)
-
cvat-ui: Improve UI/UX on label, create task and create project forms (
<https://github.com/cvat-ai/cvat/pull/7>
)
-
Removed link to OpenVINO documentation (
<https://github.com/cvat-ai/cvat/pull/35>
)
-
Clarified meaning of chunking for videos
...
...
@@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
Image search in cloud storage (
<https://github.com/cvat-ai/cvat/pull/8>
)
-
Reset password functionality (
<https://github.com/cvat-ai/cvat/pull/52>
)
-
Creating task with cloud storage data (
<https://github.com/cvat-ai/cvat/pull/116>
)
-
Show empty tasks (
<https://github.com/cvat-ai/cvat/pull/100>
)
### Security
-
TDB
...
...
cvat-cli/src/cvat_cli/cli.py
浏览文件 @
53697eca
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
...
...
@@ -11,7 +10,7 @@ from typing import Dict, List, Sequence, Tuple
import
tqdm
from
cvat_sdk
import
Client
,
models
from
cvat_sdk.core.helpers
import
TqdmProgressReporter
from
cvat_sdk.core.
type
s
import
ResourceType
from
cvat_sdk.core.
proxies.task
s
import
ResourceType
class
CLI
:
...
...
@@ -26,7 +25,7 @@ class CLI:
def
tasks_list
(
self
,
*
,
use_json_output
:
bool
=
False
,
**
kwargs
):
"""List all tasks in either basic or JSON format."""
results
=
self
.
client
.
list_tasks
(
return_json
=
use_json_output
,
**
kwargs
)
results
=
self
.
client
.
tasks
.
list
(
return_json
=
use_json_output
,
**
kwargs
)
if
use_json_output
:
print
(
json
.
dumps
(
json
.
loads
(
results
),
indent
=
2
))
else
:
...
...
@@ -50,7 +49,7 @@ class CLI:
"""
Create a new task with the given name and labels JSON and add the files to it.
"""
task
=
self
.
client
.
create_task
(
task
=
self
.
client
.
tasks
.
create_from_data
(
spec
=
models
.
TaskWriteRequest
(
name
=
name
,
labels
=
labels
,
**
kwargs
),
resource_type
=
resource_type
,
resources
=
resources
,
...
...
@@ -66,7 +65,7 @@ class CLI:
def
tasks_delete
(
self
,
task_ids
:
Sequence
[
int
])
->
None
:
"""Delete a list of tasks, ignoring those which don't exist."""
self
.
client
.
delete_task
s
(
task_ids
=
task_ids
)
self
.
client
.
tasks
.
remove_by_id
s
(
task_ids
=
task_ids
)
def
tasks_frames
(
self
,
...
...
@@ -80,11 +79,11 @@ class CLI:
Download the requested frame numbers for a task and save images as
task_<ID>_frame_<FRAME>.jpg.
"""
self
.
client
.
retrieve_task
(
task
_id
=
task_id
).
download_frames
(
self
.
client
.
tasks
.
retrieve
(
obj
_id
=
task_id
).
download_frames
(
frame_ids
=
frame_ids
,
outdir
=
outdir
,
quality
=
quality
,
filename_pattern
=
"task_{task_id}
_frame_{frame_id:06d}{frame_ext}"
,
filename_pattern
=
f
"task_
{
task_id
}
"
+
"
_frame_{frame_id:06d}{frame_ext}"
,
)
def
tasks_dump
(
...
...
@@ -99,7 +98,7 @@ class CLI:
"""
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
self
.
client
.
retrieve_task
(
task
_id
=
task_id
).
export_dataset
(
self
.
client
.
tasks
.
retrieve
(
obj
_id
=
task_id
).
export_dataset
(
format_name
=
fileformat
,
filename
=
filename
,
pbar
=
self
.
_make_pbar
(),
...
...
@@ -112,7 +111,7 @@ class CLI:
)
->
None
:
"""Upload annotations for a task in the specified format
(e.g. 'YOLO ZIP 1.0')."""
self
.
client
.
retrieve_task
(
task
_id
=
task_id
).
import_annotations
(
self
.
client
.
tasks
.
retrieve
(
obj
_id
=
task_id
).
import_annotations
(
format_name
=
fileformat
,
filename
=
filename
,
status_check_period
=
status_check_period
,
...
...
@@ -121,13 +120,13 @@ class CLI:
def
tasks_export
(
self
,
task_id
:
str
,
filename
:
str
,
*
,
status_check_period
:
int
=
2
)
->
None
:
"""Download a task backup"""
self
.
client
.
retrieve_task
(
task
_id
=
task_id
).
download_backup
(
self
.
client
.
tasks
.
retrieve
(
obj
_id
=
task_id
).
download_backup
(
filename
=
filename
,
status_check_period
=
status_check_period
,
pbar
=
self
.
_make_pbar
()
)
def
tasks_import
(
self
,
filename
:
str
,
*
,
status_check_period
:
int
=
2
)
->
None
:
"""Import a task from a backup file"""
self
.
client
.
create_task
_from_backup
(
self
.
client
.
tasks
.
create
_from_backup
(
filename
=
filename
,
status_check_period
=
status_check_period
,
pbar
=
self
.
_make_pbar
()
)
...
...
cvat-cli/src/cvat_cli/parser.py
浏览文件 @
53697eca
...
...
@@ -10,7 +10,7 @@ import logging
import
os
from
distutils.util
import
strtobool
from
cvat_sdk.core.
type
s
import
ResourceType
from
cvat_sdk.core.
proxies.task
s
import
ResourceType
from
.version
import
VERSION
...
...
cvat-sdk/.gitignore
浏览文件 @
53697eca
cvat-sdk/cvat_sdk/core/client.py
浏览文件 @
53697eca
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
...
...
@@ -6,23 +5,22 @@
from
__future__
import
annotations
import
json
import
logging
import
os.path
as
osp
import
urllib.parse
from
time
import
sleep
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Sequence
,
Tuple
import
attrs
import
urllib3
from
cvat_sdk.api_client
import
ApiClient
,
ApiException
,
ApiValueError
,
Configuration
,
models
from
cvat_sdk.core.
git
import
create_git_repo
from
cvat_sdk.core.
helpers
import
get_paginated_collection
from
cvat_sdk.core.pro
gress
import
ProgressReporter
from
cvat_sdk.core.
tasks
import
TaskProxy
from
cvat_sdk.core.
types
import
ResourceType
from
cvat_sdk.core.
uploading
import
Uploader
from
cvat_sdk.core.
utils
import
assert_status
from
cvat_sdk.api_client
import
ApiClient
,
Configuration
,
models
from
cvat_sdk.core.
helpers
import
expect_status
from
cvat_sdk.core.
proxies.issues
import
CommentsRepo
,
IssuesRepo
from
cvat_sdk.core.pro
xies.jobs
import
JobsRepo
from
cvat_sdk.core.
proxies.model_proxy
import
Repo
from
cvat_sdk.core.
proxies.projects
import
ProjectsRepo
from
cvat_sdk.core.
proxies.tasks
import
TasksRepo
from
cvat_sdk.core.
proxies.users
import
UsersRepo
@
attrs
.
define
...
...
@@ -43,11 +41,13 @@ class Client:
):
# TODO: use requests instead of urllib3 in ApiClient
# TODO: try to autodetect schema
self
.
_api_map
=
_
CVAT_API_V2
(
url
)
self
.
api_map
=
CVAT_API_V2
(
url
)
self
.
api
=
ApiClient
(
Configuration
(
host
=
url
))
self
.
logger
=
logger
or
logging
.
getLogger
(
__name__
)
self
.
config
=
config
or
Config
()
self
.
_repos
:
Dict
[
str
,
Repo
]
=
{}
def
__enter__
(
self
):
self
.
api
.
__enter__
()
return
self
...
...
@@ -67,150 +67,93 @@ class Client:
assert
"csrftoken"
in
self
.
api
.
cookies
self
.
api
.
set_default_header
(
"Authorization"
,
"Token "
+
auth
.
key
)
def
create_task
(
self
,
spec
:
models
.
ITaskWriteRequest
,
resource_type
:
ResourceType
,
resources
:
Sequence
[
str
],
*
,
data_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
annotation_path
:
str
=
""
,
annotation_format
:
str
=
"CVAT XML 1.1"
,
status_check_period
:
int
=
None
,
dataset_repository_url
:
str
=
""
,
use_lfs
:
bool
=
False
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
)
->
TaskProxy
:
"""
Create a new task with the given name and labels JSON and
add the files to it.
def
_has_credentials
(
self
):
return
(
(
"sessionid"
in
self
.
api
.
cookies
)
or
(
"csrftoken"
in
self
.
api
.
cookies
)
or
(
self
.
api
.
get_common_headers
().
get
(
"Authorization"
,
""
))
)
Returns: id of the created task
"""
def
logout
(
self
):
if
self
.
_has_credentials
():
self
.
api
.
auth_api
.
create_logout
()
def
wait_for_completion
(
self
:
Client
,
url
:
str
,
*
,
success_status
:
int
,
status_check_period
:
Optional
[
int
]
=
None
,
query_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
post_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
method
:
str
=
"POST"
,
positive_statuses
:
Optional
[
Sequence
[
int
]]
=
None
,
)
->
urllib3
.
HTTPResponse
:
if
status_check_period
is
None
:
status_check_period
=
self
.
config
.
status_check_period
if
getattr
(
spec
,
"project_id"
,
None
)
and
getattr
(
spec
,
"labels"
,
None
):
raise
ApiValueError
(
"Can't set labels to a task inside a project. "
"Tasks inside a project use project's labels."
,
[
"labels"
],
)
(
task
,
_
)
=
self
.
api
.
tasks_api
.
create
(
spec
)
self
.
logger
.
info
(
"Created task ID: %s NAME: %s"
,
task
.
id
,
task
.
name
)
task
=
TaskProxy
(
self
,
task
)
task
.
upload_data
(
resource_type
,
resources
,
pbar
=
pbar
,
params
=
data_params
)
positive_statuses
=
set
(
positive_statuses
)
|
{
success_status
}
self
.
logger
.
info
(
"Awaiting for task %s creation..."
,
task
.
id
)
status
=
None
while
status
!=
models
.
RqStatusStateEnum
.
allowed_values
[(
"value"
,)][
"FINISHED"
]:
while
True
:
sleep
(
status_check_period
)
(
status
,
_
)
=
self
.
api
.
tasks_api
.
retrieve_status
(
task
.
id
)
self
.
logger
.
info
(
"Task %s creation status=%s, message=%s"
,
task
.
id
,
status
.
state
.
value
,
status
.
message
,
)
if
status
.
state
.
value
==
models
.
RqStatusStateEnum
.
allowed_values
[(
"value"
,)][
"FAILED"
]:
raise
ApiException
(
status
=
status
.
state
.
value
,
reason
=
status
.
message
)
status
=
status
.
state
.
value
if
annotation_path
:
task
.
import_annotations
(
annotation_format
,
annotation_path
,
pbar
=
pbar
)
if
dataset_repository_url
:
create_git_repo
(
self
,
task_id
=
task
.
id
,
repo_url
=
dataset_repository_url
,
status_check_period
=
status_check_period
,
use_lfs
=
use_lfs
,
)
task
.
fetch
()
return
task
def
list_tasks
(
self
,
*
,
return_json
:
bool
=
False
,
**
kwargs
)
->
Union
[
List
[
TaskProxy
],
List
[
Dict
[
str
,
Any
]]]:
"""List all tasks in either basic or JSON format."""
results
=
get_paginated_collection
(
endpoint
=
self
.
api
.
tasks_api
.
list_endpoint
,
return_json
=
return_json
,
**
kwargs
response
=
self
.
api
.
rest_client
.
request
(
method
=
method
,
url
=
url
,
headers
=
self
.
api
.
get_common_headers
(),
query_params
=
query_params
,
post_params
=
post_params
,
)
if
return_json
:
return
json
.
dumps
(
results
)
return
[
TaskProxy
(
self
,
v
)
for
v
in
results
]
def
retrieve_task
(
self
,
task_id
:
int
)
->
TaskProxy
:
(
task
,
_
)
=
self
.
api
.
tasks_api
.
retrieve
(
task_id
)
return
TaskProxy
(
self
,
task
)
self
.
logger
.
debug
(
"STATUS %s"
,
response
.
status
)
expect_status
(
positive_statuses
,
response
)
if
response
.
status
==
success_status
:
break
def
delete_tasks
(
self
,
task_ids
:
Sequence
[
int
]):
"""
Delete a list of tasks, ignoring those which don't exist.
"""
return
response
for
task_id
in
task_ids
:
(
_
,
response
)
=
self
.
api
.
tasks_api
.
destroy
(
task_id
,
_check_status
=
False
)
if
200
<=
response
.
status
<=
299
:
self
.
logger
.
info
(
f
"Task ID
{
task_id
}
deleted"
)
elif
response
.
status
==
404
:
self
.
logger
.
info
(
f
"Task ID
{
task_id
}
not found"
)
else
:
self
.
logger
.
warning
(
f
"Failed to delete task ID
{
task_id
}
: "
f
"
{
response
.
msg
}
(status
{
response
.
status
}
)"
)
def
_get_repo
(
self
,
key
:
str
)
->
Repo
:
_repo_map
=
{
"tasks"
:
TasksRepo
,
"projects"
:
ProjectsRepo
,
"jobs"
:
JobsRepo
,
"users"
:
UsersRepo
,
"issues"
:
IssuesRepo
,
"comments"
:
CommentsRepo
,
}
def
create_task_from_backup
(
self
,
filename
:
str
,
*
,
status_check_period
:
int
=
None
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
)
->
TaskProxy
:
"""
Import a task from a backup file
"""
if
status_check_period
is
None
:
status_check_period
=
self
.
config
.
status_check_period
repo
=
self
.
_repos
.
get
(
key
,
None
)
if
repo
is
None
:
repo
=
_repo_map
[
key
](
self
)
self
.
_repos
[
key
]
=
repo
return
repo
params
=
{
"filename"
:
osp
.
basename
(
filename
)}
url
=
self
.
_api_map
.
make_endpoint_url
(
self
.
api
.
tasks_api
.
create_backup_endpoint
.
path
)
uploader
=
Uploader
(
self
)
response
=
uploader
.
upload_file
(
url
,
filename
,
meta
=
params
,
query_params
=
params
,
pbar
=
pbar
,
logger
=
self
.
logger
.
debug
)
@
property
def
tasks
(
self
)
->
TasksRepo
:
return
self
.
_get_repo
(
"tasks"
)
rq_id
=
json
.
loads
(
response
.
data
)[
"rq_id"
]
@
property
def
projects
(
self
)
->
ProjectsRepo
:
return
self
.
_get_repo
(
"projects"
)
# check task status
while
True
:
sleep
(
status_check_period
)
@
property
def
jobs
(
self
)
->
JobsRepo
:
return
self
.
_get_repo
(
"jobs"
)
response
=
self
.
api
.
rest_client
.
POST
(
url
,
post_params
=
{
"rq_id"
:
rq_id
},
headers
=
self
.
api
.
get_common_headers
()
)
if
response
.
status
==
201
:
break
assert_status
(
202
,
response
)
@
property
def
users
(
self
)
->
UsersRepo
:
return
self
.
_get_repo
(
"users"
)
task_id
=
json
.
loads
(
response
.
data
)[
"id"
]
self
.
logger
.
info
(
f
"Task has been imported sucessfully. Task ID:
{
task_id
}
"
)
@
property
def
issues
(
self
)
->
IssuesRepo
:
return
self
.
_get_repo
(
"issues"
)
return
self
.
retrieve_task
(
task_id
)
@
property
def
comments
(
self
)
->
CommentsRepo
:
return
self
.
_get_repo
(
"comments"
)
class
_
CVAT_API_V2
:
class
CVAT_API_V2
:
"""Build parameterized API URLs"""
def
__init__
(
self
,
host
,
https
=
False
):
...
...
cvat-sdk/cvat_sdk/core/downloading.py
浏览文件 @
53697eca
...
...
@@ -8,8 +8,9 @@ from __future__ import annotations
import
os
import
os.path
as
osp
from
contextlib
import
closing
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
from
cvat_sdk.api_client.api_client
import
Endpoint
from
cvat_sdk.core.progress
import
ProgressReporter
if
TYPE_CHECKING
:
...
...
@@ -17,8 +18,12 @@ if TYPE_CHECKING:
class
Downloader
:
"""
Implements common downloading protocols
"""
def
__init__
(
self
,
client
:
Client
):
self
.
client
=
client
self
.
_
client
=
client
def
download_file
(
self
,
...
...
@@ -29,8 +34,7 @@ class Downloader:
pbar
:
Optional
[
ProgressReporter
]
=
None
,
)
->
None
:
"""
Downloads the file from url into a temporary file, then renames it
to the requested name.
Downloads the file from url into a temporary file, then renames it to the requested name.
"""
CHUNK_SIZE
=
10
*
2
**
20
...
...
@@ -41,10 +45,10 @@ class Downloader:
if
osp
.
exists
(
tmp_path
):
raise
FileExistsError
(
f
"Can't write temporary file '
{
tmp_path
}
' - file exists"
)
response
=
self
.
client
.
api
.
rest_client
.
GET
(
response
=
self
.
_
client
.
api
.
rest_client
.
GET
(
url
,
_request_timeout
=
timeout
,
headers
=
self
.
client
.
api
.
get_common_headers
(),
headers
=
self
.
_
client
.
api
.
get_common_headers
(),
_parse_response
=
False
,
)
with
closing
(
response
):
...
...
@@ -72,3 +76,38 @@ class Downloader:
except
:
os
.
unlink
(
tmp_path
)
raise
def
prepare_and_download_file_from_endpoint
(
self
,
endpoint
:
Endpoint
,
filename
:
str
,
*
,
url_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
query_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
status_check_period
:
Optional
[
int
]
=
None
,
):
client
=
self
.
_client
if
status_check_period
is
None
:
status_check_period
=
client
.
config
.
status_check_period
client
.
logger
.
info
(
"Waiting for the server to prepare the file..."
)
url
=
client
.
api_map
.
make_endpoint_url
(
endpoint
.
path
,
kwsub
=
url_params
,
query_params
=
query_params
)
client
.
wait_for_completion
(
url
,
method
=
"GET"
,
positive_statuses
=
[
202
],
success_status
=
201
,
status_check_period
=
status_check_period
,
)
query_params
=
dict
(
query_params
or
{})
query_params
[
"action"
]
=
"download"
url
=
client
.
api_map
.
make_endpoint_url
(
endpoint
.
path
,
kwsub
=
url_params
,
query_params
=
query_params
)
downloader
=
Downloader
(
client
)
downloader
.
download_file
(
url
,
output_path
=
filename
,
pbar
=
pbar
)
cvat-sdk/cvat_sdk/core/git.py
浏览文件 @
53697eca
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
...
...
@@ -27,7 +26,7 @@ def create_git_repo(
common_headers
=
client
.
api
.
get_common_headers
()
response
=
client
.
api
.
rest_client
.
POST
(
client
.
_
api_map
.
git_create
(
task_id
),
client
.
api_map
.
git_create
(
task_id
),
post_params
=
{
"path"
:
repo_url
,
"lfs"
:
use_lfs
,
"tid"
:
task_id
},
headers
=
common_headers
,
)
...
...
@@ -36,7 +35,7 @@ def create_git_repo(
client
.
logger
.
info
(
f
"Create RQ ID:
{
rq_id
}
"
)
client
.
logger
.
debug
(
"Awaiting a dataset repository to be created for the task %s..."
,
task_id
)
check_url
=
client
.
_
api_map
.
git_check
(
rq_id
)
check_url
=
client
.
api_map
.
git_check
(
rq_id
)
status
=
None
while
status
!=
"finished"
:
sleep
(
status_check_period
)
...
...
cvat-sdk/cvat_sdk/core/helpers.py
浏览文件 @
53697eca
...
...
@@ -6,13 +6,14 @@ from __future__ import annotations
import
io
import
json
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Union
import
tqdm
import
urllib3
from
cvat_sdk
import
exceptions
from
cvat_sdk.api_client.api_client
import
Endpoint
from
cvat_sdk.core.progress
import
ProgressReporter
from
cvat_sdk.core.utils
import
assert_status
def
get_paginated_collection
(
...
...
@@ -26,7 +27,7 @@ def get_paginated_collection(
page
=
1
while
True
:
(
page_contents
,
response
)
=
endpoint
.
call_with_http_info
(
**
kwargs
,
page
=
page
)
asser
t_status
(
200
,
response
)
expec
t_status
(
200
,
response
)
if
return_json
:
results
.
extend
(
json
.
loads
(
response
.
data
).
get
(
"results"
,
[]))
...
...
@@ -86,3 +87,18 @@ class StreamWithProgress:
def
tell
(
self
):
return
self
.
stream
.
tell
()
def
expect_status
(
codes
:
Union
[
int
,
Iterable
[
int
]],
response
:
urllib3
.
HTTPResponse
)
->
None
:
if
not
hasattr
(
codes
,
"__iter__"
):
codes
=
[
codes
]
if
response
.
status
in
codes
:
return
if
300
<=
response
.
status
<=
500
:
raise
exceptions
.
ApiException
(
response
.
status
,
reason
=
response
.
msg
,
http_resp
=
response
)
else
:
raise
exceptions
.
ApiException
(
response
.
status
,
reason
=
"Unexpected status code received"
,
http_resp
=
response
)
cvat-sdk/cvat_sdk/core/proxies/__init__.py
0 → 100644
浏览文件 @
53697eca
cvat-sdk/cvat_sdk/core/proxies/annotations.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
abc
import
ABC
from
enum
import
Enum
from
typing
import
Optional
,
Sequence
from
cvat_sdk
import
models
from
cvat_sdk.core.proxies.model_proxy
import
_EntityT
class
AnnotationUpdateAction
(
Enum
):
CREATE
=
"create"
UPDATE
=
"update"
DELETE
=
"delete"
class
AnnotationCrudMixin
(
ABC
):
# TODO: refactor
@
property
def
_put_annotations_data_param
(
self
)
->
str
:
...
def
get_annotations
(
self
:
_EntityT
)
->
models
.
ILabeledData
:
(
annotations
,
_
)
=
self
.
api
.
retrieve_annotations
(
getattr
(
self
,
self
.
_model_id_field
))
return
annotations
def
set_annotations
(
self
:
_EntityT
,
data
:
models
.
ILabeledDataRequest
):
self
.
api
.
update_annotations
(
getattr
(
self
,
self
.
_model_id_field
),
**
{
self
.
_put_annotations_data_param
:
data
}
)
def
update_annotations
(
self
:
_EntityT
,
data
:
models
.
IPatchedLabeledDataRequest
,
*
,
action
:
AnnotationUpdateAction
=
AnnotationUpdateAction
.
UPDATE
,
):
self
.
api
.
partial_update_annotations
(
action
=
action
.
value
,
id
=
getattr
(
self
,
self
.
_model_id_field
),
patched_labeled_data_request
=
data
,
)
def
remove_annotations
(
self
:
_EntityT
,
*
,
ids
:
Optional
[
Sequence
[
int
]]
=
None
):
if
ids
:
anns
=
self
.
get_annotations
()
if
not
isinstance
(
ids
,
set
):
ids
=
set
(
ids
)
anns_to_remove
=
models
.
PatchedLabeledDataRequest
(
tags
=
[
models
.
LabeledImageRequest
(
**
a
.
to_dict
())
for
a
in
anns
.
tags
if
a
.
id
in
ids
],
tracks
=
[
models
.
LabeledTrackRequest
(
**
a
.
to_dict
())
for
a
in
anns
.
tracks
if
a
.
id
in
ids
],
shapes
=
[
models
.
LabeledShapeRequest
(
**
a
.
to_dict
())
for
a
in
anns
.
shapes
if
a
.
id
in
ids
],
)
self
.
update_annotations
(
anns_to_remove
,
action
=
AnnotationUpdateAction
.
DELETE
)
else
:
self
.
api
.
destroy_annotations
(
getattr
(
self
,
self
.
_model_id_field
))
cvat-sdk/cvat_sdk/core/proxies/issues.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
__future__
import
annotations
from
cvat_sdk.api_client
import
apis
,
models
from
cvat_sdk.core.proxies.model_proxy
import
(
ModelCreateMixin
,
ModelDeleteMixin
,
ModelListMixin
,
ModelRetrieveMixin
,
ModelUpdateMixin
,
build_model_bases
,
)
_CommentEntityBase
,
_CommentRepoBase
=
build_model_bases
(
models
.
CommentRead
,
apis
.
CommentsApi
,
api_member_name
=
"comments_api"
)
class
Comment
(
models
.
ICommentRead
,
_CommentEntityBase
,
ModelUpdateMixin
[
models
.
IPatchedCommentWriteRequest
],
ModelDeleteMixin
,
):
_model_partial_update_arg
=
"patched_comment_write_request"
class
CommentsRepo
(
_CommentRepoBase
,
ModelListMixin
[
Comment
],
ModelCreateMixin
[
Comment
,
models
.
ICommentWriteRequest
],
ModelRetrieveMixin
[
Comment
],
):
_entity_type
=
Comment
_IssueEntityBase
,
_IssueRepoBase
=
build_model_bases
(
models
.
IssueRead
,
apis
.
IssuesApi
,
api_member_name
=
"issues_api"
)
class
Issue
(
models
.
IIssueRead
,
_IssueEntityBase
,
ModelUpdateMixin
[
models
.
IPatchedIssueWriteRequest
],
ModelDeleteMixin
,
):
_model_partial_update_arg
=
"patched_issue_write_request"
class
IssuesRepo
(
_IssueRepoBase
,
ModelListMixin
[
Issue
],
ModelCreateMixin
[
Issue
,
models
.
IIssueWriteRequest
],
ModelRetrieveMixin
[
Issue
],
):
_entity_type
=
Issue
cvat-sdk/cvat_sdk/core/proxies/jobs.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
__future__
import
annotations
import
io
import
mimetypes
import
os
import
os.path
as
osp
from
typing
import
List
,
Optional
,
Sequence
from
PIL
import
Image
from
cvat_sdk.api_client
import
apis
,
models
from
cvat_sdk.core.downloading
import
Downloader
from
cvat_sdk.core.helpers
import
get_paginated_collection
from
cvat_sdk.core.progress
import
ProgressReporter
from
cvat_sdk.core.proxies.annotations
import
AnnotationCrudMixin
from
cvat_sdk.core.proxies.issues
import
Issue
from
cvat_sdk.core.proxies.model_proxy
import
(
ModelListMixin
,
ModelRetrieveMixin
,
ModelUpdateMixin
,
build_model_bases
,
)
from
cvat_sdk.core.uploading
import
AnnotationUploader
_JobEntityBase
,
_JobRepoBase
=
build_model_bases
(
models
.
JobRead
,
apis
.
JobsApi
,
api_member_name
=
"jobs_api"
)
class
Job
(
models
.
IJobRead
,
_JobEntityBase
,
ModelUpdateMixin
[
models
.
IPatchedJobWriteRequest
],
AnnotationCrudMixin
,
):
_model_partial_update_arg
=
"patched_job_write_request"
_put_annotations_data_param
=
"job_annotations_update_request"
def
import_annotations
(
self
,
format_name
:
str
,
filename
:
str
,
*
,
status_check_period
:
Optional
[
int
]
=
None
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
):
"""
Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
AnnotationUploader
(
self
.
_client
).
upload_file_and_wait
(
self
.
api
.
create_annotations_endpoint
,
filename
,
format_name
,
url_params
=
{
"id"
:
self
.
id
},
pbar
=
pbar
,
status_check_period
=
status_check_period
,
)
self
.
_client
.
logger
.
info
(
f
"Annotation file '
{
filename
}
' for job #
{
self
.
id
}
uploaded"
)
def
export_dataset
(
self
,
format_name
:
str
,
filename
:
str
,
*
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
status_check_period
:
Optional
[
int
]
=
None
,
include_images
:
bool
=
True
,
)
->
None
:
"""
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if
include_images
:
endpoint
=
self
.
api
.
retrieve_dataset_endpoint
else
:
endpoint
=
self
.
api
.
retrieve_annotations_endpoint
Downloader
(
self
.
_client
).
prepare_and_download_file_from_endpoint
(
endpoint
=
endpoint
,
filename
=
filename
,
url_params
=
{
"id"
:
self
.
id
},
query_params
=
{
"format"
:
format_name
},
pbar
=
pbar
,
status_check_period
=
status_check_period
,
)
self
.
_client
.
logger
.
info
(
f
"Dataset for job
{
self
.
id
}
has been downloaded to
{
filename
}
"
)
def
get_frame
(
self
,
frame_id
:
int
,
*
,
quality
:
Optional
[
str
]
=
None
,
)
->
io
.
RawIOBase
:
(
_
,
response
)
=
self
.
api
.
retrieve_data
(
self
.
id
,
number
=
frame_id
,
quality
=
quality
,
type
=
"frame"
)
return
io
.
BytesIO
(
response
.
data
)
def
get_preview
(
self
,
)
->
io
.
RawIOBase
:
(
_
,
response
)
=
self
.
api
.
retrieve_data
(
self
.
id
,
type
=
"preview"
)
return
io
.
BytesIO
(
response
.
data
)
def
download_frames
(
self
,
frame_ids
:
Sequence
[
int
],
*
,
outdir
:
str
=
""
,
quality
:
str
=
"original"
,
filename_pattern
:
str
=
"frame_{frame_id:06d}{frame_ext}"
,
)
->
Optional
[
List
[
Image
.
Image
]]:
"""
Download the requested frame numbers for a job and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
os
.
makedirs
(
outdir
,
exist_ok
=
True
)
for
frame_id
in
frame_ids
:
frame_bytes
=
self
.
get_frame
(
frame_id
,
quality
=
quality
)
im
=
Image
.
open
(
frame_bytes
)
mime_type
=
im
.
get_format_mimetype
()
or
"image/jpg"
im_ext
=
mimetypes
.
guess_extension
(
mime_type
)
# FIXME It is better to use meta information from the server
# to determine the extension
# replace '.jpe' or '.jpeg' with a more used '.jpg'
if
im_ext
in
(
".jpe"
,
".jpeg"
,
None
):
im_ext
=
".jpg"
outfile
=
filename_pattern
.
format
(
frame_id
=
frame_id
,
frame_ext
=
im_ext
)
im
.
save
(
osp
.
join
(
outdir
,
outfile
))
def
get_meta
(
self
)
->
models
.
IDataMetaRead
:
(
meta
,
_
)
=
self
.
api
.
retrieve_data_meta
(
self
.
id
)
return
meta
def
get_frames_info
(
self
)
->
List
[
models
.
IFrameMeta
]:
return
self
.
get_meta
().
frames
def
remove_frames_by_ids
(
self
,
ids
:
Sequence
[
int
])
->
None
:
self
.
_client
.
api
.
tasks_api
.
jobs_partial_update_data_meta
(
self
.
id
,
patched_data_meta_write_request
=
models
.
PatchedDataMetaWriteRequest
(
deleted_frames
=
ids
),
)
def
get_issues
(
self
)
->
List
[
Issue
]:
return
[
Issue
(
self
.
_client
,
m
)
for
m
in
self
.
api
.
list_issues
(
id
=
self
.
id
)[
0
]]
def
get_commits
(
self
)
->
List
[
models
.
IJobCommit
]:
return
get_paginated_collection
(
self
.
api
.
list_commits_endpoint
,
id
=
self
.
id
)
class
JobsRepo
(
_JobRepoBase
,
ModelListMixin
[
Job
],
ModelRetrieveMixin
[
Job
],
):
_entity_type
=
Job
cvat-sdk/cvat_sdk/core/proxies/model_proxy.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
__future__
import
annotations
import
json
from
abc
import
ABC
from
copy
import
deepcopy
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Literal
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
,
)
from
typing_extensions
import
Self
from
cvat_sdk.api_client.model_utils
import
IModelData
,
ModelNormal
,
to_json
from
cvat_sdk.core.helpers
import
get_paginated_collection
if
TYPE_CHECKING
:
from
cvat_sdk.core.client
import
Client
IModel
=
TypeVar
(
"IModel"
,
bound
=
IModelData
)
ModelType
=
TypeVar
(
"ModelType"
,
bound
=
ModelNormal
)
ApiType
=
TypeVar
(
"ApiType"
)
class ModelProxy(ABC, Generic[ModelType, ApiType]):
    """
    Base class for server-object proxies: binds a client to one of the
    generated API objects, selected by ``_api_member_name``.
    """

    # The client all requests of this proxy go through.
    _client: Client

    @property
    def _api_member_name(self) -> str:
        # Declaration only: subclasses are expected to provide the attribute
        # name of the API object on client.api (e.g. "tasks_api") as a plain
        # string class attribute — see build_model_bases().
        # NOTE(review): get_api() reads this through the class, so a subclass
        # that never overrides it would pass this property object to getattr;
        # confirm every concrete subclass sets a string.
        ...

    def __init__(self, client: Client) -> None:
        # Write straight into the instance dict to bypass any
        # __setattr__/__getattr__ machinery added by subclasses.
        self.__dict__["_client"] = client

    @classmethod
    def get_api(cls, client: Client) -> ApiType:
        """Return the generated API object this proxy type talks to."""
        return getattr(client.api, cls._api_member_name)

    @property
    def api(self) -> ApiType:
        """The generated API object bound to this proxy's client."""
        return self.get_api(self._client)
class Entity(ModelProxy[ModelType, ApiType]):
    """
    Represents a single object. Implements related operations and provides access to data members.
    """

    # The underlying generated model whose fields this proxy exposes.
    _model: ModelType

    def __init__(self, client: Client, model: ModelType) -> None:
        super().__init__(client)
        # Direct __dict__ write, consistent with ModelProxy.__init__.
        self.__dict__["_model"] = model

    @property
    def _model_id_field(self) -> str:
        # Name of the model field holding the server-side identifier.
        return "id"

    def __getattr__(self, __name: str) -> Any:
        # Delegate unknown attributes to the wrapped model's fields.
        # NOTE: be aware of potential problems with throwing AttributeError from @property
        # in derived classes!
        # https://medium.com/@ceshine/python-debugging-pitfall-mixed-use-of-property-and-getattr-f89e0ede13f1
        return self._model[__name]

    def __str__(self) -> str:
        return str(self._model)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: id={getattr(self, self._model_id_field)}>"
class Repo(ModelProxy[ModelType, ApiType]):
    """
    Represents a collection of corresponding Entity objects.
    Implements group and management operations for entities.
    """

    # Entity subclass that wraps models returned by this repository.
    _entity_type: Type[Entity[ModelType, ApiType]]
### Utilities
def build_model_bases(
    mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None
) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]:
    """
    Helps to remove code duplication in declarations of derived classes
    """

    class _EntityBase(Entity[ModelType, ApiType]):
        pass

    class _RepoBase(Repo[ModelType, ApiType]):
        pass

    # Bind both bases to the same API member when one is given.
    if api_member_name:
        _EntityBase._api_member_name = api_member_name
        _RepoBase._api_member_name = api_member_name

    return _EntityBase, _RepoBase
### CRUD mixins
_EntityT
=
TypeVar
(
"_EntityT"
,
bound
=
Entity
)
#### Repo mixins
class ModelCreateMixin(Generic[_EntityT, IModel]):
    def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT:
        """
        Creates a new object on the server and returns corresponding local object
        """
        created_model, _ = self.api.create(spec)
        return self._entity_type(self._client, created_model)
class ModelRetrieveMixin(Generic[_EntityT]):
    def retrieve(self: Repo, obj_id: int) -> _EntityT:
        """
        Retrieves an object from server by ID
        """
        raw_model, _ = self.api.retrieve(id=obj_id)
        return self._entity_type(self._client, raw_model)
class ModelListMixin(Generic[_EntityT]):
    @overload
    def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]:
        ...

    @overload
    def list(self: Repo, *, return_json: Literal[True]) -> str:
        ...

    def list(self: Repo, *, return_json: bool = False) -> Union[List[_EntityT], str]:
        """
        Retrieves all objects from the server.

        By default returns a list of entity proxies. With return_json=True,
        returns the raw results serialized into a single JSON string instead.
        """
        results = get_paginated_collection(
            endpoint=self.api.list_endpoint, return_json=return_json
        )

        if return_json:
            # results is plain JSON-compatible data here; hand the caller one
            # JSON document (note: a str, not a list — overloads reflect this).
            return json.dumps(results)
        return [self._entity_type(self._client, model) for model in results]
#### Entity mixins
class ModelUpdateMixin(ABC, Generic[IModel]):
    """Mixin adding server synchronization (fetch) and updating to entities."""

    @property
    def _model_partial_update_arg(self: Entity) -> str:
        # Declaration only: subclasses provide the keyword argument name that
        # the generated partial_update() call expects for the payload.
        ...

    def _export_update_fields(
        self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None
    ) -> Dict[str, Any]:
        # Convert the override payload into a plain JSON-compatible dict.
        # TODO: support field conversion and assignment updating
        # fields = to_json(self._model)

        if isinstance(overrides, ModelNormal):
            overrides = to_json(overrides)
        fields = deepcopy(overrides)

        return fields

    def fetch(self: Entity) -> Self:
        """
        Updates current object from the server
        """
        # TODO: implement revision checking
        (self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field))
        return self

    def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self:
        """
        Commits local model changes to the server
        """
        # TODO: implement revision checking
        self.api.partial_update(
            id=getattr(self, self._model_id_field),
            **{self._model_partial_update_arg: self._export_update_fields(values)},
        )

        # TODO: use the response model, once input and output models are same
        return self.fetch()
class ModelDeleteMixin:
    def remove(self: Entity) -> None:
        """
        Removes current object on the server
        """
        obj_id = getattr(self, self._model_id_field)
        self.api.destroy(id=obj_id)
cvat-sdk/cvat_sdk/core/proxies/projects.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
__future__
import
annotations
import
json
import
os.path
as
osp
from
typing
import
Optional
from
cvat_sdk.api_client
import
apis
,
models
from
cvat_sdk.core.downloading
import
Downloader
from
cvat_sdk.core.progress
import
ProgressReporter
from
cvat_sdk.core.proxies.model_proxy
import
(
ModelCreateMixin
,
ModelDeleteMixin
,
ModelListMixin
,
ModelRetrieveMixin
,
ModelUpdateMixin
,
build_model_bases
,
)
from
cvat_sdk.core.uploading
import
DatasetUploader
,
Uploader
# Shared Entity/Repo base classes bound to the project read-model and the
# generated ProjectsApi (exposed as client.api.projects_api).
_ProjectEntityBase, _ProjectRepoBase = build_model_bases(
    models.ProjectRead, apis.ProjectsApi, api_member_name="projects_api"
)
class Project(
    _ProjectEntityBase,
    models.IProjectRead,
    ModelUpdateMixin[models.IPatchedProjectWriteRequest],
):
    """
    Proxy for one project on the server: dataset import/export,
    backup download, and annotation retrieval.
    """

    # Keyword argument name ProjectsApi.partial_update expects for the payload.
    _model_partial_update_arg = "patched_project_write_request"

    def import_dataset(
        self,
        format_name: str,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ):
        """
        Import dataset for a project in the specified format (e.g. 'YOLO ZIP 1.0').
        """
        # Upload the file and poll the server until the import is processed.
        DatasetUploader(self._client).upload_file_and_wait(
            self.api.create_dataset_endpoint,
            filename,
            format_name,
            url_params={"id": self.id},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Annotation file '{filename}' for project #{self.id} uploaded")

    def export_dataset(
        self,
        format_name: str,
        filename: str,
        *,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
        include_images: bool = True,
    ) -> None:
        """
        Download annotations for a project in the specified format (e.g. 'YOLO ZIP 1.0').
        """
        # With images -> full dataset endpoint; otherwise annotations only.
        if include_images:
            endpoint = self.api.retrieve_dataset_endpoint
        else:
            endpoint = self.api.retrieve_annotations_endpoint

        Downloader(self._client).prepare_and_download_file_from_endpoint(
            endpoint=endpoint,
            filename=filename,
            url_params={"id": self.id},
            query_params={"format": format_name},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Dataset for project {self.id} has been downloaded to {filename}")

    def download_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> None:
        """
        Download a project backup
        """
        Downloader(self._client).prepare_and_download_file_from_endpoint(
            self.api.retrieve_backup_endpoint,
            filename=filename,
            pbar=pbar,
            status_check_period=status_check_period,
            url_params={"id": self.id},
        )

        self._client.logger.info(f"Backup for project {self.id} has been downloaded to {filename}")

    def get_annotations(self) -> models.ILabeledData:
        """Retrieve all annotations of this project from the server."""
        (annotations, _) = self.api.retrieve_annotations(self.id)
        return annotations
class ProjectsRepo(
    _ProjectRepoBase,
    ModelCreateMixin[Project, models.IProjectWriteRequest],
    ModelListMixin[Project],
    ModelRetrieveMixin[Project],
    ModelDeleteMixin,
):
    """Collection-level operations for projects: create, list, retrieve, delete."""

    _entity_type = Project

    def create_from_dataset(
        self,
        spec: models.IProjectWriteRequest,
        *,
        dataset_path: str = "",
        dataset_format: str = "CVAT XML 1.1",
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> Project:
        """
        Create a new project from the given spec and, if a dataset path is
        given, import the dataset into it.

        Returns: the created Project proxy
        """
        project = self.create(spec=spec)
        self._client.logger.info("Created project ID: %s NAME: %s", project.id, project.name)

        if dataset_path:
            project.import_dataset(
                format_name=dataset_format,
                filename=dataset_path,
                pbar=pbar,
                status_check_period=status_check_period,
            )

        # Refresh the local model: the import may have changed server state.
        project.fetch()
        return project

    def create_from_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> Project:
        """
        Import a project from a backup file
        """
        if status_check_period is None:
            # Fix: config and api_map live on the client, not on the repo —
            # self.config / self.api_map would raise AttributeError here.
            status_check_period = self._client.config.status_check_period

        params = {"filename": osp.basename(filename)}
        url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)

        uploader = Uploader(self._client)
        response = uploader.upload_file(
            url,
            filename,
            meta=params,
            query_params=params,
            pbar=pbar,
            logger=self._client.logger.debug,
        )

        # The server returns an async request id; poll until import finishes.
        rq_id = json.loads(response.data)["rq_id"]
        response = self._client.wait_for_completion(
            url,
            success_status=201,
            positive_statuses=[202],
            post_params={"rq_id": rq_id},
            status_check_period=status_check_period,
        )

        project_id = json.loads(response.data)["id"]
        self._client.logger.info(f"Project has been imported successfully. Project ID: {project_id}")

        return self.retrieve(project_id)
cvat-sdk/cvat_sdk/core/tasks.py
→
cvat-sdk/cvat_sdk/core/
proxies/
tasks.py
浏览文件 @
53697eca
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
...
...
@@ -6,71 +5,60 @@
from
__future__
import
annotations
import
io
import
json
import
mimetypes
import
os
import
os.path
as
osp
from
abc
import
ABC
,
abstractmethod
from
io
import
BytesIO
from
enum
import
Enum
from
time
import
sleep
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Sequence
from
PIL
import
Image
from
cvat_sdk
import
models
from
cvat_sdk.
api_client.model_utils
import
OpenApiModel
from
cvat_sdk
.api_client
import
apis
,
exceptions
,
models
from
cvat_sdk.
core
import
git
from
cvat_sdk.core.downloading
import
Downloader
from
cvat_sdk.core.progress
import
ProgressReporter
from
cvat_sdk.core.types
import
ResourceType
from
cvat_sdk.core.uploading
import
Uploader
from
cvat_sdk.core.proxies.annotations
import
AnnotationCrudMixin
from
cvat_sdk.core.proxies.jobs
import
Job
from
cvat_sdk.core.proxies.model_proxy
import
(
ModelCreateMixin
,
ModelDeleteMixin
,
ModelListMixin
,
ModelRetrieveMixin
,
ModelUpdateMixin
,
build_model_bases
,
)
from
cvat_sdk.core.uploading
import
AnnotationUploader
,
DataUploader
,
Uploader
from
cvat_sdk.core.utils
import
filter_dict
if
TYPE_CHECKING
:
from
cvat_sdk.core.client
import
Client
class
ResourceType
(
Enum
):
LOCAL
=
0
SHARE
=
1
REMOTE
=
2
class
ModelProxy
(
ABC
):
_client
:
Client
_model
:
OpenApiModel
def
__str__
(
self
):
return
self
.
name
.
lower
()
def
__init__
(
self
,
client
:
Client
,
model
:
OpenApiModel
)
->
None
:
self
.
__dict__
[
"_client"
]
=
client
self
.
__dict__
[
"_model"
]
=
model
def
__repr__
(
self
):
return
str
(
self
)
def
__getattr__
(
self
,
__name
:
str
)
->
Any
:
return
self
.
_model
[
__name
]
def
__setattr__
(
self
,
__name
:
str
,
__value
:
Any
)
->
None
:
if
__name
in
self
.
__dict__
:
self
.
__dict__
[
__name
]
=
__value
else
:
self
.
_model
[
__name
]
=
__value
@
abstractmethod
def
fetch
(
self
,
force
:
bool
=
False
):
"""Fetches model data from the server"""
...
@
abstractmethod
def
commit
(
self
,
force
:
bool
=
False
):
"""Commits local changes to the server"""
...
def
sync
(
self
):
"""Pulls server state and commits local model changes"""
raise
NotImplementedError
@
abstractmethod
def
update
(
self
,
**
kwargs
):
"""Updates multiple fields at once"""
...
_TaskEntityBase
,
_TaskRepoBase
=
build_model_bases
(
models
.
TaskRead
,
apis
.
TasksApi
,
api_member_name
=
"tasks_api"
)
class
TaskProxy
(
ModelProxy
,
models
.
ITaskRead
):
def
__init__
(
self
,
client
:
Client
,
task
:
models
.
TaskRead
):
ModelProxy
.
__init__
(
self
,
client
=
client
,
model
=
task
)
def
remove
(
self
):
self
.
_client
.
api
.
tasks_api
.
destroy
(
self
.
id
)
class
Task
(
_TaskEntityBase
,
models
.
ITaskRead
,
ModelUpdateMixin
[
models
.
IPatchedTaskWriteRequest
],
ModelDeleteMixin
,
AnnotationCrudMixin
,
):
_model_partial_update_arg
=
"patched_task_write_request"
_put_annotations_data_param
=
"task_annotations_update_request"
def
upload_data
(
self
,
...
...
@@ -83,9 +71,6 @@ class TaskProxy(ModelProxy, models.ITaskRead):
"""
Add local, remote, or shared files to an existing task.
"""
client
=
self
.
_client
task_id
=
self
.
id
params
=
params
or
{}
data
=
{}
...
...
@@ -116,73 +101,58 @@ class TaskProxy(ModelProxy, models.ITaskRead):
data
[
"frame_filter"
]
=
f
"step=
{
params
.
get
(
'frame_step'
)
}
"
if
resource_type
in
[
ResourceType
.
REMOTE
,
ResourceType
.
SHARE
]:
client
.
api
.
tasks_
api
.
create_data
(
task_
id
,
self
.
api
.
create_data
(
self
.
id
,
data_request
=
models
.
DataRequest
(
**
data
),
_content_type
=
"multipart/form-data"
,
)
elif
resource_type
==
ResourceType
.
LOCAL
:
url
=
client
.
_
api_map
.
make_endpoint_url
(
client
.
api
.
tasks_api
.
create_data_endpoint
.
path
,
kwsub
=
{
"id"
:
task_
id
}
url
=
self
.
_client
.
api_map
.
make_endpoint_url
(
self
.
api
.
create_data_endpoint
.
path
,
kwsub
=
{
"id"
:
self
.
id
}
)
uploader
=
Uploader
(
client
)
uploader
.
upload_files
(
url
,
resources
,
pbar
=
pbar
,
**
data
)
DataUploader
(
self
.
_client
).
upload_files
(
url
,
resources
,
pbar
=
pbar
,
**
data
)
def
import_annotations
(
self
,
format_name
:
str
,
filename
:
str
,
*
,
status_check_period
:
int
=
None
,
status_check_period
:
Optional
[
int
]
=
None
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
):
"""
Upload annotations for a task in the specified format
(e.g. 'YOLO ZIP 1.0').
Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
client
=
self
.
_client
if
status_check_period
is
None
:
status_check_period
=
client
.
config
.
status_check_period
task_id
=
self
.
id
url
=
client
.
_api_map
.
make_endpoint_url
(
client
.
api
.
tasks_api
.
create_annotations_endpoint
.
path
,
kwsub
=
{
"id"
:
task_id
},
AnnotationUploader
(
self
.
_client
).
upload_file_and_wait
(
self
.
api
.
create_annotations_endpoint
,
filename
,
format_name
,
url_params
=
{
"id"
:
self
.
id
},
pbar
=
pbar
,
status_check_period
=
status_check_period
,
)
params
=
{
"format"
:
format_name
,
"filename"
:
osp
.
basename
(
filename
)}
uploader
=
Uploader
(
client
)
uploader
.
upload_file
(
url
,
filename
,
pbar
=
pbar
,
query_params
=
params
,
meta
=
{
"filename"
:
params
[
"filename"
]}
)
while
True
:
response
=
client
.
api
.
rest_client
.
POST
(
url
,
headers
=
client
.
api
.
get_common_headers
(),
query_params
=
params
)
if
response
.
status
==
201
:
break
sleep
(
status_check_period
)
self
.
_client
.
logger
.
info
(
f
"Annotation file '
{
filename
}
' for task #
{
self
.
id
}
uploaded"
)
client
.
logger
.
info
(
f
"Upload job for Task ID
{
task_id
}
with annotation file
{
filename
}
finished"
)
def
retrieve_frame
(
def
get_frame
(
self
,
frame_id
:
int
,
*
,
quality
:
Optional
[
str
]
=
None
,
)
->
io
.
RawIOBase
:
client
=
self
.
_client
task_id
=
self
.
id
params
=
{}
if
quality
:
params
[
"quality"
]
=
quality
(
_
,
response
)
=
self
.
api
.
retrieve_data
(
self
.
id
,
number
=
frame_id
,
**
params
,
type
=
"frame"
)
return
io
.
BytesIO
(
response
.
data
)
(
_
,
response
)
=
client
.
api
.
tasks_api
.
retrieve_data
(
task_id
,
frame_id
,
quality
,
type
=
"frame"
)
return
BytesIO
(
response
.
data
)
def get_preview(
    self,
) -> io.RawIOBase:
    """Download the preview image of this task's data as an in-memory stream."""
    _, response = self.api.retrieve_data(self.id, type="preview")
    return io.BytesIO(response.data)
def
download_frames
(
self
,
...
...
@@ -190,19 +160,16 @@ class TaskProxy(ModelProxy, models.ITaskRead):
*
,
outdir
:
str
=
""
,
quality
:
str
=
"original"
,
filename_pattern
:
str
=
"
task_{task_id}_
frame_{frame_id:06d}{frame_ext}"
,
filename_pattern
:
str
=
"frame_{frame_id:06d}{frame_ext}"
,
)
->
Optional
[
List
[
Image
.
Image
]]:
"""
Download the requested frame numbers for a task and save images as
outdir/filename_pattern
Download the requested frame numbers for a task and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
task_id
=
self
.
id
os
.
makedirs
(
outdir
,
exist_ok
=
True
)
for
frame_id
in
frame_ids
:
frame_bytes
=
self
.
retrieve
_frame
(
frame_id
,
quality
=
quality
)
frame_bytes
=
self
.
get
_frame
(
frame_id
,
quality
=
quality
)
im
=
Image
.
open
(
frame_bytes
)
mime_type
=
im
.
get_format_mimetype
()
or
"image/jpg"
...
...
@@ -214,7 +181,7 @@ class TaskProxy(ModelProxy, models.ITaskRead):
if
im_ext
in
(
".jpe"
,
".jpeg"
,
None
):
im_ext
=
".jpg"
outfile
=
filename_pattern
.
format
(
task_id
=
task_id
,
frame_id
=
frame_id
,
frame_ext
=
im_ext
)
outfile
=
filename_pattern
.
format
(
frame_id
=
frame_id
,
frame_ext
=
im_ext
)
im
.
save
(
osp
.
join
(
outdir
,
outfile
))
def
export_dataset
(
...
...
@@ -223,40 +190,27 @@ class TaskProxy(ModelProxy, models.ITaskRead):
filename
:
str
,
*
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
status_check_period
:
int
=
None
,
status_check_period
:
Optional
[
int
]
=
None
,
include_images
:
bool
=
True
,
)
->
None
:
"""
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
client
=
self
.
_client
if
status_check_period
is
None
:
status_check_period
=
client
.
config
.
status_check_period
task_id
=
self
.
id
params
=
{
"filename"
:
self
.
name
,
"format"
:
format_name
}
if
include_images
:
endpoint
=
client
.
api
.
tasks_
api
.
retrieve_dataset_endpoint
endpoint
=
self
.
api
.
retrieve_dataset_endpoint
else
:
endpoint
=
client
.
api
.
tasks_api
.
retrieve_annotations_endpoint
client
.
logger
.
info
(
"Waiting for the server to prepare the file..."
)
while
True
:
(
_
,
response
)
=
endpoint
.
call_with_http_info
(
id
=
task_id
,
**
params
)
client
.
logger
.
debug
(
"STATUS {}"
.
format
(
response
.
status
))
if
response
.
status
==
201
:
break
sleep
(
status_check_period
)
params
[
"action"
]
=
"download"
url
=
client
.
_api_map
.
make_endpoint_url
(
endpoint
.
path
,
kwsub
=
{
"id"
:
task_id
},
query_params
=
params
endpoint
=
self
.
api
.
retrieve_annotations_endpoint
Downloader
(
self
.
_client
).
prepare_and_download_file_from_endpoint
(
endpoint
=
endpoint
,
filename
=
filename
,
url_params
=
{
"id"
:
self
.
id
},
query_params
=
{
"format"
:
format_name
},
pbar
=
pbar
,
status_check_period
=
status_check_period
,
)
downloader
=
Downloader
(
client
)
downloader
.
download_file
(
url
,
output_path
=
filename
,
pbar
=
pbar
)
client
.
logger
.
info
(
f
"Dataset has been export
ed to
{
filename
}
"
)
self
.
_client
.
logger
.
info
(
f
"Dataset for task
{
self
.
id
}
has been download
ed to
{
filename
}
"
)
def
download_backup
(
self
,
...
...
@@ -264,45 +218,171 @@ class TaskProxy(ModelProxy, models.ITaskRead):
*
,
status_check_period
:
int
=
None
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
):
)
->
None
:
"""
Download a task backup
"""
client
=
self
.
_client
Downloader
(
self
.
_client
).
prepare_and_download_file_from_endpoint
(
self
.
api
.
retrieve_backup_endpoint
,
filename
=
filename
,
pbar
=
pbar
,
status_check_period
=
status_check_period
,
url_params
=
{
"id"
:
self
.
id
},
)
self
.
_client
.
logger
.
info
(
f
"Backup for task
{
self
.
id
}
has been downloaded to
{
filename
}
"
)
def get_jobs(self) -> List[Job]:
    """Fetch the jobs belonging to this task, wrapped in Job proxies."""
    job_models, _ = self.api.list_jobs(id=self.id)
    return [Job(self._client, job_model) for job_model in job_models]
def get_meta(self) -> models.IDataMetaRead:
    """Fetch this task's data metadata (frames, chunks, etc.) from the server."""
    meta, _ = self.api.retrieve_data_meta(self.id)
    return meta
def get_frames_info(self) -> List[models.IFrameMeta]:
    """Return the per-frame metadata entries of this task's data."""
    meta = self.get_meta()
    return meta.frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
    """Mark the given frame ids as deleted in this task's data meta."""
    request = models.PatchedDataMetaWriteRequest(deleted_frames=ids)
    self.api.partial_update_data_meta(self.id, patched_data_meta_write_request=request)
class
TasksRepo
(
_TaskRepoBase
,
ModelCreateMixin
[
Task
,
models
.
ITaskWriteRequest
],
ModelRetrieveMixin
[
Task
],
ModelListMixin
[
Task
],
ModelDeleteMixin
,
):
_entity_type
=
Task
def
create_from_data
(
self
,
spec
:
models
.
ITaskWriteRequest
,
resource_type
:
ResourceType
,
resources
:
Sequence
[
str
],
*
,
data_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
annotation_path
:
str
=
""
,
annotation_format
:
str
=
"CVAT XML 1.1"
,
status_check_period
:
int
=
None
,
dataset_repository_url
:
str
=
""
,
use_lfs
:
bool
=
False
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
)
->
Task
:
"""
Create a new task with the given name and labels JSON and
add the files to it.
Returns: id of the created task
"""
if
status_check_period
is
None
:
status_check_period
=
client
.
config
.
status_check_period
status_check_period
=
self
.
_client
.
config
.
status_check_period
if
getattr
(
spec
,
"project_id"
,
None
)
and
getattr
(
spec
,
"labels"
,
None
):
raise
exceptions
.
ApiValueError
(
"Can't set labels to a task inside a project. "
"Tasks inside a project use project's labels."
,
[
"labels"
],
)
task_id
=
self
.
id
task
=
self
.
create
(
spec
=
spec
)
self
.
_client
.
logger
.
info
(
"Created task ID: %s NAME: %s"
,
task
.
id
,
task
.
name
)
endpoint
=
client
.
api
.
tasks_api
.
retrieve_backup_endpoint
client
.
logger
.
info
(
"Waiting for the server to prepare the file..."
)
while
True
:
(
_
,
response
)
=
endpoint
.
call_with_http_info
(
id
=
task_id
)
client
.
logger
.
debug
(
"STATUS {}"
.
format
(
response
.
status
))
if
response
.
status
==
201
:
break
task
.
upload_data
(
resource_type
,
resources
,
pbar
=
pbar
,
params
=
data_params
)
self
.
_client
.
logger
.
info
(
"Awaiting for task %s creation..."
,
task
.
id
)
status
:
models
.
RqStatus
=
None
while
status
!=
models
.
RqStatusStateEnum
.
allowed_values
[(
"value"
,)][
"FINISHED"
]:
sleep
(
status_check_period
)
(
status
,
response
)
=
self
.
api
.
retrieve_status
(
task
.
id
)
self
.
_client
.
logger
.
info
(
"Task %s creation status=%s, message=%s"
,
task
.
id
,
status
.
state
.
value
,
status
.
message
,
)
if
status
.
state
.
value
==
models
.
RqStatusStateEnum
.
allowed_values
[(
"value"
,)][
"FAILED"
]:
raise
exceptions
.
ApiException
(
status
=
status
.
state
.
value
,
reason
=
status
.
message
,
http_resp
=
response
)
status
=
status
.
state
.
value
url
=
client
.
_api_map
.
make_endpoint_url
(
endpoint
.
path
,
kwsub
=
{
"id"
:
task_id
},
query_params
=
{
"action"
:
"download"
}
if
annotation_path
:
task
.
import_annotations
(
annotation_format
,
annotation_path
,
pbar
=
pbar
)
if
dataset_repository_url
:
git
.
create_git_repo
(
self
,
task_id
=
task
.
id
,
repo_url
=
dataset_repository_url
,
status_check_period
=
status_check_period
,
use_lfs
=
use_lfs
,
)
downloader
=
Downloader
(
client
)
downloader
.
download_file
(
url
,
output_path
=
filename
,
pbar
=
pbar
)
client
.
logger
.
info
(
f
"Task
{
task_id
}
has been exported sucessfully to
{
osp
.
abspath
(
filename
)
}
"
task
.
fetch
()
return
task
def remove_by_ids(self, task_ids: Sequence[int]) -> None:
    """
    Delete a list of tasks, ignoring those which don't exist.
    """
    for task_id in task_ids:
        _, response = self.api.destroy(task_id, _check_status=False)
        status = response.status

        # Treat 404 as a non-error; anything else non-2xx is worth a warning.
        if 200 <= status <= 299:
            self._client.logger.info(f"Task ID {task_id} deleted")
        elif status == 404:
            self._client.logger.info(f"Task ID {task_id} not found")
        else:
            self._client.logger.warning(
                f"Failed to delete task ID {task_id}: "
                f"{response.msg} (status {status})"
            )
def
fetch
(
self
,
force
:
bool
=
False
):
# TODO: implement revision checking
model
,
_
=
self
.
_client
.
api
.
tasks_api
.
retrieve
(
self
.
id
)
self
.
_model
=
model
def
create_from_backup
(
self
,
filename
:
str
,
*
,
status_check_period
:
int
=
None
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
)
->
Task
:
"""
Import a task from a backup file
"""
if
status_check_period
is
None
:
status_check_period
=
self
.
_client
.
config
.
status_check_period
params
=
{
"filename"
:
osp
.
basename
(
filename
)}
url
=
self
.
_client
.
api_map
.
make_endpoint_url
(
self
.
api
.
create_backup_endpoint
.
path
)
uploader
=
Uploader
(
self
.
_client
)
response
=
uploader
.
upload_file
(
url
,
filename
,
meta
=
params
,
query_params
=
params
,
pbar
=
pbar
,
logger
=
self
.
_client
.
logger
.
debug
,
)
def
commit
(
self
,
force
:
bool
=
False
):
return
super
().
commit
(
force
)
rq_id
=
json
.
loads
(
response
.
data
)[
"rq_id"
]
response
=
self
.
_client
.
wait_for_completion
(
url
,
success_status
=
201
,
positive_statuses
=
[
202
],
post_params
=
{
"rq_id"
:
rq_id
},
status_check_period
=
status_check_period
,
)
def
update
(
self
,
**
kwargs
):
return
super
().
update
(
**
kwargs
)
task_id
=
json
.
loads
(
response
.
data
)[
"id"
]
self
.
_client
.
logger
.
info
(
f
"Task has been imported sucessfully. Task ID:
{
task_id
}
"
)
def
__str__
(
self
)
->
str
:
return
str
(
self
.
_model
)
return
self
.
retrieve
(
task_id
)
cvat-sdk/cvat_sdk/core/proxies/users.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
__future__
import
annotations
from
cvat_sdk.api_client
import
apis
,
models
from
cvat_sdk.core.proxies.model_proxy
import
(
ModelDeleteMixin
,
ModelListMixin
,
ModelRetrieveMixin
,
ModelUpdateMixin
,
build_model_bases
,
)
# Shared Entity/Repo base classes bound to the User model and the generated
# UsersApi (exposed as client.api.users_api).
_UserEntityBase, _UserRepoBase = build_model_bases(
    models.User, apis.UsersApi, api_member_name="users_api"
)
class User(
    models.IUser, _UserEntityBase, ModelUpdateMixin[models.IPatchedUserRequest], ModelDeleteMixin
):
    """A single user account; supports update() and remove()."""

    # Keyword argument name UsersApi.partial_update expects for the payload.
    _model_partial_update_arg = "patched_user_request"
class UsersRepo(
    _UserRepoBase,
    ModelListMixin[User],
    ModelRetrieveMixin[User],
):
    """Collection-level operations for users: list, retrieve, current user."""

    _entity_type = User

    def retrieve_current_user(self) -> User:
        """Return a proxy for the currently authenticated user."""
        user_model, _ = self.api.retrieve_self()
        return User(self._client, user_model)
cvat-sdk/cvat_sdk/core/types.py
已删除
100644 → 0
浏览文件 @
b60d3b48
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
enum
import
Enum
class ResourceType(Enum):
    """Kind of location a task's input data is taken from."""

    LOCAL = 0
    SHARE = 1
    REMOTE = 2

    def __str__(self):
        # Lower-cased member name, e.g. "local" — used directly in CLI output.
        return self.name.lower()

    def __repr__(self):
        # Keep debug output identical to the user-facing string form.
        return str(self)
cvat-sdk/cvat_sdk/core/uploading.py
浏览文件 @
53697eca
...
...
@@ -7,16 +7,15 @@ from __future__ import annotations
import
os
import
os.path
as
osp
from
contextlib
import
ExitStack
,
closing
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
import
requests
import
urllib3
from
cvat_sdk.api_client
import
ApiClie
nt
from
cvat_sdk.api_client
.api_client
import
ApiClient
,
Endpoi
nt
from
cvat_sdk.api_client.rest
import
RESTClientObject
from
cvat_sdk.core.helpers
import
StreamWithProgress
from
cvat_sdk.core.helpers
import
StreamWithProgress
,
expect_status
from
cvat_sdk.core.progress
import
ProgressReporter
from
cvat_sdk.core.utils
import
assert_status
if
TYPE_CHECKING
:
from
cvat_sdk.core.client
import
Client
...
...
@@ -25,57 +24,12 @@ MAX_REQUEST_SIZE = 100 * 2**20
class
Uploader
:
def
__init__
(
self
,
client
:
Client
):
self
.
client
=
client
def
upload_files
(
self
,
url
:
str
,
resources
:
List
[
str
],
*
,
pbar
:
Optional
[
ProgressReporter
]
=
None
,
**
kwargs
,
):
bulk_file_groups
,
separate_files
,
total_size
=
self
.
_split_files_by_requests
(
resources
)
if
pbar
is
not
None
:
pbar
.
start
(
total_size
,
desc
=
"Uploading data"
)
self
.
_tus_start_upload
(
url
)
for
group
,
group_size
in
bulk_file_groups
:
with
ExitStack
()
as
es
:
files
=
{}
for
i
,
filename
in
enumerate
(
group
):
files
[
f
"client_files[
{
i
}
]"
]
=
(
filename
,
es
.
enter_context
(
closing
(
open
(
filename
,
"rb"
))).
read
(),
)
response
=
self
.
client
.
api
.
rest_client
.
POST
(
url
,
post_params
=
dict
(
**
kwargs
,
**
files
),
headers
=
{
"Content-Type"
:
"multipart/form-data"
,
"Upload-Multiple"
:
""
,
**
self
.
client
.
api
.
get_common_headers
(),
},
)
assert_status
(
200
,
response
)
if
pbar
is
not
None
:
pbar
.
advance
(
group_size
)
for
filename
in
separate_files
:
# TODO: check if basename produces invalid paths here, can lead to overwriting
self
.
_upload_file_data_with_tus
(
url
,
filename
,
meta
=
{
"filename"
:
osp
.
basename
(
filename
)},
pbar
=
pbar
,
logger
=
self
.
client
.
logger
.
debug
,
)
"""
Implements common uploading protocols
"""
self
.
_tus_finish_upload
(
url
,
fields
=
kwargs
)
def
__init__
(
self
,
client
:
Client
):
self
.
_client
=
client
def
upload_file
(
self
,
...
...
@@ -121,6 +75,27 @@ class Uploader:
)
return
self
.
_tus_finish_upload
(
url
,
query_params
=
query_params
,
fields
=
fields
)
def _wait_for_completion(
    self,
    url: str,
    *,
    success_status: int,
    status_check_period: Optional[int] = None,
    query_params: Optional[Dict[str, Any]] = None,
    post_params: Optional[Dict[str, Any]] = None,
    method: str = "POST",
    positive_statuses: Optional[Sequence[int]] = None,
) -> urllib3.HTTPResponse:
    """
    Poll *url* until the server reports *success_status*; thin delegation to
    the client's wait_for_completion.
    """
    return self._client.wait_for_completion(
        url,
        success_status=success_status,
        status_check_period=status_check_period,
        query_params=query_params,
        post_params=post_params,
        method=method,
        positive_statuses=positive_statuses,
    )
def
_split_files_by_requests
(
self
,
filenames
:
List
[
str
]
)
->
Tuple
[
List
[
Tuple
[
List
[
str
],
int
]],
List
[
str
],
int
]:
...
...
@@ -268,7 +243,7 @@ class Uploader:
input_file
=
StreamWithProgress
(
input_file
,
pbar
,
length
=
file_size
)
tus_uploader
=
self
.
_make_tus_uploader
(
self
.
client
.
api
,
self
.
_
client
.
api
,
url
=
url
.
rstrip
(
"/"
)
+
"/"
,
metadata
=
meta
,
file_stream
=
input_file
,
...
...
@@ -278,26 +253,131 @@ class Uploader:
tus_uploader
.
upload
()
def
_tus_start_upload
(
self
,
url
,
*
,
query_params
=
None
):
response
=
self
.
client
.
api
.
rest_client
.
POST
(
response
=
self
.
_
client
.
api
.
rest_client
.
POST
(
url
,
query_params
=
query_params
,
headers
=
{
"Upload-Start"
:
""
,
**
self
.
client
.
api
.
get_common_headers
(),
**
self
.
_
client
.
api
.
get_common_headers
(),
},
)
asser
t_status
(
202
,
response
)
expec
t_status
(
202
,
response
)
return
response
def
_tus_finish_upload
(
self
,
url
,
*
,
query_params
=
None
,
fields
=
None
):
response
=
self
.
client
.
api
.
rest_client
.
POST
(
response
=
self
.
_
client
.
api
.
rest_client
.
POST
(
url
,
headers
=
{
"Upload-Finish"
:
""
,
**
self
.
client
.
api
.
get_common_headers
(),
**
self
.
_
client
.
api
.
get_common_headers
(),
},
query_params
=
query_params
,
post_params
=
fields
,
)
asser
t_status
(
202
,
response
)
expec
t_status
(
202
,
response
)
return
response
class AnnotationUploader(Uploader):
    """Uploads an annotation file and waits until the server finishes importing it."""

    def upload_file_and_wait(
        self,
        endpoint: Endpoint,
        filename: str,
        format_name: str,
        *,
        url_params: Optional[Dict[str, Any]] = None,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
    ):
        request_url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
        query = {"format": format_name, "filename": osp.basename(filename)}

        self.upload_file(
            request_url, filename, pbar=pbar, query_params=query, meta={"filename": query["filename"]}
        )

        # The server parses the annotations asynchronously; poll with POST
        # until it reports 201 (202 means "still processing").
        self._wait_for_completion(
            request_url,
            success_status=201,
            positive_statuses=[202],
            status_check_period=status_check_period,
            query_params=query,
            method="POST",
        )
class DatasetUploader(Uploader):
    """Uploads a dataset file and waits until the server finishes importing it."""

    def upload_file_and_wait(
        self,
        endpoint: Endpoint,
        filename: str,
        format_name: str,
        *,
        url_params: Optional[Dict[str, Any]] = None,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
    ):
        request_url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
        query = {"format": format_name, "filename": osp.basename(filename)}

        self.upload_file(
            request_url, filename, pbar=pbar, query_params=query, meta={"filename": query["filename"]}
        )

        # Unlike annotation import, dataset import status is polled with GET.
        self._wait_for_completion(
            request_url,
            success_status=201,
            positive_statuses=[202],
            status_check_period=status_check_period,
            query_params=query,
            method="GET",
        )
class DataUploader(Uploader):
    """Uploads task input files: small files in bulk multipart requests, large ones via TUS."""

    def upload_files(
        self,
        url: str,
        resources: List[str],
        *,
        pbar: Optional[ProgressReporter] = None,
        **kwargs,
    ):
        # Partition files into multipart groups (small files batched together)
        # and files sent individually through resumable TUS uploads.
        bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources)

        if pbar is not None:
            pbar.start(total_size, desc="Uploading data")

        # Announce the upload session to the server (Upload-Start).
        self._tus_start_upload(url)

        for group, group_size in bulk_file_groups:
            with ExitStack() as es:
                files = {}
                for i, filename in enumerate(group):
                    # Read each file fully; ExitStack guarantees all handles close.
                    files[f"client_files[{i}]"] = (
                        filename,
                        es.enter_context(closing(open(filename, "rb"))).read(),
                    )
                response = self._client.api.rest_client.POST(
                    url,
                    post_params=dict(**kwargs, **files),
                    headers={
                        "Content-Type": "multipart/form-data",
                        "Upload-Multiple": "",
                        **self._client.api.get_common_headers(),
                    },
                )
            expect_status(200, response)

            if pbar is not None:
                pbar.advance(group_size)

        for filename in separate_files:
            # TODO: check if basename produces invalid paths here, can lead to overwriting
            self._upload_file_data_with_tus(
                url,
                filename,
                meta={"filename": osp.basename(filename)},
                pbar=pbar,
                logger=self._client.logger.debug,
            )

        # Close the session and submit the remaining form fields (Upload-Finish).
        self._tus_finish_upload(url, fields=kwargs)
cvat-sdk/cvat_sdk/core/utils.py
浏览文件 @
53697eca
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
...
...
@@ -7,13 +6,6 @@ from __future__ import annotations
from
typing
import
Any
,
Dict
,
Sequence
import
urllib3
def
assert_status
(
code
:
int
,
response
:
urllib3
.
HTTPResponse
)
->
None
:
if
response
.
status
!=
code
:
raise
Exception
(
f
"Unexpected status code received
{
response
.
status
}
"
)
def
filter_dict
(
d
:
Dict
[
str
,
Any
],
*
,
keep
:
Sequence
[
str
]
=
None
,
drop
:
Sequence
[
str
]
=
None
...
...
cvat-sdk/gen/postprocess.py
浏览文件 @
53697eca
...
...
@@ -48,7 +48,7 @@ class Processor:
tokenized_path
=
tokenized_path
[
2
:]
prefix
=
tokenized_path
[
0
]
+
"_"
if
new_name
.
startswith
(
prefix
):
if
new_name
.
startswith
(
prefix
)
and
tokenized_path
[
0
]
in
operation
[
"tags"
]
:
new_name
=
new_name
[
len
(
prefix
)
:]
return
new_name
...
...
cvat-sdk/gen/templates/openapi-generator/api_client.mustache
浏览文件 @
53697eca
...
...
@@ -345,6 +345,9 @@ class ApiClient(object):
"""
if response_schema == (file_type,):
# TODO: response schema can be "oneOf" with a file option,
# this implementation does not cover this.
# handle file downloading
# save response body into a tmp file and return the instance
content_disposition = response.getheader("Content-Disposition")
...
...
cvat-sdk/gen/templates/openapi-generator/model.mustache
浏览文件 @
53697eca
...
...
@@ -9,6 +9,7 @@ import sys # noqa: F401
from
{{
packageName
}}
.model_utils import ( # noqa: F401
ApiTypeError,
IModelData,
ModelComposed,
ModelNormal,
ModelSimple,
...
...
cvat-sdk/gen/templates/openapi-generator/model_templates/model_normal.mustache
浏览文件 @
53697eca
class I
{{
classname
}}
:
class I
{{
classname
}}
(IModelData)
:
"""
NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
...
...
cvat-sdk/gen/templates/openapi-generator/model_templates/model_simple.mustache
浏览文件 @
53697eca
class I
{{
classname
}}
:
class I
{{
classname
}}
(IModelData)
:
"""
NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech
...
...
cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
浏览文件 @
53697eca
...
...
@@ -113,6 +113,11 @@ def composed_model_input_classes(cls):
return []
class IModelData:
"""
The base class for model data. Declares model fields and their types for better introspection
"""
class OpenApiModel(object):
"""The base class for all OpenAPIModels"""
...
...
cvat-sdk/gen/templates/requirements/base.txt
浏览文件 @
53697eca
...
...
@@ -3,3 +3,4 @@
attrs >= 21.4.0
tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0
cvat-ui/package.json
浏览文件 @
53697eca
{
"name"
:
"cvat-ui"
,
"version"
:
"1.41.
0
"
,
"version"
:
"1.41.
1
"
,
"description"
:
"CVAT single-page application"
,
"main"
:
"src/index.tsx"
,
"scripts"
:
{
...
...
cvat-ui/src/components/tasks-page/tasks-page.tsx
浏览文件 @
53697eca
...
...
@@ -70,7 +70,7 @@ function TasksPageComponent(props: Props): JSX.Element {
<
Button
type
=
'link'
onClick
=
{
():
void
=>
{
dispatch
(
hideEmptyTasks
(
tru
e
));
dispatch
(
hideEmptyTasks
(
fals
e
));
message
.
destroy
();
}
}
>
...
...
cvat/apps/dataset_repo/views.py
浏览文件 @
53697eca
...
...
@@ -109,10 +109,8 @@ def update_git_repo(request, tid):
status
=
http
.
HTTPStatus
.
OK
,
)
except
Exception
as
ex
:
try
:
with
contextlib
.
suppress
(
Exception
)
:
slogger
.
task
[
tid
].
error
(
"error occurred during changing repository request"
,
exc_info
=
True
)
except
Exception
:
pass
return
HttpResponseBadRequest
(
str
(
ex
))
...
...
cvat/apps/engine/filters.py
浏览文件 @
53697eca
...
...
@@ -15,7 +15,7 @@ from rest_framework.exceptions import ValidationError
class
SearchFilter
(
filters
.
SearchFilter
):
def
get_search_fields
(
self
,
view
,
request
):
search_fields
=
getattr
(
view
,
'search_fields'
,
[])
search_fields
=
getattr
(
view
,
'search_fields'
)
or
[]
lookup_fields
=
{
field
:
field
for
field
in
search_fields
}
view_lookup_fields
=
getattr
(
view
,
'lookup_fields'
,
{})
keys_to_update
=
set
(
search_fields
)
&
set
(
view_lookup_fields
.
keys
())
...
...
cvat/apps/engine/mixins.py
浏览文件 @
53697eca
...
...
@@ -9,7 +9,7 @@ import uuid
from
django.conf
import
settings
from
django.core.cache
import
cache
from
distutils.util
import
strtobool
from
rest_framework
import
status
from
rest_framework
import
status
,
mixins
from
rest_framework.response
import
Response
from
cvat.apps.engine.models
import
Location
...
...
@@ -315,3 +315,17 @@ class SerializeMixin:
file_name
=
request
.
query_params
.
get
(
"filename"
,
""
)
return
import_func
(
request
,
filename
=
file_name
)
return
self
.
upload_data
(
request
)
class
PartialUpdateModelMixin
:
"""
Update fields of a model instance.
Almost the same as UpdateModelMixin, but has no public PUT / update() method.
"""
def
perform_update
(
self
,
serializer
):
mixins
.
UpdateModelMixin
.
perform_update
(
self
,
serializer
=
serializer
)
def
partial_update
(
self
,
request
,
*
args
,
**
kwargs
):
kwargs
[
'partial'
]
=
True
return
mixins
.
UpdateModelMixin
.
update
(
self
,
request
=
request
,
*
args
,
**
kwargs
)
cvat/apps/engine/schema.py
浏览文件 @
53697eca
...
...
@@ -2,12 +2,26 @@
#
# SPDX-License-Identifier: MIT
from
typing
import
Type
from
rest_framework
import
serializers
from
drf_spectacular.extensions
import
OpenApiSerializerExtension
from
drf_spectacular.plumbing
import
force_instance
from
drf_spectacular.plumbing
import
force_instance
,
build_basic_type
from
drf_spectacular.types
import
OpenApiTypes
from
drf_spectacular.serializers
import
PolymorphicProxySerializerExtension
def
_copy_serializer
(
instance
:
serializers
.
Serializer
,
*
,
_new_type
:
Type
[
serializers
.
Serializer
]
=
None
,
**
kwargs
)
->
serializers
.
Serializer
:
_new_type
=
_new_type
or
type
(
instance
)
instance_kwargs
=
instance
.
_kwargs
instance_kwargs
[
'partial'
]
=
instance
.
partial
# this can be set separately
instance_kwargs
.
update
(
kwargs
)
return
_new_type
(
*
instance
.
_args
,
**
instance
.
_kwargs
)
class
DataSerializerExtension
(
OpenApiSerializerExtension
):
# *FileSerializer mimics a FileField
# but it is mapped as an object with a file field, which
...
...
@@ -23,40 +37,106 @@ class DataSerializerExtension(OpenApiSerializerExtension):
target_class
=
'cvat.apps.engine.serializers.DataSerializer'
def
map_serializer
(
self
,
auto_schema
,
direction
):
assert
is
instance
(
self
.
target_class
,
type
)
assert
is
subclass
(
self
.
target_class
,
serializers
.
ModelSerializer
)
instance
=
force_instance
(
self
.
target_class
)
instance
=
self
.
target
assert
isinstance
(
instance
,
serializers
.
ModelSerializer
)
def
_get_field
(
instance
,
source_name
,
field_name
):
def
_get_field
(
instance
:
serializers
.
ModelSerializer
,
source_name
:
str
,
field_name
:
str
)
->
serializers
.
ModelField
:
child_instance
=
force_instance
(
instance
.
fields
[
source_name
].
child
)
assert
isinstance
(
child_instance
,
serializers
.
ModelSerializer
)
child_fields
=
child_instance
.
fields
assert
child_fields
.
keys
()
==
{
'file'
}
# protect
from
changes
assert
child_fields
.
keys
()
==
{
'file'
}
# protect
ion from implementation
changes
return
child_fields
[
field_name
]
def
_sanitize_field
(
field
)
:
def
_sanitize_field
(
field
:
serializers
.
ModelField
)
->
serializers
.
ModelField
:
field
.
source
=
None
field
.
source_attrs
=
[]
return
field
def
_make_field
(
source_name
,
field_name
)
:
def
_make_field
(
source_name
:
str
,
field_name
:
str
)
->
serializers
.
ModelField
:
return
_sanitize_field
(
_get_field
(
instance
,
source_name
,
field_name
))
class
_Override
(
self
.
target_class
):
# pylint: disable=inherit-non-class
client_files
=
serializers
.
ListField
(
child
=
_make_field
(
'client_files'
,
'file'
),
default
=
[])
server_files
=
serializers
.
ListField
(
child
=
_make_field
(
'server_files'
,
'file'
),
default
=
[])
remote_files
=
serializers
.
ListField
(
child
=
_make_field
(
'remote_files'
,
'file'
),
default
=
[])
client_files
=
serializers
.
ListField
(
child
=
_make_field
(
'client_files'
,
'file'
),
default
=
[])
server_files
=
serializers
.
ListField
(
child
=
_make_field
(
'server_files'
,
'file'
),
default
=
[])
remote_files
=
serializers
.
ListField
(
child
=
_make_field
(
'remote_files'
,
'file'
),
default
=
[])
return
auto_schema
.
_map_serializer
(
_copy_serializer
(
instance
,
_new_type
=
_Override
,
context
=
{
'view'
:
auto_schema
.
view
}),
direction
,
bypass_extensions
=
False
)
class
WriteOnceSerializerExtension
(
OpenApiSerializerExtension
):
"""
Enables support for cvat.apps.engine.serializers.WriteOnceMixin in drf-spectacular.
Doesn't block other extensions on the target serializer.
"""
return
auto_schema
.
_map_serializer
(
_Override
(),
direction
,
bypass_extensions
=
False
)
match_subclasses
=
True
target_class
=
'cvat.apps.engine.serializers.WriteOnceMixin'
_PROCESSED_INDICATOR_NAME
=
'write_once_serializer_extension_processed'
class
CustomProxySerializerExtension
(
PolymorphicProxySerializerExtension
):
"""
Allows to patch PolymorphicProxySerializer-based schema.
@
classmethod
def
_matches
(
cls
,
target
)
->
bool
:
if
super
().
_matches
(
target
):
# protect from recursive invocations
assert
isinstance
(
target
,
serializers
.
Serializer
)
processed
=
target
.
context
.
get
(
cls
.
_PROCESSED_INDICATOR_NAME
,
False
)
return
not
processed
return
False
Override "target_component" in children classes.
def
map_serializer
(
self
,
auto_schema
,
direction
):
return
auto_schema
.
_map_serializer
(
_copy_serializer
(
self
.
target
,
context
=
{
'view'
:
auto_schema
.
view
,
self
.
_PROCESSED_INDICATOR_NAME
:
True
}),
direction
,
bypass_extensions
=
False
)
class
OpenApiTypeProxySerializerExtension
(
PolymorphicProxySerializerExtension
):
"""
Provides support for OpenApiTypes in the PolymorphicProxySerializer list
"""
priority
=
0
# restore normal priority
def
_process_serializer
(
self
,
auto_schema
,
serializer
,
direction
):
if
isinstance
(
serializer
,
OpenApiTypes
):
schema
=
build_basic_type
(
serializer
)
return
(
None
,
schema
)
else
:
return
super
().
_process_serializer
(
auto_schema
=
auto_schema
,
serializer
=
serializer
,
direction
=
direction
)
def
map_serializer
(
self
,
auto_schema
,
direction
):
""" custom handling for @extend_schema's injection of PolymorphicProxySerializer """
result
=
super
().
map_serializer
(
auto_schema
=
auto_schema
,
direction
=
direction
)
if
isinstance
(
self
.
target
.
serializers
,
dict
):
required
=
OpenApiTypes
.
NONE
not
in
self
.
target
.
serializers
.
values
()
else
:
required
=
OpenApiTypes
.
NONE
not
in
self
.
target
.
serializers
if
not
required
:
result
[
'nullable'
]
=
True
return
result
class
ComponentProxySerializerExtension
(
OpenApiTypeProxySerializerExtension
):
"""
Allows to patch PolymorphicProxySerializer-based component schema.
Override the "target_component" field in children classes.
"""
priority
=
1
# higher than in the parent class
target_component
:
str
=
''
@
classmethod
...
...
@@ -69,7 +149,7 @@ class CustomProxySerializerExtension(PolymorphicProxySerializerExtension):
return
target
.
component_name
==
cls
.
target_component
class
AnyOfProxySerializerExtension
(
C
ustom
ProxySerializerExtension
):
class
AnyOfProxySerializerExtension
(
C
omponent
ProxySerializerExtension
):
"""
Replaces oneOf with anyOf in the generated schema. Useful when
no disciminator field is available, and the options are
...
...
cvat/apps/engine/serializers.py
浏览文件 @
53697eca
...
...
@@ -198,7 +198,9 @@ class JobReadSerializer(serializers.ModelSerializer):
class
JobWriteSerializer
(
serializers
.
ModelSerializer
):
assignee
=
serializers
.
IntegerField
(
allow_null
=
True
,
required
=
False
)
def
to_representation
(
self
,
instance
):
# FIXME: deal with resquest/response separation
serializer
=
JobReadSerializer
(
instance
,
context
=
self
.
context
)
return
serializer
.
data
...
...
@@ -307,8 +309,8 @@ class RqStatusSerializer(serializers.Serializer):
progress
=
serializers
.
FloatField
(
max_value
=
100
,
default
=
0
)
class
WriteOnceMixin
:
"""
Adds support for write once fields to serializers.
"""
Adds support for write once fields to serializers.
To use it, specify a list of fields as `write_once_fields` on the
serializer's Meta:
...
...
@@ -329,12 +331,15 @@ class WriteOnceMixin:
# We're only interested in PATCH/PUT.
if
'update'
in
getattr
(
self
.
context
.
get
(
'view'
),
'action'
,
''
):
return
self
.
_set_write_once_fields
(
extra_kwargs
)
extra_kwargs
=
self
.
_set_write_once_fields
(
extra_kwargs
)
return
extra_kwargs
def
_set_write_once_fields
(
self
,
extra_kwargs
):
"""Set all fields in `Meta.write_once_fields` to read_only."""
"""
Set all fields in `Meta.write_once_fields` to read_only.
"""
write_once_fields
=
getattr
(
self
.
Meta
,
'write_once_fields'
,
None
)
if
not
write_once_fields
:
return
extra_kwargs
...
...
@@ -352,7 +357,7 @@ class WriteOnceMixin:
return
extra_kwargs
class
DataSerializer
(
serializers
.
ModelSerializer
):
class
DataSerializer
(
WriteOnceMixin
,
serializers
.
ModelSerializer
):
image_quality
=
serializers
.
IntegerField
(
min_value
=
0
,
max_value
=
100
)
use_zip_chunks
=
serializers
.
BooleanField
(
default
=
False
)
client_files
=
ClientFileSerializer
(
many
=
True
,
default
=
[])
...
...
@@ -876,16 +881,16 @@ class AnnotationSerializer(serializers.Serializer):
id
=
serializers
.
IntegerField
(
default
=
None
,
allow_null
=
True
)
frame
=
serializers
.
IntegerField
(
min_value
=
0
)
label_id
=
serializers
.
IntegerField
(
min_value
=
0
)
group
=
serializers
.
IntegerField
(
min_value
=
0
,
allow_null
=
True
)
source
=
serializers
.
CharField
(
default
=
'manual'
)
group
=
serializers
.
IntegerField
(
min_value
=
0
,
allow_null
=
True
,
default
=
None
)
source
=
serializers
.
CharField
(
default
=
'manual'
)
class
LabeledImageSerializer
(
AnnotationSerializer
):
attributes
=
AttributeValSerializer
(
many
=
True
,
source
=
"labeledimageattributeval_set"
)
source
=
"labeledimageattributeval_set"
,
default
=
[]
)
class
ShapeSerializer
(
serializers
.
Serializer
):
type
=
serializers
.
ChoiceField
(
choices
=
models
.
ShapeType
.
choices
())
occluded
=
serializers
.
BooleanField
()
occluded
=
serializers
.
BooleanField
(
default
=
False
)
outside
=
serializers
.
BooleanField
(
default
=
False
,
required
=
False
)
z_order
=
serializers
.
IntegerField
(
default
=
0
)
rotation
=
serializers
.
FloatField
(
default
=
0
,
min_value
=
0
,
max_value
=
360
)
...
...
@@ -896,7 +901,7 @@ class ShapeSerializer(serializers.Serializer):
class
SubLabeledShapeSerializer
(
ShapeSerializer
,
AnnotationSerializer
):
attributes
=
AttributeValSerializer
(
many
=
True
,
source
=
"labeledshapeattributeval_set"
)
source
=
"labeledshapeattributeval_set"
,
default
=
[]
)
class
LabeledShapeSerializer
(
SubLabeledShapeSerializer
):
elements
=
SubLabeledShapeSerializer
(
many
=
True
,
required
=
False
)
...
...
@@ -905,22 +910,22 @@ class TrackedShapeSerializer(ShapeSerializer):
id
=
serializers
.
IntegerField
(
default
=
None
,
allow_null
=
True
)
frame
=
serializers
.
IntegerField
(
min_value
=
0
)
attributes
=
AttributeValSerializer
(
many
=
True
,
source
=
"trackedshapeattributeval_set"
)
source
=
"trackedshapeattributeval_set"
,
default
=
[]
)
class
SubLabeledTrackSerializer
(
AnnotationSerializer
):
shapes
=
TrackedShapeSerializer
(
many
=
True
,
allow_empty
=
True
,
source
=
"trackedshape_set"
)
attributes
=
AttributeValSerializer
(
many
=
True
,
source
=
"labeledtrackattributeval_set"
)
source
=
"labeledtrackattributeval_set"
,
default
=
[]
)
class
LabeledTrackSerializer
(
SubLabeledTrackSerializer
):
elements
=
SubLabeledTrackSerializer
(
many
=
True
,
required
=
False
)
class
LabeledDataSerializer
(
serializers
.
Serializer
):
version
=
serializers
.
IntegerField
(
)
tags
=
LabeledImageSerializer
(
many
=
True
)
shapes
=
LabeledShapeSerializer
(
many
=
True
)
tracks
=
LabeledTrackSerializer
(
many
=
True
)
version
=
serializers
.
IntegerField
(
default
=
0
)
# TODO: remove
tags
=
LabeledImageSerializer
(
many
=
True
,
default
=
[]
)
shapes
=
LabeledShapeSerializer
(
many
=
True
,
default
=
[]
)
tracks
=
LabeledTrackSerializer
(
many
=
True
,
default
=
[]
)
class
FileInfoSerializer
(
serializers
.
Serializer
):
name
=
serializers
.
CharField
(
max_length
=
1024
)
...
...
@@ -991,6 +996,10 @@ class IssueReadSerializer(serializers.ModelSerializer):
fields
=
(
'id'
,
'frame'
,
'position'
,
'job'
,
'owner'
,
'assignee'
,
'created_date'
,
'updated_date'
,
'comments'
,
'resolved'
)
read_only_fields
=
fields
extra_kwargs
=
{
'created_date'
:
{
'allow_null'
:
True
},
'updated_date'
:
{
'allow_null'
:
True
},
}
class
IssueWriteSerializer
(
WriteOnceMixin
,
serializers
.
ModelSerializer
):
...
...
@@ -1010,6 +1019,12 @@ class IssueWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
message
=
message
,
owner
=
db_issue
.
owner
)
return
db_issue
def
update
(
self
,
instance
,
validated_data
):
message
=
validated_data
.
pop
(
'message'
,
None
)
if
message
:
raise
NotImplementedError
(
'Check https://github.com/cvat-ai/cvat/issues/122'
)
return
super
().
update
(
instance
,
validated_data
)
class
Meta
:
model
=
models
.
Issue
fields
=
(
'id'
,
'frame'
,
'position'
,
'job'
,
'owner'
,
'assignee'
,
...
...
cvat/apps/engine/tests/test_rest_api.py
浏览文件 @
53697eca
...
...
@@ -313,7 +313,7 @@ class JobGetAPITestCase(APITestCase):
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_401_UNAUTHORIZED
)
class
JobUpdateAPITestCase
(
APITestCase
):
class
Job
Partial
UpdateAPITestCase
(
APITestCase
):
def
setUp
(
self
):
self
.
client
=
APIClient
()
self
.
task
=
create_dummy_db_tasks
(
self
)[
0
]
...
...
@@ -327,7 +327,7 @@ class JobUpdateAPITestCase(APITestCase):
def
_run_api_v2_jobs_id
(
self
,
jid
,
user
,
data
):
with
ForceLogin
(
user
,
self
.
client
):
response
=
self
.
client
.
p
ut
(
'/api/jobs/{}'
.
format
(
jid
),
data
=
data
,
format
=
'json'
)
response
=
self
.
client
.
p
atch
(
'/api/jobs/{}'
.
format
(
jid
),
data
=
data
,
format
=
'json'
)
return
response
...
...
@@ -382,22 +382,43 @@ class JobUpdateAPITestCase(APITestCase):
response
=
self
.
_run_api_v2_jobs_id
(
self
.
job
.
id
+
10
,
None
,
data
)
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_401_UNAUTHORIZED
)
class
JobPartialUpdateAPITestCase
(
JobUpdateAPITestCase
):
def
test_api_v2_jobs_id_annotator_partial
(
self
):
data
=
{
"stage"
:
StageChoice
.
ANNOTATION
}
response
=
self
.
_run_api_v2_jobs_id
(
self
.
job
.
id
,
self
.
annotator
,
data
)
self
.
assertEquals
(
response
.
status_code
,
status
.
HTTP_403_FORBIDDEN
,
response
)
def
test_api_v2_jobs_id_admin_partial
(
self
):
data
=
{
"assignee_id"
:
self
.
user
.
id
}
response
=
self
.
_run_api_v2_jobs_id
(
self
.
job
.
id
,
self
.
owner
,
data
)
self
.
_check_request
(
response
,
data
)
class
JobUpdateAPITestCase
(
APITestCase
):
def
setUp
(
self
):
self
.
client
=
APIClient
()
self
.
task
=
create_dummy_db_tasks
(
self
)[
0
]
self
.
job
=
Job
.
objects
.
filter
(
segment__task_id
=
self
.
task
.
id
).
first
()
self
.
job
.
assignee
=
self
.
annotator
self
.
job
.
save
()
@
classmethod
def
setUpTestData
(
cls
):
create_db_users
(
cls
)
def
_run_api_v2_jobs_id
(
self
,
jid
,
user
,
data
):
with
ForceLogin
(
user
,
self
.
client
):
response
=
self
.
client
.
p
atch
(
'/api/jobs/{}'
.
format
(
jid
),
data
=
data
,
format
=
'json'
)
response
=
self
.
client
.
p
ut
(
'/api/jobs/{}'
.
format
(
jid
),
data
=
data
,
format
=
'json'
)
return
response
def
test_api_v2_jobs_id_annotator
_partial
(
self
):
def
test_api_v2_jobs_id_annotator
(
self
):
data
=
{
"stage"
:
StageChoice
.
ANNOTATION
}
response
=
self
.
_run_api_v2_jobs_id
(
self
.
job
.
id
,
self
.
annotator
,
data
)
self
.
assertEquals
(
response
.
status_code
,
status
.
HTTP_40
3_FORBIDDEN
,
response
)
self
.
assertEquals
(
response
.
status_code
,
status
.
HTTP_40
5_METHOD_NOT_ALLOWED
,
response
)
def
test_api_v2_jobs_id_admin
_partial
(
self
):
def
test_api_v2_jobs_id_admin
(
self
):
data
=
{
"assignee_id"
:
self
.
user
.
id
}
response
=
self
.
_run_api_v2_jobs_id
(
self
.
job
.
id
,
self
.
owner
,
data
)
self
.
_check_request
(
response
,
data
)
self
.
assertEquals
(
response
.
status_code
,
status
.
HTTP_405_METHOD_NOT_ALLOWED
,
response
)
class
JobDataMetaPartialUpdateAPITestCase
(
APITestCase
):
def
setUp
(
self
):
...
...
@@ -1987,7 +2008,6 @@ class TaskDeleteAPITestCase(APITestCase):
self
.
assertFalse
(
os
.
path
.
exists
(
task_dir
))
class
TaskUpdateAPITestCase
(
APITestCase
):
def
setUp
(
self
):
self
.
client
=
APIClient
()
...
...
@@ -2003,6 +2023,39 @@ class TaskUpdateAPITestCase(APITestCase):
return
response
def
_check_api_v2_tasks_id
(
self
,
user
,
data
):
for
db_task
in
self
.
tasks
:
response
=
self
.
_run_api_v2_tasks_id
(
db_task
.
id
,
user
,
data
)
if
user
is
None
:
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_401_UNAUTHORIZED
)
else
:
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_405_METHOD_NOT_ALLOWED
)
def
test_api_v2_tasks_id_admin
(
self
):
data
=
{
"name"
:
"new name for the task"
}
self
.
_check_api_v2_tasks_id
(
self
.
admin
,
data
)
def
test_api_v2_tasks_id_user
(
self
):
data
=
{
"name"
:
"new name for the task"
}
self
.
_check_api_v2_tasks_id
(
self
.
user
,
data
)
def
test_api_v2_tasks_id_somebody
(
self
):
data
=
{
"name"
:
"new name for the task"
}
self
.
_check_api_v2_tasks_id
(
self
.
somebody
,
data
)
def
test_api_v2_tasks_id_no_auth
(
self
):
data
=
{
"name"
:
"new name for the task"
}
self
.
_check_api_v2_tasks_id
(
None
,
data
)
class
TaskPartialUpdateAPITestCase
(
APITestCase
):
def
setUp
(
self
):
self
.
client
=
APIClient
()
@
classmethod
def
setUpTestData
(
cls
):
create_db_users
(
cls
)
cls
.
tasks
=
create_dummy_db_tasks
(
cls
)
def
_check_response
(
self
,
response
,
db_task
,
data
):
self
.
assertEqual
(
response
.
status_code
,
status
.
HTTP_200_OK
)
name
=
data
.
get
(
"name"
,
db_task
.
name
)
...
...
@@ -2034,6 +2087,13 @@ class TaskUpdateAPITestCase(APITestCase):
[
label
[
"name"
]
for
label
in
response
.
data
[
"labels"
]]
)
def
_run_api_v2_tasks_id
(
self
,
tid
,
user
,
data
):
with
ForceLogin
(
user
,
self
.
client
):
response
=
self
.
client
.
patch
(
'/api/tasks/{}'
.
format
(
tid
),
data
=
data
,
format
=
"json"
)
return
response
def
_check_api_v2_tasks_id
(
self
,
user
,
data
):
for
db_task
in
self
.
tasks
:
response
=
self
.
_run_api_v2_tasks_id
(
db_task
.
id
,
user
,
data
)
...
...
@@ -2077,32 +2137,6 @@ class TaskUpdateAPITestCase(APITestCase):
}
self
.
_check_api_v2_tasks_id
(
self
.
user
,
data
)
def
test_api_v2_tasks_id_somebody
(
self
):
data
=
{
"name"
:
"new name for the task"
,
"labels"
:
[{
"name"
:
"test"
,
}]
}
self
.
_check_api_v2_tasks_id
(
self
.
somebody
,
data
)
def
test_api_v2_tasks_id_no_auth
(
self
):
data
=
{
"name"
:
"new name for the task"
,
"labels"
:
[{
"name"
:
"test"
,
}]
}
self
.
_check_api_v2_tasks_id
(
None
,
data
)
class
TaskPartialUpdateAPITestCase
(
TaskUpdateAPITestCase
):
def
_run_api_v2_tasks_id
(
self
,
tid
,
user
,
data
):
with
ForceLogin
(
user
,
self
.
client
):
response
=
self
.
client
.
patch
(
'/api/tasks/{}'
.
format
(
tid
),
data
=
data
,
format
=
"json"
)
return
response
def
test_api_v2_tasks_id_admin_partial
(
self
):
data
=
{
"name"
:
"new name for the task #2"
,
...
...
cvat/apps/engine/views.py
浏览文件 @
53697eca
此差异已折叠。
点击以展开。
cvat/utils/version.py
浏览文件 @
53697eca
...
...
@@ -45,10 +45,10 @@ def get_git_changeset():
so it's sufficient for generating the development version numbers.
"""
repo_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
git_log
=
subprocess
.
Popen
(
'git log --pretty=format:%ct --quiet -1 HEAD'
,
git_log
=
subprocess
.
Popen
(
# nosec: B603, B607
[
'git'
,
'log'
,
'--pretty=format:%ct'
,
'--quiet'
,
'-1'
,
'HEAD'
]
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
shell
=
True
,
cwd
=
repo_dir
,
universal_newlines
=
True
,
cwd
=
repo_dir
,
universal_newlines
=
True
,
)
timestamp
=
git_log
.
communicate
()[
0
]
try
:
...
...
@@ -56,4 +56,3 @@ def get_git_changeset():
except
ValueError
:
return
None
return
timestamp
.
strftime
(
'%Y%m%d%H%M%S'
)
tests/python/cli/test_cli.py
浏览文件 @
53697eca
...
...
@@ -8,9 +8,9 @@ import os
from
pathlib
import
Path
import
pytest
from
cvat_sdk
import
exceptions
,
make_client
from
cvat_sdk.
core.tasks
import
TaskProxy
from
cvat_sdk.core.
types
import
ResourceType
from
cvat_sdk
import
make_client
from
cvat_sdk.
api_client
import
exceptions
from
cvat_sdk.core.
proxies.tasks
import
ResourceType
,
Task
from
PIL
import
Image
from
sdk.util
import
generate_coco_json
...
...
@@ -41,8 +41,6 @@ class TestCLI:
yield
self
.
tmp_path
=
None
@
pytest
.
fixture
def
fxt_image_file
(
self
):
img_path
=
self
.
tmp_path
/
"img_0.png"
...
...
@@ -61,7 +59,7 @@ class TestCLI:
yield
ann_filename
@
pytest
.
fixture
def
fxt_backup_file
(
self
,
fxt_new_task
:
Task
Proxy
,
fxt_coco_file
:
str
):
def
fxt_backup_file
(
self
,
fxt_new_task
:
Task
,
fxt_coco_file
:
str
):
backup_path
=
self
.
tmp_path
/
"backup.zip"
fxt_new_task
.
import_annotations
(
"COCO 1.0"
,
filename
=
fxt_coco_file
)
...
...
@@ -73,7 +71,7 @@ class TestCLI:
def
fxt_new_task
(
self
):
files
=
generate_images
(
str
(
self
.
tmp_path
),
5
)
task
=
self
.
client
.
create_task
(
task
=
self
.
client
.
tasks
.
create_from_data
(
spec
=
{
"name"
:
"test_task"
,
"labels"
:
[{
"name"
:
"car"
},
{
"name"
:
"person"
}],
...
...
@@ -114,30 +112,28 @@ class TestCLI:
)
task_id
=
int
(
stdout
.
split
()[
-
1
])
assert
self
.
client
.
retrieve_task
(
task_id
).
size
==
5
assert
self
.
client
.
tasks
.
retrieve
(
task_id
).
size
==
5
def
test_can_list_tasks_in_simple_format
(
self
,
fxt_new_task
:
Task
Proxy
):
def
test_can_list_tasks_in_simple_format
(
self
,
fxt_new_task
:
Task
):
output
=
self
.
run_cli
(
"ls"
)
results
=
output
.
split
(
"
\n
"
)
assert
any
(
str
(
fxt_new_task
.
id
)
in
r
for
r
in
results
)
def
test_can_list_tasks_in_json_format
(
self
,
fxt_new_task
:
Task
Proxy
):
def
test_can_list_tasks_in_json_format
(
self
,
fxt_new_task
:
Task
):
output
=
self
.
run_cli
(
"ls"
,
"--json"
)
results
=
json
.
loads
(
output
)
assert
any
(
r
[
"id"
]
==
fxt_new_task
.
id
for
r
in
results
)
def
test_can_delete_task
(
self
,
fxt_new_task
:
Task
Proxy
):
def
test_can_delete_task
(
self
,
fxt_new_task
:
Task
):
self
.
run_cli
(
"delete"
,
str
(
fxt_new_task
.
id
))
with
pytest
.
raises
(
exceptions
.
ApiException
)
as
capture
:
with
pytest
.
raises
(
exceptions
.
NotFoundException
)
:
fxt_new_task
.
fetch
()
assert
capture
.
value
.
status
==
404
def
test_can_download_task_annotations
(
self
,
fxt_new_task
:
TaskProxy
):
filename
:
Path
=
self
.
tmp_path
/
"task_{fxt_new_task.id}-cvat.zip"
def
test_can_download_task_annotations
(
self
,
fxt_new_task
:
Task
):
filename
=
self
.
tmp_path
/
"task_{fxt_new_task.id}-cvat.zip"
self
.
run_cli
(
"dump"
,
str
(
fxt_new_task
.
id
),
...
...
@@ -152,8 +148,8 @@ class TestCLI:
assert
0
<
filename
.
stat
().
st_size
def
test_can_download_task_backup
(
self
,
fxt_new_task
:
Task
Proxy
):
filename
:
Path
=
self
.
tmp_path
/
"task_{fxt_new_task.id}-cvat.zip"
def
test_can_download_task_backup
(
self
,
fxt_new_task
:
Task
):
filename
=
self
.
tmp_path
/
"task_{fxt_new_task.id}-cvat.zip"
self
.
run_cli
(
"export"
,
str
(
fxt_new_task
.
id
),
...
...
@@ -165,7 +161,7 @@ class TestCLI:
assert
0
<
filename
.
stat
().
st_size
@
pytest
.
mark
.
parametrize
(
"quality"
,
(
"compressed"
,
"original"
))
def
test_can_download_task_frames
(
self
,
fxt_new_task
:
Task
Proxy
,
quality
:
str
):
def
test_can_download_task_frames
(
self
,
fxt_new_task
:
Task
,
quality
:
str
):
out_dir
=
str
(
self
.
tmp_path
/
"downloads"
)
self
.
run_cli
(
"frames"
,
...
...
@@ -182,13 +178,13 @@ class TestCLI:
"task_{}_frame_{:06d}.jpg"
.
format
(
fxt_new_task
.
id
,
i
)
for
i
in
range
(
2
)
}
def
test_can_upload_annotations
(
self
,
fxt_new_task
:
Task
Proxy
,
fxt_coco_file
:
Path
):
def
test_can_upload_annotations
(
self
,
fxt_new_task
:
Task
,
fxt_coco_file
:
Path
):
self
.
run_cli
(
"upload"
,
str
(
fxt_new_task
.
id
),
str
(
fxt_coco_file
),
"--format"
,
"COCO 1.0"
)
def
test_can_create_from_backup
(
self
,
fxt_new_task
:
Task
Proxy
,
fxt_backup_file
:
Path
):
def
test_can_create_from_backup
(
self
,
fxt_new_task
:
Task
,
fxt_backup_file
:
Path
):
stdout
=
self
.
run_cli
(
"import"
,
str
(
fxt_backup_file
))
task_id
=
int
(
stdout
.
split
()[
-
1
])
assert
task_id
assert
task_id
!=
fxt_new_task
.
id
assert
self
.
client
.
retrieve_task
(
task_id
).
size
==
fxt_new_task
.
size
assert
self
.
client
.
tasks
.
retrieve
(
task_id
).
size
==
fxt_new_task
.
size
tests/python/rest_api/test_auth.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import
json
from
http
import
HTTPStatus
import
pytest
from
cvat_sdk.api_client
import
ApiClient
,
Configuration
,
models
from
shared.utils.config
import
BASE_URL
,
USER_PASS
,
make_api_client
@
pytest
.
mark
.
usefixtures
(
"dontchangedb"
)
class
TestBasicAuth
:
def
test_can_do_basic_auth
(
self
,
admin_user
:
str
):
username
=
admin_user
config
=
Configuration
(
host
=
BASE_URL
,
username
=
username
,
password
=
USER_PASS
)
with
ApiClient
(
config
)
as
client
:
(
user
,
response
)
=
client
.
users_api
.
retrieve_self
()
assert
response
.
status
==
HTTPStatus
.
OK
assert
user
.
username
==
username
@
pytest
.
mark
.
usefixtures
(
"changedb"
)
class
TestTokenAuth
:
@
staticmethod
def
login
(
client
:
ApiClient
,
username
:
str
)
->
models
.
Token
:
(
auth
,
_
)
=
client
.
auth_api
.
create_login
(
models
.
LoginRequest
(
username
=
username
,
password
=
USER_PASS
)
)
client
.
set_default_header
(
"Authorization"
,
"Token "
+
auth
.
key
)
return
auth
@
classmethod
def
make_client
(
cls
,
username
:
str
)
->
ApiClient
:
with
ApiClient
(
Configuration
(
host
=
BASE_URL
))
as
client
:
cls
.
login
(
client
,
username
)
return
client
def
test_can_do_token_auth_and_manage_cookies
(
self
,
admin_user
:
str
):
username
=
admin_user
with
ApiClient
(
Configuration
(
host
=
BASE_URL
))
as
client
:
auth
=
self
.
login
(
client
,
username
=
username
)
assert
"sessionid"
in
client
.
cookies
assert
"csrftoken"
in
client
.
cookies
assert
auth
.
key
(
user
,
response
)
=
client
.
users_api
.
retrieve_self
()
assert
response
.
status
==
HTTPStatus
.
OK
assert
user
.
username
==
username
def
test_can_do_logout
(
self
,
admin_user
:
str
):
username
=
admin_user
with
self
.
make_client
(
username
)
as
client
:
(
_
,
response
)
=
client
.
auth_api
.
create_logout
()
assert
response
.
status
==
HTTPStatus
.
OK
(
_
,
response
)
=
client
.
users_api
.
retrieve_self
(
_parse_response
=
False
,
_check_status
=
False
)
assert
response
.
status
==
HTTPStatus
.
UNAUTHORIZED
@
pytest
.
mark
.
usefixtures
(
"changedb"
)
class
TestCredentialsManagement
:
def
test_can_register
(
self
):
username
=
"newuser"
email
=
"123@456.com"
with
ApiClient
(
Configuration
(
host
=
BASE_URL
))
as
client
:
(
user
,
response
)
=
client
.
auth_api
.
create_register
(
models
.
RestrictedRegisterRequest
(
username
=
username
,
password1
=
USER_PASS
,
password2
=
USER_PASS
,
email
=
email
)
)
assert
response
.
status
==
HTTPStatus
.
CREATED
assert
user
.
username
==
username
with
make_api_client
(
username
)
as
client
:
(
user
,
response
)
=
client
.
users_api
.
retrieve_self
()
assert
response
.
status
==
HTTPStatus
.
OK
assert
user
.
username
==
username
assert
user
.
email
==
email
def
test_can_change_password
(
self
,
admin_user
:
str
):
username
=
admin_user
new_pass
=
"5w4knrqaW#$@gewa"
with
make_api_client
(
username
)
as
client
:
(
info
,
response
)
=
client
.
auth_api
.
create_password_change
(
models
.
PasswordChangeRequest
(
old_password
=
USER_PASS
,
new_password1
=
new_pass
,
new_password2
=
new_pass
)
)
assert
response
.
status
==
HTTPStatus
.
OK
assert
info
.
detail
==
"New password has been saved."
(
_
,
response
)
=
client
.
users_api
.
retrieve_self
(
_parse_response
=
False
,
_check_status
=
False
)
assert
response
.
status
==
HTTPStatus
.
UNAUTHORIZED
client
.
configuration
.
password
=
new_pass
(
user
,
response
)
=
client
.
users_api
.
retrieve_self
()
assert
response
.
status
==
HTTPStatus
.
OK
assert
user
.
username
==
username
def
test_can_report_weak_password
(
self
,
admin_user
:
str
):
username
=
admin_user
new_pass
=
"pass"
with
make_api_client
(
username
)
as
client
:
(
_
,
response
)
=
client
.
auth_api
.
create_password_change
(
models
.
PasswordChangeRequest
(
old_password
=
USER_PASS
,
new_password1
=
new_pass
,
new_password2
=
new_pass
),
_parse_response
=
False
,
_check_status
=
False
,
)
assert
response
.
status
==
HTTPStatus
.
BAD_REQUEST
assert
json
.
loads
(
response
.
data
)
==
{
"new_password2"
:
[
"This password is too short. It must contain at least 8 characters."
,
"This password is too common."
,
]
}
def
test_can_report_mismatching_passwords
(
self
,
admin_user
:
str
):
username
=
admin_user
with
make_api_client
(
username
)
as
client
:
(
_
,
response
)
=
client
.
auth_api
.
create_password_change
(
models
.
PasswordChangeRequest
(
old_password
=
USER_PASS
,
new_password1
=
"3j4tb13/T$#"
,
new_password2
=
"q#@$n34g5"
),
_parse_response
=
False
,
_check_status
=
False
,
)
assert
response
.
status
==
HTTPStatus
.
BAD_REQUEST
assert
json
.
loads
(
response
.
data
)
==
{
"new_password2"
:
[
"The two password fields didn’t match."
]
}
tests/python/rest_api/test_issues.py
浏览文件 @
53697eca
...
...
@@ -3,48 +3,78 @@
#
# SPDX-License-Identifier: MIT
import
pytest
import
json
from
copy
import
deepcopy
from
http
import
HTTPStatus
import
pytest
from
cvat_sdk
import
models
from
deepdiff
import
DeepDiff
from
copy
import
deepcopy
from
shared.utils.config
import
post_method
,
patch_method
from
cvat_sdk.api_client
import
exceptions
@
pytest
.
mark
.
usefixtures
(
'changedb'
)
from
shared.utils.config
import
make_api_client
@
pytest
.
mark
.
usefixtures
(
"changedb"
)
class
TestPostIssues
:
def
_test_check_response
(
self
,
user
,
data
,
is_allow
,
**
kwargs
):
response
=
post_method
(
user
,
'issues'
,
data
,
**
kwargs
)
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
issues_api
.
create
(
models
.
IssueWriteRequest
(
**
data
),
**
kwargs
,
_parse_response
=
False
,
_check_status
=
False
,
)
if
is_allow
:
assert
response
.
status_code
==
HTTPStatus
.
CREATED
assert
user
==
response
.
json
()[
'owner'
][
'username'
]
assert
data
[
'message'
]
==
response
.
json
()[
'comments'
][
0
][
'message'
]
assert
DeepDiff
(
data
,
response
.
json
(),
exclude_regex_paths
=
r
"root\['created_date|updated_date|comments|id|owner|message'\]"
)
==
{}
assert
response
.
status
==
HTTPStatus
.
CREATED
response_json
=
json
.
loads
(
response
.
data
)
assert
user
==
response_json
[
"owner"
][
"username"
]
assert
data
[
"message"
]
==
response_json
[
"comments"
][
0
][
"message"
]
assert
(
DeepDiff
(
data
,
response_json
,
exclude_regex_paths
=
r
"root\['created_date|updated_date|comments|id|owner|message'\]"
,
)
==
{}
)
else
:
assert
response
.
status_code
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
mark
.
parametrize
(
'org'
,
[
''
])
@
pytest
.
mark
.
parametrize
(
'privilege, job_staff, is_allow'
,
[
(
'admin'
,
True
,
True
),
(
'admin'
,
False
,
True
),
(
'business'
,
True
,
True
),
(
'business'
,
False
,
False
),
(
'worker'
,
True
,
True
),
(
'worker'
,
False
,
False
),
(
'user'
,
True
,
True
),
(
'user'
,
False
,
False
)
])
def
test_user_create_issue
(
self
,
org
,
privilege
,
job_staff
,
is_allow
,
find_job_staff_user
,
find_users
,
jobs_by_org
):
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
mark
.
parametrize
(
"org"
,
[
""
])
@
pytest
.
mark
.
parametrize
(
"privilege, job_staff, is_allow"
,
[
(
"admin"
,
True
,
True
),
(
"admin"
,
False
,
True
),
(
"business"
,
True
,
True
),
(
"business"
,
False
,
False
),
(
"worker"
,
True
,
True
),
(
"worker"
,
False
,
False
),
(
"user"
,
True
,
True
),
(
"user"
,
False
,
False
),
],
)
def
test_user_create_issue
(
self
,
org
,
privilege
,
job_staff
,
is_allow
,
find_job_staff_user
,
find_users
,
jobs_by_org
):
users
=
find_users
(
privilege
=
privilege
)
jobs
=
jobs_by_org
[
org
]
username
,
jid
=
find_job_staff_user
(
jobs
,
users
,
job_staff
)
job
,
=
filter
(
lambda
job
:
job
[
'id'
]
==
jid
,
jobs
)
(
job
,)
=
filter
(
lambda
job
:
job
[
"id"
]
==
jid
,
jobs
)
data
=
{
"assignee"
:
None
,
"comments"
:
[],
"job"
:
jid
,
"frame"
:
job
[
'start_frame'
],
"frame"
:
job
[
"start_frame"
],
"position"
:
[
0.
,
0.
,
1.
,
1.
,
0.0
,
0.0
,
1.0
,
1.0
,
],
"resolved"
:
False
,
"message"
:
"lorem ipsum"
,
...
...
@@ -52,16 +82,23 @@ class TestPostIssues:
self
.
_test_check_response
(
username
,
data
,
is_allow
)
@
pytest
.
mark
.
parametrize
(
'org'
,
[
2
])
@
pytest
.
mark
.
parametrize
(
'role, job_staff, is_allow'
,
[
(
'maintainer'
,
False
,
True
),
(
'owner'
,
False
,
True
),
(
'supervisor'
,
False
,
False
),
(
'worker'
,
False
,
False
),
(
'maintainer'
,
True
,
True
),
(
'owner'
,
True
,
True
),
(
'supervisor'
,
True
,
True
),
(
'worker'
,
True
,
True
)
])
def
test_member_create_issue
(
self
,
org
,
role
,
job_staff
,
is_allow
,
find_job_staff_user
,
find_users
,
jobs_by_org
,
jobs
):
@
pytest
.
mark
.
parametrize
(
"org"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"role, job_staff, is_allow"
,
[
(
"maintainer"
,
False
,
True
),
(
"owner"
,
False
,
True
),
(
"supervisor"
,
False
,
False
),
(
"worker"
,
False
,
False
),
(
"maintainer"
,
True
,
True
),
(
"owner"
,
True
,
True
),
(
"supervisor"
,
True
,
True
),
(
"worker"
,
True
,
True
),
],
)
def
test_member_create_issue
(
self
,
org
,
role
,
job_staff
,
is_allow
,
find_job_staff_user
,
find_users
,
jobs_by_org
,
jobs
):
users
=
find_users
(
role
=
role
,
org
=
org
)
username
,
jid
=
find_job_staff_user
(
jobs_by_org
[
org
],
users
,
job_staff
)
job
=
jobs
[
jid
]
...
...
@@ -70,50 +107,85 @@ class TestPostIssues:
"assignee"
:
None
,
"comments"
:
[],
"job"
:
jid
,
"frame"
:
job
[
'start_frame'
],
"frame"
:
job
[
"start_frame"
],
"position"
:
[
0.
,
0.
,
1.
,
1.
,
0.0
,
0.0
,
1.0
,
1.0
,
],
"resolved"
:
False
,
"message"
:
"lorem ipsum"
,
}
self
.
_test_check_response
(
username
,
data
,
is_allow
,
org_id
=
org
)
@
pytest
.
mark
.
usefixtures
(
'changedb'
)
@
pytest
.
mark
.
usefixtures
(
"changedb"
)
class
TestPatchIssues
:
def
_test_check_response
(
self
,
user
,
issue_id
,
data
,
is_allow
,
**
kwargs
):
response
=
patch_method
(
user
,
f
'issues/
{
issue_id
}
'
,
data
,
action
=
'update'
,
**
kwargs
)
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
issues_api
.
partial_update
(
issue_id
,
patched_issue_write_request
=
models
.
PatchedIssueWriteRequest
(
**
data
),
**
kwargs
,
_parse_response
=
False
,
_check_status
=
False
,
)
if
is_allow
:
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
response
.
json
(),
exclude_regex_paths
=
r
"root\['created_date|updated_date|comments|id|owner'\]"
)
==
{}
assert
response
.
status
==
HTTPStatus
.
OK
assert
(
DeepDiff
(
data
,
json
.
loads
(
response
.
data
),
exclude_regex_paths
=
r
"root\['created_date|updated_date|comments|id|owner'\]"
,
)
==
{}
)
else
:
assert
response
.
status
_code
==
HTTPStatus
.
FORBIDDEN
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
fixture
(
scope
=
'class'
)
@
pytest
.
fixture
(
scope
=
"class"
)
def
request_data
(
self
,
issues
):
def
get_data
(
issue_id
):
data
=
deepcopy
(
issues
[
issue_id
])
data
[
'resolved'
]
=
not
data
[
'resolved'
]
data
.
pop
(
'comments'
)
data
.
pop
(
'updated_date'
)
data
.
pop
(
'id'
)
data
.
pop
(
'owner'
)
data
[
"resolved"
]
=
not
data
[
"resolved"
]
data
.
pop
(
"comments"
)
data
.
pop
(
"updated_date"
)
data
.
pop
(
"id"
)
data
.
pop
(
"owner"
)
return
data
return
get_data
@
pytest
.
mark
.
parametrize
(
'org'
,
[
''
])
@
pytest
.
mark
.
parametrize
(
'privilege, issue_staff, issue_admin, is_allow'
,
[
(
'admin'
,
True
,
None
,
True
),
(
'admin'
,
False
,
None
,
True
),
(
'business'
,
True
,
None
,
True
),
(
'business'
,
False
,
None
,
False
),
(
'user'
,
True
,
None
,
True
),
(
'user'
,
False
,
None
,
False
),
(
'worker'
,
False
,
True
,
True
),
(
'worker'
,
True
,
False
,
False
),
(
'worker'
,
False
,
False
,
False
)
])
def
test_user_update_issue
(
self
,
org
,
privilege
,
issue_staff
,
issue_admin
,
is_allow
,
find_issue_staff_user
,
find_users
,
issues_by_org
,
request_data
):
@
pytest
.
mark
.
parametrize
(
"org"
,
[
""
])
@
pytest
.
mark
.
parametrize
(
"privilege, issue_staff, issue_admin, is_allow"
,
[
(
"admin"
,
True
,
None
,
True
),
(
"admin"
,
False
,
None
,
True
),
(
"business"
,
True
,
None
,
True
),
(
"business"
,
False
,
None
,
False
),
(
"user"
,
True
,
None
,
True
),
(
"user"
,
False
,
None
,
False
),
(
"worker"
,
False
,
True
,
True
),
(
"worker"
,
True
,
False
,
False
),
(
"worker"
,
False
,
False
,
False
),
],
)
def
test_user_update_issue
(
self
,
org
,
privilege
,
issue_staff
,
issue_admin
,
is_allow
,
find_issue_staff_user
,
find_users
,
issues_by_org
,
request_data
,
):
users
=
find_users
(
privilege
=
privilege
)
issues
=
issues_by_org
[
org
]
username
,
issue_id
=
find_issue_staff_user
(
issues
,
users
,
issue_staff
,
issue_admin
)
...
...
@@ -121,19 +193,135 @@ class TestPatchIssues:
data
=
request_data
(
issue_id
)
self
.
_test_check_response
(
username
,
issue_id
,
data
,
is_allow
)
@
pytest
.
mark
.
parametrize
(
'org'
,
[
2
])
@
pytest
.
mark
.
parametrize
(
'role, issue_staff, issue_admin, is_allow'
,
[
(
'maintainer'
,
True
,
None
,
True
),
(
'maintainer'
,
False
,
None
,
True
),
(
'supervisor'
,
True
,
None
,
True
),
(
'supervisor'
,
False
,
None
,
False
),
(
'owner'
,
True
,
None
,
True
),
(
'owner'
,
False
,
None
,
True
),
(
'worker'
,
False
,
True
,
True
),
(
'worker'
,
True
,
False
,
False
),
(
'worker'
,
False
,
False
,
False
)
])
def
test_member_update_issue
(
self
,
org
,
role
,
issue_staff
,
issue_admin
,
is_allow
,
find_issue_staff_user
,
find_users
,
issues_by_org
,
request_data
):
@
pytest
.
mark
.
parametrize
(
"org"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"role, issue_staff, issue_admin, is_allow"
,
[
(
"maintainer"
,
True
,
None
,
True
),
(
"maintainer"
,
False
,
None
,
True
),
(
"supervisor"
,
True
,
None
,
True
),
(
"supervisor"
,
False
,
None
,
False
),
(
"owner"
,
True
,
None
,
True
),
(
"owner"
,
False
,
None
,
True
),
(
"worker"
,
False
,
True
,
True
),
(
"worker"
,
True
,
False
,
False
),
(
"worker"
,
False
,
False
,
False
),
],
)
def
test_member_update_issue
(
self
,
org
,
role
,
issue_staff
,
issue_admin
,
is_allow
,
find_issue_staff_user
,
find_users
,
issues_by_org
,
request_data
,
):
users
=
find_users
(
role
=
role
,
org
=
org
)
issues
=
issues_by_org
[
org
]
username
,
issue_id
=
find_issue_staff_user
(
issues
,
users
,
issue_staff
,
issue_admin
)
data
=
request_data
(
issue_id
)
self
.
_test_check_response
(
username
,
issue_id
,
data
,
is_allow
,
org_id
=
org
)
@pytest.mark.xfail(
    raises=exceptions.ServiceException,
    reason="server bug, https://github.com/cvat-ai/cvat/issues/122",
)
def test_cant_update_message(self, admin_user: str, issues_by_org):
    """Patching an issue's message is expected to fail (known server-side bug)."""
    org = 2
    target_issue_id = issues_by_org[org][0]["id"]
    update = models.PatchedIssueWriteRequest(message="foo")
    with make_api_client(admin_user) as api_client:
        api_client.issues_api.partial_update(
            target_issue_id, patched_issue_write_request=update, org_id=org
        )
@pytest.mark.usefixtures("changedb")
class TestDeleteIssues:
    """Permission checks for the DELETE /issues/{id} endpoint."""

    def _test_check_response(self, user, issue_id, expect_success, **kwargs):
        # Deletes the issue on behalf of `user` and verifies either a successful
        # removal (followed by a 404 on re-read) or a 403 denial.
        with make_api_client(user) as api_client:
            (_, delete_response) = api_client.issues_api.destroy(
                issue_id, **kwargs, _parse_response=False, _check_status=False
            )

            if not expect_success:
                assert delete_response.status == HTTPStatus.FORBIDDEN
                return

            assert delete_response.status == HTTPStatus.NO_CONTENT
            # The issue must be gone after a successful deletion
            (_, get_response) = api_client.issues_api.retrieve(
                issue_id, _parse_response=False, _check_status=False
            )
            assert get_response.status == HTTPStatus.NOT_FOUND

    @pytest.mark.parametrize("org", [""])
    @pytest.mark.parametrize(
        "privilege, issue_staff, issue_admin, expect_success",
        [
            ("admin", True, None, True),
            ("admin", False, None, True),
            ("business", True, None, True),
            ("business", False, None, False),
            ("user", True, None, True),
            ("user", False, None, False),
            ("worker", False, True, True),
            ("worker", True, False, False),
            ("worker", False, False, False),
        ],
    )
    def test_user_delete_issue(
        self,
        org,
        privilege,
        issue_staff,
        issue_admin,
        expect_success,
        find_issue_staff_user,
        find_users,
        issues_by_org,
    ):
        # Sandbox (non-org) case: access is governed by the global privilege
        # and by whether the user is staff/admin of the issue.
        candidate_users = find_users(privilege=privilege)
        org_issues = issues_by_org[org]
        username, issue_id = find_issue_staff_user(
            org_issues, candidate_users, issue_staff, issue_admin
        )
        self._test_check_response(username, issue_id, expect_success)

    @pytest.mark.parametrize("org", [2])
    @pytest.mark.parametrize(
        "role, issue_staff, issue_admin, expect_success",
        [
            ("maintainer", True, None, True),
            ("maintainer", False, None, True),
            ("supervisor", True, None, True),
            ("supervisor", False, None, False),
            ("owner", True, None, True),
            ("owner", False, None, True),
            ("worker", False, True, True),
            ("worker", True, False, False),
            ("worker", False, False, False),
        ],
    )
    def test_org_member_delete_issue(
        self,
        org,
        role,
        issue_staff,
        issue_admin,
        expect_success,
        find_issue_staff_user,
        find_users,
        issues_by_org,
    ):
        # Organization case: access is governed by the member's org role.
        candidate_users = find_users(role=role, org=org)
        org_issues = issues_by_org[org]
        username, issue_id = find_issue_staff_user(
            org_issues, candidate_users, issue_staff, issue_admin
        )
        self._test_check_response(username, issue_id, expect_success, org_id=org)
tests/python/rest_api/test_jobs.py
浏览文件 @
53697eca
...
...
@@ -4,10 +4,14 @@
# SPDX-License-Identifier: MIT
from
http
import
HTTPStatus
import
json
from
typing
import
List
from
cvat_sdk.core.helpers
import
get_paginated_collection
from
deepdiff
import
DeepDiff
import
pytest
from
copy
import
deepcopy
from
shared.utils.config
import
get_method
,
patch_method
from
shared.utils.config
import
make_api_client
from
.utils
import
export_dataset
def
get_job_staff
(
job
,
tasks
,
projects
):
job_staff
=
[]
...
...
@@ -42,15 +46,17 @@ def filter_jobs(jobs, tasks, org):
@
pytest
.
mark
.
usefixtures
(
'dontchangedb'
)
class
TestGetJobs
:
def
_test_get_job_200
(
self
,
user
,
jid
,
data
,
**
kwargs
):
response
=
get_method
(
user
,
f
'jobs/
{
jid
}
'
,
**
kwargs
)
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
response
.
json
(
),
exclude_paths
=
"root['updated_date']"
,
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
retrieve
(
jid
,
**
kwargs
)
assert
response
.
status
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
json
.
loads
(
response
.
data
),
exclude_paths
=
"root['updated_date']"
,
ignore_order
=
True
)
==
{}
def
_test_get_job_403
(
self
,
user
,
jid
,
**
kwargs
):
response
=
get_method
(
user
,
f
'jobs/
{
jid
}
'
,
**
kwargs
)
assert
response
.
status_code
==
HTTPStatus
.
FORBIDDEN
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
retrieve
(
jid
,
**
kwargs
,
_check_status
=
False
,
_parse_response
=
False
)
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
mark
.
parametrize
(
'org'
,
[
None
,
''
,
1
,
2
])
def
test_admin_get_job
(
self
,
jobs
,
tasks
,
org
):
...
...
@@ -82,15 +88,17 @@ class TestGetJobs:
@
pytest
.
mark
.
usefixtures
(
'dontchangedb'
)
class
TestListJobs
:
def
_test_list_jobs_200
(
self
,
user
,
data
,
**
kwargs
):
response
=
get_method
(
user
,
'jobs'
,
**
kwargs
,
page_size
=
'all'
)
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
response
.
json
()[
'results'
]
,
exclude_paths
=
"root['updated_date']"
,
with
make_api_client
(
user
)
as
client
:
results
=
get_paginated_collection
(
client
.
jobs_api
.
list_endpoint
,
return_json
=
True
,
**
kwargs
)
assert
DeepDiff
(
data
,
results
,
exclude_paths
=
"root['updated_date']"
,
ignore_order
=
True
)
==
{}
def
_test_list_jobs_403
(
self
,
user
,
**
kwargs
):
response
=
get_method
(
user
,
'jobs'
,
**
kwargs
)
assert
response
.
status_code
==
HTTPStatus
.
FORBIDDEN
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
list
(
**
kwargs
,
_check_status
=
False
,
_parse_response
=
False
)
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
mark
.
parametrize
(
'org'
,
[
None
,
''
,
1
,
2
])
def
test_admin_list_jobs
(
self
,
jobs
,
tasks
,
org
):
...
...
@@ -119,52 +127,54 @@ class TestListJobs:
@
pytest
.
mark
.
usefixtures
(
'dontchangedb'
)
class
TestGetAnnotations
:
def
_test_get_job_annotations_200
(
self
,
user
,
jid
,
data
,
**
kwargs
):
response
=
get_method
(
user
,
f
'jobs/
{
jid
}
/annotations'
,
**
kwargs
)
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
retrieve_annotations
(
jid
,
**
kwargs
)
assert
response
.
status
==
HTTPStatus
.
OK
response_data
=
response
.
json
(
)
response_data
=
json
.
loads
(
response
.
data
)
response_data
[
'shapes'
]
=
sorted
(
response_data
[
'shapes'
],
key
=
lambda
a
:
a
[
'id'
])
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
response_data
,
exclude_regex_paths
=
r
"root\['version|updated_date'\]"
)
==
{}
def
_test_get_job_annotations_403
(
self
,
user
,
jid
,
**
kwargs
):
response
=
get_method
(
user
,
f
'jobs/
{
jid
}
/annotations'
,
**
kwargs
)
assert
response
.
status_code
==
HTTPStatus
.
FORBIDDEN
with
make_api_client
(
user
)
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
retrieve_annotations
(
jid
,
**
kwargs
,
_check_status
=
False
,
_parse_response
=
False
)
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
mark
.
parametrize
(
'org'
,
[
''
])
@
pytest
.
mark
.
parametrize
(
'groups, job_staff,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'groups, job_staff,
expect_success
'
,
[
([
'admin'
],
True
,
True
),
([
'admin'
],
False
,
True
),
([
'business'
],
True
,
True
),
([
'business'
],
False
,
False
),
([
'worker'
],
True
,
True
),
([
'worker'
],
False
,
False
),
([
'user'
],
True
,
True
),
([
'user'
],
False
,
False
)
])
def
test_user_get_job_annotations
(
self
,
org
,
groups
,
job_staff
,
is_allow
,
users
,
jobs
,
tasks
,
annotations
,
find_job_staff_user
):
expect_success
,
users
,
jobs
,
tasks
,
annotations
,
find_job_staff_user
):
users
=
[
u
for
u
in
users
if
u
[
'groups'
]
==
groups
]
jobs
,
kwargs
=
filter_jobs
(
jobs
,
tasks
,
org
)
username
,
job_id
=
find_job_staff_user
(
jobs
,
users
,
job_staff
)
if
is_allow
:
if
expect_success
:
self
.
_test_get_job_annotations_200
(
username
,
job_id
,
annotations
[
'job'
][
str
(
job_id
)],
**
kwargs
)
else
:
self
.
_test_get_job_annotations_403
(
username
,
job_id
,
**
kwargs
)
@
pytest
.
mark
.
parametrize
(
'org'
,
[
2
])
@
pytest
.
mark
.
parametrize
(
'role, job_staff,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'role, job_staff,
expect_success
'
,
[
(
'owner'
,
True
,
True
),
(
'owner'
,
False
,
True
),
(
'maintainer'
,
True
,
True
),
(
'maintainer'
,
False
,
True
),
(
'supervisor'
,
True
,
True
),
(
'supervisor'
,
False
,
False
),
(
'worker'
,
True
,
True
),
(
'worker'
,
False
,
False
),
])
def
test_member_get_job_annotations
(
self
,
org
,
role
,
job_staff
,
is_allow
,
def
test_member_get_job_annotations
(
self
,
org
,
role
,
job_staff
,
expect_success
,
jobs
,
tasks
,
find_job_staff_user
,
annotations
,
find_users
):
users
=
find_users
(
org
=
org
,
role
=
role
)
jobs
,
kwargs
=
filter_jobs
(
jobs
,
tasks
,
org
)
username
,
jid
=
find_job_staff_user
(
jobs
,
users
,
job_staff
)
if
is_allow
:
if
expect_success
:
data
=
annotations
[
'job'
][
str
(
jid
)]
data
[
'shapes'
]
=
sorted
(
data
[
'shapes'
],
key
=
lambda
a
:
a
[
'id'
])
self
.
_test_get_job_annotations_200
(
username
,
jid
,
data
,
**
kwargs
)
...
...
@@ -172,17 +182,17 @@ class TestGetAnnotations:
self
.
_test_get_job_annotations_403
(
username
,
jid
,
**
kwargs
)
@
pytest
.
mark
.
parametrize
(
'org'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'privilege,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'privilege,
expect_success
'
,
[
(
'admin'
,
True
),
(
'business'
,
False
),
(
'worker'
,
False
),
(
'user'
,
False
)
])
def
test_non_member_get_job_annotations
(
self
,
org
,
privilege
,
is_allow
,
def
test_non_member_get_job_annotations
(
self
,
org
,
privilege
,
expect_success
,
jobs
,
tasks
,
find_job_staff_user
,
annotations
,
find_users
):
users
=
find_users
(
privilege
=
privilege
,
exclude_org
=
org
)
jobs
,
kwargs
=
filter_jobs
(
jobs
,
tasks
,
org
)
username
,
job_id
=
find_job_staff_user
(
jobs
,
users
,
False
)
kwargs
=
{
'org_id'
:
org
}
if
is_allow
:
if
expect_success
:
self
.
_test_get_job_annotations_200
(
username
,
job_id
,
annotations
[
'job'
][
str
(
job_id
)],
**
kwargs
)
else
:
...
...
@@ -190,15 +200,25 @@ class TestGetAnnotations:
@
pytest
.
mark
.
usefixtures
(
'changedb'
)
class
TestPatchJobAnnotations
:
_ORG
=
2
def
_check_respone
(
self
,
username
,
jid
,
expect_success
,
data
=
None
,
org
=
None
):
kwargs
=
{}
if
org
is
not
None
:
if
isinstance
(
org
,
str
):
kwargs
[
'org'
]
=
org
else
:
kwargs
[
'org_id'
]
=
org
with
make_api_client
(
username
)
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
partial_update_annotations
(
id
=
jid
,
patched_labeled_data_request
=
deepcopy
(
data
),
action
=
'update'
,
**
kwargs
,
_parse_response
=
expect_success
,
_check_status
=
expect_success
)
def
_test_check_respone
(
self
,
is_allow
,
response
,
data
=
None
):
if
is_allow
:
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
response
.
json
(),
if
expect_success
:
assert
response
.
status
==
HTTPStatus
.
OK
assert
DeepDiff
(
data
,
json
.
loads
(
response
.
data
),
exclude_regex_paths
=
r
"root\['version|updated_date'\]"
)
==
{}
else
:
assert
response
.
status_code
==
HTTPStatus
.
FORBIDDEN
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@
pytest
.
fixture
(
scope
=
'class'
)
def
request_data
(
self
,
annotations
):
...
...
@@ -210,13 +230,13 @@ class TestPatchJobAnnotations:
return
get_data
@
pytest
.
mark
.
parametrize
(
'org'
,
[
2
])
@
pytest
.
mark
.
parametrize
(
'role, job_staff,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'role, job_staff,
expect_success
'
,
[
(
'maintainer'
,
False
,
True
),
(
'owner'
,
False
,
True
),
(
'supervisor'
,
False
,
False
),
(
'worker'
,
False
,
False
),
(
'maintainer'
,
True
,
True
),
(
'owner'
,
True
,
True
),
(
'supervisor'
,
True
,
True
),
(
'worker'
,
True
,
True
)
])
def
test_member_update_job_annotations
(
self
,
org
,
role
,
job_staff
,
is_allow
,
def
test_member_update_job_annotations
(
self
,
org
,
role
,
job_staff
,
expect_success
,
find_job_staff_user
,
find_users
,
request_data
,
jobs_by_org
,
filter_jobs_with_shapes
):
users
=
find_users
(
role
=
role
,
org
=
org
)
jobs
=
jobs_by_org
[
org
]
...
...
@@ -224,17 +244,13 @@ class TestPatchJobAnnotations:
username
,
jid
=
find_job_staff_user
(
filtered_jobs
,
users
,
job_staff
)
data
=
request_data
(
jid
)
response
=
patch_method
(
username
,
f
'jobs/
{
jid
}
/annotations'
,
data
,
org_id
=
org
,
action
=
'update'
)
self
.
_test_check_respone
(
is_allow
,
response
,
data
)
self
.
_check_respone
(
username
,
jid
,
expect_success
,
data
,
org
=
org
)
@
pytest
.
mark
.
parametrize
(
'org'
,
[
2
])
@
pytest
.
mark
.
parametrize
(
'privilege,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'privilege,
expect_success
'
,
[
(
'admin'
,
True
),
(
'business'
,
False
),
(
'worker'
,
False
),
(
'user'
,
False
)
])
def
test_non_member_update_job_annotations
(
self
,
org
,
privilege
,
is_allow
,
def
test_non_member_update_job_annotations
(
self
,
org
,
privilege
,
expect_success
,
find_job_staff_user
,
find_users
,
request_data
,
jobs_by_org
,
filter_jobs_with_shapes
):
users
=
find_users
(
privilege
=
privilege
,
exclude_org
=
org
)
jobs
=
jobs_by_org
[
org
]
...
...
@@ -242,19 +258,16 @@ class TestPatchJobAnnotations:
username
,
jid
=
find_job_staff_user
(
filtered_jobs
,
users
,
False
)
data
=
request_data
(
jid
)
response
=
patch_method
(
username
,
f
'jobs/
{
jid
}
/annotations'
,
data
,
org_id
=
org
,
action
=
'update'
)
self
.
_test_check_respone
(
is_allow
,
response
,
data
)
self
.
_check_respone
(
username
,
jid
,
expect_success
,
data
,
org
=
org
)
@
pytest
.
mark
.
parametrize
(
'org'
,
[
''
])
@
pytest
.
mark
.
parametrize
(
'privilege, job_staff,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'privilege, job_staff,
expect_success
'
,
[
(
'admin'
,
True
,
True
),
(
'admin'
,
False
,
True
),
(
'business'
,
True
,
True
),
(
'business'
,
False
,
False
),
(
'worker'
,
True
,
True
),
(
'worker'
,
False
,
False
),
(
'user'
,
True
,
True
),
(
'user'
,
False
,
False
)
])
def
test_user_update_job_annotations
(
self
,
org
,
privilege
,
job_staff
,
is_allow
,
def
test_user_update_job_annotations
(
self
,
org
,
privilege
,
job_staff
,
expect_success
,
find_job_staff_user
,
find_users
,
request_data
,
jobs_by_org
,
filter_jobs_with_shapes
):
users
=
find_users
(
privilege
=
privilege
)
jobs
=
jobs_by_org
[
org
]
...
...
@@ -262,15 +275,10 @@ class TestPatchJobAnnotations:
username
,
jid
=
find_job_staff_user
(
filtered_jobs
,
users
,
job_staff
)
data
=
request_data
(
jid
)
response
=
patch_method
(
username
,
f
'jobs/
{
jid
}
/annotations'
,
data
,
org_id
=
org
,
action
=
'update'
)
self
.
_test_check_respone
(
is_allow
,
response
,
data
)
self
.
_check_respone
(
username
,
jid
,
expect_success
,
data
,
org
=
org
)
@
pytest
.
mark
.
usefixtures
(
'changedb'
)
class
TestPatchJob
:
_ORG
=
2
@
pytest
.
fixture
(
scope
=
'class'
)
def
find_task_staff_user
(
self
,
is_task_staff
):
def
find
(
jobs
,
users
,
is_staff
):
...
...
@@ -300,24 +308,47 @@ class TestPatchJob:
return
find_new_assignee
@
pytest
.
mark
.
parametrize
(
'org'
,
[
2
])
@
pytest
.
mark
.
parametrize
(
'role, task_staff,
is_allow
'
,
[
@
pytest
.
mark
.
parametrize
(
'role, task_staff,
expect_success
'
,
[
(
'maintainer'
,
False
,
True
),
(
'owner'
,
False
,
True
),
(
'supervisor'
,
False
,
False
),
(
'worker'
,
False
,
False
),
(
'maintainer'
,
True
,
True
),
(
'owner'
,
True
,
True
),
(
'supervisor'
,
True
,
True
),
(
'worker'
,
True
,
True
)
])
def
test_member_update_job_assignee
(
self
,
org
,
role
,
task_staff
,
is_allow
,
def
test_member_update_job_assignee
(
self
,
org
,
role
,
task_staff
,
expect_success
,
find_task_staff_user
,
find_users
,
jobs_by_org
,
new_assignee
,
expected_data
):
users
,
jobs
=
find_users
(
role
=
role
,
org
=
org
),
jobs_by_org
[
org
]
user
,
jid
=
find_task_staff_user
(
jobs
,
users
,
task_staff
)
assignee
=
new_assignee
(
jid
,
user
[
'id'
])
response
=
patch_method
(
user
[
'username'
],
f
'jobs/
{
jid
}
'
,
{
'assignee'
:
assignee
},
org_id
=
self
.
_ORG
)
if
is_allow
:
assert
response
.
status_code
==
HTTPStatus
.
OK
assert
DeepDiff
(
expected_data
(
jid
,
assignee
),
response
.
json
(),
with
make_api_client
(
user
[
'username'
])
as
client
:
(
_
,
response
)
=
client
.
jobs_api
.
partial_update
(
id
=
jid
,
patched_job_write_request
=
{
'assignee'
:
assignee
},
org_id
=
org
,
_parse_response
=
expect_success
,
_check_status
=
expect_success
)
if
expect_success
:
assert
response
.
status
==
HTTPStatus
.
OK
assert
DeepDiff
(
expected_data
(
jid
,
assignee
),
json
.
loads
(
response
.
data
),
exclude_paths
=
"root['updated_date']"
,
ignore_order
=
True
)
==
{}
else
:
assert
response
.
status_code
==
HTTPStatus
.
FORBIDDEN
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
@pytest.mark.usefixtures("dontchangedb")
class TestJobDataset:
    """Smoke tests for exporting a job's dataset and annotations."""

    def _export_dataset(self, username, job_id, **kwargs):
        # Runs the full export flow for the job dataset endpoint and returns the response.
        with make_api_client(username) as api_client:
            endpoint = api_client.jobs_api.retrieve_dataset_endpoint
            return export_dataset(endpoint, id=job_id, **kwargs)

    def _export_annotations(self, username, job_id, **kwargs):
        # Runs the full export flow for the job annotations endpoint and returns the response.
        with make_api_client(username) as api_client:
            endpoint = api_client.jobs_api.retrieve_annotations_endpoint
            return export_dataset(endpoint, id=job_id, **kwargs)

    def test_can_export_dataset(self, admin_user: str, jobs_with_shapes: List):
        target_job = jobs_with_shapes[0]
        response = self._export_dataset(
            admin_user, target_job["id"], format="CVAT for images 1.1"
        )
        assert response.data

    def test_can_export_annotations(self, admin_user: str, jobs_with_shapes: List):
        target_job = jobs_with_shapes[0]
        response = self._export_annotations(
            admin_user, target_job["id"], format="CVAT for images 1.1"
        )
        assert response.data
tests/python/rest_api/test_projects.py
浏览文件 @
53697eca
...
...
@@ -13,9 +13,8 @@ import pytest
from
copy
import
deepcopy
from
deepdiff
import
DeepDiff
from
cvat_sdk.models
import
DatasetFileRequest
,
ProjectWriteRequest
from
shared.utils.config
import
get_method
,
patch_method
,
make_api_client
from
.utils
import
export_dataset
@
pytest
.
mark
.
usefixtures
(
'dontchangedb'
)
...
...
@@ -229,12 +228,12 @@ class TestGetProjectBackup:
class
TestPostProjects
:
def
_test_create_project_201
(
self
,
user
,
spec
,
**
kwargs
):
with
make_api_client
(
user
)
as
api_client
:
(
_
,
response
)
=
api_client
.
projects_api
.
create
(
ProjectWriteRequest
(
**
spec
)
,
**
kwargs
)
(
_
,
response
)
=
api_client
.
projects_api
.
create
(
spec
,
**
kwargs
)
assert
response
.
status
==
HTTPStatus
.
CREATED
def
_test_create_project_403
(
self
,
user
,
spec
,
**
kwargs
):
with
make_api_client
(
user
)
as
api_client
:
(
_
,
response
)
=
api_client
.
projects_api
.
create
(
ProjectWriteRequest
(
**
spec
)
,
**
kwargs
,
(
_
,
response
)
=
api_client
.
projects_api
.
create
(
spec
,
**
kwargs
,
_parse_response
=
False
,
_check_status
=
False
)
assert
response
.
status
==
HTTPStatus
.
FORBIDDEN
...
...
@@ -316,43 +315,30 @@ class TestPostProjects:
self
.
_test_create_project_201
(
user
[
'username'
],
spec
,
org_id
=
user
[
'org'
])
@
pytest
.
mark
.
usefixtures
(
"changedb"
)
@
pytest
.
mark
.
usefixtures
(
"restore_cvat_data"
)
class
TestImportExportDatasetProject
:
def
_test_export_project
(
self
,
username
,
p
roject_
id
,
format_name
):
def
_test_export_project
(
self
,
username
,
pid
,
format_name
):
with
make_api_client
(
username
)
as
api_client
:
while
True
:
(
_
,
response
)
=
api_client
.
projects_api
.
retrieve_dataset
(
id
=
project_id
,
format
=
format_name
)
if
response
.
status
==
HTTPStatus
.
CREATED
:
break
(
_
,
response
)
=
api_client
.
projects_api
.
retrieve_dataset
(
id
=
project_id
,
format
=
format_name
,
action
=
'download'
)
assert
response
.
status
==
HTTPStatus
.
OK
return
response
return
export_dataset
(
api_client
.
projects_api
.
retrieve_dataset_endpoint
,
id
=
pid
,
format
=
format_name
)
def
_test_import_project
(
self
,
username
,
project_id
,
format_name
,
data
):
with
make_api_client
(
username
)
as
api_client
:
(
_
,
response
)
=
api_client
.
projects_api
.
create_dataset
(
id
=
project_id
,
format
=
format_name
,
dataset_
file_request
=
DatasetFileRequest
(
**
data
),
format
=
format_name
,
dataset_
write_request
=
deepcopy
(
data
),
_content_type
=
"multipart/form-data"
)
assert
response
.
status
==
HTTPStatus
.
ACCEPTED
while
True
:
# TODO: Request schema doesn't describe this capability.
# It's better be refactored to a separate endpoint to get request status
response
=
get_method
(
username
,
f
'projects/
{
project_id
}
/dataset'
,
# TODO: It's better be refactored to a separate endpoint to get request status
(
_
,
response
)
=
api_client
.
projects_api
.
retrieve_dataset
(
project_id
,
action
=
'import_status'
)
response
.
raise_for_status
()
if
response
.
status_code
==
HTTPStatus
.
CREATED
:
if
response
.
status
==
HTTPStatus
.
CREATED
:
break
def
test_can_import_dataset_in_org
(
self
):
username
=
'admin1'
def
test_can_import_dataset_in_org
(
self
,
admin_user
):
project_id
=
4
response
=
self
.
_test_export_project
(
username
,
project_id
,
'CVAT for images 1.1'
)
response
=
self
.
_test_export_project
(
admin_user
,
project_id
,
'CVAT for images 1.1'
)
tmp_file
=
io
.
BytesIO
(
response
.
data
)
tmp_file
.
name
=
'dataset.zip'
...
...
@@ -361,7 +347,7 @@ class TestImportExportDatasetProject:
'dataset_file'
:
tmp_file
,
}
self
.
_test_import_project
(
username
,
project_id
,
'CVAT 1.1'
,
import_data
)
self
.
_test_import_project
(
admin_user
,
project_id
,
'CVAT 1.1'
,
import_data
)
@
pytest
.
mark
.
usefixtures
(
'changedb'
)
class
TestPatchProjectLabel
:
...
...
tests/python/rest_api/test_tasks.py
浏览文件 @
53697eca
...
...
@@ -7,14 +7,15 @@ import json
from
copy
import
deepcopy
from
http
import
HTTPStatus
from
time
import
sleep
from
cvat_sdk.api_client
.apis
import
TasksApi
from
cvat_sdk.
api_client
import
models
from
cvat_sdk.api_client
import
models
,
apis
from
cvat_sdk.
core.helpers
import
get_paginated_collection
import
pytest
from
deepdiff
import
DeepDiff
from
shared.utils.config
import
make_api_client
from
shared.utils.helpers
import
generate_image_files
from
.utils
import
export_dataset
def
get_cloud_storage_content
(
username
,
cloud_storage_id
,
manifest
):
with
make_api_client
(
username
)
as
api_client
:
...
...
@@ -27,12 +28,9 @@ def get_cloud_storage_content(username, cloud_storage_id, manifest):
class
TestGetTasks
:
def
_test_task_list_200
(
self
,
user
,
project_id
,
data
,
exclude_paths
=
''
,
**
kwargs
):
with
make_api_client
(
user
)
as
api_client
:
(
_
,
response
)
=
api_client
.
projects_api
.
list_tasks
(
project_id
,
**
kwargs
,
_parse_response
=
False
)
assert
response
.
status
==
HTTPStatus
.
OK
response_data
=
json
.
loads
(
response
.
data
)
assert
DeepDiff
(
data
,
response_data
[
'results'
],
ignore_order
=
True
,
exclude_paths
=
exclude_paths
)
==
{}
results
=
get_paginated_collection
(
api_client
.
projects_api
.
list_tasks_endpoint
,
return_json
=
True
,
id
=
project_id
,
**
kwargs
)
assert
DeepDiff
(
data
,
results
,
ignore_order
=
True
,
exclude_paths
=
exclude_paths
)
==
{}
def
_test_task_list_403
(
self
,
user
,
project_id
,
**
kwargs
):
with
make_api_client
(
user
)
as
api_client
:
...
...
@@ -60,7 +58,7 @@ class TestGetTasks:
for
user
in
staff_users
:
with
make_api_client
(
user
[
'username'
])
as
api_client
:
(
_
,
response
)
=
api_client
.
tasks_api
.
list
(
**
kwargs
,
_parse_response
=
False
)
(
_
,
response
)
=
api_client
.
tasks_api
.
list
(
**
kwargs
)
assert
response
.
status
==
HTTPStatus
.
OK
response_data
=
json
.
loads
(
response
.
data
)
...
...
@@ -113,12 +111,12 @@ class TestGetTasks:
class
TestPostTasks
:
def _test_create_task_201(self, user, spec, **kwargs):
    """Create a task from the given spec and expect HTTP 201 CREATED.

    Keeps the post-commit form that passes `spec` directly (the generated
    API client validates/serializes it), dropping the stale pre-commit
    `models.TaskWriteRequest(**spec)` wrapper left over from the diff.
    """
    with make_api_client(user) as api_client:
        (_, response) = api_client.tasks_api.create(spec, **kwargs)
        assert response.status == HTTPStatus.CREATED
def _test_create_task_403(self, user, spec, **kwargs):
    """Attempt to create a task and expect HTTP 403 FORBIDDEN.

    Resolves the interleaved pre/post-commit diff lines to the committed
    version: `spec` is passed directly, and response parsing/status checks
    are disabled so the 403 can be asserted manually.
    """
    with make_api_client(user) as api_client:
        (_, response) = api_client.tasks_api.create(
            spec, **kwargs, _parse_response=False, _check_status=False
        )
        assert response.status == HTTPStatus.FORBIDDEN
...
...
@@ -210,10 +208,9 @@ class TestPatchTaskAnnotations:
data
=
request_data
(
tid
)
with
make_api_client
(
username
)
as
api_client
:
patched_data
=
models
.
PatchedTaskWriteRequest
(
**
deepcopy
(
data
))
(
_
,
response
)
=
api_client
.
tasks_api
.
partial_update_annotations
(
id
=
tid
,
action
=
'update'
,
org
=
org
,
patched_
task_write_request
=
patched_data
,
patched_
labeled_data_request
=
deepcopy
(
data
)
,
_parse_response
=
False
,
_check_status
=
False
)
self
.
_test_check_response
(
is_allow
,
response
,
data
)
...
...
@@ -233,30 +230,23 @@ class TestPatchTaskAnnotations:
data
=
request_data
(
tid
)
with
make_api_client
(
username
)
as
api_client
:
patched_data
=
models
.
PatchedTaskWriteRequest
(
**
deepcopy
(
data
))
(
_
,
response
)
=
api_client
.
tasks_api
.
partial_update_annotations
(
id
=
tid
,
org_id
=
org
,
action
=
'update'
,
patched_
task_write_request
=
patched_data
,
patched_
labeled_data_request
=
deepcopy
(
data
)
,
_parse_response
=
False
,
_check_status
=
False
)
self
.
_test_check_response
(
is_allow
,
response
,
data
)
@
pytest
.
mark
.
usefixtures
(
'dontchangedb'
)
class
TestGetTaskDataset
:
def _test_export_task(self, username, tid, **kwargs):
    """Export a task dataset and return the download response.

    Replaces the pre-commit inline polling (ACCEPTED -> CREATED -> download)
    with the committed delegation to the shared `export_dataset` helper,
    which performs the same retry-until-created-then-download sequence.
    """
    with make_api_client(username) as api_client:
        return export_dataset(api_client.tasks_api.retrieve_dataset_endpoint, id=tid, **kwargs)
def test_can_export_task_dataset(self, admin_user, tasks_with_shapes):
    """An admin can export a dataset for a task that has shapes.

    Keeps the committed version: uses the `admin_user` fixture instead of a
    hard-coded 'admin1' and asserts the downloaded payload is non-empty.
    """
    task = tasks_with_shapes[0]
    response = self._test_export_task(admin_user, task["id"], format="CVAT for images 1.1")
    assert response.data
@
pytest
.
mark
.
usefixtures
(
"changedb"
)
@
pytest
.
mark
.
usefixtures
(
"restore_cvat_data"
)
...
...
@@ -264,7 +254,7 @@ class TestPostTaskData:
_USERNAME
=
'admin1'
@
staticmethod
def
_wait_until_task_is_created
(
api
:
TasksApi
,
task_id
:
int
)
->
models
.
RqStatus
:
def
_wait_until_task_is_created
(
api
:
apis
.
TasksApi
,
task_id
:
int
)
->
models
.
RqStatus
:
for
_
in
range
(
100
):
(
status
,
_
)
=
api
.
retrieve_status
(
task_id
)
if
status
.
state
.
value
in
[
'Finished'
,
'Failed'
]:
...
...
@@ -274,11 +264,10 @@ class TestPostTaskData:
def
_test_create_task
(
self
,
username
,
spec
,
data
,
content_type
,
**
kwargs
):
with
make_api_client
(
username
)
as
api_client
:
(
task
,
response
)
=
api_client
.
tasks_api
.
create
(
models
.
TaskWriteRequest
(
**
spec
)
,
**
kwargs
)
(
task
,
response
)
=
api_client
.
tasks_api
.
create
(
spec
,
**
kwargs
)
assert
response
.
status
==
HTTPStatus
.
CREATED
task_data
=
models
.
DataRequest
(
**
data
)
(
_
,
response
)
=
api_client
.
tasks_api
.
create_data
(
task
.
id
,
task_data
,
(
_
,
response
)
=
api_client
.
tasks_api
.
create_data
(
task
.
id
,
data_request
=
deepcopy
(
data
),
_content_type
=
content_type
,
**
kwargs
)
assert
response
.
status
==
HTTPStatus
.
ACCEPTED
...
...
tests/python/rest_api/utils.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from
http
import
HTTPStatus
from
time
import
sleep
from
cvat_sdk.api_client.api_client
import
Endpoint
from
urllib3
import
HTTPResponse
def export_dataset(
    endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
) -> HTTPResponse:
    """Request a dataset export, poll until it is ready, then download it.

    The endpoint is polled up to ``max_retries`` times with ``interval``
    seconds between attempts. Every intermediate reply must be 202 ACCEPTED;
    the export is considered finished on 201 CREATED. Returns the raw
    200 OK download response.
    """
    attempt = 0
    while attempt < max_retries:
        _, response = endpoint.call_with_http_info(**kwargs, _parse_response=False)
        if response.status == HTTPStatus.CREATED:
            break
        # Anything other than "still processing" is an unexpected failure.
        assert response.status == HTTPStatus.ACCEPTED
        sleep(interval)
        attempt += 1
    assert response.status == HTTPStatus.CREATED

    _, response = endpoint.call_with_http_info(
        **kwargs, action="download", _parse_response=False
    )
    assert response.status == HTTPStatus.OK

    return response
tests/python/sdk/fixtures.py
浏览文件 @
53697eca
...
...
@@ -2,10 +2,16 @@
#
# SPDX-License-Identifier: MIT
from
pathlib
import
Path
import
pytest
from
cvat_sdk
import
Client
from
PIL
import
Image
from
shared.utils.config
import
BASE_URL
from
shared.utils.helpers
import
generate_image_file
from
.util
import
generate_coco_json
@
pytest
.
fixture
...
...
@@ -20,3 +26,22 @@ def fxt_client(fxt_logger):
with
client
:
yield
client
@pytest.fixture
def fxt_image_file(tmp_path: Path):
    """Write a small 5x10 test PNG into the temp dir and return its path."""
    img_path = tmp_path / "img.png"
    image_bytes = generate_image_file(filename=str(img_path), size=(5, 10)).getvalue()
    with img_path.open("wb") as f:
        f.write(image_bytes)
    return img_path
@pytest.fixture
def fxt_coco_file(tmp_path: Path, fxt_image_file: Path):
    """Generate a COCO annotation file referring to the test image."""
    img_filename = fxt_image_file
    # The annotation file must record the real image dimensions.
    img_size = Image.open(img_filename).size

    ann_filename = tmp_path / "coco.json"
    generate_coco_json(ann_filename, img_info=(img_filename, *img_size))

    yield ann_filename
tests/python/sdk/test_issues_comments.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import
io
from
logging
import
Logger
from
pathlib
import
Path
from
typing
import
Tuple
import
pytest
from
cvat_sdk
import
Client
from
cvat_sdk.api_client
import
exceptions
,
models
from
cvat_sdk.core.proxies.tasks
import
ResourceType
,
Task
from
shared.utils.config
import
USER_PASS
class TestIssuesUsecases:
    """SDK high-level API tests for issue creation, listing, update, removal."""

    @pytest.fixture(autouse=True)
    def setup(
        self,
        changedb,  # force fixture call order to allow DB setup
        tmp_path: Path,
        fxt_logger: Tuple[Logger, io.StringIO],
        fxt_client: Client,
        fxt_stdout: io.StringIO,
        admin_user: str,
    ):
        self.tmp_path = tmp_path
        _, self.logger_stream = fxt_logger
        self.client = fxt_client
        self.stdout = fxt_stdout
        self.user = admin_user
        self.client.login((self.user, USER_PASS))

        yield

    @pytest.fixture
    def fxt_new_task(self, fxt_image_file: Path):
        """Create a fresh one-image task to attach issues to."""
        task = self.client.tasks.create_from_data(
            spec={
                "name": "test_task",
                "labels": [{"name": "car"}, {"name": "person"}],
            },
            resource_type=ResourceType.LOCAL,
            resources=[str(fxt_image_file)],
            data_params={"image_quality": 80},
        )
        return task

    def _make_issue(self, task, **extra):
        """Create an issue on the task's first job; `extra` extends the request."""
        return self.client.issues.create(
            models.IssueWriteRequest(
                frame=0,
                position=[2.0, 4.0],
                job=task.get_jobs()[0].id,
                message="hello",
                **extra,
            )
        )

    def test_can_retrieve_issue(self, fxt_new_task: Task):
        issue = self._make_issue(fxt_new_task)

        retrieved_issue = self.client.issues.retrieve(issue.id)

        assert issue.id == retrieved_issue.id
        assert self.stdout.getvalue() == ""

    def test_can_list_issues(self, fxt_new_task: Task):
        issue = self._make_issue(fxt_new_task, assignee=self.client.users.list()[0].id)

        issues = self.client.issues.list()

        assert any(issue.id == j.id for j in issues)
        assert self.stdout.getvalue() == ""

    def test_can_list_comments(self, fxt_new_task: Task):
        issue = self._make_issue(fxt_new_task)
        comment = self.client.comments.create(
            models.CommentWriteRequest(issue.id, message="hi!")
        )

        issue.fetch()
        comment_ids = {c.id for c in issue.comments}

        # The issue's creation message produces the first comment.
        assert len(comment_ids) == 2
        assert comment.id in comment_ids
        assert self.stdout.getvalue() == ""

    def test_can_modify_issue(self, fxt_new_task: Task):
        issue = self._make_issue(fxt_new_task)

        issue.update(models.PatchedIssueWriteRequest(resolved=True))

        retrieved_issue = self.client.issues.retrieve(issue.id)
        assert retrieved_issue.resolved is True
        assert issue.resolved == retrieved_issue.resolved
        assert self.stdout.getvalue() == ""

    def test_can_remove_issue(self, fxt_new_task: Task):
        issue = self._make_issue(fxt_new_task)

        issue.remove()

        with pytest.raises(exceptions.NotFoundException):
            issue.fetch()
        # Removing the issue cascades to its comments.
        with pytest.raises(exceptions.NotFoundException):
            self.client.comments.retrieve(issue.comments[0].id)
        assert self.stdout.getvalue() == ""
class TestCommentsUsecases:
    """SDK high-level API tests for comment creation, listing, update, removal."""

    @pytest.fixture(autouse=True)
    def setup(
        self,
        changedb,  # force fixture call order to allow DB setup
        tmp_path: Path,
        fxt_logger: Tuple[Logger, io.StringIO],
        fxt_client: Client,
        fxt_stdout: io.StringIO,
        admin_user: str,
    ):
        self.tmp_path = tmp_path
        _, self.logger_stream = fxt_logger
        self.client = fxt_client
        self.stdout = fxt_stdout
        self.user = admin_user
        self.client.login((self.user, USER_PASS))

        yield

    @pytest.fixture
    def fxt_new_task(self, fxt_image_file: Path):
        """Create a fresh one-image task to attach issues/comments to."""
        task = self.client.tasks.create_from_data(
            spec={
                "name": "test_task",
                "labels": [{"name": "car"}, {"name": "person"}],
            },
            resource_type=ResourceType.LOCAL,
            resources=[str(fxt_image_file)],
            data_params={"image_quality": 80},
        )
        return task

    def _make_comment(self, task):
        """Create an issue on the task's first job, then a comment on it."""
        issue = self.client.issues.create(
            models.IssueWriteRequest(
                frame=0,
                position=[2.0, 4.0],
                job=task.get_jobs()[0].id,
                message="hello",
            )
        )
        return self.client.comments.create(
            models.CommentWriteRequest(issue.id, message="hi!")
        )

    def test_can_retrieve_comment(self, fxt_new_task: Task):
        comment = self._make_comment(fxt_new_task)

        retrieved_comment = self.client.comments.retrieve(comment.id)

        assert comment.id == retrieved_comment.id
        assert self.stdout.getvalue() == ""

    def test_can_list_comments(self, fxt_new_task: Task):
        comment = self._make_comment(fxt_new_task)

        comments = self.client.comments.list()

        assert any(comment.id == c.id for c in comments)
        assert self.stdout.getvalue() == ""

    def test_can_modify_comment(self, fxt_new_task: Task):
        comment = self._make_comment(fxt_new_task)

        comment.update(models.PatchedCommentWriteRequest(message="bar"))

        retrieved_comment = self.client.comments.retrieve(comment.id)
        assert retrieved_comment.message == "bar"
        assert comment.message == retrieved_comment.message
        assert self.stdout.getvalue() == ""

    def test_can_remove_comment(self, fxt_new_task: Task):
        comment = self._make_comment(fxt_new_task)

        comment.remove()

        with pytest.raises(exceptions.NotFoundException):
            comment.fetch()
        assert self.stdout.getvalue() == ""
tests/python/sdk/test_jobs.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import
io
import
os.path
as
osp
from
logging
import
Logger
from
pathlib
import
Path
from
typing
import
Tuple
import
pytest
from
cvat_sdk
import
Client
from
cvat_sdk.api_client
import
models
from
cvat_sdk.core.proxies.tasks
import
ResourceType
,
Task
from
PIL
import
Image
from
shared.utils.config
import
USER_PASS
from
.util
import
make_pbar
class TestJobUsecases:
    """SDK high-level API tests for job retrieval, data access, and annotations."""

    @pytest.fixture(autouse=True)
    def setup(
        self,
        changedb,  # force fixture call order to allow DB setup
        tmp_path: Path,
        fxt_logger: Tuple[Logger, io.StringIO],
        fxt_client: Client,
        fxt_stdout: io.StringIO,
        admin_user: str,
    ):
        self.tmp_path = tmp_path
        _, self.logger_stream = fxt_logger
        self.client = fxt_client
        self.stdout = fxt_stdout
        self.user = admin_user
        self.client.login((self.user, USER_PASS))

        yield

    @pytest.fixture
    def fxt_new_task(self, fxt_image_file: Path):
        """Create a fresh one-image task; its single job is exercised below."""
        task = self.client.tasks.create_from_data(
            spec={
                "name": "test_task",
                "labels": [{"name": "car"}, {"name": "person"}],
            },
            resource_type=ResourceType.LOCAL,
            resources=[str(fxt_image_file)],
            data_params={"image_quality": 80},
        )
        return task

    @pytest.fixture
    def fxt_task_with_shapes(self, fxt_new_task: Task):
        """The new task with one rectangle shape pre-created on frame 0."""
        fxt_new_task.set_annotations(
            models.LabeledDataRequest(
                shapes=[
                    models.LabeledShapeRequest(
                        frame=0,
                        label_id=fxt_new_task.labels[0].id,
                        type="rectangle",
                        points=[1, 1, 2, 2],
                    ),
                ],
            )
        )
        return fxt_new_task

    def test_can_retrieve_job(self, fxt_new_task: Task):
        job_id = fxt_new_task.get_jobs()[0].id

        job = self.client.jobs.retrieve(job_id)

        assert job.id == job_id
        assert self.stdout.getvalue() == ""

    def test_can_list_jobs(self, fxt_new_task: Task):
        task_job_ids = {j.id for j in fxt_new_task.get_jobs()}

        jobs = self.client.jobs.list()

        assert len(task_job_ids) != 0
        assert task_job_ids.issubset(j.id for j in jobs)
        assert self.stdout.getvalue() == ""

    def test_can_update_job_field_directly(self, fxt_new_task: Task):
        job = self.client.jobs.list()[0]
        assert not job.assignee
        new_assignee = self.client.users.list()[0]

        job.update({"assignee": new_assignee.id})

        updated_job = self.client.jobs.retrieve(job.id)
        assert updated_job.assignee.id == new_assignee.id
        assert self.stdout.getvalue() == ""

    @pytest.mark.parametrize("include_images", (True, False))
    def test_can_download_dataset(self, fxt_new_task: Task, include_images: bool):
        pbar_out = io.StringIO()
        pbar = make_pbar(file=pbar_out)

        task_id = fxt_new_task.id
        path = str(self.tmp_path / f"task_{task_id}-cvat.zip")
        job = self.client.jobs.retrieve(task_id)
        job.export_dataset(
            format_name="CVAT for images 1.1",
            filename=path,
            pbar=pbar,
            include_images=include_images,
        )

        # The last carriage-return-separated progress line must reach 100%.
        assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
        assert osp.isfile(path)
        assert self.stdout.getvalue() == ""

    def test_can_download_preview(self, fxt_new_task: Task):
        frame_encoded = fxt_new_task.get_jobs()[0].get_preview()

        assert Image.open(frame_encoded).size != 0
        assert self.stdout.getvalue() == ""

    @pytest.mark.parametrize("quality", ("compressed", "original"))
    def test_can_download_frame(self, fxt_new_task: Task, quality: str):
        frame_encoded = fxt_new_task.get_jobs()[0].get_frame(0, quality=quality)

        assert Image.open(frame_encoded).size != 0
        assert self.stdout.getvalue() == ""

    @pytest.mark.parametrize("quality", ("compressed", "original"))
    def test_can_download_frames(self, fxt_new_task: Task, quality: str):
        fxt_new_task.get_jobs()[0].download_frames(
            [0],
            quality=quality,
            outdir=str(self.tmp_path),
            filename_pattern="frame-{frame_id}{frame_ext}",
        )

        assert osp.isfile(self.tmp_path / "frame-0.jpg")
        assert self.stdout.getvalue() == ""

    def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
        pbar_out = io.StringIO()
        pbar = make_pbar(file=pbar_out)

        fxt_new_task.get_jobs()[0].import_annotations(
            format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar
        )

        assert "uploaded" in self.logger_stream.getvalue()
        assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
        assert self.stdout.getvalue() == ""

    def test_can_get_meta(self, fxt_new_task: Task):
        meta = fxt_new_task.get_jobs()[0].get_meta()

        assert meta.image_quality == 80
        assert meta.size == 1
        assert len(meta.frames) == meta.size
        assert meta.frames[0].name == "img.png"
        assert meta.frames[0].width == 5
        assert meta.frames[0].height == 10
        assert not meta.deleted_frames
        assert self.stdout.getvalue() == ""

    def test_can_remove_frames(self, fxt_new_task: Task):
        fxt_new_task.get_jobs()[0].remove_frames_by_ids([0])

        meta = fxt_new_task.get_jobs()[0].get_meta()
        assert meta.deleted_frames == [0]
        assert self.stdout.getvalue() == ""

    def test_can_get_issues(self, fxt_new_task: Task):
        issue = self.client.issues.create(
            models.IssueWriteRequest(
                frame=0,
                position=[2.0, 4.0],
                job=fxt_new_task.get_jobs()[0].id,
                message="hello",
            )
        )

        job_issue_ids = {j.id for j in fxt_new_task.get_jobs()[0].get_issues()}

        assert {issue.id} == job_issue_ids
        assert self.stdout.getvalue() == ""

    def test_can_get_annotations(self, fxt_task_with_shapes: Task):
        anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()

        assert len(anns.shapes) == 1
        assert anns.shapes[0].type.value == "rectangle"
        assert self.stdout.getvalue() == ""

    def test_can_set_annotations(self, fxt_new_task: Task):
        fxt_new_task.get_jobs()[0].set_annotations(
            models.LabeledDataRequest(
                tags=[models.LabeledImageRequest(frame=0, label_id=fxt_new_task.labels[0].id)],
            )
        )

        anns = fxt_new_task.get_jobs()[0].get_annotations()
        assert len(anns.tags) == 1
        assert self.stdout.getvalue() == ""

    def test_can_clear_annotations(self, fxt_task_with_shapes: Task):
        fxt_task_with_shapes.get_jobs()[0].remove_annotations()

        anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
        assert len(anns.tags) == 0
        assert len(anns.tracks) == 0
        assert len(anns.shapes) == 0
        assert self.stdout.getvalue() == ""

    def test_can_remove_annotations(self, fxt_new_task: Task):
        fxt_new_task.get_jobs()[0].set_annotations(
            models.LabeledDataRequest(
                shapes=[
                    models.LabeledShapeRequest(
                        frame=0,
                        label_id=fxt_new_task.labels[0].id,
                        type="rectangle",
                        points=[1, 1, 2, 2],
                    ),
                    models.LabeledShapeRequest(
                        frame=0,
                        label_id=fxt_new_task.labels[0].id,
                        type="rectangle",
                        points=[2, 2, 3, 3],
                    ),
                ],
            )
        )
        anns = fxt_new_task.get_jobs()[0].get_annotations()

        # Delete only the first of the two shapes just created.
        fxt_new_task.get_jobs()[0].remove_annotations(ids=[anns.shapes[0].id])

        anns = fxt_new_task.get_jobs()[0].get_annotations()
        assert len(anns.tags) == 0
        assert len(anns.tracks) == 0
        assert len(anns.shapes) == 1
        assert self.stdout.getvalue() == ""

    def test_can_update_annotations(self, fxt_task_with_shapes: Task):
        fxt_task_with_shapes.get_jobs()[0].update_annotations(
            models.PatchedLabeledDataRequest(
                shapes=[
                    models.LabeledShapeRequest(
                        frame=0,
                        label_id=fxt_task_with_shapes.labels[0].id,
                        type="rectangle",
                        points=[0, 1, 2, 3],
                    ),
                ],
                tracks=[
                    models.LabeledTrackRequest(
                        frame=0,
                        label_id=fxt_task_with_shapes.labels[0].id,
                        shapes=[
                            models.TrackedShapeRequest(
                                frame=0, type="polygon", points=[3, 2, 2, 3, 3, 4]
                            ),
                        ],
                    )
                ],
                tags=[
                    models.LabeledImageRequest(
                        frame=0, label_id=fxt_task_with_shapes.labels[0].id
                    )
                ],
            )
        )

        anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()

        # One pre-existing shape + one added; plus the new track and tag.
        assert len(anns.shapes) == 2
        assert len(anns.tracks) == 1
        assert len(anns.tags) == 1
        assert self.stdout.getvalue() == ""
tests/python/sdk/test_tasks.py
浏览文件 @
53697eca
此差异已折叠。
点击以展开。
tests/python/sdk/test_users.py
0 → 100644
浏览文件 @
53697eca
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import
io
from
logging
import
Logger
from
pathlib
import
Path
from
typing
import
Tuple
import
pytest
from
cvat_sdk
import
Client
,
models
from
cvat_sdk.api_client
import
exceptions
from
shared.utils.config
import
USER_PASS
class TestUserUsecases:
    """SDK high-level API tests for user retrieval, listing, update, removal."""

    @pytest.fixture(autouse=True)
    def setup(
        self,
        changedb,  # force fixture call order to allow DB setup
        tmp_path: Path,
        fxt_logger: Tuple[Logger, io.StringIO],
        fxt_client: Client,
        fxt_stdout: io.StringIO,
        admin_user: str,
    ):
        self.tmp_path = tmp_path
        _, self.logger_stream = fxt_logger
        self.client = fxt_client
        self.stdout = fxt_stdout
        self.user = admin_user
        self.client.login((self.user, USER_PASS))

        yield

    def test_can_retrieve_user(self):
        me = self.client.users.retrieve_current_user()

        user = self.client.users.retrieve(me.id)

        assert user.id == me.id
        assert user.username == self.user
        assert self.stdout.getvalue() == ""

    def test_can_list_users(self):
        users = self.client.users.list()

        assert self.user in {u.username for u in users}
        assert self.stdout.getvalue() == ""

    def test_can_update_user(self):
        user = self.client.users.retrieve_current_user()

        user.update(models.PatchedUserRequest(first_name="foo", last_name="bar"))

        retrieved_user = self.client.users.retrieve(user.id)
        assert retrieved_user.first_name == "foo"
        assert retrieved_user.last_name == "bar"
        assert user.first_name == retrieved_user.first_name
        assert user.last_name == retrieved_user.last_name
        assert self.stdout.getvalue() == ""

    def test_can_remove_user(self):
        users = self.client.users.list()
        # Pick any user other than the logged-in admin to delete.
        removed_user = next(u for u in users if u.username != self.user)

        removed_user.remove()

        with pytest.raises(exceptions.NotFoundException):
            removed_user.fetch()
        assert self.stdout.getvalue() == ""
tests/python/sdk/util.py
浏览文件 @
53697eca
tests/python/shared/fixtures/data.py
浏览文件 @
53697eca
...
...
@@ -279,6 +279,10 @@ def filter_tasks_with_shapes(annotations):
return
list
(
filter
(
lambda
t
:
annotations
[
'task'
][
str
(
t
[
'id'
])][
'shapes'
],
tasks
))
return
find
@
pytest
.
fixture
(
scope
=
'session'
)
def
jobs_with_shapes
(
jobs
,
filter_jobs_with_shapes
):
return
filter_jobs_with_shapes
(
jobs
)
@pytest.fixture(scope="session")
def tasks_with_shapes(tasks, filter_tasks_with_shapes):
    """Session-scoped list of tasks that have at least one shape annotation."""
    return filter_tasks_with_shapes(tasks)
...
...
tests/python/shared/utils/config.py
浏览文件 @
53697eca
...
...
@@ -48,5 +48,6 @@ def post_files_method(username, endpoint, data, files, **kwargs):
def server_get(username, endpoint, **kwargs):
    """Issue an authenticated GET against the given CVAT server endpoint."""
    url = get_server_url(endpoint, **kwargs)
    return requests.get(url, auth=(username, USER_PASS))
def make_api_client(user: str, *, password: str = None) -> ApiClient:
    """Build an ApiClient authenticated as *user*.

    Resolves the interleaved pre/post-commit diff to the committed version,
    which adds a keyword-only ``password`` override; the shared test password
    is used when no override is supplied.
    """
    return ApiClient(
        configuration=Configuration(host=BASE_URL, username=user, password=password or USER_PASS)
    )
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录