未验证 提交 a3d6e047 编写于 作者: A Artem Zuikov 提交者: GitHub

Fix lambdas with multiple_joins_rewriter v2 (#11587)

上级 29bf9aa1
......@@ -4,6 +4,7 @@
#include <Interpreters/IdentifierSemantic.h>
#include <Interpreters/AsteriskSemantic.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/RequiredSourceColumnsVisitor.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSubquery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
......@@ -378,10 +379,43 @@ using AppendSemanticVisitor = InDepthNodeVisitor<AppendSemanticMatcher, true>;
struct CollectColumnIdentifiersMatcher
{
using Data = std::vector<ASTIdentifier *>;
using Visitor = ConstInDepthNodeVisitor<CollectColumnIdentifiersMatcher, true>;
struct Data
{
std::vector<ASTIdentifier *> & identifiers;
std::vector<std::unordered_set<String>> ignored;
explicit Data(std::vector<ASTIdentifier *> & identifiers_)
: identifiers(identifiers_)
{}
void addIdentirier(const ASTIdentifier & ident)
{
for (const auto & aliases : ignored)
if (aliases.count(ident.name))
return;
identifiers.push_back(const_cast<ASTIdentifier *>(&ident));
}
void pushIgnored(const Names & names)
{
ignored.emplace_back(std::unordered_set<String>(names.begin(), names.end()));
}
void popIgnored()
{
ignored.pop_back();
}
};
static bool needChildVisit(const ASTPtr & node, const ASTPtr &)
{
/// "lambda" visit children itself.
if (const auto * f = node->as<ASTFunction>())
if (f->name == "lambda")
return false;
/// Do not go into subqueries. Do not collect table identifiers. Do not get identifier from 't.*'.
return !node->as<ASTSubquery>() &&
!node->as<ASTTablesInSelectQuery>() &&
......@@ -392,14 +426,28 @@ struct CollectColumnIdentifiersMatcher
{
if (auto * t = ast->as<ASTIdentifier>())
visit(*t, ast, data);
else if (auto * f = ast->as<ASTFunction>())
visit(*f, ast, data);
}
static void visit(const ASTIdentifier & ident, const ASTPtr &, Data & data)
{
data.push_back(const_cast<ASTIdentifier *>(&ident));
data.addIdentirier(ident);
}
static void visit(const ASTFunction & func, const ASTPtr &, Data & data)
{
if (func.name == "lambda")
{
data.pushIgnored(RequiredSourceColumnsMatcher::extractNamesFromLambda(func));
Visitor(data).visit(func.arguments->children[1]);
data.popIgnored();
}
}
};
using CollectColumnIdentifiersVisitor = ConstInDepthNodeVisitor<CollectColumnIdentifiersMatcher, true>;
using CollectColumnIdentifiersVisitor = CollectColumnIdentifiersMatcher::Visitor;
struct CheckAliasDependencyVisitorData
{
......@@ -709,7 +757,8 @@ void JoinToSubqueryTransformMatcher::visitV2(ASTSelectQuery & select, ASTPtr & a
/// Collect column identifiers
std::vector<ASTIdentifier *> identifiers;
CollectColumnIdentifiersVisitor(identifiers).visit(ast);
CollectColumnIdentifiersVisitor::Data data_identifiers(identifiers);
CollectColumnIdentifiersVisitor(data_identifiers).visit(ast);
std::vector<ASTIdentifier *> using_identifiers;
std::vector<std::vector<ASTPtr>> alias_pushdown(tables_count);
......@@ -725,7 +774,8 @@ void JoinToSubqueryTransformMatcher::visitV2(ASTSelectQuery & select, ASTPtr & a
if (join.on_expression)
{
std::vector<ASTIdentifier *> on_identifiers;
CollectColumnIdentifiersVisitor(on_identifiers).visit(join.on_expression);
CollectColumnIdentifiersVisitor::Data data_on_identifiers(on_identifiers);
CollectColumnIdentifiersVisitor(data_on_identifiers).visit(join.on_expression);
identifiers.insert(identifiers.end(), on_identifiers.begin(), on_identifiers.end());
/// Extract aliases used in ON section for pushdown. Exclude the last table.
......@@ -744,7 +794,10 @@ void JoinToSubqueryTransformMatcher::visitV2(ASTSelectQuery & select, ASTPtr & a
}
}
else if (join.using_expression_list)
CollectColumnIdentifiersVisitor(using_identifiers).visit(join.on_expression);
{
CollectColumnIdentifiersVisitor::Data data_using_identifiers(using_identifiers);
CollectColumnIdentifiersVisitor(data_using_identifiers).visit(join.using_expression_list);
}
}
}
......
......@@ -17,7 +17,7 @@ namespace ErrorCodes
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
static std::vector<String> extractNamesFromLambda(const ASTFunction & node)
std::vector<String> RequiredSourceColumnsMatcher::extractNamesFromLambda(const ASTFunction & node)
{
if (node.arguments->children.size() != 2)
throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
......
......@@ -26,6 +26,8 @@ public:
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
static void visit(const ASTPtr & ast, Data & data);
static std::vector<String> extractNamesFromLambda(const ASTFunction & node);
private:
static void visit(const ASTIdentifier & node, const ASTPtr &, Data & data);
static void visit(const ASTFunction & node, const ASTPtr &, Data & data);
......
set multiple_joins_rewriter_version = 2;
select
arrayMap(x, y -> floor((y - x) / x, 3), l, r) diff_percent,
test, query
from (select [1] l) s1,
(select [2] r) s2,
(select 'test' test, 'query' query) any_query,
(select 1 ) check_single_query;
select
arrayMap(x -> floor(x, 4), original_medians_array.medians_by_version[1] as l) l_rounded,
arrayMap(x -> floor(x, 4), original_medians_array.medians_by_version[2] as r) r_rounded,
arrayMap(x, y -> floor((y - x) / x, 3), l, r) diff_percent,
test, query
from (select 1) rd,
(select [[1,2], [3,4]] medians_by_version) original_medians_array,
(select 'test' test, 'query' query) any_query,
(select 1 as A) check_single_query;
drop table if exists table;
create table table(query String, test String, run UInt32, metrics Array(UInt32), version UInt32) engine Memory;
select
arrayMap(x -> floor(x, 4), original_medians_array.medians_by_version[1] as l) l_rounded,
arrayMap(x -> floor(x, 4), original_medians_array.medians_by_version[2] as r) r_rounded,
arrayMap(x, y -> floor((y - x) / x, 3), l, r) diff_percent,
arrayMap(x, y -> floor(x / y, 3), threshold, l) threshold_percent,
test, query
from
(
select quantileExactForEach(0.999)(arrayMap(x, y -> abs(x - y), metrics_by_label[1], metrics_by_label[2]) as d) threshold
from
(
select virtual_run, groupArrayInsertAt(median_metrics, random_label) metrics_by_label
from
(
select medianExactForEach(metrics) median_metrics, virtual_run, random_label
from
(
select *, toUInt32(rowNumberInAllBlocks() % 2) random_label
from
(
select metrics, number virtual_run
from (select metrics, run, version from table) no_query, numbers(1, 100000) nn
order by virtual_run, rand()
) virtual_runs
) relabeled
group by virtual_run, random_label
) virtual_medians
group by virtual_run
) virtual_medians_array
) rd,
(
select groupArrayInsertAt(median_metrics, version) medians_by_version
from
(
select medianExactForEach(metrics) median_metrics, version
from table
group by version
) original_medians
) original_medians_array,
(
select any(test) test, any(query) query from table
) any_query,
(
select throwIf(uniq((test, query))) from table
) check_single_query;
drop table table;
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册