提交 9be3cbb9 编写于 作者: X xinwen 提交者: baltery

perf: 优化用户详情页授权列表加载速度&添加可重入锁

上级 e599bca9
......@@ -4,9 +4,7 @@ from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView
from django.shortcuts import get_object_or_404
from django.utils.decorators import method_decorator
from assets.locks import NodeTreeUpdateLock
from common.utils import get_logger, get_object_or_none
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser
from orgs.mixins.api import OrgBulkModelViewSet
......
......@@ -2,7 +2,7 @@ from typing import List
from common.utils.common import timeit
from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination
from assets.pagination import NodeAssetTreePagination
from common.utils import lazyproperty
from assets.utils import get_node, is_query_node_all_assets
......@@ -81,7 +81,7 @@ class SerializeToTreeNodeMixin:
class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination
pagination_class = NodeAssetTreePagination
@lazyproperty
def is_query_node_all_assets(self):
......
......@@ -8,7 +8,6 @@ from rest_framework.response import Response
from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.const.http import POST
......@@ -25,10 +24,10 @@ from ..models import Node
from ..tasks import (
update_node_assets_hardware_info_manual,
test_node_assets_connectivity_manual,
check_node_assets_amount_task
)
from .. import serializers
from .mixin import SerializeToTreeNodeMixin
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__)
......@@ -54,6 +53,11 @@ class NodeViewSet(OrgModelViewSet):
serializer.validated_data["key"] = child_key
serializer.save()
@action(methods=[POST], detail=False, url_path='check_assets_amount_task')
def check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
def perform_update(self, serializer):
node = self.get_object()
if node.is_org_root() and node.value != serializer.validated_data['value']:
......
......@@ -15,7 +15,6 @@ class NodeTreeUpdateLock(DistributedLock):
)
return name
def __init__(self, blocking=True):
def __init__(self):
name = self.get_name()
super().__init__(name=name, blocking=blocking,
release_lock_on_transaction_commit=True)
super().__init__(name=name, release_on_transaction_commit=True, reentrant=True)
# Generated by Django 3.1 on 2021-02-04 09:49
# Generated by Django 3.1 on 2021-02-08 10:02
from django.db import migrations
......@@ -10,8 +10,8 @@ class Migration(migrations.Migration):
]
operations = [
migrations.RemoveField(
model_name='node',
name='assets_amount',
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['hostname'], 'verbose_name': 'Asset'},
),
]
......@@ -353,4 +353,4 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
class Meta:
unique_together = [('org_id', 'hostname')]
verbose_name = _("Asset")
ordering = ["hostname", "ip"]
ordering = ["hostname", ]
......@@ -425,11 +425,6 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
@property
def assets_amount(self):
assets_id = self.get_all_assets_id()
return len(assets_id)
def get_all_assets_id(self):
assets_id = self.get_all_assets_id_by_node_key(org_id=self.org_id, node_key=self.key)
return set(assets_id)
......@@ -550,6 +545,7 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
date_create = models.DateTimeField(auto_now_add=True)
parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
db_index=True, default='')
assets_amount = models.IntegerField(default=0)
objects = OrgManager.from_queryset(NodeQuerySet)()
is_node = True
......
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node
logger = get_logger(__name__)
class AssetPaginationBase(LimitOffsetPagination):
def init_attrs(self, queryset, request: Request, view=None):
self._request = request
self._view = view
self._user = request.user
def paginate_queryset(self, queryset, request: Request, view=None):
self.init_attrs(queryset, request, view)
return super().paginate_queryset(queryset, request, view=None)
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
def get_count(self, queryset):
"""
1. 如果查询节点下的所有资产,那 count 使用 Node.assets_amount
2. 如果有其他过滤条件使用 super
3. 如果只查询该节点下的资产使用 super
"""
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'node', 'all', 'show_current_asset',
'node_id', 'display', 'draw', 'fields_size',
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset)
node_assets_count = self.get_count_from_nodes(queryset)
if node_assets_count is None:
return super().get_count(queryset)
return node_assets_count
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class NodeAssetTreePagination(AssetPaginationBase):
def get_count_from_nodes(self, queryset):
is_query_all = self._view.is_query_node_all_assets
if is_query_all:
node = self._view.node
if not node:
node = Node.org_root()
logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}')
return node.assets_amount
return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)
return None
from .common import *
from .maintain_nodes_tree import *
from .node_assets_amount import *
from .node_assets_mapping import *
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
class NodeAssetsAmountUtils:
@classmethod
def _remove_ancestor_keys(cls, ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
@classmethod
def _is_asset_exists_in_node(cls, asset_pk, node_key):
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时,更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时,更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
from celery import shared_task
from django.utils.translation import gettext_lazy as _
from orgs.models import Organization
from orgs.utils import tmp_to_org
from ops.celery.decorator import register_as_period_task
from assets.utils import check_node_assets_amount
from common.utils.lock import AcquireFailed
from common.utils import get_logger
logger = get_logger(__file__)
@shared_task
def check_node_assets_amount_task(orgid=None):
if orgid is None:
orgs = [*Organization.objects.all(), Organization.default()]
else:
orgs = [Organization.get_instance(orgid)]
for org in orgs:
try:
with tmp_to_org(org):
check_node_assets_amount()
except AcquireFailed:
logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))
@register_as_period_task(crontab='0 2 * * *')
@shared_task
def check_node_assets_amount_period_task():
check_node_assets_amount_task()
......@@ -5,12 +5,45 @@ from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none,
from common.http import is_true
from common.struct import Stack
from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org
from .models import Node
from .locks import NodeTreeUpdateLock
from .models import Node, Asset
logger = get_logger(__file__)
@NodeTreeUpdateLock()
@ensure_in_real_or_default_org
def check_node_assets_amount():
logger.info(f'Check node assets amount {current_org}')
nodes = list(Node.objects.all().only('id', 'key', 'assets_amount'))
nodeid_assetid_pairs = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id'))
nodekey_assetids_mapper = defaultdict(set)
nodeid_nodekey_mapper = {}
for node in nodes:
nodeid_nodekey_mapper[node.id] = node.key
for nodeid, assetid in nodeid_assetid_pairs:
if nodeid not in nodeid_nodekey_mapper:
continue
nodekey = nodeid_nodekey_mapper[nodeid]
nodekey_assetids_mapper[nodekey].add(assetid)
util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
util.generate()
to_updates = []
for node in nodes:
assets_amount = util.get_assets_amount(node.key)
if node.assets_amount != assets_amount:
logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}')
node.assets_amount = assets_amount
to_updates.append(node)
Node.objects.bulk_update(to_updates, fields=('assets_amount',))
def is_query_node_all_assets(request):
request = request
query_all_arg = request.query_params.get('all', 'true')
......@@ -104,5 +137,3 @@ class NodeAssetsUtil:
util = cls(nodes, mapping)
util.generate()
return util
......@@ -8,6 +8,7 @@ from django.db import transaction
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG
from common.local import thread_local
logger = get_logger(__file__)
......@@ -16,24 +17,28 @@ class AcquireFailed(RuntimeError):
pass
class LockHasTimeOut(RuntimeError):
pass
class DistributedLock(RedisLock):
def __init__(self, name, blocking=True, expire=None, release_lock_on_transaction_commit=False,
release_raise_exc=False, auto_renewal_seconds=60*2):
def __init__(self, name, *, expire=None, release_on_transaction_commit=False,
reentrant=False, release_raise_exc=False, auto_renewal_seconds=60):
"""
使用 redis 构造的分布式锁
:param name:
锁的名字,要全局唯一
:param blocking:
该参数只在锁作为装饰器或者 `with` 时有效。
:param expire:
锁的过期时间
:param release_lock_on_transaction_commit:
:param release_on_transaction_commit:
是否在当前事务结束后再释放锁
:param release_raise_exc:
释放锁时,如果没有持有锁是否抛异常或静默
:param auto_renewal_seconds:
当持有一个无限期锁的时候,刷新锁的时间,具体参考 `redis_lock.Lock#auto_renewal`
:param reentrant:
是否可重入
"""
self.kwargs_copy = copy_function_args(self.__init__, locals())
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
......@@ -45,28 +50,20 @@ class DistributedLock(RedisLock):
auto_renewal = False
super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
self._blocking = blocking
self._release_lock_on_transaction_commit = release_lock_on_transaction_commit
self._release_on_transaction_commit = release_on_transaction_commit
self._release_raise_exc = release_raise_exc
self._reentrant = reentrant
self._acquired_reentrant_lock = False
self._thread_id = threading.current_thread().ident
def __enter__(self):
thread_id = threading.current_thread().ident
logger.debug(f'Attempt to acquire global lock: thread {thread_id} lock {self._name}')
acquired = self.acquire(blocking=self._blocking)
if self._blocking and not acquired:
logger.debug(f'Not acquired lock, but blocking=True, thread {thread_id} lock {self._name}')
raise EnvironmentError("Lock wasn't acquired, but blocking=True")
acquired = self.acquire(blocking=True)
if not acquired:
logger.debug(f'Not acquired the lock, thread {thread_id} lock {self._name}')
raise AcquireFailed
logger.debug(f'Acquire lock success, thread {thread_id} lock {self._name}')
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
if self._release_lock_on_transaction_commit:
transaction.on_commit(self.release)
else:
self.release()
self.release()
def __call__(self, func):
@wraps(func)
......@@ -82,9 +79,105 @@ class DistributedLock(RedisLock):
return True
return False
def release(self):
def locked_by_current_thread(self):
if self.locked():
owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id:
return True
return False
def acquire(self, blocking=True, timeout=None):
if self._reentrant:
if self.locked_by_current_thread():
self._acquired_reentrant_lock = True
logger.debug(
f'I[{self.id}] reentry lock[{self.name}] in thread[{self._thread_id}].')
return True
logger.debug(f'I[{self.id}] attempt acquire reentrant-lock[{self.name}].')
acquired = super().acquire(blocking=blocking, timeout=timeout)
if acquired:
logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] now.')
setattr(thread_local, self.name, self.id)
else:
logger.debug(f'I[{self.id}] acquired reentrant-lock[{self.name}] failed.')
return acquired
else:
logger.debug(f'I[{self.id}] attempt acquire lock[{self.name}].')
acquired = super().acquire(blocking=blocking, timeout=timeout)
logger.debug(f'I[{self.id}] acquired lock[{self.name}] {acquired}.')
return acquired
@property
def name(self):
return self._name
def _raise_exc_with_log(self, msg, *, exc_cls=NotAcquired):
e = exc_cls(msg)
logger.error(msg)
self._raise_exc(e)
def _raise_exc(self, e):
if self._release_raise_exc:
raise e
def _release_on_reentrant_locked_by_brother(self):
if self._acquired_reentrant_lock:
self._acquired_reentrant_lock = False
logger.debug(f'I[{self.id}] released reentrant-lock[{self.name}] owner[{self.get_owner_id()}] in thread[{self._thread_id}]')
return
else:
self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired by me[{self.id}].')
def _release_on_reentrant_locked_by_me(self):
logger.debug(f'I[{self.id}] release reentrant-lock[{self.name}] in thread[{self._thread_id}]')
id = getattr(thread_local, self.name, None)
if id != self.id:
raise PermissionError(f'Reentrant-lock[{self.name}] is not locked by me[{self.id}], owner[{id}]')
try:
# 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name)
except AttributeError:
pass
finally:
try:
# 这里处理的是边界情况,
# 判断锁是我的 -> 锁超时 -> 释放锁报错
# 此时的报错应该被静默
self._release_redis_lock()
except NotAcquired:
pass
def _release_redis_lock(self):
# 最底层 api
super().release()
def _release(self):
try:
super().release()
except AcquireFailed as e:
if self._release_raise_exc:
raise e
self._release_redis_lock()
except NotAcquired as e:
logger.error(f'I[{self.id}] release lock[{self.name}] failed {e}')
self._raise_exc(e)
def release(self):
_release = self._release
# 处理可重入锁
if self._reentrant:
if self.locked_by_current_thread():
if self.locked_by_me():
_release = self._release_on_reentrant_locked_by_me
else:
_release = self._release_on_reentrant_locked_by_brother
else:
self._raise_exc_with_log(f'Reentrant-lock[{self.name}] is not acquired in current-thread[{self._thread_id}]')
# 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit:
logger.debug(f'I[{self.id}] release lock[{self.name}] on transaction commit ...')
transaction.on_commit(_release)
else:
_release()
......@@ -186,6 +186,10 @@ def org_aware_func(org_arg_name):
current_org = LocalProxy(get_current_org)
def ensure_in_real_or_default_org():
if not current_org or current_org.is_root():
raise ValueError('You must in a real or default org!')
def ensure_in_real_or_default_org(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not current_org or current_org.is_root():
raise ValueError('You must in a real or default org!')
return func(*args, **kwargs)
return wrapper
......@@ -26,12 +26,6 @@ class AssetPermissionViewSet(BasePermissionViewSet):
'node_id', 'node', 'asset_id', 'hostname', 'ip'
]
def get_queryset(self):
queryset = super().get_queryset().prefetch_related(
"nodes", "assets", "users", "user_groups", "system_users"
)
return queryset
def filter_node(self, queryset):
node_id = self.request.query_params.get('node_id')
node_name = self.request.query_params.get('node')
......
......@@ -14,7 +14,6 @@ from .mixin import RoleUserMixin, RoleAdminMixin
from perms.utils.asset.user_permission import (
UserGrantedTreeBuildUtils, get_user_all_asset_perm_ids,
UserGrantedNodesQueryUtils, UserGrantedAssetsQueryUtils,
QuerySetStage,
)
from perms.models import AssetPermission, PermNode
from assets.models import Asset
......@@ -44,10 +43,10 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
def add_favorite_resource(self, data: list, nodes_query_utils, assets_query_utils):
favorite_node = nodes_query_utils.get_favorite_node()
qs_state = QuerySetStage().annotate(
favorite_assets = assets_query_utils.get_favorite_assets()
favorite_assets = favorite_assets.annotate(
parent_key=Value(favorite_node.key, output_field=CharField())
).prefetch_related('platform')
favorite_assets = assets_query_utils.get_favorite_assets(qs_stage=qs_state, only=())
data.extend(self.serialize_nodes([favorite_node], with_asset_amount=True))
data.extend(self.serialize_assets(favorite_assets))
......@@ -59,13 +58,11 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
data.extend(self.serialize_nodes(nodes, with_asset_amount=True))
def add_assets(self, data: list, assets_query_utils: UserGrantedAssetsQueryUtils):
qs_stage = QuerySetStage().annotate(parent_key=F('nodes__key')).prefetch_related('platform')
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
all_assets = assets_query_utils.get_direct_granted_nodes_assets(qs_stage=qs_stage)
all_assets = assets_query_utils.get_direct_granted_nodes_assets()
else:
all_assets = assets_query_utils.get_all_granted_assets(qs_stage=qs_stage)
all_assets = assets_query_utils.get_all_granted_assets()
all_assets = all_assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform')
data.extend(self.serialize_assets(all_assets))
@tmp_to_root_org()
......@@ -144,8 +141,6 @@ class GrantedNodeChildrenWithAssetsAsTreeApiMixin(SerializeToTreeNodeMixin,
assets = assets_query_utils.get_node_assets(key)
assets = assets.prefetch_related('platform')
user = self.user
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, key)
return Response(data=[*tree_nodes, *tree_assets])
......
......@@ -45,7 +45,7 @@ class BasePermissionViewSet(OrgBulkModelViewSet):
if not self.is_query_all():
queryset = queryset.filter(users=user)
return queryset
groups = user.groups.all()
groups = list(user.groups.all().values_list('id', flat=True))
queryset = queryset.filter(
Q(users=user) | Q(user_groups__in=groups)
).distinct()
......
# Generated by Django 3.1 on 2021-02-04 09:49
# Generated by Django 3.1 on 2021-02-08 07:15
import assets.models.node
from django.conf import settings
......@@ -9,8 +9,8 @@ import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('assets', '0066_remove_node_assets_amount'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('assets', '0065_auto_20210121_1549'),
('perms', '0017_auto_20210104_0435'),
]
......
from rest_framework.pagination import LimitOffsetPagination
from django.conf import settings
from rest_framework.request import Request
from django.db.models import Sum
from assets.pagination import AssetPaginationBase
from perms.models import UserAssetGrantedTreeNodeRelation
from common.utils import get_logger
logger = get_logger(__name__)
class GrantedAssetPaginationBase(LimitOffsetPagination):
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
self._user = request.user
return super().paginate_queryset(queryset, request, view=None)
def get_count(self, queryset):
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw',
'order',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset)
return self.get_count_from_nodes(queryset)
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class GrantedAssetPaginationBase(AssetPaginationBase):
def init_attrs(self, queryset, request: Request, view=None):
super().init_attrs(queryset, request, view)
self._user = view.user
class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
......@@ -42,11 +22,13 @@ class NodeGrantedAssetPagination(GrantedAssetPaginationBase):
return node.assets_amount
else:
logger.warn(f'Not hit node.assets_amount[{node}] because {self._view} not has `pagination_node` -> {self._request.get_full_path()}')
return super().get_count(queryset)
return None
class AllGrantedAssetPagination(GrantedAssetPaginationBase):
def get_count_from_nodes(self, queryset):
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return None
assets_amount = sum(UserAssetGrantedTreeNodeRelation.objects.filter(
user=self._user, node_parent_key=''
).values_list('node_assets_amount', flat=True))
......
......@@ -3,9 +3,12 @@
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from django.db.models import Prefetch
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import AssetPermission, Action
from assets.models import Asset, Node, SystemUser
from users.models import User, UserGroup
__all__ = [
'AssetPermissionSerializer',
......@@ -68,5 +71,11 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('users', 'user_groups', 'assets', 'nodes', 'system_users')
queryset = queryset.prefetch_related(
Prefetch('system_users', queryset=SystemUser.objects.only('id')),
Prefetch('user_groups', queryset=UserGroup.objects.only('id')),
Prefetch('users', queryset=User.objects.only('id')),
Prefetch('assets', queryset=Asset.objects.only('id')),
Prefetch('nodes', queryset=Node.objects.only('id'))
)
return queryset
......@@ -115,8 +115,8 @@ class UnionQuerySet(QuerySet):
def __getitem__(self, item):
return self.__execute()[item]
def __next__(self):
return next(self.__execute())
def __iter__(self):
return iter(self.__execute())
@classmethod
def test_it(cls):
......@@ -299,12 +299,12 @@ class UserGrantedTreeRefreshController:
cls.remove_builed_orgs_from_users(orgs_id, users_id)
@classmethod
@ensure_in_real_or_default_org
def add_need_refresh_on_nodes_assets_relate_change(cls, node_ids, asset_ids):
"""
1,计算与这些资产有关的授权
2,计算与这些节点以及祖先节点有关的授权
"""
ensure_in_real_or_default_org()
node_ids = set(node_ids)
ancestor_node_keys = set()
......@@ -340,8 +340,8 @@ class UserGrantedTreeRefreshController:
cls.add_need_refresh_by_asset_perm_ids(perm_ids)
@classmethod
@ensure_in_real_or_default_org
def add_need_refresh_by_asset_perm_ids(cls, asset_perm_ids):
ensure_in_real_or_default_org()
group_ids = AssetPermission.user_groups.through.objects.filter(
assetpermission_id__in=asset_perm_ids
......@@ -429,8 +429,8 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
return asset_ids
@timeit
@ensure_in_real_or_default_org
def rebuild_user_granted_tree(self):
ensure_in_real_or_default_org()
logger.info(f'Rebuild user:{self.user} tree in org:{current_org}')
user = self.user
......@@ -618,13 +618,13 @@ class UserGrantedTreeBuildUtils(UserGrantedUtilsBase):
class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
def get_favorite_assets(self, only=('id', )) -> QuerySet:
def get_favorite_assets(self) -> QuerySet:
favorite_asset_ids = FavoriteAsset.objects.filter(
user=self.user
).values_list('asset_id', flat=True)
favorite_asset_ids = list(favorite_asset_ids)
assets = self.get_all_granted_assets()
assets = assets.filter(id__in=favorite_asset_ids).only(*only)
assets = assets.filter(id__in=favorite_asset_ids)
return assets
def get_ungroup_assets(self) -> AssetQuerySet:
......@@ -670,7 +670,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
granted_status = node.get_granted_status(self.user)
if granted_status == NodeFrom.granted:
assets = Asset.objects.order_by().filter(nodes_id=node.id)
assets = Asset.objects.order_by().filter(nodes__id=node.id)
return assets
elif granted_status == NodeFrom.asset:
return self._get_indirect_granted_node_assets(node.id)
......@@ -678,7 +678,7 @@ class UserGrantedAssetsQueryUtils(UserGrantedUtilsBase):
return Asset.objects.none()
def _get_indirect_granted_node_assets(self, id) -> AssetQuerySet:
assets = Asset.objects.order_by().filter(nodes_id=id) & self.get_direct_granted_assets()
assets = Asset.objects.order_by().filter(nodes__id=id).distinct() & self.get_direct_granted_assets()
return assets
def _get_indirect_granted_node_all_assets(self, node) -> QuerySet:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册