未验证 提交 838f1daf 编写于 作者: M Maxim Zhiltsov 提交者: GitHub

Allow typed simple filters (#5760)

上级 a491c56e
......@@ -57,7 +57,7 @@ class Issue(
return [
Comment(self._client, m)
for m in get_paginated_collection(
self._client.api_client.comments_api.list_endpoint, issue_id=str(self.id)
self._client.api_client.comments_api.list_endpoint, issue_id=self.id
)
]
......
......@@ -153,7 +153,7 @@ class Job(
def get_labels(self) -> List[models.ILabel]:
return get_paginated_collection(
self._client.api_client.labels_api.list_endpoint, job_id=str(self.id)
self._client.api_client.labels_api.list_endpoint, job_id=self.id
)
def get_frames_info(self) -> List[models.IFrameMeta]:
......@@ -169,13 +169,10 @@ class Job(
return [
Issue(self._client, m)
for m in get_paginated_collection(
self._client.api_client.issues_api.list_endpoint, job_id=str(self.id)
self._client.api_client.issues_api.list_endpoint, job_id=self.id
)
]
def get_commits(self) -> List[models.IJobCommit]:
return get_paginated_collection(self.api.list_commits_endpoint, id=self.id)
class JobsRepo(
_JobRepoBase,
......
......@@ -128,13 +128,13 @@ class Project(
return [
Task(self._client, m)
for m in get_paginated_collection(
self._client.api_client.tasks_api.list_endpoint, project_id=str(self.id)
self._client.api_client.tasks_api.list_endpoint, project_id=self.id
)
]
def get_labels(self) -> List[models.ILabel]:
return get_paginated_collection(
self._client.api_client.labels_api.list_endpoint, project_id=str(self.id)
self._client.api_client.labels_api.list_endpoint, project_id=self.id
)
def get_preview(
......
......@@ -306,7 +306,7 @@ class Task(
return [
Job(self._client, model=m)
for m in get_paginated_collection(
self._client.api_client.jobs_api.list_endpoint, task_id=str(self.id)
self._client.api_client.jobs_api.list_endpoint, task_id=self.id
)
]
......@@ -316,7 +316,7 @@ class Task(
def get_labels(self) -> List[models.ILabel]:
return get_paginated_collection(
self._client.api_client.labels_api.list_endpoint, task_id=str(self.id)
self._client.api_client.labels_api.list_endpoint, task_id=self.id
)
def get_frames_info(self) -> List[models.IFrameMeta]:
......
......@@ -9,6 +9,7 @@ import operator
import json
from django_filters import FilterSet
from django_filters import filters as djf
from django_filters.filterset import BaseFilterSet
from django_filters.rest_framework import DjangoFilterBackend
from django.db.models import Q
......@@ -283,7 +284,7 @@ class SimpleFilter(DjangoFilterBackend):
def get_filterset_class(self, view, queryset=None):
lookup_fields = self.get_lookup_fields(view)
if not lookup_fields or not queryset:
if not lookup_fields or queryset is None:
return None
MetaBase = getattr(self.filterset_base, 'Meta', object)
......@@ -306,33 +307,37 @@ class SimpleFilter(DjangoFilterBackend):
return get_lookup_fields(view, fields=simple_filters)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
lookup_fields = self.get_lookup_fields(view)
return [
coreapi.Field(
name=field_name,
location='query',
schema={
'type': 'string',
}
) for field_name in lookup_fields
]
def get_schema_operation_parameters(self, view):
lookup_fields = self.get_lookup_fields(view)
queryset = view.get_queryset()
filterset_class = self.get_filterset_class(view, queryset)
if not filterset_class:
return []
parameters = []
for field_name in lookup_fields:
parameters.append({
for field_name, filter_ in filterset_class.base_filters.items():
if isinstance(filter_, djf.BooleanFilter):
parameter_schema = { 'type': 'boolean' }
elif isinstance(filter_, (djf.NumberFilter, djf.ModelChoiceFilter)):
parameter_schema = { 'type': 'integer' }
elif isinstance(filter_, (djf.CharFilter, djf.ChoiceFilter)):
# Choices use their labels as filter values
parameter_schema = { 'type': 'string' }
else:
raise Exception("Filter field '{}' type '{}' is not supported".format(
'.'.join([view.basename, view.action, field_name]),
filter_
))
parameter = {
'name': field_name,
'in': 'query',
'description': force_str(self.filter_desc.format_map({'field_name': field_name})),
'schema': {
'type': 'string',
},
})
'description': force_str(self.filter_desc.format_map({
'field_name': filter_.label if filter_.label is not None else field_name
})),
'schema': parameter_schema,
}
if filter_.extra and 'choices' in filter_.extra:
parameter['schema']['enum'] = [c[0] for c in filter_.extra['choices']]
parameters.append(parameter)
return parameters
......@@ -243,6 +243,10 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = ProjectPermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......@@ -708,6 +712,10 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = TaskPermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......@@ -1302,6 +1310,9 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = JobPermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......@@ -1718,6 +1729,10 @@ class IssueViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = IssuePermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......@@ -1786,6 +1801,10 @@ class CommentViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = CommentPermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......@@ -1813,9 +1832,12 @@ class CommentViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
summary='Method returns a paginated list of labels',
parameters=[
# These filters are implemented differently from others
OpenApiParameter('job_id', description='A simple equality filter for job id'),
OpenApiParameter('task_id', description='A simple equality filter for task id'),
OpenApiParameter('project_id', description='A simple equality filter for project id'),
OpenApiParameter('job_id', type=OpenApiTypes.INT,
description='A simple equality filter for job id'),
OpenApiParameter('task_id', type=OpenApiTypes.INT,
description='A simple equality filter for task id'),
OpenApiParameter('project_id', type=OpenApiTypes.INT,
description='A simple equality filter for project id'),
],
responses={
'200': LabelSerializer(many=True),
......@@ -1866,6 +1888,9 @@ class LabelViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
serializer_class = LabelSerializer
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return super().get_queryset()
if self.action == 'list':
job_id = self.request.GET.get('job_id', None)
task_id = self.request.GET.get('task_id', None)
......@@ -1986,6 +2011,10 @@ class UserViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = UserPermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......@@ -2077,6 +2106,10 @@ class CloudStorageViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == 'list':
perm = CloudStoragePermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......
......@@ -68,6 +68,10 @@ class OrganizationViewSet(viewsets.GenericViewSet,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
permission = OrganizationPermission.create_scope_list(self.request)
return permission.filter(queryset)
......@@ -131,6 +135,10 @@ class MembershipViewSet(mixins.RetrieveModelMixin, DestroyModelMixin,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
permission = MembershipPermission.create_scope_list(self.request)
return permission.filter(queryset)
......@@ -190,6 +198,10 @@ class InvitationViewSet(viewsets.GenericViewSet,
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
permission = InvitationPermission.create_scope_list(self.request)
return permission.filter(queryset)
......
......@@ -93,6 +93,10 @@ class WebhookViewSet(viewsets.ModelViewSet):
def get_queryset(self):
queryset = super().get_queryset()
if getattr(self, 'swagger_fake_view', False):
return queryset
if self.action == "list":
perm = WebhookPermission.create_scope_list(self.request)
queryset = perm.filter(queryset)
......
......@@ -36,7 +36,7 @@ class TestPostIssues:
assert user == response_json["owner"]["username"]
with make_api_client(user) as client:
(comments, _) = client.comments_api.list(issue_id=str(response_json["id"]))
(comments, _) = client.comments_api.list(issue_id=response_json["id"])
assert data["message"] == comments.results[0].message
assert (
......
......@@ -268,7 +268,7 @@ class TestLabelsListFilters(CollectionSimpleFilterTestBase):
else:
assert False
kwargs[key] = str(v)
kwargs[key] = v
with pytest.raises(exceptions.ApiException) as capture:
self._retrieve_collection(**kwargs)
......@@ -307,12 +307,12 @@ class TestLabelsListFilters(CollectionSimpleFilterTestBase):
dst_obj = next(
t for t in self.task_samples if t.get(f"{src}_id") == src_with_labels["id"]
)
kwargs["task_id"] = str(dst_obj["id"])
kwargs["task_id"] = dst_obj["id"]
elif dst == "job":
dst_obj = next(
j for j in self.job_samples if j.get(f"{src}_id") == src_with_labels["id"]
)
kwargs["job_id"] = str(dst_obj["id"])
kwargs["job_id"] = dst_obj["id"]
else:
assert False
......@@ -428,7 +428,7 @@ class TestListLabels(_TestLabelsPermissionsBase):
kwargs = {
"org_id": org_id,
f"{source_type}_id": str(source["id"]),
f"{source_type}_id": source["id"],
}
if staff:
......@@ -480,7 +480,7 @@ class TestListLabels(_TestLabelsPermissionsBase):
kwargs = {
"org_id": org_id,
f"{source_type}_id": str(source["id"]),
f"{source_type}_id": source["id"],
}
self._test_list_ok(admin_user, source_labels, **kwargs)
......
......@@ -657,7 +657,7 @@ class TestPatchProjectLabel:
kwargs.setdefault("return_json", True)
with make_api_client(user) as api_client:
return get_paginated_collection(
api_client.labels_api.list_endpoint, project_id=str(pid), **kwargs
api_client.labels_api.list_endpoint, project_id=pid, **kwargs
)
def test_can_delete_label(self, projects, labels, admin_user):
......
......@@ -55,7 +55,7 @@ class TestGetTasks:
results = get_paginated_collection(
api_client.tasks_api.list_endpoint,
return_json=True,
project_id=str(project_id),
project_id=project_id,
**kwargs,
)
assert DeepDiff(data, results, ignore_order=True, exclude_paths=exclude_paths) == {}
......@@ -919,7 +919,7 @@ class TestPostTaskData:
with make_api_client(self._USERNAME) as api_client:
jobs: List[models.JobRead] = get_paginated_collection(
api_client.jobs_api.list_endpoint, task_id=str(task_id), sort="id"
api_client.jobs_api.list_endpoint, task_id=task_id, sort="id"
)
(task_meta, _) = api_client.tasks_api.retrieve_data_meta(id=task_id)
......@@ -967,7 +967,7 @@ class TestPatchTaskLabel:
kwargs.setdefault("return_json", True)
with make_api_client(user) as api_client:
return get_paginated_collection(
api_client.labels_api.list_endpoint, task_id=str(pid), **kwargs
api_client.labels_api.list_endpoint, task_id=pid, **kwargs
)
def test_can_delete_label(self, tasks, labels, admin_user):
......
......@@ -104,7 +104,7 @@ class CollectionSimpleFilterTestBase(metaclass=ABCMeta):
def test_can_use_simple_filter_for_object_list(self, field):
value, gt_objects = self._get_field_samples(field)
received_items = self._retrieve_collection(**{field: str(value)})
received_items = self._retrieve_collection(**{field: value})
self._compare_results(gt_objects, received_items)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册