未验证 提交 2584b963 编写于 作者: K Kirill Sizov 提交者: GitHub

Fix comments for PR 5892 (#6206)

### Motivation and context
Applied comments:
- [x] Aligning ORGANIZATION_OPEN_API_PARAMETERS
https://github.com/opencv/cvat/pull/5892#discussion_r1207319660
- [x] Moving ORGANIZATION_OPEN_API_PARAMETERS
https://github.com/opencv/cvat/pull/5892#discussion_r1207320699
~~- [ ] Moving CustomerAutoSchema
https://github.com/opencv/cvat/pull/5892#discussion_r1207326857~~ this
uncritical comment that cannot be done easily, [see
answer](https://github.com/opencv/cvat/pull/5892#discussion_r1209015397)
- [x] Raise error if cannot get `organization_id` for objects
https://github.com/opencv/cvat/pull/5892#discussion_r1207365213
- [x] Multiply fields for `iam_organization_field`
https://github.com/opencv/cvat/pull/5892#issuecomment-1566841192Co-authored-by: NMaxim Zhiltsov <zhiltsov.max35@gmail.com>
上级 21503b3d
......@@ -546,9 +546,9 @@ class Label(models.Model):
@property
def organization_id(self):
if self.project is not None:
return self.project.organization.id
return self.project.organization_id
if self.task is not None:
return self.task.organization.id
return self.task.organization_id
return None
class Meta:
......
......@@ -4,12 +4,12 @@
import textwrap
from typing import Type
from rest_framework import serializers
from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import force_instance, build_basic_type
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.plumbing import build_basic_type, force_instance
from drf_spectacular.serializers import PolymorphicProxySerializerExtension
from drf_spectacular.types import OpenApiTypes
from rest_framework import serializers
def _copy_serializer(
......@@ -229,27 +229,5 @@ class CloudStorageReadSerializerExtension(_CloudStorageSerializerExtension):
class CloudStorageWriteSerializerExtension(_CloudStorageSerializerExtension):
target_class = 'cvat.apps.engine.serializers.CloudStorageWriteSerializer'
ORGANIZATION_OPEN_API_PARAMETERS = [
OpenApiParameter(
name='org',
type=str,
required=False,
location=OpenApiParameter.QUERY,
description="Organization unique slug",
),
OpenApiParameter(
name='org_id',
type=int,
required=False,
location=OpenApiParameter.QUERY,
description="Organization identifier",
),
OpenApiParameter(
name='X-Organization',
type=str,
required=False,
location=OpenApiParameter.HEADER
),
]
__all__ = [] # No public symbols here
......@@ -64,7 +64,6 @@ from cvat.apps.engine.serializers import (
CloudStorageReadSerializer, DatasetFileSerializer,
ProjectFileSerializer, TaskFileSerializer, CloudStorageContentSerializer)
from cvat.apps.engine.view_utils import get_cloud_storage_for_import_or_export
from cvat.apps.engine.schema import ORGANIZATION_OPEN_API_PARAMETERS
from utils.dataset_manifest import ImageManifestManager
from cvat.apps.engine.utils import (
......@@ -79,6 +78,7 @@ from .log import slogger
from cvat.apps.iam.permissions import (CloudStoragePermission,
CommentPermission, IssuePermission, JobPermission, LabelPermission, ProjectPermission,
TaskPermission, UserPermission)
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.engine.cache import MediaCache
from cvat.apps.events.handlers import handle_annotations_patch
from cvat.apps.engine.view_utils import tus_chunk_action
......@@ -1824,10 +1824,7 @@ class LabelViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
'project__organization'
).all()
# NOTE: This filter works incorrectly for this view
# it requires task__organization OR project__organization check.
# Thus, we rely on permission-based filtering
iam_organization_field = None
iam_organization_field = ('task__organization', 'project__organization')
search_fields = ('name', 'parent')
filter_fields = list(search_fields) + ['id', 'type', 'color', 'parent_id']
......
......@@ -11,9 +11,9 @@ from rest_framework.renderers import JSONRenderer
from cvat.apps.iam.permissions import EventsPermission
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.events.serializers import ClientEventsSerializer
from cvat.apps.engine.log import vlogger
from cvat.apps.engine.schema import ORGANIZATION_OPEN_API_PARAMETERS
from .export import export
class EventsViewSet(viewsets.ViewSet):
......
......@@ -3,13 +3,65 @@
# SPDX-License-Identifier: MIT
from rest_framework.filters import BaseFilterBackend
from django.db.models import Q
from collections.abc import Iterable
from drf_spectacular.utils import OpenApiParameter
ORGANIZATION_OPEN_API_PARAMETERS = [
OpenApiParameter(
name='org',
type=str,
required=False,
location=OpenApiParameter.QUERY,
description="Organization unique slug",
),
OpenApiParameter(
name='org_id',
type=int,
required=False,
location=OpenApiParameter.QUERY,
description="Organization identifier",
),
OpenApiParameter(
name='X-Organization',
type=str,
required=False,
location=OpenApiParameter.HEADER,
description="Organization unique slug",
),
]
class OrganizationFilterBackend(BaseFilterBackend):
organization_slug = 'org'
organization_slug_description = 'Organization unique slug'
organization_id = 'org_id'
organization_id_description = 'Organization identifier'
organization_slug_header = 'X-Organization'
def _parameter_is_provided(self, request):
for parameter in ORGANIZATION_OPEN_API_PARAMETERS:
if parameter.location == 'header' and parameter.name in request.headers:
return True
elif parameter.location == 'query' and parameter.name in request.query_params:
return True
return False
def _construct_filter_query(self, organization_fields, org_id):
if isinstance(organization_fields, str):
return Q(**{organization_fields: org_id})
if isinstance(organization_fields, Iterable):
# we select all db records where AT LEAST ONE organization field is equal org_id
operation = Q.OR
if org_id is None:
# but to get all non-org objects we need select db records where ALL organization fields are None
operation = Q.AND
filter_query = Q()
for org_field in organization_fields:
filter_query.add(Q(**{org_field: org_id}), operation)
return filter_query
return Q()
def filter_queryset(self, request, queryset, view):
# Filter works only for "list" requests and allows to return
......@@ -24,16 +76,14 @@ class OrganizationFilterBackend(BaseFilterBackend):
if org:
visibility = {'organization': org.id}
elif not org and (
self.organization_slug in request.query_params
or self.organization_id in request.query_params
or self.organization_slug_header in request.headers
):
elif not org and self._parameter_is_provided(request):
visibility = {'organization': None}
if visibility:
visibility[view.iam_organization_field] = visibility.pop('organization')
return queryset.filter(**visibility).distinct()
org_id = visibility.pop("organization")
query = self._construct_filter_query(view.iam_organization_field, org_id)
return queryset.filter(query).distinct()
return queryset
......@@ -41,23 +91,20 @@ class OrganizationFilterBackend(BaseFilterBackend):
if not view.iam_organization_field or view.detail:
return []
return [
{
'name': self.organization_slug,
'in': 'query',
'description': self.organization_slug_description,
'schema': {'type': 'string'},
},
{
'name': self.organization_id,
'in': 'query',
'description': self.organization_id_description,
'schema': {'type': 'integer'},
},
{
'name': self.organization_slug_header,
'in': 'header',
'description': self.organization_slug_description,
'schema': {'type': 'string'},
},
]
parameters = []
for parameter in ORGANIZATION_OPEN_API_PARAMETERS:
parameter_type = None
if parameter.type == int:
parameter_type = 'integer'
elif parameter.type == str:
parameter_type = 'string'
parameters.append({
'name': parameter.name,
'in': parameter.location,
'description': parameter.description,
'schema': {'type': parameter_type}
})
return parameters
......@@ -14,11 +14,12 @@ from typing import Any, Dict, List, Optional, Sequence, Union, cast
from attrs import define, field
from django.conf import settings
from django.db.models import Q
from rest_framework.exceptions import ValidationError, PermissionDenied
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.permissions import BasePermission
from cvat.apps.engine.models import (CloudStorage, Issue, Job, Label, Project,
Task)
from cvat.apps.organizations.models import Membership, Organization
from cvat.apps.engine.models import CloudStorage, Label, Project, Task, Job, Issue
from cvat.apps.webhooks.models import WebhookTypeChoice
from cvat.utils.http import make_requests_session
......@@ -56,14 +57,24 @@ def get_organization(request, obj):
return obj
if obj:
if organization_id := getattr(obj, "organization_id", None):
try:
return Organization.objects.get(id=organization_id)
except Organization.DoesNotExist:
try:
organization_id = getattr(obj, 'organization_id')
except AttributeError as exc:
# Skip initialization of organization for those objects that don't related with organization
view = request.parser_context.get('view')
if view and view.basename in ('user', 'function', 'request'):
return None
return None
return request.iam_context["organization"]
raise exc
try:
return Organization.objects.get(id=organization_id)
except Organization.DoesNotExist:
return None
return request.iam_context['organization']
def get_membership(request, organization):
if organization is None:
......@@ -80,12 +91,13 @@ def get_iam_context(request, obj):
membership = get_membership(request, organization)
if organization and not request.user.is_superuser and membership is None:
raise PermissionDenied({"message": "You should be an active member in the organization"})
raise PermissionDenied({'message': 'You should be an active member in the organization'})
return {
'user_id': request.user.id,
'group_name': request.iam_context['privilege'],
'org_id': getattr(organization, 'id', None),
'org_slug': getattr(organization, 'slug', None),
'org_owner_id': getattr(organization.owner, 'id', None)
if organization else None,
'org_role': getattr(membership, 'role', None),
......
......@@ -5,9 +5,10 @@
import re
import textwrap
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.authentication import SessionScheme, TokenScheme
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.authentication import TokenScheme, SessionScheme
from drf_spectacular.openapi import AutoSchema
class SignatureAuthenticationScheme(OpenApiAuthenticationExtension):
......@@ -95,3 +96,4 @@ class CustomAutoSchema(AutoSchema):
tokenized_path.append('formatted')
return '_'.join([tokenized_path[0]] + [action] + tokenized_path[1:])
......@@ -32,9 +32,9 @@ import cvat.apps.dataset_manager as dm
from cvat.apps.engine.frame_provider import FrameProvider
from cvat.apps.engine.models import Job, ShapeType, SourceType, Task
from cvat.apps.engine.serializers import LabeledDataSerializer
from cvat.apps.engine.schema import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.utils.http import make_requests_session
from cvat.apps.iam.permissions import LambdaPermission
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
class LambdaType(Enum):
......
......@@ -9,10 +9,10 @@ from django.utils.crypto import get_random_string
from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view
from cvat.apps.engine.mixins import PartialUpdateModelMixin
from cvat.apps.engine.schema import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.iam.permissions import (
InvitationPermission, MembershipPermission, OrganizationPermission)
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from .models import Invitation, Membership, Organization
from .serializers import (
......
......@@ -10,9 +10,9 @@ from rest_framework.decorators import action
from rest_framework.permissions import SAFE_METHODS
from rest_framework.response import Response
from cvat.apps.engine.schema import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.engine.view_utils import list_action, make_paginated_response
from cvat.apps.iam.permissions import WebhookPermission
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from .event_type import AllEvents, OrganizationEvents, ProjectEvents
from .models import Webhook, WebhookDelivery, WebhookTypeChoice
......
......@@ -338,6 +338,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -796,6 +797,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -991,6 +993,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -1101,6 +1104,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -1315,6 +1319,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -2052,6 +2057,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- name: color
in: query
description: A simple equality filter for the color field
......@@ -2337,6 +2343,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -2807,6 +2814,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -3221,6 +3229,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: cloud_storage_id
schema:
......@@ -3640,6 +3649,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......@@ -4337,6 +4347,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: cloud_storage_id
schema:
......@@ -4673,6 +4684,7 @@ paths:
name: X-Organization
schema:
type: string
description: Organization unique slug
- in: query
name: org
schema:
......
......@@ -16,7 +16,7 @@ from cvat_sdk.api_client.model.file_info import FileInfo
from deepdiff import DeepDiff
from PIL import Image
from shared.utils.config import make_api_client
from shared.utils.config import get_method, make_api_client
from .utils import CollectionSimpleFilterTestBase
......@@ -575,3 +575,22 @@ class TestGetCloudStorageContent:
break
assert expected_content == current_content
@pytest.mark.usefixtures("restore_db_per_class")
class TestListCloudStorages:
def _test_can_see_cloud_storages(self, user, data, **kwargs):
response = get_method(user, "cloudstorages", **kwargs)
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json()["results"]) == {}
def test_admin_can_see_all_cloud_storages(self, cloud_storages):
self._test_can_see_cloud_storages("admin2", cloud_storages.raw, page_size="all")
@pytest.mark.parametrize("field_value, query_value", [(2, 2), (None, "")])
def test_can_filter_by_org_id(self, field_value, query_value, cloud_storages):
cloud_storages = filter(lambda i: i["organization"] == field_value, cloud_storages)
self._test_can_see_cloud_storages(
"admin2", list(cloud_storages), page_size="all", org_id=query_value
)
......@@ -7,8 +7,9 @@ from http import HTTPStatus
import pytest
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from deepdiff import DeepDiff
from shared.utils.config import post_method
from shared.utils.config import get_method, post_method
from .utils import CollectionSimpleFilterTestBase
......@@ -120,3 +121,22 @@ class TestInvitationsListFilters(CollectionSimpleFilterTestBase):
)
def test_can_use_simple_filter_for_object_list(self, field):
return super().test_can_use_simple_filter_for_object_list(field)
@pytest.mark.usefixtures("restore_db_per_class")
class TestListInvitations:
def _test_can_see_invitations(self, user, data, **kwargs):
response = get_method(user, "invitations", **kwargs)
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json()["results"]) == {}
def test_admin_can_see_all_invitations(self, invitations):
self._test_can_see_invitations("admin2", invitations.raw, page_size="all")
@pytest.mark.parametrize("field_value, query_value", [(1, 1), (None, "")])
def test_can_filter_by_org_id(self, field_value, query_value, invitations):
invitations = filter(lambda i: i["organization"] == field_value, invitations)
self._test_can_see_invitations(
"admin2", list(invitations), page_size="all", org_id=query_value
)
......@@ -13,7 +13,7 @@ from cvat_sdk import models
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from deepdiff import DeepDiff
from shared.utils.config import make_api_client
from shared.utils.config import get_method, make_api_client
from .utils import CollectionSimpleFilterTestBase
......@@ -394,3 +394,20 @@ class TestCommentsListFilters(CollectionSimpleFilterTestBase):
)
def test_can_use_simple_filter_for_object_list(self, field):
return super().test_can_use_simple_filter_for_object_list(field)
@pytest.mark.usefixtures("restore_db_per_class")
class TestListIssues:
def _test_can_see_issues(self, user, data, **kwargs):
response = get_method(user, "issues", **kwargs)
assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json()["results"]) == {}
def test_admin_can_see_all_issues(self, issues):
self._test_can_see_issues("admin2", issues.raw, page_size="all")
@pytest.mark.parametrize("field_value, query_value", [(1, 1), (None, "")])
def test_can_filter_by_org_id(self, field_value, query_value, issues, jobs):
issues = filter(lambda i: jobs[i["job"]]["organization"] == field_value, issues)
self._test_can_see_issues("admin2", list(issues), page_size="all", org_id=query_value)
......@@ -30,6 +30,13 @@ class TestGetMemberships:
def test_admin_can_see_all_memberships(self, memberships):
self._test_can_see_memberships("admin2", memberships.raw, page_size="all")
@pytest.mark.parametrize("field_value, query_value", [(1, 1), (None, "")])
def test_can_filter_by_org_id(self, field_value, query_value, memberships):
memberships = filter(lambda m: m["organization"] == field_value, memberships)
self._test_can_see_memberships(
"admin2", list(memberships), page_size="all", org_id=query_value
)
def test_non_admin_can_see_only_self_memberships(self, memberships):
non_admins = ["business1", "user1", "dummy1", "worker2"]
for username in non_admins:
......
......@@ -213,9 +213,7 @@ class TestGetProjectBackup:
and is_org_member(user["id"], project["organization"])
)
self._test_cannot_get_project_backup(
user["username"], project["id"], org_id=project["organization"]
)
self._test_cannot_get_project_backup(user["username"], project["id"])
# Org worker that in [project:owner, project:assignee] can get project backup.
def test_org_worker_can_get_project_backup(
......@@ -231,9 +229,7 @@ class TestGetProjectBackup:
and is_org_member(user["id"], project["organization"])
)
self._test_can_get_project_backup(
user["username"], project["id"], org_id=project["organization"]
)
self._test_can_get_project_backup(user["username"], project["id"])
# Org supervisor that in [project:owner, project:assignee] can get project backup.
def test_org_supervisor_can_get_project_backup(
......@@ -249,9 +245,7 @@ class TestGetProjectBackup:
and is_org_member(user["id"], project["organization"])
)
self._test_can_get_project_backup(
user["username"], project["id"], org_id=project["organization"]
)
self._test_can_get_project_backup(user["username"], project["id"])
# Org supervisor that not in [project:owner, project:assignee] cannot get project backup.
def test_org_supervisor_cannot_get_project_backup(
......@@ -267,9 +261,7 @@ class TestGetProjectBackup:
and is_org_member(user["id"], project["organization"])
)
self._test_cannot_get_project_backup(
user["username"], project["id"], org_id=project["organization"]
)
self._test_cannot_get_project_backup(user["username"], project["id"])
# Org maintainer that not in [project:owner, project:assignee] can get project backup.
def test_org_maintainer_can_get_project_backup(
......@@ -285,9 +277,7 @@ class TestGetProjectBackup:
and is_org_member(user["id"], project["organization"])
)
self._test_can_get_project_backup(
user["username"], project["id"], org_id=project["organization"]
)
self._test_can_get_project_backup(user["username"], project["id"])
# Org owner that not in [project:owner, project:assignee] can get project backup.
def test_org_owner_can_get_project_backup(
......@@ -303,9 +293,7 @@ class TestGetProjectBackup:
and is_org_member(user["id"], project["organization"])
)
self._test_can_get_project_backup(
user["username"], project["id"], org_id=project["organization"]
)
self._test_can_get_project_backup(user["username"], project["id"])
@pytest.mark.usefixtures("restore_db_per_function")
......@@ -794,7 +782,6 @@ class TestPatchProjectLabel:
user["username"],
f'projects/{project["id"]}',
{"labels": [new_label]},
org_id=project["organization"],
)
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == project["labels"]["count"] + 1
......@@ -823,7 +810,6 @@ class TestPatchProjectLabel:
user["username"],
f'projects/{project["id"]}',
{"labels": [new_label]},
org_id=project["organization"],
)
assert response.status_code == HTTPStatus.FORBIDDEN
......@@ -848,7 +834,6 @@ class TestPatchProjectLabel:
user["username"],
f'projects/{project["id"]}',
{"labels": [new_label]},
org_id=project["organization"],
)
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == project["labels"]["count"] + 1
......
......@@ -1317,7 +1317,6 @@ class TestPatchTaskLabel:
user["username"],
f'tasks/{task["id"]}',
{"labels": [new_label]},
org_id=task["organization"],
)
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == task["labels"]["count"] + 1
......@@ -1346,7 +1345,6 @@ class TestPatchTaskLabel:
user["username"],
f'tasks/{task["id"]}',
{"labels": [new_label]},
org_id=task["organization"],
)
assert response.status_code == HTTPStatus.FORBIDDEN
......@@ -1371,7 +1369,6 @@ class TestPatchTaskLabel:
user["username"],
f'tasks/{task["id"]}',
{"labels": [new_label]},
org_id=task["organization"],
)
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == task["labels"]["count"] + 1
......
......@@ -773,6 +773,14 @@ class TestGetListWebhooks:
assert response.status_code == HTTPStatus.OK
assert DeepDiff(expected_response, response.json()["results"], ignore_order=True) == {}
@pytest.mark.parametrize("field_value, query_value", [(1, 1), (None, "")])
def test_can_filter_by_org_id(self, field_value, query_value, webhooks):
webhooks = filter(lambda w: w["organization"] == field_value, webhooks)
response = get_method("admin2", f"webhooks", org_id=query_value)
assert response.status_code == HTTPStatus.OK
assert DeepDiff(list(webhooks), response.json()["results"], ignore_order=True) == {}
@pytest.mark.usefixtures("restore_db_per_function")
class TestPatchWebhooks:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册