Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
2dot5
ClickHouse
提交
ff088b4a
C
ClickHouse
项目概览
2dot5
/
ClickHouse
通知
3
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
ClickHouse
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ff088b4a
编写于
10月 09, 2017
作者:
N
Nikolai Kochetov
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modified CatBoostModel [#CLICKHOUSE-3305]
上级
e817de7e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
372 addition
and
17 deletion
+372
-17
dbms/src/Core/ErrorCodes.cpp
dbms/src/Core/ErrorCodes.cpp
+2
-0
dbms/src/Dictionaries/CatBoostModel.cpp
dbms/src/Dictionaries/CatBoostModel.cpp
+353
-15
dbms/src/Dictionaries/CatBoostModel.h
dbms/src/Dictionaries/CatBoostModel.h
+17
-2
未找到文件。
dbms/src/Core/ErrorCodes.cpp
浏览文件 @
ff088b4a
...
...
@@ -384,6 +384,8 @@ namespace ErrorCodes
extern
const
int
UNKNOWN_STATUS_OF_DISTRIBUTED_DDL_TASK
=
379
;
extern
const
int
CANNOT_KILL
=
380
;
extern
const
int
HTTP_LENGTH_REQUIRED
=
381
;
extern
const
int
CANNOT_LOAD_CATBOOST_MODEL
=
382
;
extern
const
int
CANNOT_APPLY_CATBOOST_MODEL
=
383
;
extern
const
int
KEEPER_EXCEPTION
=
999
;
extern
const
int
POCO_EXCEPTION
=
1000
;
...
...
dbms/src/Dictionaries/CatBoostModel.cpp
浏览文件 @
ff088b4a
#include <Dictionaries/CatBoostModel.h>
#include <Core/FieldVisitors.h>
#include <boost/dll/import.hpp>
#include <mutex>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <Common/typeid_cast.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/PODArray.h>
namespace
DB
{
namespace
ErrorCodes
{
extern
const
int
LOGICAL_ERROR
;
extern
const
int
BAD_ARGUMENTS
;
extern
const
int
CANNOT_LOAD_CATBOOST_MODEL
;
extern
const
int
CANNOT_APPLY_CATBOOST_MODEL
;
}
namespace
{
...
...
@@ -39,10 +55,322 @@ struct CatBoostWrapperApi
int
(
*
GetIntegerCatFeatureHash
)(
long
long
val
);
};
class
CatBoostWrapperHolder
:
public
CatBoostWrapperApiProvider
class
CatBoostModelHolder
{
private:
CatBoostWrapperApi
::
ModelCalcerHandle
*
handle
;
CatBoostWrapperApi
*
api
;
public:
CatBoostWrapperHolder
(
const
std
::
string
&
lib_path
)
:
lib
(
lib_path
),
lib_path
(
lib_path
)
{
initApi
();
}
explicit
CatBoostModelHolder
(
CatBoostWrapperApi
*
api
)
:
api
(
api
)
{
handle
=
api
->
ModelCalcerCreate
();
}
~
CatBoostModelHolder
()
{
api
->
ModelCalcerDelete
(
handle
);
}
CatBoostWrapperApi
::
ModelCalcerHandle
*
get
()
{
return
handle
;
}
explicit
operator
CatBoostWrapperApi
::
ModelCalcerHandle
*
()
{
return
handle
;
}
};
class
CatBoostModelImpl
:
public
ICatBoostModel
{
public:
CatBoostModelImpl
(
CatBoostWrapperApi
*
api
,
const
std
::
string
&
model_path
)
:
api
(
api
)
{
auto
handle_
=
std
::
make_unique
<
CatBoostModelHolder
>
(
api
);
if
(
!
handle_
)
{
std
::
string
msg
=
"Cannot create CatBoost model: "
;
throw
Exception
(
msg
+
api
->
GetErrorString
(),
ErrorCodes
::
CANNOT_LOAD_CATBOOST_MODEL
);
}
if
(
!
api
->
LoadFullModelFromFile
(
handle_
.
get
(),
model_path
.
c_str
()))
{
std
::
string
msg
=
"Cannot load CatBoost model: "
;
throw
Exception
(
msg
+
api
->
GetErrorString
(),
ErrorCodes
::
CANNOT_LOAD_CATBOOST_MODEL
);
}
handle
=
std
::
move
(
handle_
);
}
ColumnPtr
calc
(
const
Columns
&
columns
,
size_t
float_features_count
,
size_t
cat_features_count
)
{
if
(
columns
.
empty
())
throw
Exception
(
"Got empty columns list for CatBoost model."
,
ErrorCodes
::
BAD_ARGUMENTS
);
if
(
columns
.
size
()
!=
float_features_count
+
cat_features_count
)
{
std
::
string
msg
;
{
WriteBufferFromString
buffer
(
msg
);
buffer
<<
"Number of columns is different with number of features: "
;
buffer
<<
columns
.
size
()
<<
" vs "
<<
float_features_count
<<
" + "
<<
cat_features_count
;
}
throw
Exception
(
msg
,
ErrorCodes
::
BAD_ARGUMENTS
);
}
for
(
size_t
i
=
0
;
i
<
float_features_count
;
++
i
)
{
if
(
!
columns
[
i
]
->
isNumeric
())
{
std
::
string
msg
;
{
WriteBufferFromString
buffer
(
msg
);
buffer
<<
"Column "
<<
i
<<
"should be numeric to make float feature."
;
}
throw
Exception
(
msg
,
ErrorCodes
::
BAD_ARGUMENTS
);
}
}
bool
cat_features_are_strings
=
true
;
for
(
size_t
i
=
float_features_count
;
i
<
float_features_count
+
cat_features_count
;
++
i
)
{
const
auto
&
column
=
columns
[
i
];
if
(
column
->
isNumeric
())
cat_features_are_strings
=
false
;
else
if
(
!
(
typeid_cast
<
const
ColumnString
*>
(
column
.
get
())
||
typeid_cast
<
const
ColumnFixedString
*>
(
column
.
get
())))
{
std
::
string
msg
;
{
WriteBufferFromString
buffer
(
msg
);
buffer
<<
"Column "
<<
i
<<
"should be numeric or string."
;
}
throw
Exception
(
msg
,
ErrorCodes
::
BAD_ARGUMENTS
);
}
}
return
calcImpl
(
columns
,
float_features_count
,
cat_features_count
,
cat_features_are_strings
);
}
private:
std
::
unique_ptr
<
CatBoostModelHolder
>
handle
;
CatBoostWrapperApi
*
api
;
/// Buffer should be allocated with features_count * column->size() elements.
/// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
template
<
typename
T
>
void
placeColumnAsNumber
(
const
ColumnPtr
&
column
,
T
*
buffer
,
size_t
features_count
)
{
size_t
size
=
column
->
size
();
FieldVisitorConvertToNumber
<
T
>
visitor
;
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
/// TODO: Replace with column visitor.
Field
field
;
column
->
get
(
i
,
field
);
*
buffer
=
applyVisitor
(
visitor
,
field
);
buffer
+=
features_count
;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
void
placeStringColumn
(
const
ColumnString
&
column
,
const
char
**
buffer
,
size_t
features_count
)
{
size_t
size
=
column
.
size
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
*
buffer
=
const_cast
<
char
*>
(
column
.
getDataAtWithTerminatingZero
(
i
).
data
);
buffer
+=
features_count
;
}
}
/// Buffer should be allocated with features_count * column->size() elements.
/// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
/// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
PODArray
<
char
>
placeFixedStringColumn
(
const
ColumnFixedString
&
column
,
const
char
**
buffer
,
size_t
features_count
)
{
size_t
size
=
column
.
size
();
size_t
str_size
=
column
.
getN
();
PODArray
<
char
>
data
(
size
*
(
str_size
+
1
));
char
*
data_ptr
=
data
.
data
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
auto
ref
=
column
.
getDataAt
(
i
);
memcpy
(
data_ptr
,
ref
.
data
,
ref
.
size
);
data_ptr
[
ref
.
size
]
=
0
;
*
buffer
=
data_ptr
;
data_ptr
+=
ref
.
size
+
1
;
buffer
+=
features_count
;
}
return
data
;
}
/// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
template
<
typename
T
>
ColumnPtr
placeNumericColumns
(
const
Columns
&
columns
,
size_t
offset
,
size_t
size
,
const
T
**
buffer
)
{
if
(
size
==
0
)
return
nullptr
;
size_t
column_size
=
columns
[
offset
]
->
size
();
auto
data_column
=
std
::
make_shared
<
typename
ColumnVector
<
T
>>
(
size
*
column_size
);
T
*
data
=
data_column
->
getData
().
data
();
for
(
size_t
i
=
offset
;
i
<
offset
+
size
;
++
i
)
{
const
auto
&
column
=
columns
[
i
];
if
(
column
->
isNumeric
())
placeColumnAsNumber
(
column
,
data
+
i
,
size
);
}
for
(
size_t
i
=
0
;
i
<
column_size
;
++
i
)
{
*
buffer
=
data
;
++
buffer
;
data
+=
size
;
}
return
data_column
;
}
/// Place columns into buffer, returns data which was used for fixed string columns.
/// Buffer should contains column->size() values, each value contains size strings.
std
::
vector
<
PODArray
<
char
>>
placeStringColumns
(
const
Columns
&
columns
,
size_t
offset
,
size_t
size
,
const
char
***
buffer
)
{
if
(
size
==
0
)
return
{};
size_t
column_size
=
columns
[
offset
]
->
size
();
std
::
vector
<
PODArray
<
char
>>
data
;
for
(
size_t
i
=
offset
;
i
<
offset
+
size
;
++
i
)
{
const
auto
&
column
=
columns
[
i
];
if
(
auto
column_string
=
typeid_cast
<
const
ColumnString
*>
(
column
.
get
()))
placeStringColumn
(
*
column_string
,
buffer
[
i
],
size
);
else
if
(
auto
column_fixed_string
=
typeid_cast
<
const
ColumnFixedString
*>
(
column
.
get
()))
data
.
push_back
(
placeFixedStringColumn
(
*
column_fixed_string
,
buffer
[
i
],
size
));
else
throw
Exception
(
"Cannot place string column."
,
ErrorCodes
::
LOGICAL_ERROR
);
}
return
data
;
}
/// Calc hash for string cat feature at ps positions.
template
<
typename
Column
>
void
calcStringHashes
(
const
Column
*
column
,
size_t
features_count
,
size_t
ps
,
const
int
**
buffer
)
{
size_t
column_size
=
column
->
size
();
for
(
size_t
j
=
0
;
j
<
column_size
;
++
j
)
{
auto
ref
=
column
->
getDataAt
(
j
);
const_cast
<
int
*>
(
*
buffer
)[
ps
]
=
api
->
GetStringCatFeatureHash
(
ref
.
data
,
ref
.
size
);
buffer
+=
features_count
;
}
}
/// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
void
calcIntHashes
(
size_t
column_size
,
size_t
features_count
,
size_t
ps
,
const
int
**
buffer
)
{
for
(
size_t
j
=
0
;
j
<
column_size
;
++
j
)
{
const_cast
<
int
*>
(
*
buffer
)[
ps
]
=
api
->
GetIntegerCatFeatureHash
((
*
buffer
)[
ps
]);
buffer
+=
features_count
;
}
}
void
calcHashes
(
const
Columns
&
columns
,
size_t
offset
,
size_t
size
,
const
int
**
buffer
)
{
if
(
size
==
0
)
return
;
size_t
column_size
=
columns
[
offset
]
->
size
();
std
::
vector
<
PODArray
<
char
>>
data
;
for
(
size_t
i
=
offset
;
i
<
offset
+
size
;
++
i
)
{
const
auto
&
column
=
columns
[
i
];
auto
buffer_ptr
=
buffer
;
if
(
auto
column_string
=
typeid_cast
<
const
ColumnString
*>
(
column
.
get
()))
calcStringHashes
(
column_string
,
size
,
column_size
,
buffer
);
else
if
(
auto
column_fixed_string
=
typeid_cast
<
const
ColumnFixedString
*>
(
column
.
get
()))
calcStringHashes
(
column_fixed_string
,
size
,
column_size
,
buffer
);
else
calcIntHashes
(
column_size
,
size
,
column_size
,
buffer
);
}
}
void
fillCatFeaturesBuffer
(
const
char
***
cat_features
,
const
char
**
buffer
,
size_t
column_size
,
size_t
cat_features_count
)
{
for
(
size_t
i
=
0
;
i
<
column_size
;
++
i
)
{
*
cat_features
=
buffer
;
++
cat_features
;
buffer
+=
cat_features_count
;
}
}
ColumnPtr
calcImpl
(
const
Columns
&
columns
,
size_t
float_features_count
,
size_t
cat_features_count
,
bool
cat_features_are_strings
)
{
// size_t size = columns.size();
size_t
column_size
=
columns
.
front
()
->
size
();
PODArray
<
const
float
*>
float_features
(
column_size
);
auto
float_features_buf
=
float_features
.
data
();
auto
float_features_col
=
placeNumericColumns
<
float
>
(
columns
,
0
,
float_features_count
,
float_features_buf
);
auto
result
=
std
::
make_shared
<
ColumnFloat64
>
(
column_size
);
auto
result_buf
=
result
->
getData
().
data
();
std
::
string
error_msg
=
"Error occurred while applying CatBoost model: "
;
if
(
cat_features_count
==
0
)
{
if
(
!
api
->
CalcModelPredictionFlat
(
handle
.
get
(),
column_size
,
float_features_buf
,
float_features_count
,
result_buf
,
column_size
))
{
throw
Exception
(
error_msg
+
api
->
GetErrorString
(),
ErrorCodes
::
CANNOT_APPLY_CATBOOST_MODEL
);
}
return
result
;
}
if
(
cat_features_are_strings
)
{
PODArray
<
const
char
*>
cat_features_holder
(
cat_features_count
*
column_size
);
PODArray
<
const
char
**>
cat_features
(
column_size
);
auto
cat_features_buf
=
cat_features
.
data
();
fillCatFeaturesBuffer
(
cat_features_buf
,
cat_features_holder
.
data
(),
column_size
,
cat_features_count
);
auto
fixed_strings_data
=
placeStringColumns
(
columns
,
float_features_count
,
cat_features_count
,
cat_features_buf
);
if
(
!
api
->
CalcModelPrediction
(
handle
.
get
(),
column_size
,
float_features_buf
,
float_features_count
,
cat_features_buf
,
cat_features_count
,
result_buf
,
column_size
))
{
throw
Exception
(
error_msg
+
api
->
GetErrorString
(),
ErrorCodes
::
CANNOT_APPLY_CATBOOST_MODEL
);
}
}
else
{
PODArray
<
const
int
*>
cat_features
(
column_size
);
auto
cat_features_buf
=
cat_features
.
data
();
auto
cat_features_col
=
placeNumericColumns
<
int
>
(
columns
,
float_features_count
,
cat_features_count
,
cat_features_buf
);
calcHashes
(
columns
,
float_features_count
,
cat_features_count
,
cat_features_buf
);
if
(
!
api
->
CalcModelPredictionWithHashedCatFeatures
(
handle
.
get
(),
column_size
,
float_features_buf
,
float_features_count
,
cat_features_buf
,
cat_features_count
,
result_buf
,
column_size
))
{
throw
Exception
(
error_msg
+
api
->
GetErrorString
(),
ErrorCodes
::
CANNOT_APPLY_CATBOOST_MODEL
);
}
}
return
result
;
}
};
class
CatBoostLibHolder
:
public
CatBoostWrapperApiProvider
{
public:
explicit
CatBoostLibHolder
(
const
std
::
string
&
lib_path
)
:
lib
(
lib_path
),
lib_path
(
lib_path
)
{
initApi
();
}
const
CatBoostWrapperApi
&
getApi
()
const
override
{
return
api
;
}
const
std
::
string
&
getCurrentPath
()
const
{
return
lib_path
;
}
...
...
@@ -62,7 +390,7 @@ private:
}
};
void
CatBoost
Wrapper
Holder
::
initApi
()
void
CatBoost
Lib
Holder
::
initApi
()
{
load
(
api
.
ModelCalcerCreate
,
"ModelCalcerCreate"
);
load
(
api
.
ModelCalcerDelete
,
"ModelCalcerDelete"
);
...
...
@@ -75,9 +403,9 @@ void CatBoostWrapperHolder::initApi()
load
(
api
.
GetIntegerCatFeatureHash
,
"GetIntegerCatFeatureHash"
);
}
std
::
shared_ptr
<
CatBoost
Wrapper
Holder
>
getCatBoostWrapperHolder
(
const
std
::
string
&
lib_path
)
std
::
shared_ptr
<
CatBoost
Lib
Holder
>
getCatBoostWrapperHolder
(
const
std
::
string
&
lib_path
)
{
static
std
::
weak_ptr
<
CatBoost
Wrapper
Holder
>
ptr
;
static
std
::
weak_ptr
<
CatBoost
Lib
Holder
>
ptr
;
static
std
::
mutex
mutex
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex
);
...
...
@@ -85,7 +413,7 @@ std::shared_ptr<CatBoostWrapperHolder> getCatBoostWrapperHolder(const std::strin
if
(
!
result
||
result
->
getCurrentPath
()
!=
lib_path
)
{
result
=
std
::
make_shared
<
CatBoost
Wrapper
Holder
>
(
lib_path
);
result
=
std
::
make_shared
<
CatBoost
Lib
Holder
>
(
lib_path
);
/// This assignment is not atomic, which prevents from creating lock only inside 'if'.
ptr
=
result
;
}
...
...
@@ -103,13 +431,14 @@ CatBoostModel::CatBoostModel(const Poco::Util::AbstractConfiguration & config,
}
CatBoostModel
::
CatBoostModel
(
const
std
::
string
&
name
,
const
std
::
string
&
model_path
,
const
std
::
string
&
lib_path
,
const
ExternalLoadableLifetime
&
lifetime
)
:
name
(
name
),
model_path
(
model_path
),
lifetime
(
lifetime
)
const
ExternalLoadableLifetime
&
lifetime
,
size_t
float_features_count
,
size_t
cat_features_count
)
:
name
(
name
),
model_path
(
model_path
),
lib_path
(
lib_path
),
lifetime
(
lifetime
),
float_features_count
(
float_features_count
),
cat_features_count
(
cat_features_count
)
{
try
{
api_provider
=
getCatBoostWrapperHolder
(
lib_path
);
api
=
&
api_provider
->
getApi
();
init
(
lib_path
);
}
catch
(...)
{
...
...
@@ -117,6 +446,13 @@ CatBoostModel::CatBoostModel(const std::string & name, const std::string & model
}
}
void
CatBoostModel
::
init
(
const
std
::
string
&
lib_path
)
{
api_provider
=
getCatBoostWrapperHolder
(
lib_path
);
api
=
&
api_provider
->
getApi
();
model
=
std
::
make_unique
<
CatBoostModelImpl
>
(
api
,
model_path
);
}
const
ExternalLoadableLifetime
&
CatBoostModel
::
getLifetime
()
const
{
return
lifetime
;
...
...
@@ -129,22 +465,24 @@ bool CatBoostModel::isModified() const
std
::
unique_ptr
<
IExternalLoadable
>
CatBoostModel
::
cloneObject
()
const
{
return
nullptr
;
return
std
::
make_unique
<
CatBoostModel
>
(
name
,
model_path
,
lib_path
,
lifetime
,
float_features_count
,
cat_features_count
)
;
}
size_t
CatBoostModel
::
getFloatFeaturesCount
()
const
{
return
0
;
return
float_features_count
;
}
size_t
CatBoostModel
::
getCatFeaturesCount
()
const
{
return
0
;
return
cat_features_count
;
}
void
CatBoostModel
::
apply
(
const
Columns
&
floatColumns
,
const
Columns
&
catColumns
,
ColumnFloat64
&
result
)
ColumnPtr
CatBoostModel
::
apply
(
const
Columns
&
columns
)
{
if
(
!
model
)
throw
Exception
(
"CatBoost model was not loaded."
,
ErrorCodes
::
LOGICAL_ERROR
);
return
model
->
calc
(
columns
,
float_features_count
,
cat_features_count
);
}
}
dbms/src/Dictionaries/CatBoostModel.h
浏览文件 @
ff088b4a
...
...
@@ -15,6 +15,12 @@ public:
virtual
const
CatBoostWrapperApi
&
getApi
()
const
=
0
;
};
class
ICatBoostModel
{
public:
virtual
~
ICatBoostModel
()
=
default
;
virtual
ColumnPtr
calc
(
const
Columns
&
columns
,
size_t
float_features_count
,
size_t
cat_features_count
)
=
0
;
};
class
CatBoostModel
:
public
IExternalLoadable
{
...
...
@@ -37,18 +43,27 @@ public:
size_t
getFloatFeaturesCount
()
const
;
size_t
getCatFeaturesCount
()
const
;
void
apply
(
const
Columns
&
floatColumns
,
const
Columns
&
catColumns
,
ColumnFloat64
&
result
);
ColumnPtr
apply
(
const
Columns
&
columns
);
private:
std
::
string
name
;
std
::
string
model_path
;
std
::
string
lib_path
;
ExternalLoadableLifetime
lifetime
;
std
::
exception_ptr
creation_exception
;
std
::
shared_ptr
<
CatBoostWrapperApiProvider
>
api_provider
;
const
CatBoostWrapperApi
*
api
;
std
::
unique_ptr
<
ICatBoostModel
>
model
;
size_t
float_features_count
;
size_t
cat_features_count
;
CatBoostModel
(
const
std
::
string
&
name
,
const
std
::
string
&
model_path
,
const
std
::
string
&
lib_path
,
const
ExternalLoadableLifetime
&
lifetime
);
const
std
::
string
&
lib_path
,
const
ExternalLoadableLifetime
&
lifetime
,
size_t
float_features_count
,
size_t
cat_features_count
);
void
init
(
const
std
::
string
&
lib_path
);
};
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录