未验证 提交 e07d0201 编写于 作者: A alexey-milovidov 提交者: GitHub

Merge pull request #11661 from ClickHouse/return-not-nullable-from-count-distinct-2

Return non-Nullable results from COUNT(DISTINCT), more complete.
......@@ -14,6 +14,7 @@ set (CLICKHOUSE_ODBC_BRIDGE_SOURCES
set (CLICKHOUSE_ODBC_BRIDGE_LINK
PRIVATE
clickhouse_parsers
clickhouse_aggregate_functions
daemon
dbms
Poco::Data
......
......@@ -36,7 +36,10 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{
return std::make_shared<AggregateFunctionArray>(nested_function, arguments);
}
......
......@@ -7,6 +7,12 @@
namespace DB
{
AggregateFunctionPtr AggregateFunctionCount::getOwnNullAdapter(
const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const
{
return std::make_shared<AggregateFunctionCountNotNullUnary>(types[0], params);
}
namespace
{
......@@ -22,7 +28,7 @@ AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, cons
void registerAggregateFunctionCount(AggregateFunctionFactory & factory)
{
factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("count", {createAggregateFunctionCount, {true}}, AggregateFunctionFactory::CaseInsensitive);
}
}
......@@ -68,16 +68,14 @@ public:
data(place).count = new_count;
}
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const override;
};
/// Simply count number of not-NULL values.
class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>
class AggregateFunctionCountNotNullUnary final
: public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>
{
public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)
......
......@@ -29,18 +29,18 @@ namespace ErrorCodes
}
void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness)
void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
{
if (creator == nullptr)
if (creator_with_properties.creator == nullptr)
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
" a null constructor", ErrorCodes::LOGICAL_ERROR);
if (!aggregate_functions.emplace(name, creator).second)
if (!aggregate_functions.emplace(name, creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second)
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
}
......@@ -59,6 +59,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
const String & name,
const DataTypes & argument_types,
const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level) const
{
auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
......@@ -76,18 +77,15 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
Array nested_parameters = combinator->transformParameters(parameters);
AggregateFunctionPtr nested_function;
bool has_null_arguments = std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
[](const auto & type) { return type->onlyNull(); });
/// A little hack - if we have NULL arguments, don't even create nested function.
/// Combinator will check if nested_function was created.
if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
[](const auto & type) { return type->onlyNull(); }))
nested_function = getImpl(name, nested_types, nested_parameters, recursion_level);
return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters);
AggregateFunctionPtr nested_function = getImpl(
name, nested_types, nested_parameters, out_properties, has_null_arguments, recursion_level);
return combinator->transformAggregateFunction(nested_function, out_properties, type_without_low_cardinality, parameters);
}
auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level);
auto res = getImpl(name, type_without_low_cardinality, parameters, out_properties, false, recursion_level);
if (!res)
throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR);
return res;
......@@ -98,19 +96,35 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
const String & name_param,
const DataTypes & argument_types,
const Array & parameters,
AggregateFunctionProperties & out_properties,
bool has_null_arguments,
int recursion_level) const
{
String name = getAliasToOrName(name_param);
Value found;
/// Find by exact match.
if (auto it = aggregate_functions.find(name); it != aggregate_functions.end())
return it->second(name, argument_types, parameters);
{
found = it->second;
}
/// Find by case-insensitive name.
/// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names.
if (recursion_level == 0)
else if (recursion_level == 0)
{
if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end())
return it->second(name, argument_types, parameters);
if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
found = jt->second;
}
if (found.creator)
{
out_properties = found.properties;
/// The case when aggregate function should return NULL on NULL arguments. This case is handled in "get" method.
if (!out_properties.returns_default_when_only_null && has_null_arguments)
return nullptr;
return found.creator(name, argument_types, parameters);
}
/// Combinators of aggregate functions.
......@@ -126,9 +140,8 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
DataTypes nested_types = combinator->transformArguments(argument_types);
Array nested_parameters = combinator->transformParameters(parameters);
AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1);
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, out_properties, recursion_level + 1);
return combinator->transformAggregateFunction(nested_function, out_properties, argument_types, parameters);
}
auto hints = this->getHints(name);
......@@ -140,10 +153,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
}
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const
AggregateFunctionPtr AggregateFunctionFactory::tryGet(
const String & name, const DataTypes & argument_types, const Array & parameters, AggregateFunctionProperties & out_properties) const
{
return isAggregateFunctionName(name)
? get(name, argument_types, parameters)
? get(name, argument_types, parameters, out_properties)
: nullptr;
}
......
......@@ -26,34 +26,50 @@ using DataTypes = std::vector<DataTypePtr>;
*/
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>;
struct AggregateFunctionWithProperties
{
AggregateFunctionCreator creator;
AggregateFunctionProperties properties;
AggregateFunctionWithProperties() = default;
AggregateFunctionWithProperties(const AggregateFunctionWithProperties &) = default;
template <typename Creator, std::enable_if_t<!std::is_same_v<Creator, AggregateFunctionWithProperties>> * = nullptr>
AggregateFunctionWithProperties(Creator creator_, AggregateFunctionProperties properties_ = {})
: creator(std::forward<Creator>(creator_)), properties(std::move(properties_))
{
}
};
/** Creates an aggregate function by name.
*/
class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionCreator>
class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionWithProperties>
{
public:
static AggregateFunctionFactory & instance();
/// Register a function by its name.
/// No locking, you must register all functions before usage of get.
void registerFunction(
const String & name,
Creator creator,
Value creator,
CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Throws an exception if not found.
AggregateFunctionPtr get(
const String & name,
const DataTypes & argument_types,
const Array & parameters = {},
const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level = 0) const;
/// Returns nullptr if not found.
AggregateFunctionPtr tryGet(
const String & name,
const DataTypes & argument_types,
const Array & parameters = {}) const;
const Array & parameters,
AggregateFunctionProperties & out_properties) const;
bool isAggregateFunctionName(const String & name, int recursion_level = 0) const;
......@@ -62,19 +78,21 @@ private:
const String & name,
const DataTypes & argument_types,
const Array & parameters,
AggregateFunctionProperties & out_properties,
bool has_null_arguments,
int recursion_level) const;
private:
using AggregateFunctions = std::unordered_map<String, Creator>;
using AggregateFunctions = std::unordered_map<String, Value>;
AggregateFunctions aggregate_functions;
/// Case insensitive aggregate functions will be additionally added here with lowercased name.
AggregateFunctions case_insensitive_aggregate_functions;
const AggregateFunctions & getCreatorMap() const override { return aggregate_functions; }
const AggregateFunctions & getMap() const override { return aggregate_functions; }
const AggregateFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_aggregate_functions; }
const AggregateFunctions & getCaseInsensitiveMap() const override { return case_insensitive_aggregate_functions; }
String getFactoryName() const override { return "AggregateFunctionFactory"; }
......
......@@ -33,7 +33,10 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{
return std::make_shared<AggregateFunctionForEach>(nested_function, arguments);
}
......
......@@ -31,7 +31,10 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{
return std::make_shared<AggregateFunctionIf>(nested_function, arguments);
}
......
......@@ -34,7 +34,10 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{
const DataTypePtr & argument = arguments[0];
......
......@@ -25,7 +25,7 @@ public:
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>());
return argument_types.front();
}
void create(AggregateDataPtr) const override
......
......@@ -31,13 +31,11 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties & properties,
const DataTypes & arguments,
const Array & params) const override
{
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
bool has_nullable_types = false;
bool has_null_types = false;
for (const auto & arg_type : arguments)
......@@ -58,15 +56,23 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (has_null_types)
return std::make_shared<AggregateFunctionNothing>(arguments, params);
{
/// Currently the only functions that returns not-NULL on all NULL arguments are count and uniq, and they returns UInt64.
if (properties.returns_default_when_only_null)
return std::make_shared<AggregateFunctionNothing>(DataTypes{
std::make_shared<DataTypeUInt64>()}, params);
else
return std::make_shared<AggregateFunctionNothing>(DataTypes{
std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>())}, params);
}
assert(nested_function);
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter;
bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull();
bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
if (arguments.size() == 1)
{
......
......@@ -21,6 +21,7 @@ public:
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
......
......@@ -43,6 +43,7 @@ public:
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
......
......@@ -24,7 +24,10 @@ public:
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
return std::make_shared<AggregateFunctionState>(nested_function, arguments, params);
}
......
......@@ -123,13 +123,13 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory)
{
factory.registerFunction("uniq",
createAggregateFunctionUniq<AggregateFunctionUniqUniquesHashSetData, AggregateFunctionUniqUniquesHashSetDataForVariadic>);
{createAggregateFunctionUniq<AggregateFunctionUniqUniquesHashSetData, AggregateFunctionUniqUniquesHashSetDataForVariadic>, {true}});
factory.registerFunction("uniqHLL12",
createAggregateFunctionUniq<false, AggregateFunctionUniqHLL12Data, AggregateFunctionUniqHLL12DataForVariadic>);
{createAggregateFunctionUniq<false, AggregateFunctionUniqHLL12Data, AggregateFunctionUniqHLL12DataForVariadic>, {true}});
factory.registerFunction("uniqExact",
createAggregateFunctionUniq<true, AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>);
{createAggregateFunctionUniq<true, AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>, {true}});
}
}
......@@ -244,12 +244,6 @@ public:
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
}
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
};
......@@ -304,12 +298,6 @@ public:
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
}
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
};
}
......@@ -85,7 +85,7 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c
void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory)
{
factory.registerFunction("uniqUpTo", createAggregateFunctionUniqUpTo);
factory.registerFunction("uniqUpTo", {createAggregateFunctionUniqUpTo, {true}});
}
}
......@@ -166,17 +166,12 @@ public:
* nested_function is a smart pointer to this aggregate function itself.
* arguments and params are for nested_function.
*/
virtual AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const
virtual AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const
{
return nullptr;
}
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
virtual bool returnDefaultWhenOnlyNull() const { return false; }
const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; }
......@@ -286,4 +281,15 @@ public:
};
/// Properties of aggregate function that are independent of argument types and parameters.
struct AggregateFunctionProperties
{
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
bool returns_default_when_only_null = false;
};
}
......@@ -59,6 +59,7 @@ public:
*/
virtual AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties & properties,
const DataTypes & arguments,
const Array & params) const = 0;
......
......@@ -381,6 +381,6 @@ if (ENABLE_TESTS AND USE_GTEST)
-Wno-gnu-zero-variadic-macro-arguments
)
target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils)
target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_aggregate_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils)
add_check(unit_tests_dbms)
endif ()
add_executable(test-connect test_connect.cpp)
target_link_libraries (test-connect PRIVATE dbms)
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <iostream>
#include <Poco/Net/StreamSocket.h>
#include <Common/Exception.h>
#include <IO/ReadHelpers.h>
/** In a loop it connects to the server and immediately breaks the connection.
* Using the SO_LINGER option, we ensure that the connection is terminated by sending a RST packet (not FIN).
* This behavior causes a bug in the TCPServer implementation in the Poco library.
*/
int main(int argc, char ** argv)
try
{
for (size_t i = 0, num_iters = argc >= 2 ? DB::parse<size_t>(argv[1]) : 1; i < num_iters; ++i)
{
std::cerr << ".";
Poco::Net::SocketAddress address("localhost", 9000);
int fd = socket(PF_INET, SOCK_STREAM, IPPROTO_IP);
if (fd < 0)
DB::throwFromErrno("Cannot create socket", 0);
linger linger_value;
linger_value.l_onoff = 1;
linger_value.l_linger = 0;
if (0 != setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger_value, sizeof(linger_value)))
DB::throwFromErrno("Cannot set linger", 0);
try
{
int res = connect(fd, address.addr(), address.length());
if (res != 0 && errno != EINPROGRESS && errno != EWOULDBLOCK)
{
close(fd);
DB::throwFromErrno("Cannot connect", 0);
}
close(fd);
}
catch (const Poco::Exception & e)
{
std::cerr << e.displayText() << "\n";
}
}
std::cerr << "\n";
}
catch (const Poco::Exception & e)
{
std::cerr << e.displayText() << "\n";
}
......@@ -16,14 +16,14 @@ namespace ErrorCodes
}
/** If stored objects may have several names (aliases)
* this interface may be helpful
* template parameter is available as Creator
*/
template <typename CreatorFunc>
class IFactoryWithAliases : public IHints<2, IFactoryWithAliases<CreatorFunc>>
* this interface may be helpful
* template parameter is available as Value
*/
template <typename ValueType>
class IFactoryWithAliases : public IHints<2, IFactoryWithAliases<ValueType>>
{
protected:
using Creator = CreatorFunc;
using Value = ValueType;
String getAliasToOrName(const String & name) const
{
......@@ -43,13 +43,13 @@ public:
CaseInsensitive
};
/** Register additional name for creator
* real_name have to be already registered.
*/
/** Register additional name for value
* real_name have to be already registered.
*/
void registerAlias(const String & alias_name, const String & real_name, CaseSensitiveness case_sensitiveness = CaseSensitive)
{
const auto & creator_map = getCreatorMap();
const auto & case_insensitive_creator_map = getCaseInsensitiveCreatorMap();
const auto & creator_map = getMap();
const auto & case_insensitive_creator_map = getCaseInsensitiveMap();
const String factory_name = getFactoryName();
String real_dict_name;
......@@ -80,7 +80,7 @@ public:
{
std::vector<String> result;
auto getter = [](const auto & pair) { return pair.first; };
std::transform(getCreatorMap().begin(), getCreatorMap().end(), std::back_inserter(result), getter);
std::transform(getMap().begin(), getMap().end(), std::back_inserter(result), getter);
std::transform(aliases.begin(), aliases.end(), std::back_inserter(result), getter);
return result;
}
......@@ -88,7 +88,7 @@ public:
bool isCaseInsensitive(const String & name) const
{
String name_lowercase = Poco::toLower(name);
return getCaseInsensitiveCreatorMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase);
return getCaseInsensitiveMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase);
}
const String & aliasTo(const String & name) const
......@@ -109,11 +109,11 @@ public:
virtual ~IFactoryWithAliases() override {}
private:
using InnerMap = std::unordered_map<String, Creator>; // name -> creator
using InnerMap = std::unordered_map<String, Value>; // name -> creator
using AliasMap = std::unordered_map<String, String>; // alias -> original type
virtual const InnerMap & getCreatorMap() const = 0;
virtual const InnerMap & getCaseInsensitiveCreatorMap() const = 0;
virtual const InnerMap & getMap() const = 0;
virtual const InnerMap & getCaseInsensitiveMap() const = 0;
virtual String getFactoryName() const = 0;
/// Alias map to data_types from previous two maps
......
set(SRCS)
add_executable (finish_sorting_stream finish_sorting_stream.cpp ${SRCS})
target_link_libraries (finish_sorting_stream PRIVATE dbms)
target_link_libraries (finish_sorting_stream PRIVATE clickhouse_aggregate_functions dbms)
......@@ -392,7 +392,8 @@ static DataTypePtr create(const ASTPtr & arguments)
if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row);
}
......
......@@ -110,7 +110,8 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
// check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))
......
......@@ -80,7 +80,7 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
}
void DataTypeFactory::registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness)
void DataTypeFactory::registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness)
{
if (creator == nullptr)
throw Exception("DataTypeFactory: the data type family " + family_name + " has been provided "
......@@ -136,7 +136,7 @@ void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCre
}, case_sensitiveness);
}
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const
const DataTypeFactory::Value & DataTypeFactory::findCreatorByName(const String & family_name) const
{
{
DataTypesDictionary::const_iterator it = data_types.find(family_name);
......
......@@ -23,7 +23,7 @@ class DataTypeFactory final : private boost::noncopyable, public IFactoryWithAli
{
private:
using SimpleCreator = std::function<DataTypePtr()>;
using DataTypesDictionary = std::unordered_map<String, Creator>;
using DataTypesDictionary = std::unordered_map<String, Value>;
using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>;
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>;
......@@ -35,7 +35,7 @@ public:
DataTypePtr get(const ASTPtr & ast) const;
/// Register a type family by its name.
void registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
void registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a simple data type, that have no parameters.
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
......@@ -47,7 +47,7 @@ public:
void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
private:
const Creator& findCreatorByName(const String & family_name) const;
const Value & findCreatorByName(const String & family_name) const;
private:
DataTypesDictionary data_types;
......@@ -57,9 +57,9 @@ private:
DataTypeFactory();
const DataTypesDictionary & getCreatorMap() const override { return data_types; }
const DataTypesDictionary & getMap() const override { return data_types; }
const DataTypesDictionary & getCaseInsensitiveCreatorMap() const override { return case_insensitive_data_types; }
const DataTypesDictionary & getCaseInsensitiveMap() const override { return case_insensitive_data_types; }
String getFactoryName() const override { return "DataTypeFactory"; }
};
......
set(SRCS )
add_executable (tab_separated_streams tab_separated_streams.cpp ${SRCS})
target_link_libraries (tab_separated_streams PRIVATE dbms)
target_link_libraries (tab_separated_streams PRIVATE clickhouse_aggregate_functions dbms)
......@@ -20,7 +20,7 @@ namespace ErrorCodes
void FunctionFactory::registerFunction(const
std::string & name,
Creator creator,
Value creator,
CaseSensitiveness case_sensitiveness)
{
if (!functions.emplace(name, creator).second)
......
......@@ -53,7 +53,7 @@ public:
FunctionOverloadResolverImplPtr tryGetImpl(const std::string & name, const Context & context) const;
private:
using Functions = std::unordered_map<std::string, Creator>;
using Functions = std::unordered_map<std::string, Value>;
Functions functions;
Functions case_insensitive_functions;
......@@ -64,9 +64,9 @@ private:
return std::make_unique<DefaultOverloadResolver>(Function::create(context));
}
const Functions & getCreatorMap() const override { return functions; }
const Functions & getMap() const override { return functions; }
const Functions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_functions; }
const Functions & getCaseInsensitiveMap() const override { return case_insensitive_functions; }
String getFactoryName() const override { return "FunctionFactory"; }
......@@ -74,7 +74,7 @@ private:
/// No locking, you must register all functions before usage of get.
void registerFunction(
const std::string & name,
Creator creator,
Value creator,
CaseSensitiveness case_sensitiveness = CaseSensitive);
};
......
......@@ -113,8 +113,9 @@ public:
auto nested_type = array_type->getNestedType();
DataTypes argument_types = {nested_type};
Array params_row;
AggregateFunctionPtr bitmap_function
= AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row);
AggregateFunctionProperties properties;
AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get(
AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row, properties);
return std::make_shared<DataTypeAggregateFunction>(bitmap_function, argument_types, params_row);
}
......@@ -156,8 +157,9 @@ private:
// output data
Array params_row;
AggregateFunctionPtr bitmap_function
= AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row);
AggregateFunctionProperties properties;
AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get(
AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row, properties);
auto col_to = ColumnAggregateFunction::create(bitmap_function);
col_to->reserve(offsets.size());
......
......@@ -97,7 +97,8 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName());
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row);
AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
}
return aggregate_function->getReturnType();
......
......@@ -115,7 +115,8 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName());
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row);
AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
}
return std::make_shared<DataTypeArray>(aggregate_function->getReturnType());
......
......@@ -420,8 +420,9 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ExpressionActionsPtr & action
aggregate.argument_names[i] = name;
}
AggregateFunctionProperties properties;
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array();
aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters);
aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters, properties);
aggregate_descriptions.push_back(aggregate);
}
......
......@@ -34,11 +34,11 @@ target_include_directories (two_level_hash_map SYSTEM BEFORE PRIVATE ${SPARSEHAS
target_link_libraries (two_level_hash_map PRIVATE dbms)
add_executable (in_join_subqueries_preprocessor in_join_subqueries_preprocessor.cpp)
target_link_libraries (in_join_subqueries_preprocessor PRIVATE dbms clickhouse_parsers)
target_link_libraries (in_join_subqueries_preprocessor PRIVATE clickhouse_aggregate_functions dbms clickhouse_parsers)
add_check(in_join_subqueries_preprocessor)
add_executable (users users.cpp)
target_link_libraries (users PRIVATE dbms clickhouse_common_config)
target_link_libraries (users PRIVATE clickhouse_aggregate_functions dbms clickhouse_common_config)
if (OS_LINUX)
add_executable (internal_iotop internal_iotop.cpp)
......
......@@ -103,9 +103,10 @@ int main(int argc, char ** argv)
std::vector<Key> data(n);
Value value;
AggregateFunctionPtr func_count = factory.get("count", data_types_empty);
AggregateFunctionPtr func_avg = factory.get("avg", data_types_uint64);
AggregateFunctionPtr func_uniq = factory.get("uniq", data_types_uint64);
AggregateFunctionProperties properties;
AggregateFunctionPtr func_count = factory.get("count", data_types_empty, {}, properties);
AggregateFunctionPtr func_avg = factory.get("avg", data_types_uint64, {}, properties);
AggregateFunctionPtr func_uniq = factory.get("uniq", data_types_uint64, {}, properties);
#define INIT \
{ \
......
......@@ -47,7 +47,8 @@ struct SummingSortedAlgorithm::AggregateDescription
void init(const char * function_name, const DataTypes & argument_types)
{
function = AggregateFunctionFactory::instance().get(function_name, argument_types);
AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, {}, properties);
add_function = function->getAddressOfAddFunction();
state.reset(function->sizeOfData(), function->alignOfData());
}
......
......@@ -117,8 +117,9 @@ static void appendGraphitePattern(
aggregate_function_name, params_row, "GraphiteMergeTree storage initialization");
/// TODO Not only Float64
pattern.function = AggregateFunctionFactory::instance().get(aggregate_function_name, {std::make_shared<DataTypeFloat64>()},
params_row);
AggregateFunctionProperties properties;
pattern.function = AggregateFunctionFactory::instance().get(
aggregate_function_name, {std::make_shared<DataTypeFloat64>()}, params_row, properties);
}
else if (startsWith(key, "retention"))
{
......
......@@ -15,7 +15,7 @@ namespace ErrorCodes
}
void TableFunctionFactory::registerFunction(const std::string & name, Creator creator, CaseSensitiveness case_sensitiveness)
void TableFunctionFactory::registerFunction(const std::string & name, Value creator, CaseSensitiveness case_sensitiveness)
{
if (!table_functions.emplace(name, creator).second)
throw Exception("TableFunctionFactory: the table function name '" + name + "' is not unique",
......
......@@ -24,12 +24,11 @@ using TableFunctionCreator = std::function<TableFunctionPtr()>;
class TableFunctionFactory final: private boost::noncopyable, public IFactoryWithAliases<TableFunctionCreator>
{
public:
static TableFunctionFactory & instance();
/// Register a function by its name.
/// No locking, you must register all functions before usage of get.
void registerFunction(const std::string & name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
void registerFunction(const std::string & name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
template <typename Function>
void registerFunction(CaseSensitiveness case_sensitiveness = CaseSensitive)
......@@ -50,11 +49,11 @@ public:
bool isTableFunctionName(const std::string & name) const;
private:
using TableFunctions = std::unordered_map<std::string, Creator>;
using TableFunctions = std::unordered_map<std::string, Value>;
const TableFunctions & getCreatorMap() const override { return table_functions; }
const TableFunctions & getMap() const override { return table_functions; }
const TableFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_table_functions; }
const TableFunctions & getCaseInsensitiveMap() const override { return case_insensitive_table_functions; }
String getFactoryName() const override { return "TableFunctionFactory"; }
......
......@@ -6,7 +6,14 @@ SELECT uniq(number >= 5 ? number : NULL) FROM numbers(10);
SELECT uniqExact(number >= 5 ? number : NULL) FROM numbers(10);
SELECT count(DISTINCT number >= 5 ? number : NULL) FROM numbers(10);
SELECT '---';
SELECT count(NULL);
-- These two returns NULL for now, but we want to change them to return 0.
SELECT uniq(NULL);
SELECT count(DISTINCT NULL);
SELECT '---';
SELECT avg(NULL);
SELECT sum(NULL);
SELECT corr(NULL, NULL);
SELECT corr(1, NULL);
SELECT corr(NULL, 1);
add_executable (convert-month-partitioned-parts main.cpp)
target_link_libraries(convert-month-partitioned-parts PRIVATE dbms clickhouse_parsers boost::program_options)
target_link_libraries(convert-month-partitioned-parts PRIVATE clickhouse_aggregate_functions dbms clickhouse_parsers boost::program_options)
add_executable (zookeeper-adjust-block-numbers-to-parts main.cpp ${SRCS})
target_compile_options(zookeeper-adjust-block-numbers-to-parts PRIVATE -Wno-format)
target_link_libraries (zookeeper-adjust-block-numbers-to-parts PRIVATE dbms clickhouse_common_zookeeper boost::program_options)
target_link_libraries (zookeeper-adjust-block-numbers-to-parts PRIVATE clickhouse_aggregate_functions dbms clickhouse_common_zookeeper boost::program_options)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册