提交 5e62a082 编写于 作者: D Danila Kutenin

produce hints for typo functions and types

上级 f2ded6a0
...@@ -128,7 +128,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( ...@@ -128,7 +128,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
return combinator->transformAggregateFunction(nested_function, argument_types, parameters); return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
} }
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); auto hints = this->getHints(name);
if (!hints.empty())
throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
else
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
} }
......
#pragma once #pragma once
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/NamePrompter.h>
#include <Core/Types.h> #include <Core/Types.h>
#include <Poco/String.h> #include <Poco/String.h>
...@@ -105,6 +106,12 @@ public: ...@@ -105,6 +106,12 @@ public:
return aliases.count(name) || case_insensitive_aliases.count(name); return aliases.count(name) || case_insensitive_aliases.count(name);
} }
std::vector<String> getHints(const String & name) const
{
static const auto registeredNames = getAllRegisteredNames();
return prompter.getHints(name, registeredNames);
}
virtual ~IFactoryWithAliases() {} virtual ~IFactoryWithAliases() {}
private: private:
...@@ -120,6 +127,12 @@ private: ...@@ -120,6 +127,12 @@ private:
/// Case insensitive aliases /// Case insensitive aliases
AliasMap case_insensitive_aliases; AliasMap case_insensitive_aliases;
/**
 * Prompter for names: if a user makes a typo in a function or type name, it
 * helps to find the best possible matches (names within an edit distance of at most two symbols).
 */
NamePrompter</*MistakeFactor=*/2, /*MaxNumHints=*/2> prompter;
}; };
} }
#pragma once
#include <Core/Types.h>
#include <cctype>
#include <algorithm>
#include <queue>
#include <utility>
#include <iostream>
namespace DB
{
/** Suggests names that are close to a (possibly misspelled) name.
  *
  * Closeness is the case-insensitive Levenshtein (edit) distance.
  *
  * Template parameters:
  *  MistakeFactor - maximum edit distance at which a candidate is still suggested;
  *  MaxNumHints   - maximum number of suggestions returned.
  */
template <size_t MistakeFactor, size_t MaxNumHints>
class NamePrompter
{
public:
    /// (edit distance, index of the candidate in the list of prompting strings).
    using DistanceIndex = std::pair<size_t, size_t>;
    /// Max-heap: the worst candidate kept so far is on top, so it is the one evicted first.
    using DistanceIndexQueue = std::priority_queue<DistanceIndex>;

    /** Returns at most MaxNumHints elements of `prompting_strings` whose edit distance
      * to `name` does not exceed MistakeFactor, ordered from the closest to the farthest.
      * Candidates at equal distance keep their relative order from `prompting_strings`.
      */
    static std::vector<String> getHints(const String & name, const std::vector<String> & prompting_strings)
    {
        DistanceIndexQueue queue;
        for (size_t i = 0; i < prompting_strings.size(); ++i)
            appendToQueue(i, name, queue, prompting_strings);
        return release(queue, prompting_strings);
    }

private:
    /// std::tolower has undefined behaviour for negative arguments other than EOF,
    /// so the (possibly signed) char is converted to unsigned char first.
    static int toLower(char c)
    {
        return std::tolower(static_cast<unsigned char>(c));
    }

    /// Case-insensitive Levenshtein distance between `lhs` and `rhs`.
    /// Only two rows of the dynamic programming table are kept at any time.
    static size_t LevenshteinDistance(const String & lhs, const String & rhs)
    {
        const size_t n = lhs.size();
        const size_t m = rhs.size();

        /// prev_row[j] is the distance between the first (i - 1) symbols of lhs and the first j symbols of rhs.
        std::vector<size_t> prev_row(m + 1);
        std::vector<size_t> cur_row(m + 1);
        for (size_t j = 0; j <= m; ++j)
            prev_row[j] = j;

        for (size_t i = 1; i <= n; ++i)
        {
            cur_row[0] = i;
            const int lhs_symbol = toLower(lhs[i - 1]);
            for (size_t j = 1; j <= m; ++j)
            {
                if (lhs_symbol == toLower(rhs[j - 1]))
                {
                    cur_row[j] = prev_row[j - 1];
                }
                else
                {
                    /// Deletion, insertion or substitution - whichever is the cheapest.
                    cur_row[j] = std::min(prev_row[j], std::min(cur_row[j - 1], prev_row[j - 1])) + 1;
                }
            }
            std::swap(prev_row, cur_row);
        }

        /// After the final swap (or with no iterations at all) the last computed row is prev_row.
        return prev_row[m];
    }

    /// Puts candidate number `ind` into `queue` if it is close enough to `name`,
    /// and evicts the worst candidate when the queue outgrows MaxNumHints.
    static void appendToQueue(size_t ind, const String & name, DistanceIndexQueue & queue, const std::vector<String> & prompting_strings)
    {
        const String & candidate = prompting_strings[ind];

        /// The edit distance is never less than the difference of lengths,
        /// so candidates that are too long or too short are rejected without computing it.
        if (candidate.size() <= name.size() + MistakeFactor && candidate.size() + MistakeFactor >= name.size())
        {
            const size_t distance = LevenshteinDistance(candidate, name);
            if (distance <= MistakeFactor)
            {
                queue.emplace(distance, ind);
                if (queue.size() > MaxNumHints)
                    queue.pop();
            }
        }
    }

    /// Empties `queue` and returns the corresponding strings, closest candidate first.
    static std::vector<String> release(DistanceIndexQueue & queue, const std::vector<String> & prompting_strings)
    {
        std::vector<String> ans;
        ans.reserve(queue.size());
        while (!queue.empty())
        {
            const auto top = queue.top();
            queue.pop();
            ans.push_back(prompting_strings[top.second]);
        }
        /// The max-heap yields the farthest candidate first; callers expect the closest one first.
        std::reverse(ans.begin(), ans.end());
        return ans;
    }
};
}
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Poco/String.h> #include <Poco/String.h>
#include <Common/StringUtils/StringUtils.h> #include <Common/StringUtils/StringUtils.h>
#include <IO/WriteHelpers.h>
namespace DB namespace DB
{ {
...@@ -87,7 +87,11 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr ...@@ -87,7 +87,11 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
return it->second(parameters); return it->second(parameters);
} }
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE); auto hints = this->getHints(family_name);
if (!hints.empty())
throw Exception("Unknown data type family: " + family_name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_TYPE);
else
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
} }
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <Poco/String.h> #include <Poco/String.h>
#include <IO/WriteHelpers.h>
namespace DB namespace DB
{ {
...@@ -43,7 +45,13 @@ FunctionBuilderPtr FunctionFactory::get( ...@@ -43,7 +45,13 @@ FunctionBuilderPtr FunctionFactory::get(
{ {
auto res = tryGet(name, context); auto res = tryGet(name, context);
if (!res) if (!res)
throw Exception("Unknown function " + name, ErrorCodes::UNKNOWN_FUNCTION); {
auto hints = this->getHints(name);
if (!hints.empty())
throw Exception("Unknown function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_FUNCTION);
else
throw Exception("Unknown function " + name, ErrorCodes::UNKNOWN_FUNCTION);
}
return res; return res;
} }
......
...@@ -357,7 +357,18 @@ void ActionsVisitor::visit(const ASTPtr & ast) ...@@ -357,7 +357,18 @@ void ActionsVisitor::visit(const ASTPtr & ast)
? context.getQueryContext() ? context.getQueryContext()
: context; : context;
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(node->name, function_context); FunctionBuilderPtr function_builder;
try
{
function_builder = FunctionFactory::instance().get(node->name, function_context);
}
catch (DB::Exception & e)
{
auto hints = AggregateFunctionFactory::instance().getHints(node->name);
if (!hints.empty())
e.addMessage("Or unknown aggregate function " + node->name + ". Maybe you meant: " + toString(hints));
e.rethrow();
}
Names argument_names; Names argument_names;
DataTypes argument_types; DataTypes argument_types;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册