diff --git a/dbms/src/Core/ErrorCodes.cpp b/dbms/src/Core/ErrorCodes.cpp index 40e76c82ca1dc48e2ddf15a643c1ac0b78ca4777..ffeda42047b5815d4987d1800e7c7a646ecb703d 100644 --- a/dbms/src/Core/ErrorCodes.cpp +++ b/dbms/src/Core/ErrorCodes.cpp @@ -384,6 +384,8 @@ namespace ErrorCodes extern const int UNKNOWN_STATUS_OF_DISTRIBUTED_DDL_TASK = 379; extern const int CANNOT_KILL = 380; extern const int HTTP_LENGTH_REQUIRED = 381; + extern const int CANNOT_LOAD_CATBOOST_MODEL = 382; + extern const int CANNOT_APPLY_CATBOOST_MODEL = 383; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Dictionaries/CatBoostModel.cpp b/dbms/src/Dictionaries/CatBoostModel.cpp index deb2a0e7af86142040f94261544976fb86b8c0b6..b79cdd956610ff36f8675f4d5869cc4c74d31a40 100644 --- a/dbms/src/Dictionaries/CatBoostModel.cpp +++ b/dbms/src/Dictionaries/CatBoostModel.cpp @@ -1,10 +1,26 @@ #include +#include #include #include +#include +#include +#include +#include +#include +#include +#include namespace DB { +namespace ErrorCodes +{ +extern const int LOGICAL_ERROR; +extern const int BAD_ARGUMENTS; +extern const int CANNOT_LOAD_CATBOOST_MODEL; +extern const int CANNOT_APPLY_CATBOOST_MODEL; +} + namespace { @@ -39,10 +55,322 @@ struct CatBoostWrapperApi int (* GetIntegerCatFeatureHash)(long long val); }; -class CatBoostWrapperHolder : public CatBoostWrapperApiProvider + +class CatBoostModelHolder { +private: + CatBoostWrapperApi::ModelCalcerHandle * handle; + CatBoostWrapperApi * api; public: - CatBoostWrapperHolder(const std::string & lib_path) : lib(lib_path), lib_path(lib_path) { initApi(); } + explicit CatBoostModelHolder(CatBoostWrapperApi * api) : api(api) { handle = api->ModelCalcerCreate(); } + ~CatBoostModelHolder() { api->ModelCalcerDelete(handle); } + + CatBoostWrapperApi::ModelCalcerHandle * get() { return handle; } + explicit operator CatBoostWrapperApi::ModelCalcerHandle * () { return handle; } +}; + + +class CatBoostModelImpl : public ICatBoostModel +{ +public: + CatBoostModelImpl(CatBoostWrapperApi * api, const std::string & model_path) : api(api) + { + auto handle_ = std::make_unique(api); + if (!handle_) + { + std::string msg = "Cannot create CatBoost model: "; + throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL); + } + if (!api->LoadFullModelFromFile(handle_.get(), model_path.c_str())) + { + std::string msg = "Cannot load CatBoost model: "; + throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL); + } + handle = std::move(handle_); + } + + ColumnPtr calc(const Columns & columns, size_t float_features_count, size_t cat_features_count) + { + if (columns.empty()) + throw Exception("Got empty columns list for CatBoost model.", ErrorCodes::BAD_ARGUMENTS); + + if (columns.size() != float_features_count + cat_features_count) + { + std::string msg; + { + WriteBufferFromString buffer(msg); + buffer << "Number of columns is different with number of features: "; + buffer << columns.size() << " vs " << float_features_count << " + " << cat_features_count; + } + throw Exception(msg, ErrorCodes::BAD_ARGUMENTS); + } + + for (size_t i = 0; i < float_features_count; ++i) + { + if (!columns[i]->isNumeric()) + { + std::string msg; + { + WriteBufferFromString buffer(msg); + buffer << "Column " << i << "should be numeric to make float feature."; + } + throw Exception(msg, ErrorCodes::BAD_ARGUMENTS); + } + } + + bool cat_features_are_strings = true; + for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i) + { + const auto & column = columns[i]; + if (column->isNumeric()) + cat_features_are_strings = false; + else if (!(typeid_cast(column.get()) + || typeid_cast(column.get()))) + { + std::string msg; + { + WriteBufferFromString buffer(msg); + buffer << "Column " << i << "should be numeric or string."; + } + throw Exception(msg, ErrorCodes::BAD_ARGUMENTS); + } + } + + return calcImpl(columns, float_features_count, cat_features_count, cat_features_are_strings); + } + +private: + std::unique_ptr handle; + CatBoostWrapperApi * api; + + /// Buffer should be allocated with features_count * column->size() elements. + /// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count] + template + void placeColumnAsNumber(const ColumnPtr & column, T * buffer, size_t features_count) + { + size_t size = column->size(); + FieldVisitorConvertToNumber visitor; + for (size_t i = 0; i < size; ++i) + { + /// TODO: Replace with column visitor. + Field field; + column->get(i, field); + *buffer = applyVisitor(visitor, field); + buffer += features_count; + } + } + + /// Buffer should be allocated with features_count * column->size() elements. + /// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count] + void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count) + { + size_t size = column.size(); + for (size_t i = 0; i < size; ++i) + { + *buffer = const_cast(column.getDataAtWithTerminatingZero(i).data); + buffer += features_count; + } + } + + /// Buffer should be allocated with features_count * column->size() elements. + /// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count] + /// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero). + PODArray placeFixedStringColumn(const ColumnFixedString & column, const char ** buffer, size_t features_count) + { + size_t size = column.size(); + size_t str_size = column.getN(); + PODArray data(size * (str_size + 1)); + char * data_ptr = data.data(); + + for (size_t i = 0; i < size; ++i) + { + auto ref = column.getDataAt(i); + memcpy(data_ptr, ref.data, ref.size); + data_ptr[ref.size] = 0; + *buffer = data_ptr; + data_ptr += ref.size + 1; + buffer += features_count; + } + + return data; + } + + /// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values. + template + ColumnPtr placeNumericColumns(const Columns & columns, size_t offset, size_t size, const T** buffer) + { + if (size == 0) + return nullptr; + size_t column_size = columns[offset]->size(); + auto data_column = std::make_shared>(size * column_size); + T* data = data_column->getData().data(); + for (size_t i = offset; i < offset + size; ++i) + { + const auto & column = columns[i]; + if (column->isNumeric()) + placeColumnAsNumber(column, data + i, size); + } + + for (size_t i = 0; i < column_size; ++i) + { + *buffer = data; + ++buffer; + data += size; + } + + return data_column; + } + + /// Place columns into buffer, returns data which was used for fixed string columns. + /// Buffer should contains column->size() values, each value contains size strings. + std::vector> placeStringColumns( + const Columns & columns, size_t offset, size_t size, const char *** buffer) + { + if (size == 0) + return {}; + size_t column_size = columns[offset]->size(); + + std::vector> data; + for (size_t i = offset; i < offset + size; ++i) + { + const auto & column = columns[i]; + if (auto column_string = typeid_cast(column.get())) + placeStringColumn(*column_string, buffer[i], size); + else if (auto column_fixed_string = typeid_cast(column.get())) + data.push_back(placeFixedStringColumn(*column_fixed_string, buffer[i], size)); + else + throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR); + } + + return data; + } + + /// Calc hash for string cat feature at ps positions. + template + void calcStringHashes(const Column * column, size_t features_count, size_t ps, const int ** buffer) + { + size_t column_size = column->size(); + for (size_t j = 0; j < column_size; ++j) + { + auto ref = column->getDataAt(j); + const_cast(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size); + buffer += features_count; + } + } + + /// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values. + void calcIntHashes(size_t column_size, size_t features_count, size_t ps, const int ** buffer) + { + for (size_t j = 0; j < column_size; ++j) + { + const_cast(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]); + buffer += features_count; + } + } + + void calcHashes(const Columns & columns, size_t offset, size_t size, const int ** buffer) + { + if (size == 0) + return; + size_t column_size = columns[offset]->size(); + + std::vector> data; + for (size_t i = offset; i < offset + size; ++i) + { + const auto & column = columns[i]; + auto buffer_ptr = buffer; + if (auto column_string = typeid_cast(column.get())) + calcStringHashes(column_string, size, column_size, buffer); + else if (auto column_fixed_string = typeid_cast(column.get())) + calcStringHashes(column_fixed_string, size, column_size, buffer); + else + calcIntHashes(column_size, size, column_size, buffer); + } + } + + void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer, + size_t column_size, size_t cat_features_count) + { + for (size_t i = 0; i < column_size; ++i) + { + *cat_features = buffer; + ++cat_features; + buffer += cat_features_count; + } + } + + ColumnPtr calcImpl(const Columns & columns, size_t float_features_count, size_t cat_features_count, + bool cat_features_are_strings) + { + // size_t size = columns.size(); + size_t column_size = columns.front()->size(); + + PODArray float_features(column_size); + auto float_features_buf = float_features.data(); + auto float_features_col = placeNumericColumns(columns, 0, float_features_count, float_features_buf); + + auto result= std::make_shared(column_size); + auto result_buf = result->getData().data(); + + std::string error_msg = "Error occurred while applying CatBoost model: "; + + if (cat_features_count == 0) + { + if (!api->CalcModelPredictionFlat(handle.get(), column_size, + float_features_buf, float_features_count, + result_buf, column_size)) + { + + throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL); + } + return result; + } + + + if (cat_features_are_strings) + { + PODArray cat_features_holder(cat_features_count * column_size); + PODArray cat_features(column_size); + auto cat_features_buf = cat_features.data(); + + fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size, cat_features_count); + auto fixed_strings_data = placeStringColumns(columns, float_features_count, + cat_features_count, cat_features_buf); + + if (!api->CalcModelPrediction(handle.get(), column_size, + float_features_buf, float_features_count, + cat_features_buf, cat_features_count, + result_buf, column_size)) + { + throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL); + } + } + else + { + PODArray cat_features(column_size); + auto cat_features_buf = cat_features.data(); + auto cat_features_col = placeNumericColumns(columns, float_features_count, + cat_features_count, cat_features_buf); + calcHashes(columns, float_features_count, cat_features_count, cat_features_buf); + if (!api->CalcModelPredictionWithHashedCatFeatures( + handle.get(), column_size, + float_features_buf, float_features_count, + cat_features_buf, cat_features_count, + result_buf, column_size)) + { + throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL); + } + } + + return result; + } +}; + + +class CatBoostLibHolder: public CatBoostWrapperApiProvider +{ +public: + explicit CatBoostLibHolder(const std::string & lib_path) : lib(lib_path), lib_path(lib_path) { initApi(); } const CatBoostWrapperApi & getApi() const override { return api; } const std::string & getCurrentPath() const { return lib_path; } @@ -62,7 +390,7 @@ private: } }; -void CatBoostWrapperHolder::initApi() +void CatBoostLibHolder::initApi() { load(api.ModelCalcerCreate, "ModelCalcerCreate"); load(api.ModelCalcerDelete, "ModelCalcerDelete"); @@ -75,9 +403,9 @@ void CatBoostWrapperHolder::initApi() load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash"); } -std::shared_ptr getCatBoostWrapperHolder(const std::string & lib_path) +std::shared_ptr getCatBoostWrapperHolder(const std::string & lib_path) { - static std::weak_ptr ptr; + static std::weak_ptr ptr; static std::mutex mutex; std::lock_guard lock(mutex); @@ -85,7 +413,7 @@ std::shared_ptr getCatBoostWrapperHolder(const std::strin if (!result || result->getCurrentPath() != lib_path) { - result = std::make_shared(lib_path); + result = std::make_shared(lib_path); /// This assignment is not atomic, which prevents from creating lock only inside 'if'. ptr = result; } @@ -103,13 +431,14 @@ CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config, } CatBoostModel::CatBoostModel(const std::string & name, const std::string & model_path, const std::string & lib_path, - const ExternalLoadableLifetime & lifetime) - : name(name), model_path(model_path), lifetime(lifetime) + const ExternalLoadableLifetime & lifetime, + size_t float_features_count, size_t cat_features_count) + : name(name), model_path(model_path), lib_path(lib_path), lifetime(lifetime), + float_features_count(float_features_count), cat_features_count(cat_features_count) { try { - api_provider = getCatBoostWrapperHolder(lib_path); - api = &api_provider->getApi(); + init(lib_path); } catch (...) { @@ -117,6 +446,13 @@ CatBoostModel::CatBoostModel(const std::string & name, const std::string & model } } +void CatBoostModel::init(const std::string & lib_path) +{ + api_provider = getCatBoostWrapperHolder(lib_path); + api = &api_provider->getApi(); + model = std::make_unique(api, model_path); +} + const ExternalLoadableLifetime & CatBoostModel::getLifetime() const { return lifetime; @@ -129,22 +465,24 @@ bool CatBoostModel::isModified() const std::unique_ptr CatBoostModel::cloneObject() const { - return nullptr; + return std::make_unique(name, model_path, lib_path, lifetime, float_features_count, cat_features_count); } size_t CatBoostModel::getFloatFeaturesCount() const { - return 0; + return float_features_count; } size_t CatBoostModel::getCatFeaturesCount() const { - return 0; + return cat_features_count; } -void CatBoostModel::apply(const Columns & floatColumns, const Columns & catColumns, ColumnFloat64 & result) +ColumnPtr CatBoostModel::apply(const Columns & columns) { - + if (!model) + throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR); + return model->calc(columns, float_features_count, cat_features_count); } } diff --git a/dbms/src/Dictionaries/CatBoostModel.h b/dbms/src/Dictionaries/CatBoostModel.h index adacffab05b1b5bb63cfd48aa404030565bccc63..11e7c8c71319f8f3a45a371b4bfde7cdd60efb83 100644 --- a/dbms/src/Dictionaries/CatBoostModel.h +++ b/dbms/src/Dictionaries/CatBoostModel.h @@ -15,6 +15,12 @@ public: virtual const CatBoostWrapperApi & getApi() const = 0; }; +class ICatBoostModel +{ +public: + virtual ~ICatBoostModel() = default; + virtual ColumnPtr calc(const Columns & columns, size_t float_features_count, size_t cat_features_count) = 0; +}; class CatBoostModel : public IExternalLoadable { @@ -37,18 +43,27 @@ public: size_t getFloatFeaturesCount() const; size_t getCatFeaturesCount() const; - void apply(const Columns & floatColumns, const Columns & catColumns, ColumnFloat64 & result); + ColumnPtr apply(const Columns & columns); private: std::string name; std::string model_path; + std::string lib_path; ExternalLoadableLifetime lifetime; std::exception_ptr creation_exception; std::shared_ptr api_provider; const CatBoostWrapperApi * api; + std::unique_ptr model; + + size_t float_features_count; + size_t cat_features_count; + CatBoostModel(const std::string & name, const std::string & model_path, - const std::string & lib_path, const ExternalLoadableLifetime & lifetime); + const std::string & lib_path, const ExternalLoadableLifetime & lifetime, + size_t float_features_count, size_t cat_features_count); + + void init(const std::string & lib_path); }; }