From 9e910222da57da97efa21e6462544b8da53a1d26 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Mon, 13 Jan 2020 00:00:55 +0300 Subject: [PATCH] Add interfaces to check access rights. --- dbms/src/Access/AccessControlManager.cpp | 7 + dbms/src/Access/AccessControlManager.h | 6 + dbms/src/Access/AccessFlags.h | 330 +++++++++++ dbms/src/Access/AccessRights.cpp | 724 +++++++++++++++++++++++ dbms/src/Access/AccessRights.h | 137 +++++ dbms/src/Access/AccessRightsContext.cpp | 257 ++++++++ dbms/src/Access/AccessRightsContext.h | 82 +++ dbms/src/Access/AccessRightsElement.cpp | 85 +++ dbms/src/Access/AccessRightsElement.h | 100 ++++ dbms/src/Access/AccessType.h | 86 +++ dbms/src/Common/ErrorCodes.cpp | 4 +- dbms/src/Interpreters/Context.cpp | 57 +- dbms/src/Interpreters/Context.h | 29 +- dbms/src/Interpreters/Users.cpp | 10 + dbms/src/Interpreters/Users.h | 7 +- 15 files changed, 1903 insertions(+), 18 deletions(-) create mode 100644 dbms/src/Access/AccessFlags.h create mode 100644 dbms/src/Access/AccessRights.cpp create mode 100644 dbms/src/Access/AccessRights.h create mode 100644 dbms/src/Access/AccessRightsContext.cpp create mode 100644 dbms/src/Access/AccessRightsContext.h create mode 100644 dbms/src/Access/AccessRightsElement.cpp create mode 100644 dbms/src/Access/AccessRightsElement.h create mode 100644 dbms/src/Access/AccessType.h diff --git a/dbms/src/Access/AccessControlManager.cpp b/dbms/src/Access/AccessControlManager.cpp index 249dc54fb0..b38c715d24 100644 --- a/dbms/src/Access/AccessControlManager.cpp +++ b/dbms/src/Access/AccessControlManager.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB @@ -40,6 +41,12 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio } +std::shared_ptr AccessControlManager::getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database) +{ + return std::make_shared(client_info, granted_to_user, settings, current_database); +} + + std::shared_ptr AccessControlManager::createQuotaContext( const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) { diff --git a/dbms/src/Access/AccessControlManager.h b/dbms/src/Access/AccessControlManager.h index 9658dc7161..ddb4653451 100644 --- a/dbms/src/Access/AccessControlManager.h +++ b/dbms/src/Access/AccessControlManager.h @@ -24,6 +24,10 @@ class QuotaContextFactory; struct QuotaUsageInfo; class RowPolicyContext; class RowPolicyContextFactory; +class AccessRights; +class AccessRightsContext; +class ClientInfo; +struct Settings; /// Manages access control entities. @@ -35,6 +39,8 @@ public: void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config); + std::shared_ptr getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database); + std::shared_ptr createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key); diff --git a/dbms/src/Access/AccessFlags.h b/dbms/src/Access/AccessFlags.h new file mode 100644 index 0000000000..f8d3e840cb --- /dev/null +++ b/dbms/src/Access/AccessFlags.h @@ -0,0 +1,330 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +/// Represents a combination of access types which can be granted globally, on databases, tables, columns, etc. +/// For example "SELECT, CREATE USER" is an access type. +class AccessFlags +{ +public: + AccessFlags(AccessType type); + + /// The same as AccessFlags(AccessType::NONE). + AccessFlags() = default; + + /// Constructs from a string like "SELECT". + AccessFlags(const std::string_view & keyword); + + /// Constructs from a list of strings like "SELECT, UPDATE, INSERT". + AccessFlags(const std::vector & keywords); + AccessFlags(const Strings & keywords); + + AccessFlags(const AccessFlags & src) = default; + AccessFlags(AccessFlags && src) = default; + AccessFlags & operator =(const AccessFlags & src) = default; + AccessFlags & operator =(AccessFlags && src) = default; + + /// Returns the access type which contains two specified access types. + AccessFlags & operator |=(const AccessFlags & other) { flags |= other.flags; return *this; } + friend AccessFlags operator |(const AccessFlags & left, const AccessFlags & right) { return AccessFlags(left) |= right; } + + /// Returns the access type which contains the common part of two access types. + AccessFlags & operator &=(const AccessFlags & other) { flags &= other.flags; return *this; } + friend AccessFlags operator &(const AccessFlags & left, const AccessFlags & right) { return AccessFlags(left) &= right; } + + /// Returns the access type which contains only the part of the first access type which is not the part of the second access type. + /// (lhs - rhs) is the same as (lhs & ~rhs). + AccessFlags & operator -=(const AccessFlags & other) { flags &= ~other.flags; return *this; } + friend AccessFlags operator -(const AccessFlags & left, const AccessFlags & right) { return AccessFlags(left) -= right; } + + AccessFlags operator ~() const { AccessFlags res; res.flags = ~flags; return res; } + + bool isEmpty() const { return flags.none(); } + explicit operator bool() const { return !isEmpty(); } + bool contains(const AccessFlags & other) const { return (flags & other.flags) == other.flags; } + + friend bool operator ==(const AccessFlags & left, const AccessFlags & right) { return left.flags == right.flags; } + friend bool operator !=(const AccessFlags & left, const AccessFlags & right) { return !(left == right); } + + void clear() { flags.reset(); } + + /// Returns a comma-separated list of keywords, like "SELECT, CREATE USER, UPDATE". + String toString() const; + + /// Returns a list of keywords. + std::vector toKeywords() const; + + /// Returns the access types which could be granted on the database level. + /// For example, SELECT can be granted on the database level, but CREATE_USER cannot. + static AccessFlags databaseLevel(); + + /// Returns the access types which could be granted on the table/dictionary level. + static AccessFlags tableLevel(); + + /// Returns the access types which could be granted on the column/attribute level. + static AccessFlags columnLevel(); + +private: + static constexpr size_t NUM_FLAGS = 64; + using Flags = std::bitset; + Flags flags; + + AccessFlags(const Flags & flags_) : flags(flags_) {} + + template + class Impl; +}; + + +namespace ErrorCodes +{ + extern const int UNKNOWN_ACCESS_TYPE; +} + +template +class AccessFlags::Impl +{ +public: + static const Impl & instance() + { + static const Impl res; + return res; + } + + Flags accessTypeToFlags(AccessType type) const + { + return access_type_to_flags_mapping[static_cast(type)]; + } + + Flags keywordToFlags(const std::string_view & keyword) const + { + auto it = keyword_to_flags_map.find(keyword); + if (it == keyword_to_flags_map.end()) + { + String uppercased_keyword{keyword}; + boost::to_upper(uppercased_keyword); + it = keyword_to_flags_map.find(uppercased_keyword); + if (it == keyword_to_flags_map.end()) + throw Exception("Unknown access type: " + String(keyword), ErrorCodes::UNKNOWN_ACCESS_TYPE); + } + return it->second; + } + + Flags keywordsToFlags(const std::vector & keywords) const + { + Flags res; + for (const auto & keyword : keywords) + res |= keywordToFlags(keyword); + return res; + } + + Flags keywordsToFlags(const Strings & keywords) const + { + Flags res; + for (const auto & keyword : keywords) + res |= keywordToFlags(keyword); + return res; + } + + std::vector flagsToKeywords(const Flags & flags_) const + { + std::vector keywords; + flagsToKeywordsRec(flags_, keywords, *flags_to_keyword_tree); + + if (keywords.empty()) + keywords.push_back("USAGE"); + + return keywords; + } + + String flagsToString(const Flags & flags_) const + { + String str; + for (const auto & keyword : flagsToKeywords(flags_)) + { + if (!str.empty()) + str += ", "; + str += keyword; + } + return str; + } + + const Flags & getDatabaseLevelFlags() const { return all_grantable_on_level[DATABASE_LEVEL]; } + const Flags & getTableLevelFlags() const { return all_grantable_on_level[TABLE_LEVEL]; } + const Flags & getColumnLevelFlags() const { return all_grantable_on_level[COLUMN_LEVEL]; } + +private: + enum Level + { + UNKNOWN_LEVEL = -1, + GLOBAL_LEVEL = 0, + DATABASE_LEVEL = 1, + TABLE_LEVEL = 2, + VIEW_LEVEL = 2, + DICTIONARY_LEVEL = 2, + COLUMN_LEVEL = 3, + }; + + struct Node; + using NodePtr = std::unique_ptr; + using Nodes = std::vector; + + template + static Nodes nodes(Args&& ... args) + { + Nodes res; + ext::push_back(res, std::move(args)...); + return res; + } + + struct Node + { + std::string_view keyword; + std::vector aliases; + Flags flags; + Level level = UNKNOWN_LEVEL; + Nodes children; + + Node(std::string_view keyword_, size_t flag_, Level level_) + : keyword(keyword_), level(level_) + { + flags.set(flag_); + } + + Node(std::string_view keyword_, Nodes children_) + : keyword(keyword_), children(std::move(children_)) + { + for (const auto & child : children) + flags |= child->flags; + } + + template + Node(std::string_view keyword_, NodePtr first_child, Args &&... other_children) + : Node(keyword_, nodes(std::move(first_child), std::move(other_children)...)) {} + }; + + static void flagsToKeywordsRec(const Flags & flags_, std::vector & keywords, const Node & start_node) + { + Flags matching_flags = (flags_ & start_node.flags); + if (matching_flags.any()) + { + if (matching_flags == start_node.flags) + { + keywords.push_back(start_node.keyword); + } + else + { + for (const auto & child : start_node.children) + flagsToKeywordsRec(flags_, keywords, *child); + } + } + } + + static void makeFlagsToKeywordTree(NodePtr & flags_to_keyword_tree_) + { + size_t next_flag = 0; + Nodes all; + + auto show = std::make_unique("SHOW", next_flag++, COLUMN_LEVEL); + auto exists = std::make_unique("EXISTS", next_flag++, COLUMN_LEVEL); + ext::push_back(all, std::move(show), std::move(exists)); + + auto select = std::make_unique("SELECT", next_flag++, COLUMN_LEVEL); + auto insert = std::make_unique("INSERT", next_flag++, COLUMN_LEVEL); + auto update = std::make_unique("UPDATE", next_flag++, COLUMN_LEVEL); + auto delet = std::make_unique("DELETE", next_flag++, TABLE_LEVEL); + ext::push_back(all, std::move(select), std::move(insert), std::move(update), std::move(delet)); + + flags_to_keyword_tree_ = std::make_unique("ALL", std::move(all)); + flags_to_keyword_tree_->aliases.push_back("ALL PRIVILEGES"); + } + + void makeKeywordToFlagsMap(std::unordered_map & keyword_to_flags_map_, Node * start_node = nullptr) + { + if (!start_node) + { + start_node = flags_to_keyword_tree.get(); + keyword_to_flags_map_["USAGE"] = {}; + keyword_to_flags_map_["NONE"] = {}; + keyword_to_flags_map_["NO PRIVILEGES"] = {}; + } + start_node->aliases.emplace_back(start_node->keyword); + for (auto & alias : start_node->aliases) + { + boost::to_upper(alias); + keyword_to_flags_map_[alias] = start_node->flags; + } + for (auto & child : start_node->children) + makeKeywordToFlagsMap(keyword_to_flags_map_, child.get()); + } + + void makeAccessTypeToFlagsMapping(std::vector & access_type_to_flags_mapping_) + { + access_type_to_flags_mapping_.resize(MAX_ACCESS_TYPE); + for (auto access_type : ext::range_with_static_cast(0, MAX_ACCESS_TYPE)) + { + auto str = toKeyword(access_type); + auto it = keyword_to_flags_map.find(str); + if (it == keyword_to_flags_map.end()) + { + String uppercased{str}; + boost::to_upper(uppercased); + it = keyword_to_flags_map.find(uppercased); + } + access_type_to_flags_mapping_[static_cast(access_type)] = it->second; + } + } + + void collectAllGrantableOnLevel(std::vector & all_grantable_on_level_, const Node * start_node = nullptr) + { + if (!start_node) + { + start_node = flags_to_keyword_tree.get(); + all_grantable_on_level.resize(COLUMN_LEVEL + 1); + } + for (int i = 0; i <= start_node->level; ++i) + all_grantable_on_level_[i] |= start_node->flags; + for (const auto & child : start_node->children) + collectAllGrantableOnLevel(all_grantable_on_level_, child.get()); + } + + Impl() + { + makeFlagsToKeywordTree(flags_to_keyword_tree); + makeKeywordToFlagsMap(keyword_to_flags_map); + makeAccessTypeToFlagsMapping(access_type_to_flags_mapping); + collectAllGrantableOnLevel(all_grantable_on_level); + } + + std::unique_ptr flags_to_keyword_tree; + std::unordered_map keyword_to_flags_map; + std::vector access_type_to_flags_mapping; + std::vector all_grantable_on_level; +}; + + +inline AccessFlags::AccessFlags(AccessType type) : flags(Impl<>::instance().accessTypeToFlags(type)) {} +inline AccessFlags::AccessFlags(const std::string_view & keyword) : flags(Impl<>::instance().keywordToFlags(keyword)) {} +inline AccessFlags::AccessFlags(const std::vector & keywords) : flags(Impl<>::instance().keywordsToFlags(keywords)) {} +inline AccessFlags::AccessFlags(const Strings & keywords) : flags(Impl<>::instance().keywordsToFlags(keywords)) {} +inline String AccessFlags::toString() const { return Impl<>::instance().flagsToString(flags); } +inline std::vector AccessFlags::toKeywords() const { return Impl<>::instance().flagsToKeywords(flags); } +inline AccessFlags AccessFlags::databaseLevel() { return Impl<>::instance().getDatabaseLevelFlags(); } +inline AccessFlags AccessFlags::tableLevel() { return Impl<>::instance().getTableLevelFlags(); } +inline AccessFlags AccessFlags::columnLevel() { return Impl<>::instance().getColumnLevelFlags(); } + +inline AccessFlags operator |(AccessType left, AccessType right) { return AccessFlags(left) | right; } +inline AccessFlags operator &(AccessType left, AccessType right) { return AccessFlags(left) & right; } +inline AccessFlags operator -(AccessType left, AccessType right) { return AccessFlags(left) - right; } +inline AccessFlags operator ~(AccessType x) { return ~AccessFlags(x); } + +} diff --git a/dbms/src/Access/AccessRights.cpp b/dbms/src/Access/AccessRights.cpp new file mode 100644 index 0000000000..ceb64b2b41 --- /dev/null +++ b/dbms/src/Access/AccessRights.cpp @@ -0,0 +1,724 @@ +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int INVALID_GRANT; + extern const int LOGICAL_ERROR; +} + + +namespace +{ + enum Level + { + GLOBAL_LEVEL, + DATABASE_LEVEL, + TABLE_LEVEL, + COLUMN_LEVEL, + }; + + enum RevokeMode + { + NORMAL_REVOKE_MODE, /// for AccessRights::revoke() + PARTIAL_REVOKE_MODE, /// for AccessRights::partialRevoke() + FULL_REVOKE_MODE, /// for AccessRights::fullRevoke() + }; + + struct Helper + { + static const Helper & instance() + { + static const Helper res; + return res; + } + + const AccessFlags database_level_flags = AccessFlags::databaseLevel(); + const AccessFlags table_level_flags = AccessFlags::tableLevel(); + const AccessFlags column_level_flags = AccessFlags::columnLevel(); + const AccessFlags show_flag = AccessType::SHOW; + const AccessFlags exists_flag = AccessType::EXISTS; + }; +} + + +struct AccessRights::Node +{ +public: + std::shared_ptr node_name; + Level level = GLOBAL_LEVEL; + AccessFlags explicit_grants; + AccessFlags partial_revokes; + AccessFlags inherited_access; /// the access inherited from the parent node + AccessFlags raw_access; /// raw_access = (inherited_access - partial_revokes) | explicit_grants + AccessFlags access; /// access = raw_access | implicit_access + AccessFlags min_access; /// min_access = access & child[0].access & ... | child[N-1].access + AccessFlags max_access; /// max_access = access | child[0].access | ... | child[N-1].access + std::unique_ptr> children; + + Node() = default; + Node(const Node & src) { *this = src; } + + Node & operator =(const Node & src) + { + node_name = src.node_name; + level = src.level; + inherited_access = src.inherited_access; + explicit_grants = src.explicit_grants; + partial_revokes = src.partial_revokes; + access = src.access; + min_access = src.min_access; + max_access = src.max_access; + if (src.children) + children = std::make_unique>(*src.children); + else + children = nullptr; + return *this; + } + + void grant(AccessFlags access_to_grant, const Helper & helper) + { + if (!access_to_grant) + return; + + if (level == GLOBAL_LEVEL) + { + /// Everything can be granted on the global level. + } + else if (level == DATABASE_LEVEL) + { + AccessFlags grantable = access_to_grant & helper.database_level_flags; + if (!grantable) + throw Exception(access_to_grant.toString() + " cannot be granted on the database level", ErrorCodes::INVALID_GRANT); + access_to_grant = grantable; + } + else if (level == TABLE_LEVEL) + { + AccessFlags grantable = access_to_grant & helper.table_level_flags; + if (!grantable) + throw Exception(access_to_grant.toString() + " cannot be granted on the table level", ErrorCodes::INVALID_GRANT); + access_to_grant = grantable; + } + else if (level == COLUMN_LEVEL) + { + AccessFlags grantable = access_to_grant & helper.column_level_flags; + if (!grantable) + throw Exception(access_to_grant.toString() + " cannot be granted on the column level", ErrorCodes::INVALID_GRANT); + access_to_grant = grantable; + } + + explicit_grants |= access_to_grant - partial_revokes; + partial_revokes -= access_to_grant; + calculateAllAccessRec(helper); + } + + template + void grant(const AccessFlags & access_to_grant, const Helper & helper, const std::string_view & name, const Args &... subnames) + { + auto & child = getChild(name); + child.grant(access_to_grant, helper, subnames...); + eraseChildIfEmpty(child); + calculateImplicitAccess(helper); + calculateMinAndMaxAccess(); + } + + template + void grant(const AccessFlags & access_to_grant, const Helper & helper, const std::vector & names) + { + for (const auto & name : names) + { + auto & child = getChild(name); + child.grant(access_to_grant, helper); + eraseChildIfEmpty(child); + } + calculateImplicitAccess(helper); + calculateMinAndMaxAccess(); + } + + template + void revoke(const AccessFlags & access_to_revoke, const Helper & helper) + { + if constexpr (mode == NORMAL_REVOKE_MODE) + { + explicit_grants -= access_to_revoke; + } + else if constexpr (mode == PARTIAL_REVOKE_MODE) + { + partial_revokes |= access_to_revoke - explicit_grants; + explicit_grants -= access_to_revoke; + } + else /// mode == FULL_REVOKE_MODE + { + fullRevokeRec(access_to_revoke); + } + calculateAllAccessRec(helper); + } + + template + void revoke(const AccessFlags & access_to_revoke, const Helper & helper, const std::string_view & name, const Args &... subnames) + { + Node * child; + if (mode == NORMAL_REVOKE_MODE) + { + if (!(child = tryGetChild(name))) + return; + } + else + child = &getChild(name); + + child->revoke(access_to_revoke, helper, subnames...); + eraseChildIfEmpty(*child); + calculateImplicitAccess(helper); + calculateMinAndMaxAccess(); + } + + template + void revoke(const AccessFlags & access_to_revoke, const Helper & helper, const std::vector & names) + { + Node * child; + for (const auto & name : names) + { + if (mode == NORMAL_REVOKE_MODE) + { + if (!(child = tryGetChild(name))) + continue; + } + else + child = &getChild(name); + + child->revoke(access_to_revoke, helper); + eraseChildIfEmpty(*child); + } + calculateImplicitAccess(helper); + calculateMinAndMaxAccess(); + } + + bool isGranted(const AccessFlags & flags) const + { + return min_access.contains(flags); + } + + template + bool isGranted(AccessFlags flags, const std::string_view & name, const Args &... subnames) const + { + if (min_access.contains(flags)) + return true; + if (!max_access.contains(flags)) + return false; + + const Node * child = tryGetChild(name); + if (child) + return child->isGranted(flags, subnames...); + else + return access.contains(flags); + } + + template + bool isGranted(AccessFlags flags, const std::vector & names) const + { + if (min_access.contains(flags)) + return true; + if (!max_access.contains(flags)) + return false; + + for (const auto & name : names) + { + const Node * child = tryGetChild(name); + if (child) + { + if (!child->isGranted(flags, name)) + return false; + } + else + { + if (!access.contains(flags)) + return false; + } + } + return true; + } + + friend bool operator ==(const Node & left, const Node & right) + { + if ((left.explicit_grants != right.explicit_grants) || (left.partial_revokes != right.partial_revokes)) + return false; + + if (!left.children) + return !right.children; + + if (!right.children) + return false; + return *left.children == *right.children; + } + + friend bool operator!=(const Node & left, const Node & right) { return !(left == right); } + + bool isEmpty() const + { + return !explicit_grants && !partial_revokes && !children; + } + + void merge(const Node & other, const Helper & helper) + { + mergeRawAccessRec(other); + calculateGrantsAndPartialRevokesRec(); + calculateAllAccessRec(helper); + } + +private: + Node * tryGetChild(const std::string_view & name) + { + if (!children) + return nullptr; + auto it = children->find(name); + if (it == children->end()) + return nullptr; + return &it->second; + } + + const Node * tryGetChild(const std::string_view & name) const + { + if (!children) + return nullptr; + auto it = children->find(name); + if (it == children->end()) + return nullptr; + return &it->second; + } + + Node & getChild(const std::string_view & name) + { + auto * child = tryGetChild(name); + if (child) + return *child; + if (!children) + children = std::make_unique>(); + auto new_child_name = std::make_shared(name); + Node & new_child = (*children)[*new_child_name]; + new_child.node_name = std::move(new_child_name); + new_child.level = static_cast(level + 1); + new_child.inherited_access = raw_access; + new_child.raw_access = raw_access; + return new_child; + } + + void eraseChildIfEmpty(Node & child) + { + if (!child.isEmpty()) + return; + auto it = children->find(*child.node_name); + children->erase(it); + if (children->empty()) + children = nullptr; + } + + void calculateImplicitAccess(const Helper & helper) + { + access = raw_access; + if (access & helper.database_level_flags) + access |= helper.show_flag | helper.exists_flag; + else if ((level >= DATABASE_LEVEL) && children) + access |= helper.exists_flag; + } + + void calculateMinAndMaxAccess() + { + min_access = access; + max_access = access; + if (children) + { + for (const auto & child : *children | boost::adaptors::map_values) + { + min_access &= child.min_access; + max_access |= child.max_access; + } + } + } + + void calculateAllAccessRec(const Helper & helper) + { + partial_revokes &= inherited_access; + raw_access = (inherited_access - partial_revokes) | explicit_grants; + + /// Traverse tree. + if (children) + { + for (auto it = children->begin(); it != children->end();) + { + auto & child = it->second; + child.inherited_access = raw_access; + child.calculateAllAccessRec(helper); + if (child.isEmpty()) + it = children->erase(it); + else + ++it; + } + if (children->empty()) + children = nullptr; + } + + calculateImplicitAccess(helper); + calculateMinAndMaxAccess(); + } + + void fullRevokeRec(const AccessFlags & access_to_revoke) + { + explicit_grants -= access_to_revoke; + partial_revokes |= access_to_revoke; + if (children) + { + for (auto & child : *children | boost::adaptors::map_values) + child.fullRevokeRec(access_to_revoke); + } + } + + void mergeRawAccessRec(const Node & rhs) + { + if (rhs.children) + { + for (const auto & [rhs_childname, rhs_child] : *rhs.children) + getChild(rhs_childname).mergeRawAccessRec(rhs_child); + } + raw_access |= rhs.raw_access; + if (children) + { + for (auto & [lhs_childname, lhs_child] : *children) + { + lhs_child.inherited_access = raw_access; + if (!rhs.tryGetChild(lhs_childname)) + lhs_child.raw_access |= rhs.raw_access; + } + } + } + + void calculateGrantsAndPartialRevokesRec() + { + explicit_grants = raw_access - inherited_access; + partial_revokes = inherited_access - raw_access; + if (children) + { + for (auto & child : *children | boost::adaptors::map_values) + child.calculateGrantsAndPartialRevokesRec(); + } + } +}; + + +AccessRights::AccessRights() = default; +AccessRights::~AccessRights() = default; +AccessRights::AccessRights(AccessRights && src) = default; +AccessRights & AccessRights::operator =(AccessRights && src) = default; + + +AccessRights::AccessRights(const AccessRights & src) +{ + *this = src; +} + + +AccessRights & AccessRights::operator =(const AccessRights & src) +{ + if (src.root) + root = std::make_unique(*src.root); + else + root = nullptr; + return *this; +} + + +AccessRights::AccessRights(const AccessFlags & access) +{ + grant(access); +} + + +bool AccessRights::isEmpty() const +{ + return !root; +} + + +void AccessRights::clear() +{ + root = nullptr; +} + + +template +void AccessRights::grantImpl(const AccessFlags & access, const Args &... args) +{ + if (!root) + root = std::make_unique(); + root->grant(access, Helper::instance(), args...); + if (root->isEmpty()) + root = nullptr; +} + +void AccessRights::grantImpl(const AccessRightsElement & element, std::string_view current_database) +{ + if (element.any_database) + { + grantImpl(element.access_flags); + } + else if (element.any_table) + { + if (element.database.empty()) + grantImpl(element.access_flags, current_database); + else + grantImpl(element.access_flags, element.database); + } + else if (element.any_column) + { + if (element.database.empty()) + grantImpl(element.access_flags, current_database, element.table); + else + grantImpl(element.access_flags, element.database, element.table); + } + else + { + if (element.database.empty()) + grantImpl(element.access_flags, current_database, element.table, element.columns); + else + grantImpl(element.access_flags, element.database, element.table, element.columns); + } +} + +void AccessRights::grantImpl(const AccessRightsElements & elements, std::string_view current_database) +{ + for (const auto & element : elements) + grantImpl(element, current_database); +} + +void AccessRights::grant(const AccessFlags & access) { grantImpl(access); } +void AccessRights::grant(const AccessFlags & access, const std::string_view & database) { grantImpl(access, database); } +void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { grantImpl(access, database, table); } +void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { grantImpl(access, database, table, column); } +void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { grantImpl(access, database, table, columns); } +void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { grantImpl(access, database, table, columns); } +void AccessRights::grant(const AccessRightsElement & element, std::string_view current_database) { grantImpl(element, current_database); } +void AccessRights::grant(const AccessRightsElements & elements, std::string_view current_database) { grantImpl(elements, current_database); } + +template +void AccessRights::revokeImpl(const AccessFlags & access, const Args &... args) +{ + if (!root) + return; + root->revoke(access, Helper::instance(), args...); + if (root->isEmpty()) + root = nullptr; +} + +template +void AccessRights::revokeImpl(const AccessRightsElement & element, std::string_view current_database) +{ + if (element.any_database) + { + revokeImpl(element.access_flags); + } + else if (element.any_table) + { + if (element.database.empty()) + revokeImpl(element.access_flags, current_database); + else + revokeImpl(element.access_flags, element.database); + } + else if (element.any_column) + { + if (element.database.empty()) + revokeImpl(element.access_flags, current_database, element.table); + else + revokeImpl(element.access_flags, element.database, element.table); + } + else + { + if (element.database.empty()) + revokeImpl(element.access_flags, current_database, element.table, element.columns); + else + revokeImpl(element.access_flags, element.database, element.table, element.columns); + } +} + +template +void AccessRights::revokeImpl(const AccessRightsElements & elements, std::string_view current_database) +{ + for (const auto & element : elements) + revokeImpl(element, current_database); +} + +void AccessRights::revoke(const AccessFlags & access) { revokeImpl(access); } +void AccessRights::revoke(const AccessFlags & access, const std::string_view & database) { revokeImpl(access, database); } +void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { revokeImpl(access, database, table); } +void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(access, database, table, column); } +void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(access, database, table, columns); } +void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(access, database, table, columns); } +void AccessRights::revoke(const AccessRightsElement & element, std::string_view current_database) { revokeImpl(element, current_database); } +void AccessRights::revoke(const AccessRightsElements & elements, std::string_view current_database) { revokeImpl(elements, current_database); } + +void AccessRights::partialRevoke(const AccessFlags & access) { revokeImpl(access); } +void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database) { revokeImpl(access, database); } +void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { revokeImpl(access, database, table); } +void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(access, database, table, column); } +void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(access, database, table, columns); } +void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(access, database, table, columns); } +void AccessRights::partialRevoke(const AccessRightsElement & element, std::string_view current_database) { revokeImpl(element, current_database); } +void AccessRights::partialRevoke(const AccessRightsElements & elements, std::string_view current_database) { revokeImpl(elements, current_database); } + +void AccessRights::fullRevoke(const AccessFlags & access) { revokeImpl(access); } +void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database) { revokeImpl(access, database); } +void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { revokeImpl(access, database, table); } +void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(access, database, table, column); } +void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(access, database, table, columns); } +void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(access, database, table, columns); } +void AccessRights::fullRevoke(const AccessRightsElement & element, std::string_view current_database) { revokeImpl(element, current_database); } +void AccessRights::fullRevoke(const AccessRightsElements & elements, std::string_view current_database) { revokeImpl(elements, current_database); } + + +AccessRights::Elements AccessRights::getElements() const +{ + if (!root) + return {}; + Elements res; + if (root->explicit_grants) + res.grants.push_back({root->explicit_grants}); + if (root->children) + { + for (const auto & [db_name, db_node] : *root->children) + { + if (db_node.partial_revokes) + res.partial_revokes.push_back({db_node.partial_revokes, db_name}); + if (db_node.explicit_grants) + res.grants.push_back({db_node.explicit_grants, db_name}); + if (db_node.children) + { + for (const auto & [table_name, table_node] : *db_node.children) + { + if (table_node.partial_revokes) + res.partial_revokes.push_back({table_node.partial_revokes, db_name, table_name}); + if (table_node.explicit_grants) + res.grants.push_back({table_node.explicit_grants, db_name, table_name}); + if (table_node.children) + { + for (const auto & [column_name, column_node] : *table_node.children) + { + if (column_node.partial_revokes) + res.partial_revokes.push_back({column_node.partial_revokes, db_name, table_name, column_name}); + if (column_node.explicit_grants) + res.grants.push_back({column_node.explicit_grants, db_name, table_name, column_name}); + } + } + } + } + } + } + return res; +} + + +String AccessRights::toString() const +{ + auto elements = getElements(); + String res; + if (!elements.grants.empty()) + { + res += "GRANT "; + res += elements.grants.toString(); + } + if (!elements.partial_revokes.empty()) + { + if (!res.empty()) + res += ", "; + res += "REVOKE "; + res += elements.partial_revokes.toString(); + } + if (res.empty()) + res = "GRANT USAGE ON *.*"; + return res; +} + + +template +bool AccessRights::isGrantedImpl(const AccessFlags & access, const Args &... args) const +{ + if (!root) + return access.isEmpty(); + return root->isGranted(access, args...); +} + +bool AccessRights::isGrantedImpl(const AccessRightsElement & element, std::string_view current_database) const +{ + if (element.any_database) + { + return isGrantedImpl(element.access_flags); + } + else if (element.any_table) + { + if (element.database.empty()) + return isGrantedImpl(element.access_flags, current_database); + else + return isGrantedImpl(element.access_flags, element.database); + } + else if (element.any_column) + { + if (element.database.empty()) + return isGrantedImpl(element.access_flags, current_database, element.table); + else + return isGrantedImpl(element.access_flags, element.database, element.table); + } + else + { + if (element.database.empty()) + return isGrantedImpl(element.access_flags, current_database, element.table, element.columns); + else + return isGrantedImpl(element.access_flags, element.database, element.table, element.columns); + } +} + +bool AccessRights::isGrantedImpl(const AccessRightsElements & elements, std::string_view current_database) const +{ + for (const auto & element : elements) + if (!isGrantedImpl(element, current_database)) + return false; + return true; +} + +bool AccessRights::isGranted(const AccessFlags & access) const { return isGrantedImpl(access); } +bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database) const { return isGrantedImpl(access, database); } +bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return isGrantedImpl(access, database, table); } +bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isGrantedImpl(access, database, table, column); } +bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return isGrantedImpl(access, database, table, columns); } +bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return isGrantedImpl(access, database, table, columns); } +bool AccessRights::isGranted(const AccessRightsElement & element, std::string_view current_database) const { return isGrantedImpl(element, current_database); } +bool AccessRights::isGranted(const AccessRightsElements & elements, std::string_view current_database) const { return isGrantedImpl(elements, current_database); } + + +bool operator ==(const AccessRights & left, const AccessRights & right) +{ + if (!left.root) + return !right.root; + if (!right.root) + return false; + return *left.root == *right.root; +} + + +void AccessRights::merge(const AccessRights & other) +{ + if (!root) + { + *this = other; + return; + } + if (other.root) + { + root->merge(*other.root, Helper::instance()); + if (root->isEmpty()) + root = nullptr; + } +} + +} diff --git a/dbms/src/Access/AccessRights.h b/dbms/src/Access/AccessRights.h new file mode 100644 index 0000000000..4eefcb0d6d --- /dev/null +++ b/dbms/src/Access/AccessRights.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +/// Represents a set of access types granted on databases, tables, columns, etc. +/// For example, "GRANT SELECT, UPDATE ON db.*, GRANT INSERT ON db2.mytbl2" are access rights. +class AccessRights +{ +public: + AccessRights(); + AccessRights(const AccessFlags & access); + ~AccessRights(); + AccessRights(const AccessRights & src); + AccessRights & operator =(const AccessRights & src); + AccessRights(AccessRights && src); + AccessRights & operator =(AccessRights && src); + + bool isEmpty() const; + + /// Revokes everything. It's the same as fullRevoke(AccessType::ALL). + void clear(); + + /// Grants access on a specified database/table/column. + /// Does nothing if the specified access has been already granted. + void grant(const AccessFlags & access); + void grant(const AccessFlags & access, const std::string_view & database); + void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table); + void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); + void grant(const AccessRightsElement & access, std::string_view current_database = {}); + void grant(const AccessRightsElements & access, std::string_view current_database = {}); + + /// Revokes a specified access granted earlier on a specified database/table/column. + /// Does nothing if the specified access is not granted. + /// If the specified access is granted but on upper level (e.g. database for table, table for columns) + /// or lower level, the function also does nothing. + /// This function implements the standard SQL REVOKE behaviour. + void revoke(const AccessFlags & access); + void revoke(const AccessFlags & access, const std::string_view & database); + void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table); + void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); + void revoke(const AccessRightsElement & access, std::string_view current_database = {}); + void revoke(const AccessRightsElements & access, std::string_view current_database = {}); + + /// Revokes a specified access granted earlier on a specified database/table/column or on lower levels. + /// The function also restricts access if it's granted on upper level. + /// For example, an access could be granted on a database and then revoked on a table in this database. + /// This function implements the MySQL REVOKE behaviour with partial_revokes is ON. + void partialRevoke(const AccessFlags & access); + void partialRevoke(const AccessFlags & access, const std::string_view & database); + void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table); + void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); + void partialRevoke(const AccessRightsElement & access, std::string_view current_database = {}); + void partialRevoke(const AccessRightsElements & access, std::string_view current_database = {}); + + /// Revokes a specified access granted earlier on a specified database/table/column or on lower levels. + /// The function also restricts access if it's granted on upper level. + /// For example, fullRevoke(AccessType::ALL) revokes all grants at all, just like clear(); + /// fullRevoke(AccessType::SELECT, db) means it's not allowed to execute SELECT in that database anymore (from any table). + void fullRevoke(const AccessFlags & access); + void fullRevoke(const AccessFlags & access, const std::string_view & database); + void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table); + void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); + void fullRevoke(const AccessRightsElement & access, std::string_view current_database = {}); + void fullRevoke(const AccessRightsElements & access, std::string_view current_database = {}); + + /// Returns the information about all the access granted. + struct Elements + { + AccessRightsElements grants; + AccessRightsElements partial_revokes; + }; + Elements getElements() const; + + /// Returns the information about all the access granted as a string. + String toString() const; + + /// Whether a specified access granted. + bool isGranted(const AccessFlags & access) const; + bool isGranted(const AccessFlags & access, const std::string_view & database) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(const AccessRightsElement & access, std::string_view current_database = {}) const; + bool isGranted(const AccessRightsElements & access, std::string_view current_database = {}) const; + + friend bool operator ==(const AccessRights & left, const AccessRights & right); + friend bool operator !=(const AccessRights & left, const AccessRights & right) { return !(left == right); } + + /// Merges two sets of access rights together. + /// It's used to combine access rights from multiple roles. + void merge(const AccessRights & other); + +private: + template + void grantImpl(const AccessFlags & access, const Args &... args); + + void grantImpl(const AccessRightsElement & access, std::string_view current_database); + void grantImpl(const AccessRightsElements & access, std::string_view current_database); + + template + void revokeImpl(const AccessFlags & access, const Args &... args); + + template + void revokeImpl(const AccessRightsElement & access, std::string_view current_database); + + template + void revokeImpl(const AccessRightsElements & access, std::string_view current_database); + + template + bool isGrantedImpl(const AccessFlags & access, const Args &... args) const; + + bool isGrantedImpl(const AccessRightsElement & access, std::string_view current_database) const; + bool isGrantedImpl(const AccessRightsElements & access, std::string_view current_database) const; + + template + AccessFlags getAccessImpl(const Args &... args) const; + + struct Node; + std::unique_ptr root; +}; + +} diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp new file mode 100644 index 0000000000..3de8efce4e --- /dev/null +++ b/dbms/src/Access/AccessRightsContext.cpp @@ -0,0 +1,257 @@ +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ACCESS_DENIED; + extern const int READONLY; + extern const int QUERY_IS_PROHIBITED; + extern const int FUNCTION_NOT_ALLOWED; +} + + +namespace +{ + enum CheckAccessRightsMode + { + RETURN_FALSE_IF_ACCESS_DENIED, + LOG_WARNING_IF_ACCESS_DENIED, + THROW_IF_ACCESS_DENIED, + }; + + + String formatSkippedMessage() + { + return ""; + } + + String formatSkippedMessage(const std::string_view & database) + { + return ". Skipped database " + backQuoteIfNeed(database); + } + + String formatSkippedMessage(const std::string_view & database, const std::string_view & table) + { + String str = ". Skipped table "; + if (!database.empty()) + str += backQuoteIfNeed(database) + "."; + str += backQuoteIfNeed(table); + return str; + } + + String formatSkippedMessage(const std::string_view & database, const std::string_view & table, const std::string_view & column) + { + String str = ". Skipped column " + backQuoteIfNeed(column) + " ON "; + if (!database.empty()) + str += backQuoteIfNeed(database) + "."; + str += backQuoteIfNeed(table); + return str; + } + + template + String formatSkippedMessage(const std::string_view & database, const std::string_view & table, const std::vector & columns) + { + if (columns.size() == 1) + return formatSkippedMessage(database, table, columns[0]); + + String str = ". Skipped columns "; + bool need_comma = false; + for (const auto & column : columns) + { + if (std::exchange(need_comma, true)) + str += ", "; + str += backQuoteIfNeed(column); + } + str += " ON "; + if (!database.empty()) + str += backQuoteIfNeed(database) + "."; + str += backQuoteIfNeed(table); + return str; + } +} + + +AccessRightsContext::AccessRightsContext() +{ + result_access_cache[0].emplace().grant(AccessType::ALL); +} + + +AccessRightsContext::AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user_, const Settings & settings, const String & current_database_) + : user_name(client_info_.current_user) + , granted_to_user(granted_to_user_) + , readonly(settings.readonly) + , allow_ddl(settings.allow_ddl) + , allow_introspection(settings.allow_introspection_functions) + , current_database(current_database_) + , interface(client_info_.interface) + , http_method(client_info_.http_method) + , trace_log(&Poco::Logger::get("AccessRightsContext (" + user_name + ")")) +{ +} + + +template +bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const +{ + std::lock_guard lock{mutex}; + const auto & result_access = calculateResultAccess(); + bool is_granted = result_access.isGranted(access, args...); + + if (trace_log) + LOG_TRACE(trace_log, "Access " << (is_granted ? "granted" : "denied") << ": " << (AccessRightsElement{access, args...}.toString())); + + if (is_granted) + return true; + + if constexpr (mode == RETURN_FALSE_IF_ACCESS_DENIED) + return false; + + if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) + { + if (!log_) + return false; + } + + auto show_error = [&](const String & msg, [[maybe_unused]] int error_code) + { + if constexpr (mode == THROW_IF_ACCESS_DENIED) + throw Exception(msg, error_code); + else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) + LOG_WARNING(log_, msg + formatSkippedMessage(args...)); + }; + + if (readonly && calculateResultAccess(false, allow_ddl, allow_introspection).isGranted(access, args...)) + { + if (interface == ClientInfo::Interface::HTTP && http_method == ClientInfo::HTTPMethod::GET) + show_error( + "Cannot execute query in readonly mode. " + "For queries over HTTP, method GET implies readonly. You should use method POST for modifying queries", + ErrorCodes::READONLY); + else + show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY); + } + else if (!allow_ddl && calculateResultAccess(readonly, true, allow_introspection).isGranted(access, args...)) + { + show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED); + } + else if (!allow_introspection && calculateResultAccess(readonly, allow_ddl, true).isGranted(access, args...)) + { + show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED); + } + else + { + show_error( + user_name + ": Not enough privileges. To perform this operation you should have grant " + + AccessRightsElement{access, args...}.toString(), + ErrorCodes::ACCESS_DENIED); + } + + return false; +} + +template +bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElement & element) const +{ + if (element.any_database) + { + return checkImpl(log_, element.access_flags); + } + else if (element.any_table) + { + if (element.database.empty()) + return checkImpl(log_, element.access_flags, current_database); + else + return checkImpl(log_, element.access_flags, element.database); + } + else if (element.any_column) + { + if (element.database.empty()) + return checkImpl(log_, element.access_flags, current_database, element.table); + else + return checkImpl(log_, element.access_flags, element.database, element.table); + } + else + { + if (element.database.empty()) + return checkImpl(log_, element.access_flags, current_database, element.table, element.columns); + else + return checkImpl(log_, element.access_flags, element.database, element.table, element.columns); + } +} + + +template +bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElements & elements) const +{ + for (const auto & element : elements) + if (!checkImpl(log_, element)) + return false; + return true; +} + + +void AccessRightsContext::check(const AccessFlags & access) const { checkImpl(nullptr, access); } +void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl(nullptr, access, database); } +void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl(nullptr, access, database, table); } +void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl(nullptr, access, database, table, column); } +void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkImpl(nullptr, access, database, table, columns); } +void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl(nullptr, access, database, table, columns); } +void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl(nullptr, access); } +void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl(nullptr, access); } + +bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl(nullptr, access); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl(nullptr, access, database); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl(nullptr, access, database, table); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl(nullptr, access, database, table, column); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkImpl(nullptr, access, database, table, columns); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl(nullptr, access, database, table, columns); } +bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl(nullptr, access); } +bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl(nullptr, access); } + +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl(log_, access); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl(log_, access, database); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl(log_, access, database, table); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl(log_, access, database, table, column); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkImpl(log_, access, database, table, columns); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl(log_, access, database, table, columns); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl(log_, access); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl(log_, access); } + + +const AccessRights & AccessRightsContext::calculateResultAccess() const +{ + if (result_access_cache[0]) + return *result_access_cache[0]; + return calculateResultAccess(readonly, allow_ddl, allow_introspection); +} + + +const AccessRights & AccessRightsContext::calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const +{ + size_t cache_index = static_cast(readonly_ != readonly) + + static_cast(allow_ddl_ != allow_ddl) * 2 + + + static_cast(allow_introspection_ != allow_introspection) * 3; + assert(cache_index < std::size(result_access_cache)); + auto & cached_result = result_access_cache[cache_index]; + + if (cached_result) + return *cached_result; + auto & result = cached_result.emplace(); + + result = granted_to_user; + + /// TODO + + return result; +} + +} diff --git a/dbms/src/Access/AccessRightsContext.h b/dbms/src/Access/AccessRightsContext.h new file mode 100644 index 0000000000..3d84e88198 --- /dev/null +++ b/dbms/src/Access/AccessRightsContext.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include +#include + + +namespace Poco { class Logger; } + +namespace DB +{ +class Exception; +struct Settings; + + +class AccessRightsContext +{ +public: + /// Default constructor creates access rights' context which allows everything. + AccessRightsContext(); + + AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user, const Settings & settings, const String & current_database_); + + /// Checks if a specified access granted, and throws an exception if not. + /// Empty database means the current database. + void check(const AccessFlags & access) const; + void check(const AccessFlags & access, const std::string_view & database) const; + void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; + void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + void check(const AccessRightsElement & access) const; + void check(const AccessRightsElements & access) const; + + /// Checks if a specified access granted. + bool isGranted(const AccessFlags & access) const; + bool isGranted(const AccessFlags & access, const std::string_view & database) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(const AccessRightsElement & access) const; + bool isGranted(const AccessRightsElements & access) const; + + /// Checks if a specified access granted, and logs a warning if not. + bool isGranted(Poco::Logger * log_, const AccessFlags & access) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const; + bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const; + +private: + template + bool checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const; + + template + bool checkImpl(Poco::Logger * log_, const AccessRightsElement & access) const; + + template + bool checkImpl(Poco::Logger * log_, const AccessRightsElements & access) const; + + const AccessRights & calculateResultAccess() const; + const AccessRights & calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const; + + const String user_name; + const AccessRights granted_to_user; + const UInt64 readonly = 0; + const bool allow_ddl = true; + const bool allow_introspection = true; + const String current_database; + const ClientInfo::Interface interface = ClientInfo::Interface::TCP; + const ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; + Poco::Logger * const trace_log = nullptr; + mutable std::optional result_access_cache[4]; + mutable std::mutex mutex; +}; + +} diff --git a/dbms/src/Access/AccessRightsElement.cpp b/dbms/src/Access/AccessRightsElement.cpp new file mode 100644 index 0000000000..160c265224 --- /dev/null +++ b/dbms/src/Access/AccessRightsElement.cpp @@ -0,0 +1,85 @@ +#include +#include + + +namespace DB +{ +void AccessRightsElement::setDatabase(const String & new_database) +{ + database = new_database; + any_database = false; +} + + +void AccessRightsElement::replaceEmptyDatabase(const String & new_database) +{ + if (isEmptyDatabase()) + setDatabase(new_database); +} + + +bool AccessRightsElement::isEmptyDatabase() const +{ + return !any_database && database.empty(); +} + + +String AccessRightsElement::toString() const +{ + String columns_in_parentheses; + if (!any_column) + { + for (const auto & column : columns) + { + columns_in_parentheses += columns_in_parentheses.empty() ? "(" : ", "; + columns_in_parentheses += backQuoteIfNeed(column); + } + columns_in_parentheses += ")"; + } + + String msg; + for (const std::string_view & keyword : access_flags.toKeywords()) + { + if (!msg.empty()) + msg += ", "; + msg += String{keyword} + columns_in_parentheses; + } + + msg += " ON "; + + if (any_database) + msg += "*."; + else if (!database.empty()) + msg += backQuoteIfNeed(database) + "."; + + if (any_table) + msg += "*"; + else + msg += backQuoteIfNeed(table); + return msg; +} + + +void AccessRightsElements::replaceEmptyDatabase(const String & new_database) +{ + for (auto & element : *this) + element.replaceEmptyDatabase(new_database); +} + + +String AccessRightsElements::toString() const +{ + String res; + bool need_comma = false; + for (auto & element : *this) + { + if (std::exchange(need_comma, true)) + res += ", "; + res += element.toString(); + } + + if (res.empty()) + res = "USAGE ON *.*"; + return res; +} +} diff --git a/dbms/src/Access/AccessRightsElement.h b/dbms/src/Access/AccessRightsElement.h new file mode 100644 index 0000000000..3894b6f515 --- /dev/null +++ b/dbms/src/Access/AccessRightsElement.h @@ -0,0 +1,100 @@ +#pragma once + +#include + + +namespace DB +{ +/// An element of access rights which can be represented by single line +/// GRANT ... ON ... +struct AccessRightsElement +{ + AccessFlags access_flags; + String database; + String table; + Strings columns; + bool any_database = true; + bool any_table = true; + bool any_column = true; + + AccessRightsElement() = default; + AccessRightsElement(const AccessRightsElement &) = default; + AccessRightsElement & operator=(const AccessRightsElement &) = default; + AccessRightsElement(AccessRightsElement &&) = default; + AccessRightsElement & operator=(AccessRightsElement &&) = default; + + AccessRightsElement(AccessFlags access_flags_) : access_flags(access_flags_) {} + + AccessRightsElement(AccessFlags access_flags_, const std::string_view & database_) + : access_flags(access_flags_), database(database_), any_database(false) + { + } + + AccessRightsElement(AccessFlags access_flags_, const std::string_view & database_, const std::string_view & table_) + : access_flags(access_flags_), database(database_), table(table_), any_database(false), any_table(false) + { + } + + AccessRightsElement( + AccessFlags access_flags_, const std::string_view & database_, const std::string_view & table_, const std::string_view & column_) + : access_flags(access_flags_) + , database(database_) + , table(table_) + , columns({String{column_}}) + , any_database(false) + , any_table(false) + , any_column(false) + { + } + + AccessRightsElement( + AccessFlags access_flags_, + const std::string_view & database_, + const std::string_view & table_, + const std::vector & columns_) + : access_flags(access_flags_), database(database_), table(table_), any_database(false), any_table(false), any_column(false) + { + columns.resize(columns_.size()); + for (size_t i = 0; i != columns_.size(); ++i) + columns[i] = String{columns_[i]}; + } + + AccessRightsElement( + AccessFlags access_flags_, const std::string_view & database_, const std::string_view & table_, const Strings & columns_) + : access_flags(access_flags_) + , database(database_) + , table(table_) + , columns(columns_) + , any_database(false) + , any_table(false) + , any_column(false) + { + } + + /// Sets the database. + void setDatabase(const String & new_database); + + /// If the database is empty, replaces it with `new_database`. Otherwise does nothing. + void replaceEmptyDatabase(const String & new_database); + + bool isEmptyDatabase() const; + + /// Returns a human-readable representation like "SELECT, UPDATE(x, y) ON db.table". + /// The returned string isn't prefixed with the "GRANT" keyword. + String toString() const; +}; + + +/// Multiple elements of access rights. +class AccessRightsElements : public std::vector +{ +public: + /// Replaces the empty database with `new_database`. + void replaceEmptyDatabase(const String & new_database); + + /// Returns a human-readable representation like "SELECT, UPDATE(x, y) ON db.table". + /// The returned string isn't prefixed with the "GRANT" keyword. + String toString() const; +}; + +} diff --git a/dbms/src/Access/AccessType.h b/dbms/src/Access/AccessType.h new file mode 100644 index 0000000000..930549f5bc --- /dev/null +++ b/dbms/src/Access/AccessType.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +/// Represents an access type which can be granted on databases, tables, columns, etc. +enum class AccessType +{ + NONE, /// no access + ALL, /// full access + + SHOW, /// allows to execute SHOW TABLES, SHOW CREATE TABLE, SHOW DATABASES and so on + /// (granted implicitly with any other grant) + + EXISTS, /// allows to execute EXISTS, USE, i.e. to check existence + /// (granted implicitly on the database level with any other grant on the database and lower levels, + /// e.g. "GRANT SELECT(x) ON db.table" also grants EXISTS on db.*) + + SELECT, + INSERT, + UPDATE, /// allows to execute ALTER UPDATE + DELETE, /// allows to execute ALTER DELETE +}; + +constexpr size_t MAX_ACCESS_TYPE = static_cast(AccessType::DELETE) + 1; + +std::string_view toString(AccessType type); + + +namespace impl +{ + template + class AccessTypeToKeywordConverter + { + public: + static const AccessTypeToKeywordConverter & instance() + { + static const AccessTypeToKeywordConverter res; + return res; + } + + std::string_view convert(AccessType type) const + { + return access_type_to_keyword_mapping[static_cast(type)]; + } + + private: + void addToMapping(AccessType type, const std::string_view & str) + { + String str2{str}; + boost::replace_all(str2, "_", " "); + if (islower(str2[0])) + str2 += "()"; + access_type_to_keyword_mapping[static_cast(type)] = str2; + } + + AccessTypeToKeywordConverter() + { +#define ACCESS_TYPE_TO_KEYWORD_CASE(type) \ + addToMapping(AccessType::type, #type) + + ACCESS_TYPE_TO_KEYWORD_CASE(NONE); + ACCESS_TYPE_TO_KEYWORD_CASE(ALL); + ACCESS_TYPE_TO_KEYWORD_CASE(SHOW); + ACCESS_TYPE_TO_KEYWORD_CASE(EXISTS); + + ACCESS_TYPE_TO_KEYWORD_CASE(SELECT); + ACCESS_TYPE_TO_KEYWORD_CASE(INSERT); + ACCESS_TYPE_TO_KEYWORD_CASE(UPDATE); + ACCESS_TYPE_TO_KEYWORD_CASE(DELETE); + +#undef ACCESS_TYPE_TO_KEYWORD_CASE + } + + std::array access_type_to_keyword_mapping; + }; +} + +inline std::string_view toKeyword(AccessType type) { return impl::AccessTypeToKeywordConverter<>::instance().convert(type); } + +} diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index fe5bca9f55..2c3ae67d96 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -469,7 +469,7 @@ namespace ErrorCodes extern const int ACCESS_ENTITY_FOUND_DUPLICATES = 494; extern const int ACCESS_ENTITY_STORAGE_READONLY = 495; extern const int QUOTA_REQUIRES_CLIENT_KEY = 496; - extern const int NOT_ENOUGH_PRIVILEGES = 497; + extern const int ACCESS_DENIED = 497; extern const int LIMIT_BY_WITH_TIES_IS_NOT_SUPPORTED = 498; extern const int S3_ERROR = 499; extern const int CANNOT_CREATE_DATABASE = 501; @@ -479,6 +479,8 @@ namespace ErrorCodes extern const int CANNOT_DELETE_DIRECTORY = 505; extern const int UNEXPECTED_ERROR_CODE = 506; extern const int UNABLE_TO_SKIP_UNUSED_SHARDS = 507; + extern const int UNKNOWN_ACCESS_TYPE = 508; + extern const int INVALID_GRANT = 509; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 04d01a24cc..2a936ee356 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -95,8 +96,7 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; extern const int SCALAR_ALREADY_EXISTS; extern const int UNKNOWN_SCALAR; - extern const int NOT_ENOUGH_PRIVILEGES; - extern const int UNKNOWN_POLICY; + extern const int ACCESS_DENIED; } @@ -337,6 +337,7 @@ Context Context::createGlobal() Context res; res.quota = std::make_shared(); res.row_policy = std::make_shared(); + res.access_rights = std::make_shared(); res.shared = std::make_shared(); return res; } @@ -643,20 +644,36 @@ const AccessControlManager & Context::getAccessControlManager() const return shared->access_control_manager; } +template +void Context::checkAccessImpl(const Args &... args) const +{ + getAccessRights()->check(args...); +} + +void Context::checkAccess(const AccessFlags & access) const { return checkAccessImpl(access); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(access, database); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(access, database, table); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(access, database, table, column); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(access, database, table, columns); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(access, database, table, columns); } +void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); } +void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); } + void Context::checkQuotaManagementIsAllowed() { if (!is_quota_management_allowed) throw Exception( - "User " + client_info.current_user + " doesn't have enough privileges to manage quotas", ErrorCodes::NOT_ENOUGH_PRIVILEGES); + "User " + client_info.current_user + " doesn't have enough privileges to manage quotas", ErrorCodes::ACCESS_DENIED); } void Context::checkRowPolicyManagementIsAllowed() { if (!is_row_policy_management_allowed) throw Exception( - "User " + client_info.current_user + " doesn't have enough privileges to manage row policies", ErrorCodes::NOT_ENOUGH_PRIVILEGES); + "User " + client_info.current_user + " doesn't have enough privileges to manage row policies", ErrorCodes::ACCESS_DENIED); } + void Context::setUsersConfig(const ConfigurationPtr & config) { auto lock = getLock(); @@ -674,8 +691,6 @@ ConfigurationPtr Context::getUsersConfig() void Context::calculateUserSettings() { auto lock = getLock(); - - auto user = getUser(client_info.current_user); String profile = user->profile; /// 1) Set default settings (hardcoded values) @@ -697,8 +712,15 @@ void Context::calculateUserSettings() is_quota_management_allowed = user->is_quota_management_allowed; row_policy = getAccessControlManager().getRowPolicyContext(client_info.current_user); is_row_policy_management_allowed = user->is_row_policy_management_allowed; + calculateAccessRights(); } +void Context::calculateAccessRights() +{ + auto lock = getLock(); + if (user) + std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(client_info, user->access, settings, current_database)); +} void Context::setProfile(const String & profile) { @@ -710,7 +732,7 @@ void Context::setProfile(const String & profile) settings_constraints = std::move(new_constraints); } -std::shared_ptr Context::getUser(const String & user_name) +std::shared_ptr Context::getUser(const String & user_name) const { return shared->users_manager->getUser(user_name); } @@ -719,7 +741,7 @@ void Context::setUser(const String & name, const String & password, const Poco:: { auto lock = getLock(); - auto user_props = shared->users_manager->authorizeAndGetUser(name, password, address.host()); + user = shared->users_manager->authorizeAndGetUser(name, password, address.host()); client_info.current_user = name; client_info.current_address = address; @@ -1143,7 +1165,15 @@ Settings Context::getSettings() const void Context::setSettings(const Settings & settings_) { + auto lock = getLock(); + bool old_readonly = settings.readonly; + bool old_allow_ddl = settings.allow_ddl; + bool old_allow_introspection_functions = settings.allow_introspection_functions; + settings = settings_; + + if ((settings.readonly != old_readonly) || (settings.allow_ddl != old_allow_ddl) || (settings.allow_introspection_functions != old_allow_introspection_functions)) + calculateAccessRights(); } @@ -1156,6 +1186,9 @@ void Context::setSetting(const String & name, const String & value) return; } settings.set(name, value); + + if (name == "readonly" || name == "allow_ddl" || name == "allow_introspection_functions") + calculateAccessRights(); } @@ -1168,6 +1201,9 @@ void Context::setSetting(const String & name, const Field & value) return; } settings.set(name, value); + + if (name == "readonly" || name == "allow_ddl" || name == "allow_introspection_functions") + calculateAccessRights(); } @@ -1222,6 +1258,7 @@ void Context::setCurrentDatabase(const String & name) auto lock = getLock(); assertDatabaseExists(name); current_database = name; + calculateAccessRights(); } @@ -1589,9 +1626,9 @@ std::pair Context::getInterserverIOAddress() const return { shared->interserver_io_host, shared->interserver_io_port }; } -void Context::setInterserverCredentials(const String & user, const String & password) +void Context::setInterserverCredentials(const String & user_, const String & password) { - shared->interserver_io_user = user; + shared->interserver_io_user = user_; shared->interserver_io_password = password; } diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index dcce1a4772..c33b102eca 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -43,6 +43,7 @@ namespace DB struct ContextShared; class Context; +class AccessRightsContext; class QuotaContext; class RowPolicyContext; class EmbeddedDictionaries; @@ -145,6 +146,8 @@ private: InputInitializer input_initializer_callback; InputBlocksReader input_blocks_reader; + std::shared_ptr user; + std::shared_ptr access_rights; std::shared_ptr quota; /// Current quota. By default - empty quota, that have no limits. bool is_quota_management_allowed = false; /// Whether the current user is allowed to manage quotas via SQL commands. std::shared_ptr row_policy; @@ -219,6 +222,19 @@ public: AccessControlManager & getAccessControlManager(); const AccessControlManager & getAccessControlManager() const; + std::shared_ptr getAccessRights() const { return std::atomic_load(&access_rights); } + + /// Checks access rights. + /// Empty database means the current database. + void checkAccess(const AccessFlags & access) const; + void checkAccess(const AccessFlags & access, const std::string_view & database) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + void checkAccess(const AccessRightsElement & access) const; + void checkAccess(const AccessRightsElements & access) const; + std::shared_ptr getQuota() const { return quota; } void checkQuotaManagementIsAllowed(); std::shared_ptr getRowPolicy() const { return row_policy; } @@ -233,12 +249,10 @@ public: /// Must be called before getClientInfo. void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key); + std::shared_ptr getUser() const { return user; } /// Used by MySQL Secure Password Authentication plugin. - std::shared_ptr getUser(const String & user_name); - - /// Compute and set actual user settings, client_info.current_user should be set - void calculateUserSettings(); + std::shared_ptr getUser(const String & user_name) const; /// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once. void setExternalTablesInitializer(ExternalTablesInitializer && initializer); @@ -589,12 +603,19 @@ public: MySQLWireContext mysql; private: + /// Compute and set actual user settings, client_info.current_user should be set + void calculateUserSettings(); + void calculateAccessRights(); + /** Check if the current client has access to the specified database. * If access is denied, throw an exception. * NOTE: This method should always be called when the `shared->mutex` mutex is acquired. */ void checkDatabaseAccessRightsImpl(const std::string & database_name) const; + template + void checkAccessImpl(const Args &... args) const; + void setProfile(const String & profile); EmbeddedDictionaries & getEmbeddedDictionariesImpl(bool throw_on_error) const; diff --git a/dbms/src/Interpreters/Users.cpp b/dbms/src/Interpreters/Users.cpp index 2b48ce56e3..28b1190455 100644 --- a/dbms/src/Interpreters/Users.cpp +++ b/dbms/src/Interpreters/Users.cpp @@ -103,6 +103,16 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A } } + access.grant(AccessType::ALL); /// By default all databases are accessible. + + if (databases) + { + access.fullRevoke(AccessFlags::databaseLevel()); + for (const String & database : *databases) + access.grant(AccessFlags::databaseLevel(), database); + access.grant(AccessFlags::databaseLevel(), "system"); /// Anyone has access to the "system" database. + } + if (config.has(config_elem + ".allow_quota_management")) is_quota_management_allowed = config.getBool(config_elem + ".allow_quota_management"); if (config.has(config_elem + ".allow_row_policy_management")) diff --git a/dbms/src/Interpreters/Users.h b/dbms/src/Interpreters/Users.h index f151770cef..5f8d65e3a8 100644 --- a/dbms/src/Interpreters/Users.h +++ b/dbms/src/Interpreters/Users.h @@ -4,10 +4,9 @@ #include #include #include - -#include +#include +#include #include -#include namespace Poco @@ -45,6 +44,8 @@ struct User bool is_quota_management_allowed = false; bool is_row_policy_management_allowed = false; + AccessRights access; + User(const String & name_, const String & config_elem, const Poco::Util::AbstractConfiguration & config); }; -- GitLab