未验证 提交 cacb9973 编写于 作者: M Maxim Zhiltsov 提交者: GitHub

Add job access checks for model invocations (#5392)

Fixes #4996
- Added job access checks for model launches in the interactive mode
上级 192fd726
......@@ -90,6 +90,7 @@ non-ascii paths while adding files from "Connected file share" (issue #4428)
- Fixed FBRS serverless function runtime error on images with alpha channel (<https://github.com/opencv/cvat/pull/5384>)
- Attaching manifest with custom name (<https://github.com/opencv/cvat/pull/5377>)
- Uploading non-zip annotation files (<https://github.com/opencv/cvat/pull/5386>)
- A permission problem with interactive model launches for workers in orgs (<https://github.com/opencv/cvat/issues/4996>)
- Fix chart not being upgradable (<https://github.com/opencv/cvat/pull/5371>)
- Broken helm chart - if using custom release name (<https://github.com/opencv/cvat/pull/5403>)
- Missing source tag in project annotations (<https://github.com/opencv/cvat/pull/5408>)
......
......@@ -369,7 +369,8 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
try {
// run server request
this.setState({ fetching: true });
const response = await core.lambda.call(jobInstance.taskId, interactor, data);
const response = await core.lambda.call(jobInstance.taskId, interactor,
{ ...data, job: jobInstance.id });
// approximation with cv.approxPolyDP
const approximated = await this.approximateResponsePoints(response.points);
......@@ -740,6 +741,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
const response = await core.lambda.call(jobInstance.taskId, tracker, {
frame: frame - 1,
shapes: trackableObjects.shapes,
job: jobInstance.id,
});
const { states: serverlessStates } = response;
......@@ -787,6 +789,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
frame,
shapes: trackableObjects.shapes,
states: trackableObjects.states,
job: jobInstance.id,
});
response.shapes = response.shapes.map(trackedRectangleMapper);
......@@ -1161,7 +1164,9 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
runInference={async (model: Model, body: DetectorRequestBody) => {
try {
this.setState({ mode: 'detection', fetching: true });
const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame });
const result = await core.lambda.call(jobInstance.taskId, model, {
...body, frame, job: jobInstance.id,
});
const states = result.map(
(data: any): any => {
const jobLabel = (jobInstance.labels as Label[])
......
......@@ -446,6 +446,9 @@ class Segment(models.Model):
start_frame = models.IntegerField()
stop_frame = models.IntegerField()
def contains_frame(self, idx: int) -> bool:
    """Tell whether frame index ``idx`` lies within this segment.

    Both ``start_frame`` and ``stop_frame`` are treated as inclusive bounds.
    """
    return self.start_frame <= idx <= self.stop_frame
class Meta:
default_permissions = ()
......@@ -472,6 +475,11 @@ class Job(models.Model):
project = self.segment.task.project
return project.id if project else None
@extend_schema_field(OpenApiTypes.INT)
def get_task_id(self):
    """Return the id of the task reached through this job's segment.

    Returns ``None`` when the segment's task is falsy (e.g. not set).
    """
    owning_task = self.segment.task
    if not owning_task:
        return None
    return owning_task.id
def get_organization_id(self):
return self.segment.task.organization
......
......@@ -365,12 +365,15 @@ class LambdaPermission(OpenPolicyAgentPermission):
def create(cls, request, view, obj):
permissions = []
if view.basename == 'function' or view.basename == 'request':
for scope in cls.get_scopes(request, view, obj):
scopes = cls.get_scopes(request, view, obj)
for scope in scopes:
self = cls.create_base_perm(request, view, scope, obj)
permissions.append(self)
task_id = request.data.get('task')
if task_id:
if job_id := request.data.get('job'):
perm = JobPermission.create_scope_view_data(request, job_id)
permissions.append(perm)
elif task_id := request.data.get('task'):
perm = TaskPermission.create_scope_view_data(request, task_id)
permissions.append(perm)
......@@ -879,6 +882,14 @@ class JobPermission(OpenPolicyAgentPermission):
return permissions
@classmethod
def create_scope_view_data(cls, request, job_id):
    """Build a permission object with the 'view:data' scope for one job.

    The job is looked up by ``job_id``; the request context is unpacked
    with ``cls.unpack_context`` and passed to the constructor together with
    the job object.

    Raises:
        ValidationError: if no job with the given id exists. The original
            ``Job.DoesNotExist`` is kept as the exception cause.
    """
    try:
        obj = Job.objects.get(id=job_id)
    except Job.DoesNotExist as ex:
        # Chain explicitly so the lookup failure stays visible in tracebacks.
        raise ValidationError(str(ex)) from ex

    return cls(**cls.unpack_context(request), obj=obj, scope='view:data')
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.url = settings.IAM_OPA_DATA_URL + '/jobs/allow'
......
......@@ -6,6 +6,7 @@
import json
from collections import OrderedDict
from io import BytesIO
from typing import Dict, Optional
from unittest import mock, skip
import os
......@@ -71,7 +72,7 @@ class ForceLogin:
if self.user:
self.client.logout()
class LambdaTestCase(APITestCase):
class _LambdaTestCaseBase(APITestCase):
def setUp(self):
self.client = APIClient()
......@@ -83,11 +84,6 @@ class LambdaTestCase(APITestCase):
self.addCleanup(invoke_patcher.stop)
invoke_patcher.start()
images_main_task = self._generate_task_images(3)
images_assigneed_to_user_task = self._generate_task_images(3)
self.main_task = self._create_task(tasks["main"], images_main_task)
self.assigneed_to_user_task = self._create_task(tasks["assigneed_to_user"], images_assigneed_to_user_task)
def __get_data_from_lambda_manager_http(self, **kwargs):
url = kwargs["url"]
if url == "/api/functions":
......@@ -143,24 +139,28 @@ class LambdaTestCase(APITestCase):
user_admin = User.objects.create_superuser(username="admin", email="",
password="admin")
user_admin.groups.add(group_admin)
user_dummy = User.objects.create_user(username="user", password="user")
user_dummy = User.objects.create_user(username="user", password="user",
email="user@example.com")
user_dummy.groups.add(group_user)
cls.admin = user_admin
cls.user = user_dummy
def _create_task(self, data, image_data):
with ForceLogin(self.admin, self.client):
response = self.client.post('/api/tasks', data=data, format="json")
def _create_task(self, data, image_data, *, owner=None, org_id=None):
with ForceLogin(owner or self.admin, self.client):
response = self.client.post('/api/tasks', data=data, format="json",
QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
assert response.status_code == status.HTTP_201_CREATED, response.status_code
tid = response.data["id"]
response = self.client.post("/api/tasks/%s/data" % tid,
data=image_data)
data=image_data,
QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
response = self.client.get("/api/tasks/%s" % tid)
response = self.client.get("/api/tasks/%s" % tid,
QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
task = response.data
return task
......@@ -180,26 +180,37 @@ class LambdaTestCase(APITestCase):
cls._create_db_users()
def _get_request(self, path, user):
def _get_request(self, path, user, *, org_id=None):
with ForceLogin(user, self.client):
response = self.client.get(path)
response = self.client.get(path,
QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
return response
def _delete_request(self, path, user):
def _delete_request(self, path, user, *, org_id=None):
with ForceLogin(user, self.client):
response = self.client.delete(path)
response = self.client.delete(path,
QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
return response
def _post_request(self, path, user, data):
def _post_request(self, path, user, data, *, org_id=None):
data = json.dumps(data)
with ForceLogin(user, self.client):
response = self.client.post(path, data=data, content_type='application/json')
response = self.client.post(path, data=data, content_type='application/json',
QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
return response
def __check_expected_keys_in_response_function(self, data):
def _patch_request(self, path, user, data, *, org_id=None):
    """Send a JSON PATCH request to ``path`` while logged in as ``user``.

    ``data`` is serialized with ``json.dumps``. When ``org_id`` is given,
    it is passed to the server in the ``org_id`` query parameter;
    otherwise the query string is empty. Returns the client response.
    """
    query_string = f'org_id={org_id}' if org_id is not None else ''
    payload = json.dumps(data)

    with ForceLogin(user, self.client):
        return self.client.patch(
            path,
            data=payload,
            content_type='application/json',
            QUERY_STRING=query_string,
        )
def _check_expected_keys_in_response_function(self, data):
kind = data["kind"]
if kind == "interactor":
for key in expected_keys_in_response_function_interactor:
......@@ -212,16 +223,27 @@ class LambdaTestCase(APITestCase):
self.assertIn(key, data)
class LambdaTestCases(_LambdaTestCaseBase):
def setUp(self):
super().setUp()
images_main_task = self._generate_task_images(3)
images_assigneed_to_user_task = self._generate_task_images(3)
self.main_task = self._create_task(tasks["main"], images_main_task)
self.assigneed_to_user_task = self._create_task(
tasks["assigneed_to_user"], images_assigneed_to_user_task
)
def test_api_v2_lambda_functions_list(self):
response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
for data in response.data:
self.__check_expected_keys_in_response_function(data)
self._check_expected_keys_in_response_function(data)
response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.user)
self.assertEqual(response.status_code, status.HTTP_200_OK)
for data in response.data:
self.__check_expected_keys_in_response_function(data)
self._check_expected_keys_in_response_function(data)
response = self._get_request(LAMBDA_FUNCTIONS_PATH, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
......@@ -257,11 +279,11 @@ class LambdaTestCase(APITestCase):
response = self._get_request(path, self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.__check_expected_keys_in_response_function(response.data)
self._check_expected_keys_in_response_function(response.data)
response = self._get_request(path, self.user)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.__check_expected_keys_in_response_function(response.data)
self._check_expected_keys_in_response_function(response.data)
response = self._get_request(path, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
......@@ -966,3 +988,151 @@ class LambdaTestCase(APITestCase):
response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data)
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
class Issue4996_Cases(_LambdaTestCaseBase):
    """Regression tests for job access checks in lambda function calls.

    Check regressions for
    https://github.com/opencv/cvat/issues/4996#issuecomment-1266123032

    A job assignee must be able to call functions in the assigned jobs.
    This requires the job id to be passed in the call request.
    """

    def _create_org(self, *, owner: "User",
            members: Optional[Dict["User", str]] = None) -> dict:
        """Create the test organization and invite ``members`` into it.

        Args:
            owner: the user that creates the organization and sends the
                invitations.
            members: mapping of user -> role name to invite. ``None`` or an
                empty mapping creates an organization without extra members.

        Returns:
            The decoded JSON body of the organization creation response.
        """
        org = self._post_request('/api/organizations', user=owner, data={
            "slug": "testorg",
            "name": "test Org",
        })
        assert org.status_code == status.HTTP_201_CREATED
        org = org.json()

        # `members` is optional: treat a missing mapping as "nobody to invite".
        for member, role in (members or {}).items():
            # The invitation is addressed by email, so fetch it from
            # the member's own profile.
            user = self._get_request('/api/users/self', user=member)
            assert user.status_code == status.HTTP_200_OK
            user = user.json()

            invitation = self._post_request('/api/invitations', user=owner, data={
                'email': user['email'],
                'role': role,
            }, org_id=org['id'])
            assert invitation.status_code == status.HTTP_201_CREATED

        return org

    def _set_task_assignee(self, task: int, assignee: Optional[int], *,
            org_id: Optional[int] = None):
        """Set the assignee of task ``task`` (as admin) via PATCH /api/tasks/<id>."""
        response = self._patch_request(f'/api/tasks/{task}', user=self.admin, data={
            'assignee_id': assignee,
        }, org_id=org_id)
        assert response.status_code == status.HTTP_200_OK

    def _set_job_assignee(self, job: int, assignee: Optional[int], *,
            org_id: Optional[int] = None):
        """Set the assignee of job ``job`` (as admin) via PATCH /api/jobs/<id>."""
        response = self._patch_request(f'/api/jobs/{job}', user=self.admin, data={
            'assignee': assignee,
        }, org_id=org_id)
        assert response.status_code == status.HTTP_200_OK

    def setUp(self):
        """Create an org with a worker, a 6-image task in it and pick a job."""
        # Run the base fixture first, like LambdaTestCases.setUp does, so the
        # client and the patches it installs are in place before any request.
        super().setUp()

        self.org = self._create_org(owner=self.admin, members={self.user: 'worker'})

        task = self._create_task(data={
                'name': 'test_task',
                'labels': [{'name': 'cat'}],
                'segment_size': 2
            },
            image_data=self._generate_task_images(6),
            owner=self.admin,
            org_id=self.org['id'],
        )
        self.task = task

        jobs = self._get_request(f"/api/tasks/{self.task['id']}/jobs", self.admin,
            org_id=self.org['id'])
        assert jobs.status_code == status.HTTP_200_OK
        # The second job in the response is the one used by the tests.
        self.job = jobs.json()[1]

        # Request body shared by all tests; it addresses the task only.
        self.common_data = {
            "task": self.task['id'],
            "frame": 0,
            "cleanup": True,
            "mapping": {
                "car": { "name": "car" },
            },
        }

        self.function_name = f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}"

    def _get_valid_job_params(self):
        """Job request params with a frame the tests expect to be accepted."""
        return {
            "job": self.job['id'],
            "frame": 2
        }

    def _get_invalid_job_params(self):
        """Job request params with a frame the tests expect to be rejected."""
        return {
            "job": self.job['id'],
            "frame": 0
        }

    def _call_function_as_worker(self, extra_params: Optional[dict] = None):
        """Call the detector as the worker user inside the test organization.

        The request body is a copy of ``self.common_data`` updated with
        ``extra_params`` (if any). Returns the client response.
        """
        data = self.common_data.copy()
        if extra_params:
            data.update(extra_params)
        return self._post_request(self.function_name, self.user, data,
            org_id=self.org['id'])

    def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request(self):
        with self.subTest(job=None, assignee=None):
            response = self._call_function_as_worker()
            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_job_request(self):
        with self.subTest(job='defined', assignee=None):
            response = self._call_function_as_worker(self._get_valid_job_params())
            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request(self):
        self._set_task_assignee(self.task['id'], self.user.id, org_id=self.org['id'])

        with self.subTest(job=None, assignee='task'):
            response = self._call_function_as_worker()
            self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request(self):
        self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])

        with self.subTest(job=None, assignee='job'):
            response = self._call_function_as_worker()
            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request(self):
        self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])

        with self.subTest(job='defined', assignee='job'):
            response = self._call_function_as_worker(self._get_valid_job_params())
            self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_can_check_job_boundaries_in_function_call__fail_for_frame_outside_job(self):
        self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])

        with self.subTest(job='defined', frame='outside'):
            response = self._call_function_as_worker(self._get_invalid_job_params())
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_can_check_job_boundaries_in_function_call__ok_for_frame_inside_job(self):
        self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])

        with self.subTest(job='defined', frame='inside'):
            response = self._call_function_as_worker(self._get_valid_job_params())
            self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -8,6 +8,8 @@ import json
from functools import wraps
from enum import Enum
from copy import deepcopy
import textwrap
from typing import Any, Dict, Optional
import django_rq
import requests
......@@ -16,16 +18,17 @@ import os
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from rest_framework import status, viewsets
from rest_framework import status, viewsets, serializers
from rest_framework.response import Response
import cvat.apps.dataset_manager as dm
from cvat.apps.engine.frame_provider import FrameProvider
from cvat.apps.engine.models import Task as TaskModel
from cvat.apps.engine.models import Job, Task
from cvat.apps.engine.serializers import LabeledDataSerializer
from cvat.apps.engine.models import ShapeType, SourceType
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse, OpenApiParameter
from drf_spectacular.utils import (extend_schema, extend_schema_view,
OpenApiResponse, OpenApiParameter, inline_serializer)
from drf_spectacular.types import OpenApiTypes
class LambdaType(Enum):
......@@ -175,8 +178,13 @@ class LambdaFunction:
return response
def invoke(self, db_task, data):
def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = None):
try:
if db_job is not None and db_job.get_task_id() != db_task.id:
raise ValidationError("Job task id does not match task id",
code=status.HTTP_400_BAD_REQUEST
)
payload = {}
data = {k: v for k,v in data.items() if v is not None}
threshold = data.get("threshold")
......@@ -225,6 +233,16 @@ class LambdaFunction:
if mapped_attr in task_attr_names:
supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] })
# Check job frame boundaries
for key, desc in (
('frame', 'frame'),
('frame0', 'start frame'),
('frame1', 'end frame'),
):
if key in data and db_job and not db_job.segment.contains_frame(data[key]):
raise ValidationError(f"The {desc} is outside the job range",
code=status.HTTP_400_BAD_REQUEST)
if self.kind == LambdaType.DETECTOR:
payload.update({
"image": self._get_image(db_task, data["frame"], quality)
......@@ -647,7 +665,7 @@ class LambdaJob:
@staticmethod
def __call__(function, task, quality, cleanup, **kwargs):
# TODO: need logging
db_task = TaskModel.objects.get(pk=task)
db_task = Task.objects.get(pk=task)
if cleanup:
dm.task.delete_task_data(db_task.id)
db_labels = (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all()
......@@ -685,7 +703,7 @@ def return_response(success_code=status.HTTP_200_OK):
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
data = str(err)
except ValidationError as err:
status_code = err.code
status_code = err.code or status.HTTP_400_BAD_REQUEST
data = err.message
except ObjectDoesNotExist as err:
status_code = status.HTTP_400_BAD_REQUEST
......@@ -725,12 +743,34 @@ class FunctionViewSet(viewsets.ViewSet):
gateway = LambdaGateway()
return gateway.get(func_id).to_dict()
@extend_schema(description=textwrap.dedent("""\
Allows to execute a function for immediate computation.
Intended for short-lived executions, useful for interactive calls.
When executed for interactive annotation, the job id must be specified
in the 'job' input field. The task id is not required in this case,
but if it is specified, it must match the job task id.
"""),
request=inline_serializer("OnlineFunctionCall", fields={
"job": serializers.IntegerField(required=False),
"task": serializers.IntegerField(required=False),
}),
responses=OpenApiResponse(description="Returns function invocation results")
)
@return_response()
def call(self, request, func_id):
self.check_object_permissions(request, func_id)
try:
task_id = request.data['task']
db_task = TaskModel.objects.get(pk=task_id)
job_id = request.data.get('job')
job = None
if job_id is not None:
job = Job.objects.get(id=job_id)
task_id = job.get_task_id()
else:
task_id = request.data['task']
db_task = Task.objects.get(pk=task_id)
except (KeyError, ObjectDoesNotExist) as err:
raise ValidationError(
'`{}` lambda function was run '.format(func_id) +
......@@ -740,7 +780,7 @@ class FunctionViewSet(viewsets.ViewSet):
gateway = LambdaGateway()
lambda_func = gateway.get(func_id)
return lambda_func.invoke(db_task, request.data)
return lambda_func.invoke(db_task, request.data, db_job=job)
@extend_schema(tags=['lambda'])
@extend_schema_view(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册