提交 1e8f04b5 编写于 作者: V Vitaly Baranov

Add to class AllowedClientHost support for like pattern and for removing.

上级 d9073f27
#include <Access/AllowedClientHosts.h>
#include <Common/Exception.h>
#include <common/SimpleCache.h>
#include <Common/StringUtils/StringUtils.h>
#include <IO/ReadHelpers.h>
#include <Functions/likePatternToRegexp.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/RegularExpression.h>
#include <common/logger_useful.h>
#include <ext/scope_guard.h>
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm/find_first_of.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/algorithm/string/replace.hpp>
#include <ifaddrs.h>
......@@ -27,20 +24,6 @@ namespace
using IPSubnet = AllowedClientHosts::IPSubnet;
const IPSubnet ALL_ADDRESSES{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}};
const IPAddress & getIPV6Loopback()
{
static const IPAddress ip("::1");
return ip;
}
bool isIPV4LoopbackMappedToIPV6(const IPAddress & ip)
{
static const IPAddress prefix("::ffff:127.0.0.0");
/// 104 == 128 - 24, we have to reset the lowest 24 bits of 128 before comparing with `prefix`
/// (IPv4 loopback means any IP from 127.0.0.0 to 127.255.255.255).
return (ip & IPAddress(104, IPAddress::IPv6)) == prefix;
}
/// Converts an address to IPv6.
/// The loopback address "127.0.0.1" (or any "127.x.y.z") is converted to "::1".
IPAddress toIPv6(const IPAddress & ip)
......@@ -52,35 +35,18 @@ namespace
v6 = IPAddress("::ffff:" + ip.toString());
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6))
v6 = getIPV6Loopback();
if ((v6 & IPAddress(104, IPAddress::IPv6)) == IPAddress("::ffff:127.0.0.0"))
v6 = IPAddress{"::1"};
return v6;
}
/// Converts a subnet to IPv6.
IPSubnet toIPv6(const IPSubnet & subnet)
{
IPSubnet v6;
if (subnet.prefix.family() == IPAddress::IPv6)
v6.prefix = subnet.prefix;
else
v6.prefix = IPAddress("::ffff:" + subnet.prefix.toString());
if (subnet.mask.family() == IPAddress::IPv6)
v6.mask = subnet.mask;
else
v6.mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + subnet.mask.toString());
v6.prefix = v6.prefix & v6.mask;
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6.prefix))
v6 = {getIPV6Loopback(), IPAddress(128, IPAddress::IPv6)};
return v6;
return IPSubnet(toIPv6(subnet.getPrefix()), subnet.getMask());
}
/// Helper function for isAddressOfHost().
bool isAddressOfHostImpl(const IPAddress & address, const String & host)
{
......@@ -150,7 +116,7 @@ namespace
int err = getifaddrs(&ifa_begin);
if (err)
return {getIPV6Loopback()};
return {IPAddress{"::1"}};
for (const ifaddrs * ifa = ifa_begin; ifa; ifa = ifa->ifa_next)
{
......@@ -203,200 +169,131 @@ namespace
static SimpleCache<decltype(getHostByAddressImpl), &getHostByAddressImpl> cache;
return cache(address);
}
}
String AllowedClientHosts::IPSubnet::toString() const
{
unsigned int prefix_length = mask.prefixLength();
if (IPAddress{prefix_length, mask.family()} == mask)
return prefix.toString() + "/" + std::to_string(prefix_length);
return prefix.toString() + "/" + mask.toString();
}
AllowedClientHosts::AllowedClientHosts()
{
}
AllowedClientHosts::AllowedClientHosts(AllAddressesTag)
{
addAllAddresses();
}
AllowedClientHosts::~AllowedClientHosts() = default;
AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src)
{
*this = src;
}
AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src)
{
addresses = src.addresses;
localhost = src.localhost;
subnets = src.subnets;
host_names = src.host_names;
host_regexps = src.host_regexps;
compiled_host_regexps.clear();
return *this;
}
AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) = default;
void AllowedClientHosts::clear()
{
addresses.clear();
localhost = false;
subnets.clear();
host_names.clear();
host_regexps.clear();
compiled_host_regexps.clear();
}
bool AllowedClientHosts::empty() const
{
return addresses.empty() && subnets.empty() && host_names.empty() && host_regexps.empty();
}
void AllowedClientHosts::addAddress(const IPAddress & address)
{
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) != addresses.end())
return;
addresses.push_back(addr_v6);
if (addr_v6.isLoopback())
localhost = true;
}
void AllowedClientHosts::addAddress(const String & address)
{
addAddress(IPAddress{address});
}
void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
IPSubnet subnet_v6 = toIPv6(subnet);
if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6))
void parseLikePatternIfIPSubnet(const String & pattern, IPSubnet & subnet, IPAddress::Family address_family)
{
addAddress(subnet_v6.prefix);
return;
}
if (boost::range::find(subnets, subnet_v6) == subnets.end())
subnets.push_back(subnet_v6);
}
void AllowedClientHosts::addSubnet(const IPAddress & prefix, const IPAddress & mask)
{
addSubnet(IPSubnet{prefix, mask});
}
size_t slash = pattern.find('/');
if (slash != String::npos)
{
/// IP subnet, e.g. "192.168.0.0/16" or "192.168.0.0/255.255.0.0".
subnet = IPSubnet{pattern};
return;
}
void AllowedClientHosts::addSubnet(const IPAddress & prefix, size_t num_prefix_bits)
{
addSubnet(prefix, IPAddress(num_prefix_bits, prefix.family()));
}
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
if (has_wildcard)
{
/// IP subnet specified with one of the wildcard characters, e.g. "192.168.%.%".
String wildcard_replaced_with_zero_bits = pattern;
String wildcard_replaced_with_one_bits = pattern;
if (address_family == IPAddress::IPv6)
{
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "_", "0");
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0000");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "_", "f");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "ffff");
}
else if (address_family == IPAddress::IPv4)
{
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "255");
}
IPAddress prefix{wildcard_replaced_with_zero_bits};
IPAddress mask = ~(prefix ^ IPAddress{wildcard_replaced_with_one_bits});
subnet = IPSubnet{prefix, mask};
return;
}
void AllowedClientHosts::addSubnet(const String & subnet)
{
size_t slash = subnet.find('/');
if (slash == String::npos)
{
addAddress(subnet);
return;
/// Exact IP address.
subnet = IPSubnet{pattern};
}
IPAddress prefix{String{subnet, 0, slash}};
String mask(subnet, slash + 1, subnet.length() - slash - 1);
if (std::all_of(mask.begin(), mask.end(), isNumericASCII))
addSubnet(prefix, parseFromString<UInt8>(mask));
else
addSubnet(prefix, IPAddress{mask});
}
void AllowedClientHosts::addHostName(const String & host_name)
{
if (boost::range::find(host_names, host_name) != host_names.end())
return;
host_names.push_back(host_name);
if (boost::iequals(host_name, "localhost"))
localhost = true;
}
/// Extracts a subnet, a host name or a host name regular expession from a like pattern.
void parseLikePattern(
const String & pattern, std::optional<IPSubnet> & subnet, std::optional<String> & name, std::optional<String> & name_regexp)
{
/// If `host` starts with digits and a dot then it's an IP pattern, otherwise it's a hostname pattern.
size_t first_not_digit = pattern.find_first_not_of("0123456789");
if ((first_not_digit != String::npos) && (first_not_digit != 0) && (pattern[first_not_digit] == '.'))
{
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv4);
return;
}
void AllowedClientHosts::addHostRegexp(const String & host_regexp)
{
if (boost::range::find(host_regexps, host_regexp) == host_regexps.end())
host_regexps.push_back(host_regexp);
}
size_t first_not_hex = pattern.find_first_not_of("0123456789ABCDEFabcdef");
if (((first_not_hex == 4) && pattern[first_not_hex] == ':') || pattern.starts_with("::"))
{
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv6);
return;
}
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
if (has_wildcard)
{
name_regexp = likePatternToRegexp(pattern);
return;
}
void AllowedClientHosts::addAllAddresses()
{
clear();
addSubnet(ALL_ADDRESSES);
name = pattern;
}
}
bool AllowedClientHosts::containsAllAddresses() const
bool AllowedClientHosts::contains(const IPAddress & client_address) const
{
return (boost::range::find(subnets, ALL_ADDRESSES) != subnets.end())
|| (boost::range::find(host_regexps, ".*") != host_regexps.end())
|| (boost::range::find(host_regexps, "$") != host_regexps.end());
}
if (any_host)
return true;
IPAddress client_v6 = toIPv6(client_address);
void AllowedClientHosts::checkContains(const IPAddress & address, const String & user_name) const
{
if (!contains(address))
std::optional<bool> is_client_local_value;
auto is_client_local = [&]
{
if (user_name.empty())
throw Exception("It's not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED);
else
throw Exception("User " + user_name + " is not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED);
}
}
if (is_client_local_value)
return *is_client_local_value;
is_client_local_value = isAddressOfLocalhost(client_v6);
return *is_client_local_value;
};
bool AllowedClientHosts::contains(const IPAddress & address) const
{
/// Check `ip_addresses`.
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) != addresses.end())
if (local_host && is_client_local())
return true;
if (localhost && isAddressOfLocalhost(addr_v6))
return true;
/// Check `addresses`.
auto check_address = [&](const IPAddress & address_)
{
IPAddress address_v6 = toIPv6(address_);
if (address_v6.isLoopback())
return is_client_local();
return address_v6 == client_v6;
};
for (const auto & address : addresses)
if (check_address(address))
return true;
/// Check `subnets`.
auto check_subnet = [&](const IPSubnet & subnet_)
{
IPSubnet subnet_v6 = toIPv6(subnet_);
if (subnet_v6.isMaskAllBitsOne())
return check_address(subnet_v6.getPrefix());
return (client_v6 & subnet_v6.getMask()) == subnet_v6.getPrefix();
};
/// Check `ip_subnets`.
for (const auto & subnet : subnets)
if ((addr_v6 & subnet.mask) == subnet.prefix)
if (check_subnet(subnet))
return true;
/// Check `hosts`.
for (const String & host_name : host_names)
/// Check `names`.
auto check_name = [&](const String & name_)
{
if (boost::iequals(name_, "localhost"))
return is_client_local();
try
{
if (isAddressOfHost(addr_v6, host_name))
return true;
return isAddressOfHost(client_v6, name_);
}
catch (const Exception & e)
{
......@@ -405,55 +302,82 @@ bool AllowedClientHosts::contains(const IPAddress & address) const
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
return false;
}
}
};
/// Check `host_regexps`.
try
for (const String & name : names)
if (check_name(name))
return true;
/// Check `name_regexps`.
std::optional<String> resolved_host;
auto check_name_regexp = [&](const String & name_regexp_)
{
String resolved_host = getHostByAddress(addr_v6);
if (!resolved_host.empty())
try
{
compileRegexps();
for (const auto & compiled_regexp : compiled_host_regexps)
{
Poco::RegularExpression::Match match;
if (compiled_regexp && compiled_regexp->match(resolved_host, match))
return true;
}
if (boost::iequals(name_regexp_, "localhost"))
return is_client_local();
if (!resolved_host)
resolved_host = getHostByAddress(client_v6);
if (resolved_host->empty())
return false;
Poco::RegularExpression re(name_regexp_);
Poco::RegularExpression::Match match;
return re.match(*resolved_host, match) != 0;
}
}
catch (const Exception & e)
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
return false;
}
};
for (const String & name_regexp : name_regexps)
if (check_name_regexp(name_regexp))
return true;
auto check_like_pattern = [&](const String & pattern)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
}
std::optional<IPSubnet> subnet;
std::optional<String> name;
std::optional<String> name_regexp;
parseLikePattern(pattern, subnet, name, name_regexp);
if (subnet)
return check_subnet(*subnet);
else if (name)
return check_name(*name);
else if (name_regexp)
return check_name_regexp(*name_regexp);
else
return false;
};
for (const String & like_pattern : like_patterns)
if (check_like_pattern(like_pattern))
return true;
return false;
}
void AllowedClientHosts::compileRegexps() const
void AllowedClientHosts::checkContains(const IPAddress & address, const String & user_name) const
{
if (compiled_host_regexps.size() == host_regexps.size())
return;
size_t old_size = compiled_host_regexps.size();
compiled_host_regexps.reserve(host_regexps.size());
for (size_t i = old_size; i != host_regexps.size(); ++i)
compiled_host_regexps.emplace_back(std::make_unique<Poco::RegularExpression>(host_regexps[i]));
if (!contains(address))
{
if (user_name.empty())
throw Exception("It's not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED);
else
throw Exception("User " + user_name + " is not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED);
}
}
bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
{
return (lhs.addresses == rhs.addresses) && (lhs.subnets == rhs.subnets) && (lhs.host_names == rhs.host_names)
&& (lhs.host_regexps == rhs.host_regexps);
}
}
......@@ -4,12 +4,9 @@
#include <Poco/Net/IPAddress.h>
#include <memory>
#include <vector>
namespace Poco
{
class RegularExpression;
}
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
#include <boost/algorithm/string/predicate.hpp>
namespace DB
......@@ -20,69 +17,100 @@ class AllowedClientHosts
public:
using IPAddress = Poco::Net::IPAddress;
struct IPSubnet
class IPSubnet
{
IPAddress prefix;
IPAddress mask;
public:
IPSubnet() {}
IPSubnet(const IPAddress & prefix_, const IPAddress & mask_) { set(prefix_, mask_); }
IPSubnet(const IPAddress & prefix_, size_t num_prefix_bits) { set(prefix_, num_prefix_bits); }
explicit IPSubnet(const IPAddress & address) { set(address); }
explicit IPSubnet(const String & str);
const IPAddress & getPrefix() const { return prefix; }
const IPAddress & getMask() const { return mask; }
bool isMaskAllBitsOne() const;
String toString() const;
friend bool operator ==(const IPSubnet & lhs, const IPSubnet & rhs) { return (lhs.prefix == rhs.prefix) && (lhs.mask == rhs.mask); }
friend bool operator !=(const IPSubnet & lhs, const IPSubnet & rhs) { return !(lhs == rhs); }
private:
void set(const IPAddress & prefix_, const IPAddress & mask_);
void set(const IPAddress & prefix_, size_t num_prefix_bits);
void set(const IPAddress & address);
IPAddress prefix;
IPAddress mask;
};
struct AllAddressesTag {};
struct AnyHostTag {};
AllowedClientHosts();
explicit AllowedClientHosts(AllAddressesTag);
~AllowedClientHosts();
AllowedClientHosts() {}
explicit AllowedClientHosts(AnyHostTag) { addAnyHost(); }
~AllowedClientHosts() {}
AllowedClientHosts(const AllowedClientHosts & src);
AllowedClientHosts & operator =(const AllowedClientHosts & src);
AllowedClientHosts(AllowedClientHosts && src);
AllowedClientHosts & operator =(AllowedClientHosts && src);
AllowedClientHosts(const AllowedClientHosts & src) = default;
AllowedClientHosts & operator =(const AllowedClientHosts & src) = default;
AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & operator =(AllowedClientHosts && src) = default;
/// Removes all contained addresses. This will disallow all addresses.
/// Removes all contained addresses. This will disallow all hosts.
void clear();
bool empty() const;
/// Allows exact IP address.
/// For example, 213.180.204.3 or 2a02:6b8::3
void addAddress(const IPAddress & address);
void addAddress(const String & address);
void addAddress(const String & address) { addAddress(IPAddress(address)); }
void removeAddress(const IPAddress & address);
void removeAddress(const String & address) { removeAddress(IPAddress{address}); }
const std::vector<IPAddress> & getAddresses() const { return addresses; }
/// Allows an IP subnet.
/// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/64
void addSubnet(const IPSubnet & subnet);
void addSubnet(const String & subnet);
void addSubnet(const String & subnet) { addSubnet(IPSubnet{subnet}); }
void addSubnet(const IPAddress & prefix, const IPAddress & mask) { addSubnet({prefix, mask}); }
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits) { addSubnet({prefix, num_prefix_bits}); }
void removeSubnet(const IPSubnet & subnet);
void removeSubnet(const String & subnet) { removeSubnet(IPSubnet{subnet}); }
void removeSubnet(const IPAddress & prefix, const IPAddress & mask) { removeSubnet({prefix, mask}); }
void removeSubnet(const IPAddress & prefix, size_t num_prefix_bits) { removeSubnet({prefix, num_prefix_bits}); }
const std::vector<IPSubnet> & getSubnets() const { return subnets; }
/// Allows an IP subnet.
/// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/FFFF:FFFF:FFFF:FFFF::
void addSubnet(const IPAddress & prefix, const IPAddress & mask);
/// Allows an exact host name. The `contains()` function will check that the provided address equals to one of that host's addresses.
void addName(const String & name);
void removeName(const String & name);
const std::vector<String> & getNames() const { return names; }
/// Allows an IP subnet.
/// For example, 10.0.0.1/8 or 2a02:6b8::3/64
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits);
/// Allows the host names matching a regular expression.
void addNameRegexp(const String & name_regexp);
void removeNameRegexp(const String & name_regexp);
const std::vector<String> & getNameRegexps() const { return name_regexps; }
/// Allows all addresses.
void addAllAddresses();
/// Allows IP addresses or host names using LIKE pattern.
/// This pattern can contain % and _ wildcard characters.
/// For example, addLikePattern("@") will allow all addresses.
void addLikePattern(const String & pattern);
void removeLikePattern(const String & like_pattern);
const std::vector<String> & getLikePatterns() const { return like_patterns; }
/// Allows an exact host. The `contains()` function will check that the provided address equals to one of that host's addresses.
void addHostName(const String & host_name);
/// Allows local host.
void addLocalHost();
void removeLocalHost();
bool containsLocalHost() const { return local_host;}
/// Allows a regular expression for the host.
void addHostRegexp(const String & host_regexp);
/// Allows any host.
void addAnyHost();
bool containsAnyHost() const { return any_host;}
const std::vector<IPAddress> & getAddresses() const { return addresses; }
const std::vector<IPSubnet> & getSubnets() const { return subnets; }
const std::vector<String> & getHostNames() const { return host_names; }
const std::vector<String> & getHostRegexps() const { return host_regexps; }
void add(const AllowedClientHosts & other);
void remove(const AllowedClientHosts & other);
/// Checks if the provided address is in the list. Returns false if not.
bool contains(const IPAddress & address) const;
/// Checks if any address is allowed.
bool containsAllAddresses() const;
/// Checks if the provided address is in the list. Throws an exception if not.
/// `username` is only used for generating an error message if the address isn't in the list.
void checkContains(const IPAddress & address, const String & user_name = String()) const;
......@@ -91,13 +119,269 @@ public:
friend bool operator !=(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) { return !(lhs == rhs); }
private:
void compileRegexps() const;
std::vector<IPAddress> addresses;
bool localhost = false;
std::vector<IPSubnet> subnets;
std::vector<String> host_names;
std::vector<String> host_regexps;
mutable std::vector<std::unique_ptr<Poco::RegularExpression>> compiled_host_regexps;
Strings names;
Strings name_regexps;
Strings like_patterns;
bool any_host = false;
bool local_host = false;
};
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, const IPAddress & mask_)
{
prefix = prefix_;
mask = mask_;
if (prefix.family() != mask.family())
{
if (prefix.family() == IPAddress::IPv4)
prefix = IPAddress("::ffff:" + prefix.toString());
if (mask.family() == IPAddress::IPv4)
mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + mask.toString());
}
prefix = prefix & mask;
if (prefix.family() == IPAddress::IPv4)
{
if ((prefix & IPAddress{8, IPAddress::IPv4}) == IPAddress{"127.0.0.0"})
{
// 127.XX.XX.XX -> 127.0.0.1
prefix = IPAddress{"127.0.0.1"};
mask = IPAddress{32, IPAddress::IPv4};
}
}
else
{
if ((prefix & IPAddress{104, IPAddress::IPv6}) == IPAddress{"::ffff:127.0.0.0"})
{
// ::ffff:127.XX.XX.XX -> ::1
prefix = IPAddress{"::1"};
mask = IPAddress{128, IPAddress::IPv6};
}
}
}
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, size_t num_prefix_bits)
{
set(prefix_, IPAddress(num_prefix_bits, prefix_.family()));
}
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & address)
{
set(address, address.length() * 8);
}
inline AllowedClientHosts::IPSubnet::IPSubnet(const String & str)
{
size_t slash = str.find('/');
if (slash == String::npos)
{
set(IPAddress(str));
return;
}
IPAddress new_prefix{String{str, 0, slash}};
String mask_str(str, slash + 1, str.length() - slash - 1);
bool only_digits = (mask_str.find_first_not_of("0123456789") == std::string::npos);
if (only_digits)
set(new_prefix, std::stoul(mask_str));
else
set(new_prefix, IPAddress{mask_str});
}
inline String AllowedClientHosts::IPSubnet::toString() const
{
unsigned int prefix_length = mask.prefixLength();
if (isMaskAllBitsOne())
return prefix.toString();
else if (IPAddress{prefix_length, mask.family()} == mask)
return prefix.toString() + "/" + std::to_string(prefix_length);
else
return prefix.toString() + "/" + mask.toString();
}
inline bool AllowedClientHosts::IPSubnet::isMaskAllBitsOne() const
{
return mask == IPAddress(mask.length() * 8, mask.family());
}
inline void AllowedClientHosts::clear()
{
addresses = {};
subnets = {};
names = {};
name_regexps = {};
like_patterns = {};
any_host = false;
local_host = false;
}
inline bool AllowedClientHosts::empty() const
{
return !any_host && !local_host && addresses.empty() && subnets.empty() && names.empty() && name_regexps.empty() && like_patterns.empty();
}
inline void AllowedClientHosts::addAddress(const IPAddress & address)
{
if (address.isLoopback())
local_host = true;
else if (boost::range::find(addresses, address) == addresses.end())
addresses.push_back(address);
}
inline void AllowedClientHosts::removeAddress(const IPAddress & address)
{
if (address.isLoopback())
local_host = false;
else
boost::range::remove_erase(addresses, address);
}
inline void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
if (subnet.getMask().isWildcard())
any_host = true;
else if (subnet.isMaskAllBitsOne())
addAddress(subnet.getPrefix());
else if (boost::range::find(subnets, subnet) == subnets.end())
subnets.push_back(subnet);
}
inline void AllowedClientHosts::removeSubnet(const IPSubnet & subnet)
{
if (subnet.getMask().isWildcard())
any_host = false;
else if (subnet.isMaskAllBitsOne())
removeAddress(subnet.getPrefix());
else
boost::range::remove_erase(subnets, subnet);
}
inline void AllowedClientHosts::addName(const String & name)
{
if (boost::iequals(name, "localhost"))
local_host = true;
else if (boost::range::find(names, name) == names.end())
names.push_back(name);
}
inline void AllowedClientHosts::removeName(const String & name)
{
if (boost::iequals(name, "localhost"))
local_host = false;
else
boost::range::remove_erase(names, name);
}
inline void AllowedClientHosts::addNameRegexp(const String & name_regexp)
{
if (boost::iequals(name_regexp, "localhost"))
local_host = true;
else if (name_regexp == ".*")
any_host = true;
else if (boost::range::find(name_regexps, name_regexp) == name_regexps.end())
name_regexps.push_back(name_regexp);
}
inline void AllowedClientHosts::removeNameRegexp(const String & name_regexp)
{
if (boost::iequals(name_regexp, "localhost"))
local_host = false;
else if (name_regexp == ".*")
any_host = false;
else
boost::range::remove_erase(name_regexps, name_regexp);
}
inline void AllowedClientHosts::addLikePattern(const String & pattern)
{
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
local_host = true;
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
any_host = true;
else if (boost::range::find(like_patterns, pattern) == name_regexps.end())
like_patterns.push_back(pattern);
}
inline void AllowedClientHosts::removeLikePattern(const String & pattern)
{
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
local_host = false;
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
any_host = false;
else
boost::range::remove_erase(like_patterns, pattern);
}
inline void AllowedClientHosts::addLocalHost()
{
local_host = true;
}
inline void AllowedClientHosts::removeLocalHost()
{
local_host = false;
}
inline void AllowedClientHosts::addAnyHost()
{
clear();
any_host = true;
}
inline void AllowedClientHosts::add(const AllowedClientHosts & other)
{
if (other.containsAnyHost())
{
addAnyHost();
return;
}
if (other.containsLocalHost())
addLocalHost();
for (const IPAddress & address : other.getAddresses())
addAddress(address);
for (const IPSubnet & subnet : other.getSubnets())
addSubnet(subnet);
for (const String & name : other.getNames())
addName(name);
for (const String & name_regexp : other.getNameRegexps())
addNameRegexp(name_regexp);
for (const String & like_pattern : other.getLikePatterns())
addLikePattern(like_pattern);
}
inline void AllowedClientHosts::remove(const AllowedClientHosts & other)
{
if (other.containsAnyHost())
{
clear();
return;
}
if (other.containsLocalHost())
removeLocalHost();
for (const IPAddress & address : other.getAddresses())
removeAddress(address);
for (const IPSubnet & subnet : other.getSubnets())
removeSubnet(subnet);
for (const String & name : other.getNames())
removeName(name);
for (const String & name_regexp : other.getNameRegexps())
removeNameRegexp(name_regexp);
for (const String & like_pattern : other.getLikePatterns())
removeLikePattern(like_pattern);
}
inline bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
{
return (lhs.any_host == rhs.any_host) && (lhs.local_host == rhs.local_host) && (lhs.addresses == rhs.addresses)
&& (lhs.subnets == rhs.subnets) && (lhs.names == rhs.names) && (lhs.name_regexps == rhs.name_regexps)
&& (lhs.like_patterns == rhs.like_patterns);
}
}
......@@ -14,7 +14,7 @@ namespace DB
struct User : public IAccessEntity
{
Authentication authentication;
AllowedClientHosts allowed_client_hosts;
AllowedClientHosts allowed_client_hosts{AllowedClientHosts::AnyHostTag{}};
AccessRights access;
String profile;
......
......@@ -90,15 +90,16 @@ namespace
{
Poco::Util::AbstractConfiguration::Keys keys;
config.keys(networks_config, keys);
user->allowed_client_hosts.clear();
for (const String & key : keys)
{
String value = config.getString(networks_config + "." + key);
if (key.starts_with("ip"))
user->allowed_client_hosts.addSubnet(value);
else if (key.starts_with("host_regexp"))
user->allowed_client_hosts.addHostRegexp(value);
user->allowed_client_hosts.addNameRegexp(value);
else if (key.starts_with("host"))
user->allowed_client_hosts.addHostName(value);
user->allowed_client_hosts.addName(value);
else
throw Exception("Unknown address pattern type: " + key, ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册