未验证 提交 957d2326 编写于 作者: I Ivan 提交者: GitHub

Minimal implementation of row-level security CLICKHOUSE-4315 (#4792)

For detailed description see the related PR
上级 5588618c
......@@ -74,6 +74,26 @@
<!-- Quota for user. -->
<quota>default</quota>
<!-- For testing the table filters -->
<databases>
<test>
<!-- Simple expression filter -->
<filtered_table1>
<filter>a = 1</filter>
</filtered_table1>
<!-- Complex expression filter -->
<filtered_table2>
<filter>a + b &lt; 1 or c - d &gt; 5</filter>
</filtered_table2>
<!-- Filter with ALIAS column -->
<filtered_table3>
<filter>c = 1</filter>
</filtered_table3>
</test>
</databases>
</default>
<!-- Example of user with readonly access. -->
......
#include <Interpreters/ProcessList.h>
#include <DataStreams/BlockIO.h>
namespace DB
{
BlockIO::~BlockIO() = default;
BlockIO::BlockIO() = default;
BlockIO::BlockIO(const BlockIO &) = default;
}
......@@ -11,15 +11,19 @@ class ProcessListEntry;
struct BlockIO
{
BlockIO() = default;
BlockIO(const BlockIO &) = default;
~BlockIO() = default;
BlockOutputStreamPtr out;
BlockInputStreamPtr in;
/** process_list_entry should be destroyed after in and after out,
* since in and out contain pointer to objects inside process_list_entry (query-level MemoryTracker for example),
* which could be used before destroying of in and out.
*/
std::shared_ptr<ProcessListEntry> process_list_entry;
BlockInputStreamPtr in;
BlockOutputStreamPtr out;
/// Callbacks for query logging could be set here.
std::function<void(IBlockInputStream *, IBlockOutputStream *)> finish_callback;
std::function<void()> exception_callback;
......@@ -37,17 +41,11 @@ struct BlockIO
exception_callback();
}
/// We provide the correct order of destruction.
void reset()
BlockIO & operator= (const BlockIO & rhs)
{
out.reset();
in.reset();
process_list_entry.reset();
}
BlockIO & operator= (const BlockIO & rhs)
{
reset();
process_list_entry = rhs.process_list_entry;
in = rhs.in;
......@@ -58,10 +56,6 @@ struct BlockIO
return *this;
}
~BlockIO();
BlockIO();
BlockIO(const BlockIO &);
};
}
......@@ -26,7 +26,7 @@
#include <Core/Settings.h>
#include <Interpreters/ExpressionJIT.h>
#include <Interpreters/RuntimeComponentsFactory.h>
#include <Interpreters/ISecurityManager.h>
#include <Interpreters/IUsersManager.h>
#include <Interpreters/Quota.h>
#include <Interpreters/EmbeddedDictionaries.h>
#include <Interpreters/ExternalDictionaries.h>
......@@ -129,7 +129,7 @@ struct ContextShared
mutable std::optional<ExternalModels> external_models;
String default_profile_name; /// Default profile name used for default values.
String system_profile_name; /// Profile used by system processes
std::unique_ptr<ISecurityManager> security_manager; /// Known users.
std::unique_ptr<IUsersManager> users_manager; /// Known users.
Quotas quotas; /// Known quotas for resource use.
mutable UncompressedCachePtr uncompressed_cache; /// The cache of decompressed blocks.
mutable MarkCachePtr mark_cache; /// Cache of marks in compressed files.
......@@ -291,7 +291,7 @@ struct ContextShared
private:
void initialize()
{
security_manager = runtime_components_factory->createSecurityManager();
users_manager = runtime_components_factory->createUsersManager();
}
};
......@@ -571,7 +571,7 @@ void Context::setUsersConfig(const ConfigurationPtr & config)
{
auto lock = getLock();
shared->users_config = config;
shared->security_manager->loadFromConfig(*shared->users_config);
shared->users_manager->loadFromConfig(*shared->users_config);
shared->quotas.loadFromConfig(*shared->users_config);
}
......@@ -581,11 +581,39 @@ ConfigurationPtr Context::getUsersConfig()
return shared->users_config;
}
bool Context::hasUserProperty(const String & database, const String & table, const String & name) const
{
auto lock = getLock();
// No user - no properties.
if (client_info.current_user.empty())
return false;
const auto & props = shared->users_manager->getUser(client_info.current_user)->table_props;
auto db = props.find(database);
if (db == props.end())
return false;
auto table_props = db->second.find(table);
if (table_props == db->second.end())
return false;
return !!table_props->second.count(name);
}
const String & Context::getUserProperty(const String & database, const String & table, const String & name) const
{
auto lock = getLock();
const auto & props = shared->users_manager->getUser(client_info.current_user)->table_props;
return props.at(database).at(table).at(name);
}
void Context::calculateUserSettings()
{
auto lock = getLock();
String profile = shared->security_manager->getUser(client_info.current_user)->profile;
String profile = shared->users_manager->getUser(client_info.current_user)->profile;
/// 1) Set default settings (hardcoded values)
/// NOTE: we ignore global_context settings (from which it is usually copied)
......@@ -606,7 +634,7 @@ void Context::setUser(const String & name, const String & password, const Poco::
{
auto lock = getLock();
auto user_props = shared->security_manager->authorizeAndGetUser(name, password, address.host());
auto user_props = shared->users_manager->authorizeAndGetUser(name, password, address.host());
client_info.current_user = name;
client_info.current_address = address;
......@@ -644,7 +672,7 @@ bool Context::hasDatabaseAccessRights(const String & database_name) const
{
auto lock = getLock();
return client_info.current_user.empty() || (database_name == "system") ||
shared->security_manager->hasAccessToDatabase(client_info.current_user, database_name);
shared->users_manager->hasAccessToDatabase(client_info.current_user, database_name);
}
void Context::checkDatabaseAccessRightsImpl(const std::string & database_name) const
......@@ -655,7 +683,7 @@ void Context::checkDatabaseAccessRightsImpl(const std::string & database_name) c
/// All users have access to the database system.
return;
}
if (!shared->security_manager->hasAccessToDatabase(client_info.current_user, database_name))
if (!shared->users_manager->hasAccessToDatabase(client_info.current_user, database_name))
throw Exception("Access denied to database " + database_name + " for user " + client_info.current_user , ErrorCodes::DATABASE_ACCESS_DENIED);
}
......
......@@ -188,6 +188,10 @@ public:
void setUsersConfig(const ConfigurationPtr & config);
ConfigurationPtr getUsersConfig();
// User property is a key-value pair from the configuration entry: users.<username>.databases.<db_name>.<table_name>.<key_name>
bool hasUserProperty(const String & database, const String & table, const String & name) const;
const String & getUserProperty(const String & database, const String & table, const String & name) const;
/// Must be called before getClientInfo.
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);
/// Compute and set actual user settings, client_info.current_user should be set
......
......@@ -2,7 +2,7 @@
#include <Dictionaries/Embedded/IGeoDictionariesLoader.h>
#include <Interpreters/IExternalLoaderConfigRepository.h>
#include <Interpreters/ISecurityManager.h>
#include <Interpreters/IUsersManager.h>
#include <memory>
......@@ -16,7 +16,9 @@ namespace DB
class IRuntimeComponentsFactory
{
public:
virtual std::unique_ptr<ISecurityManager> createSecurityManager() = 0;
virtual ~IRuntimeComponentsFactory() = default;
virtual std::unique_ptr<IUsersManager> createUsersManager() = 0;
virtual std::unique_ptr<IGeoDictionariesLoader> createGeoDictionariesLoader() = 0;
......@@ -24,8 +26,6 @@ public:
virtual std::unique_ptr<IExternalLoaderConfigRepository> createExternalDictionariesConfigRepository() = 0;
virtual std::unique_ptr<IExternalLoaderConfigRepository> createExternalModelsConfigRepository() = 0;
virtual ~IRuntimeComponentsFactory() {}
};
}
......@@ -5,16 +5,18 @@
namespace DB
{
/** Duties of security manager:
/** Duties of users manager:
* 1) Authenticate users
* 2) Provide user settings (profile, quota, ACLs)
* 3) Grant access to databases
*/
class ISecurityManager
class IUsersManager
{
public:
using UserPtr = std::shared_ptr<const User>;
virtual ~IUsersManager() = default;
virtual void loadFromConfig(const Poco::Util::AbstractConfiguration & config) = 0;
/// Find user and make authorize checks
......@@ -28,8 +30,6 @@ public:
/// Check if the user has access to the database.
virtual bool hasAccessToDatabase(const String & user_name, const String & database_name) const = 0;
virtual ~ISecurityManager() {}
};
}
......@@ -23,12 +23,13 @@
#include <DataStreams/ConvertColumnLowCardinalityToFullBlockInputStream.h>
#include <DataStreams/ConvertingBlockInputStream.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/ParserSelectQuery.h>
#include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
......@@ -75,6 +76,60 @@ namespace ErrorCodes
extern const int INVALID_LIMIT_EXPRESSION;
}
namespace
{
/// Assumes `storage` is set and the table filter is not empty.
String generateFilterActions(ExpressionActionsPtr & actions, const StoragePtr & storage, const Context & context, const Names & prerequisite_columns = {})
{
const auto & db_name = storage->getDatabaseName();
const auto & table_name = storage->getTableName();
const auto & filter_str = context.getUserProperty(db_name, table_name, "filter");
/// TODO: implement some AST builders for this kind of stuff
ASTPtr query_ast = std::make_shared<ASTSelectQuery>();
auto * select_ast = query_ast->as<ASTSelectQuery>();
auto expr_list = std::make_shared<ASTExpressionList>();
select_ast->children.push_back(expr_list);
select_ast->select_expression_list = select_ast->children.back();
auto parseExpression = [] (const String & expr)
{
ParserExpression expr_parser;
return parseQuery(expr_parser, expr, 0);
};
// The first column is our filter expression.
expr_list->children.push_back(parseExpression(filter_str));
/// Keep columns that are required after the filter actions.
for (const auto & column_str : prerequisite_columns)
expr_list->children.push_back(parseExpression(column_str));
auto tables = std::make_shared<ASTTablesInSelectQuery>();
auto tables_elem = std::make_shared<ASTTablesInSelectQueryElement>();
auto table_expr = std::make_shared<ASTTableExpression>();
select_ast->children.push_back(tables);
select_ast->tables = select_ast->children.back();
tables->children.push_back(tables_elem);
tables_elem->table_expression = table_expr;
tables_elem->children.push_back(table_expr);
table_expr->database_and_table_name = createTableIdentifier(db_name, table_name);
table_expr->children.push_back(table_expr->database_and_table_name);
/// Using separate expression analyzer to prevent any possible alias injection
auto syntax_result = SyntaxAnalyzer(context).analyze(query_ast, storage->getColumns().getAllPhysical());
ExpressionAnalyzer analyzer(query_ast, syntax_result, context);
ExpressionActionsChain new_chain(context);
analyzer.appendSelect(new_chain, false);
actions = new_chain.getLastActions();
return expr_list->children.at(0)->getColumnName();
}
} // namespace
InterpreterSelectQuery::InterpreterSelectQuery(
const ASTPtr & query_ptr_,
const Context & context_,
......@@ -302,7 +357,8 @@ BlockInputStreams InterpreterSelectQuery::executeWithMultipleStreams()
return pipeline.streams;
}
InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpressions(QueryProcessingStage::Enum from_stage, bool dry_run)
InterpreterSelectQuery::AnalysisResult
InterpreterSelectQuery::analyzeExpressions(QueryProcessingStage::Enum from_stage, bool dry_run, const FilterInfoPtr & filter_info)
{
AnalysisResult res;
......@@ -318,6 +374,7 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression
* throw out unnecessary columns based on the entire query. In unnecessary parts of the query, we will not execute subqueries.
*/
bool has_filter = false;
bool has_prewhere = false;
bool has_where = false;
size_t where_step_num;
......@@ -350,10 +407,15 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression
res.columns_to_remove_after_prewhere = std::move(columns_to_remove);
}
else if (has_filter)
{
/// Can't have prewhere and filter set simultaneously
res.filter_info->do_remove_column = chain.steps.at(0).can_remove_required_output.at(0);
}
if (has_where)
res.remove_where_filter = chain.steps.at(where_step_num).can_remove_required_output.at(0);
has_prewhere = has_where = false;
has_filter = has_prewhere = has_where = false;
chain.clear();
};
......@@ -378,6 +440,26 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression
columns_for_final.begin(), columns_for_final.end());
}
if (storage && context.hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
{
has_filter = true;
/// XXX: aggregated copy-paste from ExpressionAnalyzer::appendSmth()
if (chain.steps.empty())
{
chain.steps.emplace_back(std::make_shared<ExpressionActions>(source_columns, context));
}
ExpressionActionsChain::Step & step = chain.steps.back();
// FIXME: assert(filter_info);
res.filter_info = filter_info;
step.actions = filter_info->actions;
step.required_output.push_back(res.filter_info->column_name);
step.can_remove_required_output = {true};
chain.addStep();
}
if (query_analyzer->appendPrewhere(chain, !res.first_stage, additional_required_columns_after_prewhere))
{
has_prewhere = true;
......@@ -445,6 +527,8 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression
}
/// Before executing WHERE and HAVING, remove the extra columns from the block (mostly the aggregation keys).
if (res.filter_info)
res.filter_info->actions->prependProjectInput();
if (res.has_where)
res.before_where->prependProjectInput();
if (res.has_having)
......@@ -491,7 +575,8 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt
QueryProcessingStage::Enum from_stage = QueryProcessingStage::FetchColumns;
/// PREWHERE optimization
if (storage)
/// Turn off, if the table filter is applied.
if (storage && !context.hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
{
if (!dry_run)
from_stage = storage->getQueryProcessingStage(context);
......@@ -517,11 +602,23 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt
}
AnalysisResult expressions;
FilterInfoPtr filter_info;
/// We need proper `source_header` for `NullBlockInputStream` in dry-run.
if (storage && context.hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
{
filter_info = std::make_shared<FilterInfo>();
filter_info->column_name = generateFilterActions(filter_info->actions, storage, context, required_columns);
source_header = storage->getSampleBlockForColumns(filter_info->actions->getRequiredColumns());
}
if (dry_run)
{
pipeline.streams.emplace_back(std::make_shared<NullBlockInputStream>(source_header));
expressions = analyzeExpressions(QueryProcessingStage::FetchColumns, true);
expressions = analyzeExpressions(QueryProcessingStage::FetchColumns, true, filter_info);
if (storage && expressions.filter_info && expressions.prewhere_info)
throw Exception("PREWHERE is not supported if the table is filtered by row-level security expression", ErrorCodes::ILLEGAL_PREWHERE);
if (expressions.prewhere_info)
pipeline.streams.back() = std::make_shared<FilterBlockInputStream>(
......@@ -533,12 +630,15 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt
if (prepared_input)
pipeline.streams.push_back(prepared_input);
expressions = analyzeExpressions(from_stage, false);
expressions = analyzeExpressions(from_stage, false, filter_info);
if (from_stage == QueryProcessingStage::WithMergeableState &&
options.to_stage == QueryProcessingStage::WithMergeableState)
throw Exception("Distributed on Distributed is not supported", ErrorCodes::NOT_IMPLEMENTED);
if (storage && expressions.filter_info && expressions.prewhere_info)
throw Exception("PREWHERE is not supported if the table is filtered by row-level security expression", ErrorCodes::ILLEGAL_PREWHERE);
/** Read the data from Storage. from_stage - to what stage the request was completed in Storage. */
executeFetchColumns(from_stage, pipeline, expressions.prewhere_info, expressions.columns_to_remove_after_prewhere);
......@@ -563,6 +663,18 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt
if (expressions.first_stage)
{
if (expressions.filter_info)
{
pipeline.transform([&](auto & stream)
{
stream = std::make_shared<FilterBlockInputStream>(
stream,
expressions.filter_info->actions,
expressions.filter_info->column_name,
expressions.filter_info->do_remove_column);
});
}
if (expressions.hasJoin())
{
const auto & join = query.join()->table_join->as<ASTTableJoin &>();
......@@ -788,11 +900,26 @@ void InterpreterSelectQuery::executeFetchColumns(
/// Actions to calculate ALIAS if required.
ExpressionActionsPtr alias_actions;
/// Are ALIAS columns required for query execution?
auto alias_columns_required = false;
if (storage)
{
/// Append columns from the table filter to required
if (context.hasUserProperty(storage->getDatabaseName(), storage->getTableName(), "filter"))
{
auto initial_required_columns = required_columns;
ExpressionActionsPtr actions;
generateFilterActions(actions, storage, context, initial_required_columns);
auto required_columns_from_filter = actions->getRequiredColumns();
for (const auto & column : required_columns_from_filter)
{
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column))
required_columns.push_back(column);
}
}
/// Detect, if ALIAS columns are required for query execution
auto alias_columns_required = false;
const ColumnsDescription & storage_columns = storage->getColumns();
for (const auto & column_name : required_columns)
{
......@@ -804,25 +931,33 @@ void InterpreterSelectQuery::executeFetchColumns(
}
}
/// There are multiple sources of required columns:
/// - raw required columns,
/// - columns deduced from ALIAS columns,
/// - raw required columns from PREWHERE,
/// - columns deduced from ALIAS columns from PREWHERE.
/// PREWHERE is a special case, since we need to resolve it and pass directly to `IStorage::read()`
/// before any other executions.
if (alias_columns_required)
{
/// Columns required for prewhere actions.
NameSet required_prewhere_columns;
/// Columns required for prewhere actions which are aliases in storage.
NameSet required_prewhere_aliases;
Block prewhere_actions_result;
NameSet required_columns_from_prewhere; /// Set of all (including ALIAS) required columns for PREWHERE
NameSet required_aliases_from_prewhere; /// Set of ALIAS required columns for PREWHERE
if (prewhere_info)
{
/// Get some columns directly from PREWHERE expression actions
auto prewhere_required_columns = prewhere_info->prewhere_actions->getRequiredColumns();
required_prewhere_columns.insert(prewhere_required_columns.begin(), prewhere_required_columns.end());
prewhere_actions_result = prewhere_info->prewhere_actions->getSampleBlock();
required_columns_from_prewhere.insert(prewhere_required_columns.begin(), prewhere_required_columns.end());
}
/// We will create an expression to return all the requested columns, with the calculation of the required ALIAS columns.
ASTPtr required_columns_expr_list = std::make_shared<ASTExpressionList>();
/// Separate expression for columns used in prewhere.
ASTPtr required_prewhere_columns_expr_list = std::make_shared<ASTExpressionList>();
/// Expression, that contains all raw required columns
ASTPtr required_columns_all_expr = std::make_shared<ASTExpressionList>();
/// Expression, that contains raw required columns for PREWHERE
ASTPtr required_columns_from_prewhere_expr = std::make_shared<ASTExpressionList>();
/// Sort out already known required columns between expressions,
/// also populate `required_aliases_from_prewhere`.
for (const auto & column : required_columns)
{
ASTPtr column_expr;
......@@ -833,36 +968,47 @@ void InterpreterSelectQuery::executeFetchColumns(
else
column_expr = std::make_shared<ASTIdentifier>(column);
if (required_prewhere_columns.count(column))
if (required_columns_from_prewhere.count(column))
{
required_prewhere_columns_expr_list->children.emplace_back(std::move(column_expr));
required_columns_from_prewhere_expr->children.emplace_back(std::move(column_expr));
if (is_alias)
required_prewhere_aliases.insert(column);
required_aliases_from_prewhere.insert(column);
}
else
required_columns_expr_list->children.emplace_back(std::move(column_expr));
required_columns_all_expr->children.emplace_back(std::move(column_expr));
}
/// Columns which we will get after prewhere execution.
NamesAndTypesList additional_source_columns;
/// Add columns which will be added by prewhere (otherwise we will remove them in project action).
NameSet columns_to_remove(columns_to_remove_after_prewhere.begin(), columns_to_remove_after_prewhere.end());
for (const auto & column : prewhere_actions_result)
/// Columns, which we will get after prewhere and filter executions.
NamesAndTypesList required_columns_after_prewhere;
NameSet required_columns_after_prewhere_set;
/// Collect required columns from prewhere expression actions.
if (prewhere_info)
{
if (prewhere_info->remove_prewhere_column && column.name == prewhere_info->prewhere_column_name)
continue;
NameSet columns_to_remove(columns_to_remove_after_prewhere.begin(), columns_to_remove_after_prewhere.end());
Block prewhere_actions_result = prewhere_info->prewhere_actions->getSampleBlock();
/// Populate required columns with the columns, added by PREWHERE actions and not removed afterwards.
/// XXX: looks hacky that we already know which columns after PREWHERE we won't need for sure.
for (const auto & column : prewhere_actions_result)
{
if (prewhere_info->remove_prewhere_column && column.name == prewhere_info->prewhere_column_name)
continue;
if (columns_to_remove.count(column.name))
continue;
if (columns_to_remove.count(column.name))
continue;
required_columns_expr_list->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
additional_source_columns.emplace_back(column.name, column.type);
required_columns_all_expr->children.emplace_back(std::make_shared<ASTIdentifier>(column.name));
required_columns_after_prewhere.emplace_back(column.name, column.type);
}
required_columns_after_prewhere_set
= ext::map<NameSet>(required_columns_after_prewhere, [](const auto & it) { return it.name; });
}
auto additional_source_columns_set = ext::map<NameSet>(additional_source_columns, [] (const auto & it) { return it.name; });
auto syntax_result = SyntaxAnalyzer(context).analyze(required_columns_expr_list, additional_source_columns, {}, storage);
alias_actions = ExpressionAnalyzer(required_columns_expr_list, syntax_result, context).getActions(true);
auto syntax_result = SyntaxAnalyzer(context).analyze(required_columns_all_expr, required_columns_after_prewhere, {}, storage);
alias_actions = ExpressionAnalyzer(required_columns_all_expr, syntax_result, context).getActions(true);
/// The set of required columns could be added as a result of adding an action to calculate ALIAS.
required_columns = alias_actions->getRequiredColumns();
......@@ -874,17 +1020,10 @@ void InterpreterSelectQuery::executeFetchColumns(
prewhere_info->remove_prewhere_column = false;
/// Remove columns which will be added by prewhere.
size_t next_req_column_pos = 0;
for (size_t i = 0; i < required_columns.size(); ++i)
required_columns.erase(std::remove_if(required_columns.begin(), required_columns.end(), [&](const String & name)
{
if (!additional_source_columns_set.count(required_columns[i]))
{
if (next_req_column_pos < i)
std::swap(required_columns[i], required_columns[next_req_column_pos]);
++next_req_column_pos;
}
}
required_columns.resize(next_req_column_pos);
return !!required_columns_after_prewhere_set.count(name);
}), required_columns.end());
if (prewhere_info)
{
......@@ -898,21 +1037,22 @@ void InterpreterSelectQuery::executeFetchColumns(
}
prewhere_info->prewhere_actions = std::move(new_actions);
auto analyzed_result = SyntaxAnalyzer(context).analyze(required_prewhere_columns_expr_list, storage->getColumns().getAllPhysical());
prewhere_info->alias_actions =
ExpressionAnalyzer(required_prewhere_columns_expr_list, analyzed_result, context)
.getActions(true, false);
auto analyzed_result
= SyntaxAnalyzer(context).analyze(required_columns_from_prewhere_expr, storage->getColumns().getAllPhysical());
prewhere_info->alias_actions
= ExpressionAnalyzer(required_columns_from_prewhere_expr, analyzed_result, context).getActions(true, false);
/// Add columns required by alias actions.
auto required_aliased_columns = prewhere_info->alias_actions->getRequiredColumns();
for (auto & column : required_aliased_columns)
/// Add (physical?) columns required by alias actions.
auto required_columns_from_alias = prewhere_info->alias_actions->getRequiredColumns();
Block prewhere_actions_result = prewhere_info->prewhere_actions->getSampleBlock();
for (auto & column : required_columns_from_alias)
if (!prewhere_actions_result.has(column))
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column))
required_columns.push_back(column);
/// Add columns required by prewhere actions.
for (const auto & column : required_prewhere_columns)
if (required_prewhere_aliases.count(column) == 0)
/// Add physical columns required by prewhere actions.
for (const auto & column : required_columns_from_prewhere)
if (required_aliases_from_prewhere.count(column) == 0)
if (required_columns.end() == std::find(required_columns.begin(), required_columns.end(), column))
required_columns.push_back(column);
}
......@@ -1013,12 +1153,17 @@ void InterpreterSelectQuery::executeFetchColumns(
if (pipeline.streams.empty())
{
pipeline.streams.emplace_back(std::make_shared<NullBlockInputStream>(storage->getSampleBlockForColumns(required_columns)));
pipeline.streams = {std::make_shared<NullBlockInputStream>(storage->getSampleBlockForColumns(required_columns))};
if (query_info.prewhere_info)
pipeline.streams.back() = std::make_shared<FilterBlockInputStream>(
pipeline.streams.back(), prewhere_info->prewhere_actions,
prewhere_info->prewhere_column_name, prewhere_info->remove_prewhere_column);
pipeline.transform([&](auto & stream)
{
stream = std::make_shared<FilterBlockInputStream>(
stream,
prewhere_info->prewhere_actions,
prewhere_info->prewhere_column_name,
prewhere_info->remove_prewhere_column);
});
}
pipeline.transform([&](auto & stream)
......@@ -1434,6 +1579,7 @@ void InterpreterSelectQuery::executeLimitBy(Pipeline & pipeline)
}
// TODO: move to anonymous namespace
bool hasWithTotalsInAnySubqueryInFromClause(const ASTSelectQuery & query)
{
if (query.group_by_with_totals)
......
......@@ -104,13 +104,13 @@ private:
BlockInputStreamPtr & firstStream() { return streams.at(0); }
template <typename Transform>
void transform(Transform && transform)
void transform(Transform && transformation)
{
for (auto & stream : streams)
transform(stream);
transformation(stream);
if (stream_with_non_joined_data)
transform(stream_with_non_joined_data);
transformation(stream_with_non_joined_data);
}
bool hasMoreThanOneStream() const
......@@ -154,9 +154,10 @@ private:
SubqueriesForSets subqueries_for_sets;
PrewhereInfoPtr prewhere_info;
FilterInfoPtr filter_info;
};
AnalysisResult analyzeExpressions(QueryProcessingStage::Enum from_stage, bool dry_run);
AnalysisResult analyzeExpressions(QueryProcessingStage::Enum from_stage, bool dry_run, const FilterInfoPtr & filter_info);
/** From which table to read. With JOIN, the "left" table is returned.
......
#pragma once
#include <Dictionaries/Embedded/GeoDictionariesLoader.h>
#include <Interpreters/IRuntimeComponentsFactory.h>
#include <Interpreters/ExternalLoaderConfigRepository.h>
#include <Interpreters/SecurityManager.h>
#include <Interpreters/IRuntimeComponentsFactory.h>
#include <Interpreters/UsersManager.h>
namespace DB
{
......@@ -14,9 +14,9 @@ namespace DB
class RuntimeComponentsFactory : public IRuntimeComponentsFactory
{
public:
std::unique_ptr<ISecurityManager> createSecurityManager() override
std::unique_ptr<IUsersManager> createUsersManager() override
{
return std::make_unique<SecurityManager>();
return std::make_unique<UsersManager>();
}
std::unique_ptr<IGeoDictionariesLoader> createGeoDictionariesLoader() override
......
......@@ -315,6 +315,34 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A
databases.insert(database_name);
}
}
/// Read properties per "database.table"
/// Only tables are expected to have properties, so that all the keys inside "database" are table names.
const auto config_databases = config_elem + ".databases";
if (config.has(config_databases))
{
Poco::Util::AbstractConfiguration::Keys database_names;
config.keys(config_databases, database_names);
/// Read tables within databases
for (const auto & database : database_names)
{
const auto config_database = config_databases + "." + database;
Poco::Util::AbstractConfiguration::Keys table_names;
config.keys(config_database, table_names);
/// Read table properties
for (const auto & table : table_names)
{
const auto config_filter = config_database + "." + table + ".filter";
if (config.has(config_filter))
{
const auto filter_query = config.getString(config_filter);
table_props[database][table]["filter"] = filter_query;
}
}
}
}
}
......
......@@ -2,9 +2,10 @@
#include <Core/Types.h>
#include <vector>
#include <unordered_set>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace Poco
......@@ -65,6 +66,12 @@ struct User
using DatabaseSet = std::unordered_set<std::string>;
DatabaseSet databases;
/// Table properties.
using PropertyMap = std::unordered_map<std::string /* name */, std::string /* value */>;
using TableMap = std::unordered_map<std::string /* table */, PropertyMap /* properties */>;
using DatabaseMap = std::unordered_map<std::string /* database */, TableMap /* tables */>;
DatabaseMap table_props;
User(const String & name_, const String & config_elem, const Poco::Util::AbstractConfiguration & config);
};
......
#include "SecurityManager.h"
#include <Interpreters/UsersManager.h>
#include <Poco/Net/IPAddress.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Poco/String.h>
......@@ -28,9 +29,9 @@ namespace ErrorCodes
extern const int SUPPORT_IS_DISABLED;
}
using UserPtr = SecurityManager::UserPtr;
using UserPtr = UsersManager::UserPtr;
void SecurityManager::loadFromConfig(const Poco::Util::AbstractConfiguration & config)
void UsersManager::loadFromConfig(const Poco::Util::AbstractConfiguration & config)
{
Container new_users;
......@@ -46,7 +47,7 @@ void SecurityManager::loadFromConfig(const Poco::Util::AbstractConfiguration & c
users = std::move(new_users);
}
UserPtr SecurityManager::authorizeAndGetUser(
UserPtr UsersManager::authorizeAndGetUser(
const String & user_name,
const String & password,
const Poco::Net::IPAddress & address) const
......@@ -100,7 +101,7 @@ UserPtr SecurityManager::authorizeAndGetUser(
return it->second;
}
UserPtr SecurityManager::getUser(const String & user_name) const
UserPtr UsersManager::getUser(const String & user_name) const
{
auto it = users.find(user_name);
......@@ -110,7 +111,7 @@ UserPtr SecurityManager::getUser(const String & user_name) const
return it->second;
}
bool SecurityManager::hasAccessToDatabase(const std::string & user_name, const std::string & database_name) const
bool UsersManager::hasAccessToDatabase(const std::string & user_name, const std::string & database_name) const
{
auto it = users.find(user_name);
......
#pragma once
#include <Interpreters/ISecurityManager.h>
#include <Interpreters/IUsersManager.h>
#include <map>
namespace DB
{
/** Default implementation of security manager used by native server application.
/** Default implementation of users manager used by native server application.
* Manages fixed set of users listed in 'Users' configuration file.
*/
class SecurityManager : public ISecurityManager
class UsersManager : public IUsersManager
{
private:
using Container = std::map<String, UserPtr>;
Container users;
public:
void loadFromConfig(const Poco::Util::AbstractConfiguration & config) override;
......@@ -27,6 +23,10 @@ public:
UserPtr getUser(const String & user_name) const override;
bool hasAccessToDatabase(const String & user_name, const String & database_name) const override;
private:
using Container = std::map<String, UserPtr>;
Container users;
};
}
#include <Common/Config/ConfigProcessor.h>
#include <Interpreters/SecurityManager.h>
#include <Interpreters/UsersManager.h>
#include <boost/filesystem.hpp>
#include <vector>
#include <string>
......@@ -197,11 +197,11 @@ void runOneTest(const TestDescriptor & test_descriptor)
throw std::runtime_error(os.str());
}
DB::SecurityManager security_manager;
DB::UsersManager users_manager;
try
{
security_manager.loadFromConfig(*config);
users_manager.loadFromConfig(*config);
}
catch (const Poco::Exception & ex)
{
......@@ -216,7 +216,7 @@ void runOneTest(const TestDescriptor & test_descriptor)
try
{
res = security_manager.hasAccessToDatabase(entry.user_name, entry.database_name);
res = users_manager.hasAccessToDatabase(entry.user_name, entry.database_name);
}
catch (const Poco::Exception &)
{
......
......@@ -61,9 +61,9 @@ public:
/// The main name of the table type (for example, StorageMergeTree).
virtual std::string getName() const = 0;
/** The name of the table.
*/
/// The name of the table.
virtual std::string getTableName() const = 0;
virtual std::string getDatabaseName() const { return {}; } // FIXME: should be abstract method.
/** Returns true if the storage receives data from a remote server or servers. */
virtual bool isRemote() const { return false; }
......
......@@ -27,7 +27,7 @@ friend class KafkaBlockOutputStream;
public:
std::string getName() const override { return "Kafka"; }
std::string getTableName() const override { return table_name; }
std::string getDatabaseName() const { return database_name; }
std::string getDatabaseName() const override { return database_name; }
void startup() override;
void shutdown() override;
......
......@@ -25,7 +25,16 @@ struct PrewhereInfo
: prewhere_actions(std::move(prewhere_actions_)), prewhere_column_name(std::move(prewhere_column_name_)) {}
};
/// Helper struct to store all the information about the filter expression.
struct FilterInfo
{
ExpressionActionsPtr actions;
String column_name;
bool do_remove_column = false;
};
using PrewhereInfoPtr = std::shared_ptr<PrewhereInfo>;
using FilterInfoPtr = std::shared_ptr<FilterInfo>;
struct SyntaxAnalyzerResult;
using SyntaxAnalyzerResultPtr = std::shared_ptr<const SyntaxAnalyzerResult>;
......
......@@ -27,12 +27,9 @@ public:
void shutdown() override;
~StorageMergeTree() override;
std::string getName() const override
{
return data.merging_params.getModeName() + "MergeTree";
}
std::string getName() const override { return data.merging_params.getModeName() + "MergeTree"; }
std::string getTableName() const override { return table_name; }
std::string getDatabaseName() const override { return database_name; }
bool supportsSampling() const override { return data.supportsSampling(); }
bool supportsPrewhere() const override { return data.supportsPrewhere(); }
......
......@@ -79,12 +79,10 @@ public:
void shutdown() override;
~StorageReplicatedMergeTree() override;
std::string getName() const override
{
return "Replicated" + data.merging_params.getModeName() + "MergeTree";
}
std::string getName() const override { return "Replicated" + data.merging_params.getModeName() + "MergeTree"; }
std::string getTableName() const override { return table_name; }
std::string getDatabaseName() const override { return database_name; }
bool supportsSampling() const override { return data.supportsSampling(); }
bool supportsFinal() const override { return data.supportsFinal(); }
bool supportsPrewhere() const override { return data.supportsPrewhere(); }
......
-- PREWHERE should fail
1 0
1 1
0 0 0 0
0 0 6 0
0 1
1 0
1
1
0
1
1
1
1
1
0
1
1
0
1
1
1
1
0
1
1
1
1
1
1 0 1 1
1 1 1 1
1 1 1 0
DROP TABLE IF EXISTS filtered_table1;
DROP TABLE IF EXISTS filtered_table2;
DROP TABLE IF EXISTS filtered_table3;
-- Filter: a = 1, values: (1, 0), (1, 1)
CREATE TABLE test.filtered_table1 (a UInt8, b UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO test.filtered_table1 values (0, 0), (0, 1), (1, 0), (1, 1);
-- Filter: a + b < 1 or c - d > 5, values: (0, 0, 0, 0), (0, 0, 6, 0)
CREATE TABLE test.filtered_table2 (a UInt8, b UInt8, c UInt8, d UInt8) ENGINE MergeTree ORDER BY a;
INSERT INTO test.filtered_table2 values (0, 0, 0, 0), (1, 2, 3, 4), (4, 3, 2, 1), (0, 0, 6, 0);
-- Filter: c = 1, values: (0, 1), (1, 0)
CREATE TABLE test.filtered_table3 (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE MergeTree ORDER BY a;
INSERT INTO test.filtered_table3 values (0, 0), (0, 1), (1, 0), (1, 1);
SELECT '-- PREWHERE should fail';
SELECT * FROM test.filtered_table1 PREWHERE 1; -- { serverError 182 }
SELECT * FROM test.filtered_table2 PREWHERE 1; -- { serverError 182 }
SELECT * FROM test.filtered_table3 PREWHERE 1; -- { serverError 182 }
SELECT * FROM test.filtered_table1;
SELECT * FROM test.filtered_table2;
SELECT * FROM test.filtered_table3;
SELECT a FROM test.filtered_table1;
SELECT b FROM test.filtered_table1;
SELECT a FROM test.filtered_table1 WHERE a = 1;
SELECT a = 1 FROM test.filtered_table1;
SELECT a FROM test.filtered_table3;
SELECT b FROM test.filtered_table3;
SELECT c FROM test.filtered_table3;
SELECT a + b FROM test.filtered_table3;
SELECT a FROM test.filtered_table3 WHERE c = 1;
SELECT c = 1 FROM test.filtered_table3;
SELECT a + b = 1 FROM test.filtered_table3;
SELECT * FROM test.filtered_table1 as t1 ANY LEFT JOIN test.filtered_table1 as t2 ON t1.a = t2.b;
SELECT * FROM test.filtered_table1 as t2 ANY RIGHT JOIN test.filtered_table1 as t1 ON t2.b = t1.a;
DROP TABLE test.filtered_table1;
DROP TABLE test.filtered_table2;
DROP TABLE test.filtered_table3;
DROP TABLE IF EXISTS test.test;
CREATE TABLE test.test (a UInt8, b UInt8, c UInt16 ALIAS a + b) ENGINE = MergeTree ORDER BY a;
SELECT b FROM test.test PREWHERE c = 1;
DROP TABLE test;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册