未验证 提交 192fd726 编写于 作者: R Roman Donchenko 提交者: GitHub

Fix creation of tasks with Git repositories via the SDK (#5409)

Fixes #4365
上级 8b13a2c4
......@@ -93,6 +93,8 @@ non-ascii paths while adding files from "Connected file share" (issue #4428)
- Fix chart not being upgradable (<https://github.com/opencv/cvat/pull/5371>)
- Broken helm chart - if using custom release name (<https://github.com/opencv/cvat/pull/5403>)
- Missing source tag in project annotations (<https://github.com/opencv/cvat/pull/5408>)
- Creating a task with a Git repository via the SDK
(<https://github.com/opencv/cvat/issues/4365>)
### Security
- TDB
......
......@@ -291,6 +291,9 @@ class CVAT_API_V2:
def git_check(self, rq_id: int) -> str:
return self.git + f"check/{rq_id}"
def git_get(self, task_id: int) -> str:
return self.git + f"get/{task_id}"
def make_endpoint_url(
self,
path: str,
......
......@@ -30,7 +30,7 @@ def create_git_repo(
post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id},
headers=common_headers,
)
response_json = json.loads(response)
response_json = json.loads(response.data)
rq_id = response_json["rq_id"]
client.logger.info(f"Create RQ ID: {rq_id}")
......
......@@ -339,7 +339,7 @@ class TasksRepo(
if dataset_repository_url:
git.create_git_repo(
self,
self._client,
task_id=task.id,
repo_url=dataset_repository_url,
status_check_period=status_check_period,
......
......@@ -201,6 +201,9 @@ export async function changeRepo(taskId: number, type: string, value: any): Prom
core.server
.request(`${baseURL}/git/repository/${taskId}`, {
method: 'PATCH',
headers: {
'Content-type': 'application/json',
},
data: JSON.stringify({
type,
value,
......
......@@ -3,10 +3,16 @@
# SPDX-License-Identifier: MIT
import http.client
from django.http import HttpResponseBadRequest, JsonResponse, HttpResponse
from django.http import HttpResponseBadRequest, HttpResponse
from rules.contrib.views import permission_required, objectgetter
from cvat.apps.iam.decorators import login_required
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.decorators import api_view, permission_classes
from drf_spectacular.utils import extend_schema
from cvat.apps.engine.log import slogger
from cvat.apps.engine import models
from cvat.apps.dataset_repo.models import GitData
......@@ -14,9 +20,21 @@ import contextlib
import cvat.apps.dataset_repo.dataset_repo as CVATGit
import django_rq
import json
@login_required
def _legacy_api_view(allowed_method_names=None):
# Currently, the views in this file use the legacy permission-checking
# approach, so this decorator disables the default DRF permission classes.
# TODO: migrate to DRF permissions, make the views compatible with drf-spectacular,
# and remove this decorator.
def decorator(view):
view = permission_classes([IsAuthenticated])(view)
view = api_view(allowed_method_names)(view)
view = extend_schema(exclude=True)(view)
return view
return decorator
@_legacy_api_view()
def check_process(request, rq_id):
try:
queue = django_rq.get_queue('default')
......@@ -24,40 +42,40 @@ def check_process(request, rq_id):
if rq_job is not None:
if rq_job.is_queued or rq_job.is_started:
return JsonResponse({"status": rq_job.get_status()})
return Response({"status": rq_job.get_status()})
elif rq_job.is_finished:
return JsonResponse({"status": rq_job.get_status()})
return Response({"status": rq_job.get_status()})
else:
return JsonResponse({"status": rq_job.get_status(), "stderr": rq_job.exc_info})
return Response({"status": rq_job.get_status(), "stderr": rq_job.exc_info})
else:
return JsonResponse({"status": "unknown"})
return Response({"status": "unknown"})
except Exception as ex:
slogger.glob.error("error occurred during checking repository request with rq id {}".format(rq_id), exc_info=True)
return HttpResponseBadRequest(str(ex))
@login_required
@_legacy_api_view(['POST'])
@permission_required(perm=['engine.task.create'],
fn=objectgetter(models.Task, 'tid'), raise_exception=True)
def create(request, tid):
def create(request: Request, tid):
try:
slogger.task[tid].info("create repository request")
body = json.loads(request.body.decode('utf-8'))
body = request.data
path = body["path"]
export_format = body["format"]
export_format = body.get("format")
lfs = body["lfs"]
rq_id = "git.create.{}".format(tid)
queue = django_rq.get_queue("default")
queue.enqueue_call(func = CVATGit.initial_create, args = (tid, path, export_format, lfs, request.user), job_id = rq_id)
return JsonResponse({ "rq_id": rq_id })
return Response({ "rq_id": rq_id })
except Exception as ex:
slogger.glob.error("error occurred during initial cloning repository request with rq id {}".format(rq_id), exc_info=True)
return HttpResponseBadRequest(str(ex))
@login_required
def push_repository(request, tid):
@_legacy_api_view()
def push_repository(request: Request, tid):
try:
slogger.task[tid].info("push repository request")
......@@ -65,7 +83,7 @@ def push_repository(request, tid):
queue = django_rq.get_queue('default')
queue.enqueue_call(func = CVATGit.push, args = (tid, request.user, request.scheme, request.get_host()), job_id = rq_id)
return JsonResponse({ "rq_id": rq_id })
return Response({ "rq_id": rq_id })
except Exception as ex:
with contextlib.suppress(Exception):
slogger.task[tid].error("error occurred during pushing repository request",
......@@ -74,11 +92,11 @@ def push_repository(request, tid):
return HttpResponseBadRequest(str(ex))
@login_required
def get_repository(request, tid):
@_legacy_api_view()
def get_repository(request: Request, tid):
try:
slogger.task[tid].info("get repository request")
return JsonResponse(CVATGit.get(tid, request.user))
return Response(CVATGit.get(tid, request.user))
except Exception as ex:
with contextlib.suppress(Exception):
slogger.task[tid].error("error occurred during getting repository info request",
......@@ -86,12 +104,12 @@ def get_repository(request, tid):
return HttpResponseBadRequest(str(ex))
@login_required
@_legacy_api_view(['PATCH'])
@permission_required(perm=['engine.task.access'],
fn=objectgetter(models.Task, 'tid'), raise_exception=True)
def update_git_repo(request, tid):
def update_git_repo(request: Request, tid):
try:
body = json.loads(request.body.decode('utf-8'))
body = request.data
req_type = body["type"]
value = body["value"]
git_data_obj = GitData.objects.filter(task_id=tid)[0]
......@@ -114,7 +132,7 @@ def update_git_repo(request, tid):
return HttpResponseBadRequest(str(ex))
@login_required
@_legacy_api_view()
def get_meta_info(request):
try:
db_git_records = GitData.objects.all()
......@@ -122,7 +140,7 @@ def get_meta_info(request):
for db_git in db_git_records:
response[db_git.task_id] = db_git.status
return JsonResponse(response, safe = False)
return Response(response)
except Exception as ex:
slogger.glob.exception("error occurred during get meta request", exc_info = True)
return HttpResponseBadRequest(str(ex))
# Copyright (C) 2018-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
from functools import wraps
from django.views.generic import RedirectView
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.http import JsonResponse
from django.conf import settings
from .authentication import TokenAuthenticationEx
def login_required(function=None, redirect_field_name=REDIRECT_FIELD_NAME,
login_url=None, redirect_methods=('GET')):
def decorator(view_func):
@wraps(view_func)
def _wrapped_view(request, *args, **kwargs):
if request.user.is_authenticated:
return view_func(request, *args, **kwargs)
else:
tokenAuth = TokenAuthenticationEx()
auth = tokenAuth.authenticate(request)
if auth is not None:
return view_func(request, *args, **kwargs)
login_url = '{}/login'.format(settings.UI_URL)
if request.method not in redirect_methods:
return JsonResponse({'login_page_url': login_url}, status=403)
return RedirectView.as_view(
url=login_url,
permanent=True,
query_string=True
)(request)
return _wrapped_view
return decorator(function) if function else decorator
......@@ -53,10 +53,10 @@ except ImportError:
def generate_ssh_keys():
keys_dir = '{}/keys'.format(os.getcwd())
ssh_dir = '{}/.ssh'.format(os.getenv('HOME'))
pidfile = os.path.join(ssh_dir, 'ssh.pid')
pidfile = os.path.join(keys_dir, 'ssh.pid')
def add_ssh_keys():
IGNORE_FILES = ('README.md', 'ssh.pid')
IGNORE_FILES = ('README.md',)
keys_to_add = [entry.name for entry in os.scandir(ssh_dir) if entry.name not in IGNORE_FILES]
keys_to_add = ' '.join(os.path.join(ssh_dir, f) for f in keys_to_add)
subprocess.run(['ssh-add {}'.format(keys_to_add)], # nosec
......
......@@ -15,3 +15,17 @@ services:
cvat:
aliases:
- webhooks
git_server:
image: alpine/git
restart: always
depends_on:
- cvat_server
entrypoint: /mnt/scripts/entrypoint.sh
volumes:
- ./tests/git_server/:/mnt/scripts:ro
- cvat_keys:/mnt/keys:ro
networks:
cvat:
aliases:
- gitserver
#!/bin/sh
set -e
mkdir -p ~/repos/repo.git
git -C ~/repos/repo.git init --bare
mkdir -p ~/.ssh
# Authorize CVAT's client keys
cat /mnt/keys/*.pub > ~/.ssh/authorized_keys
ssh-keygen -A
exec /usr/sbin/sshd -D
......@@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT
import io
import json
import os.path as osp
import zipfile
from logging import Logger
......@@ -169,6 +170,39 @@ class TestTaskUsecases:
assert capture.match("No media data found")
assert self.stdout.getvalue() == ""
def test_can_create_task_with_git_repo(self, fxt_image_file: Path):
pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out)
task_spec = {
"name": f"task with Git repo",
"labels": [{"name": "car"}],
}
repository_url = "root@gitserver:repos/repo.git [annotations/annot.zip]"
task = self.client.tasks.create_from_data(
spec=task_spec,
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
pbar=pbar,
dataset_repository_url=repository_url,
)
assert task.size == 1
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert self.stdout.getvalue() == ""
git_get_response = self.client.api_client.rest_client.GET(
self.client.api_map.git_get(task.id),
headers=self.client.api_client.get_common_headers(),
)
response_json = json.loads(git_get_response.data)
assert response_json["url"]["value"] == repository_url
assert response_json["format"] == "CVAT for images 1.1"
assert response_json["lfs"] is False
def test_can_retrieve_task(self, fxt_new_task: Task):
task_id = fxt_new_task.id
......
......@@ -32,7 +32,7 @@ DC_FILES = [
"docker-compose.dev.yml",
"tests/docker-compose.file_share.yml",
"tests/docker-compose.minio.yml",
"tests/docker-compose.webhook.yml",
"tests/docker-compose.test_servers.yml",
)
] + CONTAINER_NAME_FILES
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册