提交 1a0ff422 编写于 作者: baltery's avatar baltery

[Update] 优化树结构

上级 6d96b5db
......@@ -148,6 +148,7 @@ class AssetUserTestConnectiveApi(generics.RetrieveAPIView):
Test asset users connective
"""
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.TaskIDSerializer
def get_asset_users(self):
username = self.request.GET.get('username')
......
......@@ -26,6 +26,7 @@ from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import update_assets_hardware_info_util, test_asset_connectivity_util
from .. import serializers
from ..utils import NodeUtil
logger = get_logger(__file__)
......@@ -79,12 +80,10 @@ class NodeListAsTreeApi(generics.ListAPIView):
serializer_class = TreeNodeSerializer
def get_queryset(self):
queryset = [node.as_tree_node() for node in Node.objects.all()]
return queryset
def filter_queryset(self, queryset):
if self.request.query_params.get('refresh', '0') == '1':
queryset = self.refresh_nodes(queryset)
queryset = Node.objects.all()
util = NodeUtil()
nodes = util.get_nodes_by_queryset(queryset)
queryset = [node.as_tree_node() for node in nodes]
return queryset
@staticmethod
......@@ -114,15 +113,11 @@ class NodeChildrenAsTreeApi(generics.ListAPIView):
def get_queryset(self):
node_key = self.request.query_params.get('key')
if node_key:
self.node = Node.objects.get(key=node_key)
queryset = self.node.get_children(with_self=False)
else:
self.is_root = True
self.node = Node.root()
queryset = list(self.node.get_children(with_self=True))
nodes_invalid = Node.objects.exclude(key__startswith=self.node.key)
queryset.extend(list(nodes_invalid))
util = NodeUtil()
if not node_key:
node_key = Node.root().key
self.node = util.get_node_by_key(node_key)
queryset = self.node.get_children(with_self=True)
queryset = [node.as_tree_node() for node in queryset]
queryset = sorted(queryset)
return queryset
......
......@@ -46,12 +46,6 @@ class AssetQuerySet(models.QuerySet):
return self.active()
class AssetManager(OrgManager):
def get_queryset(self):
queryset = super().get_queryset().prefetch_related("nodes", "protocols")
return queryset
class Protocol(models.Model):
PROTOCOL_SSH = 'ssh'
PROTOCOL_RDP = 'rdp'
......@@ -131,7 +125,7 @@ class Asset(OrgModelMixin):
date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created'))
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)()
objects = OrgManager.from_queryset(AssetQuerySet)()
def __str__(self):
return '{0.hostname}({0.ip})'.format(self)
......@@ -300,15 +294,20 @@ class Asset(OrgModelMixin):
@classmethod
def generate_fake(cls, count=100):
from random import seed, choice
import forgery_py
from django.db import IntegrityError
from .node import Node
from orgs.utils import get_current_org
from orgs.models import Organization
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
nodes = list(Node.objects.all())
seed()
for i in range(count):
ip = [str(i) for i in random.sample(range(255), 4)]
asset = cls(ip='.'.join(ip),
hostname=forgery_py.internet.user_name(True),
hostname='.'.join(ip),
admin_user=choice(AdminUser.objects.all()),
created_by='Fake')
try:
......
# -*- coding: utf-8 -*-
#
import uuid
import re
from django.db import models, transaction
from django.db.models import Q
......@@ -15,54 +16,185 @@ from orgs.models import Organization
__all__ = ['Node']
class Node(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
value = models.CharField(max_length=128, verbose_name=_("Value"))
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
class FamilyMixin:
_parents = None
_children = None
_all_children = None
is_node = True
_assets_amount = None
_full_value_cache_key = '_NODE_VALUE_{}'
_assets_amount_cache_key = '_NODE_ASSETS_AMOUNT_{}'
class Meta:
verbose_name = _("Node")
ordering = ['key']
@property
def children(self):
if self._children:
return self._children
pattern = r'^{0}:[0-9]+$'.format(self.key)
return Node.objects.filter(key__regex=pattern)
def __str__(self):
return self.full_value
@children.setter
def children(self, value):
self._children = value
def __eq__(self, other):
if not other:
return False
return self.id == other.id
@property
def all_children(self):
if self._all_children:
return self._all_children
pattern = r'^{0}:'.format(self.key)
return Node.objects.filter(
key__regex=pattern
)
def __gt__(self, other):
if self.is_root() and not other.is_root():
return True
elif not self.is_root() and other.is_root():
return False
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]
self_parent_key = self_key[:-1]
other_parent_key = other_key[:-1]
def get_children(self, with_self=False):
children = list(self.children)
if with_self:
children.append(self)
return children
if self_parent_key == other_parent_key:
return self.name > other.name
if len(self_parent_key) < len(other_parent_key):
return True
elif len(self_parent_key) > len(other_parent_key):
return False
return self_key > other_key
def get_all_children(self, with_self=False):
children = self.all_children
if with_self:
children = list(children)
children.append(self)
return children
def __lt__(self, other):
return not self.__gt__(other)
@property
def parents(self):
if self._parents:
return self._parents
ancestor_keys = self.get_ancestor_keys()
ancestor = Node.objects.filter(
key__in=ancestor_keys
).order_by('key')
return ancestor
@parents.setter
def parents(self, value):
self._parents = value
def get_ancestor(self, with_self=False):
parents = self.parents
if with_self:
parents = list(parents)
parents.append(self)
return parents
@property
def name(self):
return self.value
def parent(self):
if self._parents:
return self._parents[0]
if self.is_root():
return self
try:
parent = Node.objects.get(key=self.parent_key)
return parent
except Node.DoesNotExist:
return Node.root()
@parent.setter
def parent(self, parent):
if not self.is_node:
self.key = parent.key + ':fake'
return
children = self.get_all_children()
old_key = self.key
with transaction.atomic():
self.key = parent.get_next_child_key()
for child in children:
child.key = child.key.replace(old_key, self.key, 1)
child.save()
self.save()
def get_sibling(self, with_self=False):
key = ':'.join(self.key.split(':')[:-1])
pattern = r'^{}:[0-9]+$'.format(key)
sibling = Node.objects.filter(
key__regex=pattern.format(self.key)
)
if not with_self:
sibling = sibling.exclude(key=self.key)
return sibling
def get_family(self):
ancestor = self.get_ancestor()
children = self.get_all_children()
return [*tuple(ancestor), self, *tuple(children)]
def get_ancestor_keys(self, with_self=False):
parent_keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def is_children(self, other):
pattern = re.compile(r'^{0}:[0-9]+$'.format(self.key))
return pattern.match(other.key)
def is_parent(self, other):
pattern = re.compile(r'^{0}:[0-9]+$'.format(other.key))
return pattern.match(self.key)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
@property
def parents_keys(self, with_self=False):
keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
keys.append(':'.join(key_list))
key_list.pop()
return keys
class FullValueMixin:
_full_value_cache_key = '_NODE_VALUE_{}'
_full_value = ''
key = ''
@property
def full_value(self):
if self._full_value:
return self._full_value
key = self._full_value_cache_key.format(self.key)
cached = cache.get(key)
if cached:
return cached
if self.is_root():
return self.value
parent_full_value = self.parent.full_value
value = parent_full_value + ' / ' + self.value
self.full_value = value
return value
@full_value.setter
def full_value(self, value):
self._full_value = value
key = self._full_value_cache_key.format(self.key)
cache.set(key, value, 3600*24)
def expire_full_value(self):
key = self._full_value_cache_key.format(self.key)
cache.delete_pattern(key+'*')
@classmethod
def expire_nodes_full_value(cls, nodes=None):
key = cls._full_value_cache_key.format('*')
cache.delete_pattern(key+'*')
from ..utils import NodeUtil
util = NodeUtil()
util.set_full_value()
class AssetsAmountMixin:
_assets_amount_cache_key = '_NODE_ASSETS_AMOUNT_{}'
_assets_amount = None
key = ''
@property
def assets_amount(self):
......@@ -77,53 +209,77 @@ class Node(OrgModelMixin):
if cached is not None:
return cached
assets_amount = self.get_all_assets().count()
cache.set(cache_key, assets_amount, 3600)
self.assets_amount = assets_amount
return assets_amount
@assets_amount.setter
def assets_amount(self, value):
self._assets_amount = value
cache_key = self._assets_amount_cache_key.format(self.key)
cache.set(cache_key, value, 3600 * 24)
def expire_assets_amount(self):
ancestor_keys = self.get_ancestor_keys(with_self=True)
cache_keys = [self._assets_amount_cache_key.format(k) for k in ancestor_keys]
cache_keys = [self._assets_amount_cache_key.format(k) for k in
ancestor_keys]
cache.delete_many(cache_keys)
@classmethod
def expire_nodes_assets_amount(cls, nodes=None):
if nodes:
for node in nodes:
node.expire_assets_amount()
return
from ..utils import NodeUtil
key = cls._assets_amount_cache_key.format('*')
cache.delete_pattern(key)
util = NodeUtil(with_assets_amount=True)
util.set_assets_amount()
@property
def full_value(self):
key = self._full_value_cache_key.format(self.key)
cached = cache.get(key)
if cached:
return cached
if self.is_root():
return self.value
parent_full_value = self.parent.full_value
value = parent_full_value + ' / ' + self.value
key = self._full_value_cache_key.format(self.key)
cache.set(key, value, 3600)
return value
def expire_full_value(self):
key = self._full_value_cache_key.format(self.key)
cache.delete_pattern(key+'*')
class Node(OrgModelMixin, FamilyMixin, FullValueMixin, AssetsAmountMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
value = models.CharField(max_length=128, verbose_name=_("Value"))
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
@classmethod
def expire_nodes_full_value(cls, nodes=None):
if nodes:
for node in nodes:
node.expire_full_value()
return
key = cls._full_value_cache_key.format('*')
cache.delete_pattern(key+'*')
is_node = True
_parents = None
class Meta:
verbose_name = _("Node")
ordering = ['key']
def __str__(self):
return self.full_value
def __eq__(self, other):
if not other:
return False
return self.id == other.id
def __gt__(self, other):
# if self.is_root() and not other.is_root():
# return False
# elif not self.is_root() and other.is_root():
# return True
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]
self_parent_key = self_key[:-1]
other_parent_key = other_key[:-1]
if self_parent_key and other_parent_key and \
self_parent_key == other_parent_key:
return self.value > other.value
# if len(self_parent_key) < len(other_parent_key):
# return True
# elif len(self_parent_key) > len(other_parent_key):
# return False
return self_key > other_key
def __lt__(self, other):
return not self.__gt__(other)
@property
def name(self):
return self.value
@property
def level(self):
......@@ -152,33 +308,6 @@ class Node(OrgModelMixin):
child = self.__class__.objects.create(id=_id, key=child_key, value=value)
return child
def get_children(self, with_self=False):
pattern = r'^{0}$|^{0}:[0-9]+$' if with_self else r'^{0}:[0-9]+$'
return self.__class__.objects.filter(
key__regex=pattern.format(self.key)
)
def get_all_children(self, with_self=False):
pattern = r'^{0}$|^{0}:' if with_self else r'^{0}:'
return self.__class__.objects.filter(
key__regex=pattern.format(self.key)
)
def get_sibling(self, with_self=False):
key = ':'.join(self.key.split(':')[:-1])
pattern = r'^{}:[0-9]+$'.format(key)
sibling = self.__class__.objects.filter(
key__regex=pattern.format(self.key)
)
if not with_self:
sibling = sibling.exclude(key=self.key)
return sibling
def get_family(self):
ancestor = self.get_ancestor()
children = self.get_all_children()
return [*tuple(ancestor), self, *tuple(children)]
def get_assets(self):
from .asset import Asset
if self.is_default_node():
......@@ -214,52 +343,6 @@ class Node(OrgModelMixin):
else:
return False
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
@property
def parent(self):
if self.is_root():
return self
try:
parent = self.__class__.objects.get(key=self.parent_key)
return parent
except Node.DoesNotExist:
return self.__class__.root()
@parent.setter
def parent(self, parent):
if not self.is_node:
self.key = parent.key + ':fake'
return
children = self.get_all_children()
old_key = self.key
with transaction.atomic():
self.key = parent.get_next_child_key()
for child in children:
child.key = child.key.replace(old_key, self.key, 1)
child.save()
self.save()
def get_ancestor_keys(self, with_self=False):
parent_keys = []
key_list = self.key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def get_ancestor(self, with_self=False):
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
ancestor = self.__class__.objects.filter(
key__in=ancestor_keys
).order_by('key')
return ancestor
@classmethod
def create_root_node(cls):
# 如果使用current_org 在set_current_org时会死循环
......@@ -310,9 +393,19 @@ class Node(OrgModelMixin):
tree_node = TreeNode(**data)
return tree_node
@classmethod
def get_queryset(cls):
from ..utils import NodeUtil
util = NodeUtil()
return util.nodes
@classmethod
def generate_fake(cls, count=100):
import random
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
for i in range(count):
node = random.choice(cls.objects.all())
node.create_child('Node {}'.format(i))
# ~*~ coding: utf-8 ~*~
#
from django.utils.translation import ugettext_lazy as _
from django.core.cache import cache
from django.utils import timezone
from django.db.models import Prefetch
from common.utils import get_object_or_none
from .models import SystemUser, Label
from common.utils import get_object_or_none, get_logger
from common.struct import Stack
from .models import SystemUser, Label, Node, Asset
def get_assets_by_id_list(id_list):
return Asset.objects.filter(id__in=id_list).filter(is_active=True)
def get_system_users_by_id_list(id_list):
return SystemUser.objects.filter(id__in=id_list)
logger = get_logger(__file__)
def get_system_user_by_name(name):
......@@ -47,4 +41,154 @@ class LabelFilter:
return queryset
class NodeUtil:
def __init__(self, with_assets_amount=False, debug=False):
self.stack = Stack()
self._nodes = {}
self.with_assets_amount = with_assets_amount
self._debug = debug
self.init()
@staticmethod
def sorted_by(node):
return [int(i) for i in node.key.split(':')]
def get_all_nodes(self):
all_nodes = Node.objects.all()
if self.with_assets_amount:
all_nodes = all_nodes.prefetch_related(
Prefetch('assets', queryset=Asset.objects.all().only('id'))
)
for node in all_nodes:
node._assets = set(node.assets.all())
all_nodes = sorted(all_nodes, key=self.sorted_by)
guarder = Node(key='', value='Guarder')
guarder._assets = []
all_nodes.append(guarder)
return all_nodes
def push_to_stack(self, node):
# 入栈之前检查
# 如果栈是空的,证明是一颗树的根部
if self.stack.is_empty():
node._full_value = node.value
node._parents = []
else:
# 如果不是根节点,
# 该节点的祖先应该是父节点的祖先加上父节点
# 该节点的名字是父节点的名字+自己的名字
node._parents = [self.stack.top] + self.stack.top._parents
node._full_value = ' / '.join(
[self.stack.top._full_value, node.value]
)
node._children = []
node._all_children = []
self.debug("入栈: {}".format(node.key))
self.stack.push(node)
# 出栈
def pop_from_stack(self):
_node = self.stack.pop()
self.debug("出栈: {} 栈顶: {}".format(_node.key, self.stack.top.key if self.stack.top else None))
self._nodes[_node.key] = _node
if not self.stack.top:
return
if self.with_assets_amount:
self.stack.top._assets.update(_node._assets)
_node._assets_amount = len(_node._assets)
delattr(_node, '_assets')
self.stack.top._children.append(_node)
self.stack.top._all_children.extend([_node] + _node._children)
def init(self):
all_nodes = self.get_all_nodes()
for node in all_nodes:
self.debug("准备: {} 栈顶: {}".format(node.key, self.stack.top.key if self.stack.top else None))
# 入栈之前检查,该节点是不是栈顶节点的子节点
# 如果不是,则栈顶出栈
while self.stack.top and not self.stack.top.is_children(node):
self.pop_from_stack()
self.push_to_stack(node)
# 出栈最后一个
self.debug("剩余: {}".format(', '.join([n.key for n in self.stack])))
def get_nodes_by_queryset(self, queryset):
nodes = []
for n in queryset:
node = self._nodes.get(n.key)
if not node:
continue
nodes.append(nodes)
return [self]
def get_node_by_key(self, key):
return self._nodes.get(key)
def debug(self, msg):
self._debug and logger.debug(msg)
def set_assets_amount(self):
for node in self._nodes.values():
node.assets_amount = node._assets_amount
def set_full_value(self):
for node in self._nodes.values():
node.full_value = node._full_value
@property
def nodes(self):
return list(self._nodes.values())
# 使用给定节点生成一颗树
# 找到他们的祖先节点
# 可选找到他们的子孙节点
def get_family(self, nodes, with_children=False):
tree_nodes = set()
for n in nodes:
node = self.get_node_by_key(n.key)
if not node:
continue
tree_nodes.update(node._parents)
tree_nodes.add(node)
if with_children:
tree_nodes.update(node._children)
for n in tree_nodes:
delattr(n, '_children')
delattr(n, '_parents')
return list(tree_nodes)
def test_node_tree():
tree = NodeUtil()
for node in tree._nodes.values():
print("Check {}".format(node.key))
children_wanted = node.get_all_children().count()
children = len(node._children)
if children != children_wanted:
print("{} children not equal: {} != {}".format(node.key, children, children_wanted))
assets_amount_wanted = node.get_all_assets().count()
if node._assets_amount != assets_amount_wanted:
print("{} assets amount not equal: {} != {}".format(
node.key, node._assets_amount, assets_amount_wanted)
)
full_value_wanted = node.full_value
if node._full_value != full_value_wanted:
print("{} full value not equal: {} != {}".format(
node.key, node._full_value, full_value_wanted)
)
parents_wanted = node.get_ancestor().count()
parents = len(node._parents)
if parents != parents_wanted:
print("{} parents count not equal: {} != {}".format(
node.key, parents, parents_wanted)
)
# -*- coding: utf-8 -*-
#
class Stack(list):
def is_empty(self):
return len(self) == 0
@property
def top(self):
if self.is_empty():
return None
return self[-1]
@property
def bottom(self):
if self.is_empty():
return None
return self[0]
def size(self):
return len(self)
def push(self, item):
self.append(item)
......@@ -7,7 +7,7 @@ from django.conf.urls.static import static
from django.conf.urls.i18n import i18n_patterns
from django.views.i18n import JavaScriptCatalog
from .views import IndexView, LunaView, I18NView
from .views import IndexView, LunaView, I18NView, HealthCheckView
from .swagger import get_swagger_view
api_v1 = [
......@@ -63,6 +63,7 @@ urlpatterns = [
path('', IndexView.as_view(), name='index'),
path('', include(api_v2_patterns)),
path('', include(api_v1_patterns)),
path('api/health/', HealthCheckView.as_view(), name="health"),
path('luna/', LunaView.as_view(), name='luna-view'),
path('i18n/<str:lang>/', I18NView.as_view(), name='i18n-switch'),
path('settings/', include('settings.urls.view_urls', namespace='settings')),
......
import datetime
import re
import time
from django.http import HttpResponse, HttpResponseRedirect
from django.conf import settings
......@@ -9,6 +10,7 @@ from django.utils.translation import ugettext_lazy as _
from django.db.models import Count
from django.shortcuts import redirect
from rest_framework.response import Response
from rest_framework.views import APIView
from django.views.decorators.csrf import csrf_exempt
from django.http import HttpResponse
from django.utils.encoding import iri_to_uri
......@@ -222,3 +224,10 @@ def redirect_format_api(request, *args, **kwargs):
return HttpResponseTemporaryRedirect(_path)
else:
return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404)
class HealthCheckView(APIView):
permission_classes = ()
def get(self, request):
return Response({"status": 1, "time": int(time.time())})
......@@ -2,7 +2,6 @@
#
from .ansible.inventory import BaseInventory
from assets.utils import get_assets_by_id_list, get_system_user_by_id
from common.utils import get_logger
......
# -*- coding: utf-8 -*-
#
import traceback
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import redirect, get_object_or_404
......@@ -33,8 +33,8 @@ class OrgManager(models.Manager):
def get_queryset(self):
queryset = super(OrgManager, self).get_queryset()
kwargs = {}
_current_org = get_current_org()
_current_org = get_current_org()
if _current_org is None:
kwargs['id'] = None
elif _current_org.is_real():
......@@ -42,12 +42,17 @@ class OrgManager(models.Manager):
elif _current_org.is_default():
queryset = queryset.filter(org_id="")
# lines = traceback.format_stack()
# print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>")
# for line in lines[-10:-5]:
# print(line)
# print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
queryset = queryset.filter(**kwargs)
return queryset
def all(self):
_current_org = get_current_org()
if _current_org is None:
if not current_org:
msg = 'You can `objects.set_current_org(org).all()` then run it'
return self
else:
......
......@@ -258,7 +258,9 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserPermissionCacheMixin, ListAPIView)
util.filter_permissions(
system_users=self.system_user_id
)
print("111111111111")
nodes = util.get_nodes_with_assets()
print("22222222222222")
for node, assets in nodes.items():
data = parse_node_to_tree_node(node)
queryset.append(data)
......
......@@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _
from orgs.mixins import OrgModelForm
from orgs.utils import current_org
from perms.models import AssetPermission
from assets.models import Asset
from assets.models import Asset, Node
__all__ = [
'AssetPermissionForm',
......
......@@ -4,6 +4,7 @@ import uuid
from collections import defaultdict
import json
from hashlib import md5
import time
from django.utils import timezone
from django.db.models import Q
......@@ -17,6 +18,7 @@ from common.tree import TreeNode
from .. import const
from ..models import AssetPermission, Action
from ..hands import Node
from assets.utils import NodeUtil
logger = get_logger(__file__)
......@@ -35,9 +37,8 @@ class GenerateTree:
"asset_instance": set("system_user")
}
"""
self.__all_nodes = list(Node.objects.all())
self.node_util = NodeUtil()
self.nodes = defaultdict(dict)
self.direct_nodes = []
self._root_node = None
self._ungroup_node = None
......@@ -48,10 +49,8 @@ class GenerateTree:
all_nodes = self.nodes.keys()
# 如果没有授权节点,就放到默认的根节点下
if not all_nodes:
root_node = Node.root()
self.add_node(root_node)
else:
root_node = max(all_nodes)
return None
root_node = min(all_nodes)
self._root_node = root_node
return root_node
......@@ -60,7 +59,10 @@ class GenerateTree:
if self._ungroup_node:
return self._ungroup_node
node_id = const.UNGROUPED_NODE_ID
node_key = self.root_node.get_next_child_key()
if self.root_node:
node_key = self.root_node.get_next_child_key()
else:
node_key = '0:0'
node_value = _("Default")
node = Node(id=node_id, key=node_key, value=node_value)
self.add_node(node)
......@@ -69,11 +71,11 @@ class GenerateTree:
def add_asset(self, asset, system_users):
nodes = asset.nodes.all()
in_nodes = set(self.direct_nodes) & set(nodes)
for node in in_nodes:
self.nodes[node][asset].update(system_users)
if not in_nodes:
self.nodes[self.ungrouped_node][asset].update(system_users)
for node in nodes:
if node in self.nodes:
self.nodes[node][asset].update(system_users)
else:
self.nodes[self.ungrouped_node][asset].update(system_users)
def get_nodes(self):
for node in self.nodes:
......@@ -84,26 +86,14 @@ class GenerateTree:
node.assets_amount = len(assets)
return self.nodes
# 添加节点时,追溯到根节点
def add_node(self, node):
if node in self.nodes:
return
else:
self.nodes[node] = defaultdict(set)
if node.is_root():
return
for n in self.__all_nodes:
if n.key == node.parent_key:
self.add_node(n)
break
self.nodes[node] = defaultdict(set)
# 添加树节点
def add_nodes(self, nodes):
for node in nodes:
need_nodes = self.node_util.get_family(nodes, with_children=True)
for node in need_nodes:
self.add_node(node)
self.add_nodes(node.get_all_children(with_self=False))
# 如果是直接授权的节点,则放到direct_nodes中
self.direct_nodes.append(node)
def get_user_permissions(user, include_group=True):
......@@ -140,35 +130,28 @@ def get_system_user_permissions(system_user):
)
class AssetPermissionUtil:
get_permissions_map = {
"User": get_user_permissions,
"UserGroup": get_user_group_permissions,
"Asset": get_asset_permissions,
"Node": get_node_permissions,
"SystemUser": get_system_user_permissions,
}
def timeit(func):
def wrapper(*args, **kwargs):
logger.debug("Start call: {}".format(func.__name__))
now = time.time()
result = func(*args, **kwargs)
using = time.time() - now
logger.debug("Call {} end, using: {:.2}".format(func.__name__, using))
return result
return wrapper
class AssetGranted:
def __init__(self):
self.system_users = {}
class AssetPermissionCacheMixin:
CACHE_KEY_PREFIX = '_ASSET_PERM_CACHE_'
CACHE_META_KEY_PREFIX = '_ASSET_PERM_META_KEY_'
CACHE_TIME = settings.ASSETS_PERM_CACHE_TIME
CACHE_POLICY_MAP = (('0', 'never'), ('1', 'using'), ('2', 'refresh'))
def __init__(self, obj, cache_policy='0'):
self.object = obj
self.obj_id = str(obj.id)
self._permissions = None
self._permissions_id = None # 标记_permission的唯一值
self._assets = None
self._filter_id = 'None' # 当通过filter更改 permission是标记
self.cache_policy = cache_policy
self.tree = GenerateTree()
self.change_org_if_need()
@staticmethod
def change_org_if_need():
set_to_root_org()
@classmethod
def is_not_using_cache(cls, cache_policy):
return cls.CACHE_TIME == 0 or cache_policy in cls.CACHE_POLICY_MAP[0]
......@@ -190,94 +173,7 @@ class AssetPermissionUtil:
def _is_refresh_cache(self):
return self.is_refresh_cache(self.cache_policy)
@property
def permissions(self):
if self._permissions:
return self._permissions
object_cls = self.object.__class__.__name__
func = self.get_permissions_map[object_cls]
permissions = func(self.object)
self._permissions = permissions
return permissions
def filter_permissions(self, **filters):
filters_json = json.dumps(filters, sort_keys=True)
self._permissions = self.permissions.filter(**filters)
self._filter_id = md5(filters_json.encode()).hexdigest()
@staticmethod
def _structured_system_user(system_users, actions):
"""
结构化系统用户
:param system_users:
:param actions:
:return: {system_user1: {'actions': set(), }, }
"""
_attr = {'actions': set(actions)}
_system_users = {system_user: _attr for system_user in system_users}
return _system_users
def get_nodes_direct(self):
"""
返回用户/组授权规则直接关联的节点
:return: {asset1: {system_user1: {'actions': set()},}}
"""
nodes = defaultdict(dict)
permissions = self.permissions.prefetch_related('nodes', 'system_users')
for perm in permissions:
actions = perm.actions.all()
self.tree.add_nodes(perm.nodes.all())
for node in perm.nodes.all():
system_users = perm.system_users.all()
system_users = self._structured_system_user(system_users, actions)
nodes[node].update(system_users)
return nodes
def get_assets_direct(self):
"""
返回用户授权规则直接关联的资产
:return: {asset1: {system_user1: {'actions': set()},}}
"""
assets = defaultdict(dict)
permissions = self.permissions.prefetch_related('assets', 'system_users')
for perm in permissions:
actions = perm.actions.all()
for asset in perm.assets.all().valid().prefetch_related('nodes'):
system_users = perm.system_users.filter(protocol__in=asset.protocols_name)
system_users = self._structured_system_user(system_users, actions)
assets[asset].update(system_users)
return assets
def get_assets_without_cache(self):
"""
:return: {asset1: set(system_user1,)}
"""
if self._assets:
return self._assets
assets = self.get_assets_direct()
nodes = self.get_nodes_direct()
for node, system_users in nodes.items():
_assets = node.get_all_assets().valid().prefetch_related('nodes')
for asset in _assets:
for system_user, attr_dict in system_users.items():
if not asset.has_protocol(system_user.protocol):
continue
if system_user in assets[asset]:
actions = assets[asset][system_user]['actions']
attr_dict['actions'].update(actions)
system_users.update({system_user: attr_dict})
assets[asset].update(system_users)
__assets = defaultdict(set)
for asset, system_users in assets.items():
for system_user, attr_dict in system_users.items():
setattr(system_user, 'actions', attr_dict['actions'])
__assets[asset] = set(system_users.keys())
self._assets = __assets
return self._assets
@timeit
def get_cache_key(self, resource):
cache_key = self.CACHE_KEY_PREFIX + '{obj_id}_{filter_id}_{resource}'
return cache_key.format(
......@@ -301,27 +197,6 @@ class AssetPermissionUtil:
cached = cache.get(self.asset_key)
return cached
def get_assets(self):
if self._is_not_using_cache():
return self.get_assets_from_cache()
elif self._is_refresh_cache():
self.expire_cache()
return self.get_assets_from_cache()
else:
self.expire_cache()
return self.get_assets_without_cache()
def get_nodes_with_assets_without_cache(self):
"""
返回节点并且包含资产
{"node": {"assets": set("system_user")}}
:return:
"""
assets = self.get_assets_without_cache()
for asset, system_users in assets.items():
self.tree.add_asset(asset, system_users)
return self.tree.get_nodes()
def get_nodes_with_assets_from_cache(self):
cached = cache.get(self.node_key)
if not cached:
......@@ -338,13 +213,6 @@ class AssetPermissionUtil:
else:
return self.get_nodes_with_assets_without_cache()
def get_system_user_without_cache(self):
system_users = set()
permissions = self.permissions.prefetch_related('system_users')
for perm in permissions:
system_users.update(perm.system_users.all())
return system_users
def get_system_user_from_cache(self):
cached = cache.get(self.system_key)
if not cached:
......@@ -418,6 +286,152 @@ class AssetPermissionUtil:
cache.delete_pattern(key)
class AssetPermissionUtil(AssetPermissionCacheMixin):
get_permissions_map = {
"User": get_user_permissions,
"UserGroup": get_user_group_permissions,
"Asset": get_asset_permissions,
"Node": get_node_permissions,
"SystemUser": get_system_user_permissions,
}
def __init__(self, obj, cache_policy='0'):
self.object = obj
self.obj_id = str(obj.id)
self._permissions = None
self._permissions_id = None # 标记_permission的唯一值
self._assets = None
self._filter_id = 'None' # 当通过filter更改 permission是标记
self.cache_policy = cache_policy
self.tree = GenerateTree()
self.change_org_if_need()
self.nodes = None
@staticmethod
def change_org_if_need():
set_to_root_org()
@property
def permissions(self):
if self._permissions:
return self._permissions
object_cls = self.object.__class__.__name__
func = self.get_permissions_map[object_cls]
permissions = func(self.object)
self._permissions = permissions
return permissions
@timeit
def filter_permissions(self, **filters):
filters_json = json.dumps(filters, sort_keys=True)
self._permissions = self.permissions.filter(**filters)
self._filter_id = md5(filters_json.encode()).hexdigest()
@staticmethod
@timeit
def _structured_system_user(system_users, actions):
"""
结构化系统用户
:param system_users:
:param actions:
:return: {system_user1: {'actions': set(), }, }
"""
_attr = {'actions': set(actions)}
_system_users = {system_user: _attr for system_user in system_users}
return _system_users
@timeit
def get_nodes_direct(self):
"""
返回用户/组授权规则直接关联的节点
:return: {asset1: {system_user1: {'actions': set()},}}
"""
nodes = defaultdict(dict)
permissions = self.permissions.prefetch_related('nodes', 'system_users', 'actions')
for perm in permissions:
actions = perm.actions.all()
for node in perm.nodes.all():
system_users = perm.system_users.all()
system_users = self._structured_system_user(system_users, actions)
nodes[node].update(system_users)
self.tree.add_nodes(nodes.keys())
# 替换成优化过的node
nodes = {self.tree.node_util.get_node_by_key(k.key): v for k, v in nodes.items()}
return nodes
@timeit
def get_assets_direct(self):
"""
返回用户授权规则直接关联的资产
:return: {asset1: {system_user1: {'actions': set()},}}
"""
assets = defaultdict(dict)
permissions = self.permissions.prefetch_related('assets', 'system_users')
for perm in permissions:
actions = perm.actions.all()
for asset in perm.assets.all().valid().prefetch_related('nodes'):
system_users = perm.system_users.filter(protocol__in=asset.protocols_name)
system_users = self._structured_system_user(system_users, actions)
assets[asset].update(system_users)
return assets
@timeit
def get_assets_without_cache(self):
"""
:return: {asset1: set(system_user1,)}
"""
if self._assets:
return self._assets
assets = self.get_assets_direct()
nodes = self.get_nodes_direct()
# for node, system_users in nodes.items():
# print(9999, node)
# _assets = node.get_all_valid_assets()
# print(".......... end .......")
# for asset in _assets:
# print(">>asset")
# for system_user, attr_dict in system_users.items():
# print(">>>system user")
# if not asset.has_protocol(system_user.protocol):
# continue
# if system_user in assets[asset]:
# actions = assets[asset][system_user]['actions']
# attr_dict['actions'].update(actions)
# system_users.update({system_user: attr_dict})
# print("<<<system user")
# print("<<<asset")
# assets[asset].update(system_users)
# print(">>>>>>")
#
__assets = defaultdict(set)
for asset, system_users in assets.items():
for system_user, attr_dict in system_users.items():
setattr(system_user, 'actions', attr_dict['actions'])
__assets[asset] = set(system_users.keys())
self._assets = __assets
return self._assets
@timeit
def get_nodes_with_assets_without_cache(self):
"""
返回节点并且包含资产
{"node": {"assets": set("system_user")}}
:return:
"""
assets = self.get_assets_without_cache()
for asset, system_users in assets.items():
self.tree.add_asset(asset, system_users)
return self.tree.get_nodes()
def get_system_user_without_cache(self):
system_users = set()
permissions = self.permissions.prefetch_related('system_users')
for perm in permissions:
system_users.update(perm.system_users.all())
return system_users
def is_obj_attr_has(obj, val, attrs=("hostname", "ip", "comment")):
if not attrs:
vals = [val for val in obj.__dict__.values() if isinstance(val, (str, int))]
......
......@@ -242,22 +242,3 @@ class CommandStorageDeleteAPI(APIView):
storage_name = str(request.data.get('name'))
Setting.delete_storage('TERMINAL_COMMAND_STORAGE', storage_name)
return Response({"msg": _('Delete succeed')}, status=200)
class DjangoSettingsAPI(APIView):
def get(self, request):
if not settings.DEBUG:
return Response("Not in debug mode")
data = {}
for i in [settings, getattr(settings, '_wrapped')]:
if not i:
continue
for k, v in i.__dict__.items():
if k and k.isupper():
try:
json.dumps(v)
data[k] = v
except (json.JSONDecodeError, TypeError):
data[k] = str(v)
return Response(data)
\ No newline at end of file
......@@ -15,5 +15,4 @@ urlpatterns = [
path('terminal/replay-storage/delete/', api.ReplayStorageDeleteAPI.as_view(), name='replay-storage-delete'),
path('terminal/command-storage/create/', api.CommandStorageCreateAPI.as_view(), name='command-storage-create'),
path('terminal/command-storage/delete/', api.CommandStorageDeleteAPI.as_view(), name='command-storage-delete'),
path('django-settings/', api.DjangoSettingsAPI.as_view(), name='django-settings'),
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册