Unverified · Commit 4ca476ec authored by alexey-milovidov, committed by GitHub

Merge pull request #5516 from yandex/fix-regression-models

Fix regression models
......@@ -45,7 +45,7 @@ namespace
     /// Such default parameters were picked because they did good on some tests,
     /// though it still requires to fit parameters to achieve better result
-    auto learning_rate = Float64(0.00001);
+    auto learning_rate = Float64(0.01);
     auto l2_reg_coef = Float64(0.1);
     UInt32 batch_size = 15;
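The defaults above only apply when stochasticLinearRegression or stochasticLogisticRegression is created without explicit parameters: learning_rate scales each gradient step, l2_reg_coef penalizes large weights, and batch_size controls how many rows are accumulated before a weight update. The snippet below is not the repository's IWeightsUpdater code, just a minimal sketch (hypothetical sgdStep name, plain std::vector) of how the three values interact in an ordinary SGD update.

#include <cstddef>
#include <vector>

/// Illustrative only: one plain SGD step with L2 regularization.
/// gradient_batch holds gradients accumulated over batch_size rows,
/// with the bias gradient stored last.
void sgdStep(std::vector<double> & weights, double & bias,
             const std::vector<double> & gradient_batch,
             double learning_rate, double l2_reg_coef, size_t batch_size)
{
    for (size_t i = 0; i < weights.size(); ++i)
    {
        double grad = gradient_batch[i] / batch_size;
        /// L2 regularization pulls each weight towards zero.
        weights[i] -= learning_rate * (grad + l2_reg_coef * weights[i]);
    }
    bias -= learning_rate * gradient_batch[weights.size()] / batch_size;
}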
......@@ -134,9 +134,14 @@ void LinearModelData::update_state()
 }
 void LinearModelData::predict(
-    ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const
+    ColumnVector<Float64>::Container & container,
+    Block & block,
+    size_t offset,
+    size_t limit,
+    const ColumnNumbers & arguments,
+    const Context & context) const
 {
-    gradient_computer->predict(container, block, arguments, weights, bias, context);
+    gradient_computer->predict(container, block, offset, limit, arguments, weights, bias, context);
 }
 void LinearModelData::returnWeights(IColumn & to) const
......@@ -345,42 +350,38 @@ void IWeightsUpdater::add_to_batch(
 void LogisticRegression::predict(
     ColumnVector<Float64>::Container & container,
     Block & block,
+    size_t offset,
+    size_t limit,
     const ColumnNumbers & arguments,
     const std::vector<Float64> & weights,
     Float64 bias,
-    const Context & context) const
+    const Context & /*context*/) const
 {
     size_t rows_num = block.rows();
-    std::vector<Float64> results(rows_num, bias);
+    if (offset > rows_num || offset + limit > rows_num)
+        throw Exception("Invalid offset and limit for LogisticRegression::predict. "
+            "Block has " + toString(rows_num) + " rows, but offset is " + toString(offset) +
+            " and limit is " + toString(limit), ErrorCodes::LOGICAL_ERROR);
+    std::vector<Float64> results(limit, bias);
     for (size_t i = 1; i < arguments.size(); ++i)
     {
         const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
         if (!isNativeNumber(cur_col.type))
         {
             throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
         }
-        /// If column type is already Float64 then castColumn simply returns it
-        auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
-        auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
-        if (!features_column)
-        {
-            throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
-        }
+        auto & features_column = cur_col.column;
-        for (size_t row_num = 0; row_num != rows_num; ++row_num)
-        {
-            results[row_num] += weights[i - 1] * features_column->getElement(row_num);
-        }
+        for (size_t row_num = 0; row_num < limit; ++row_num)
+            results[row_num] += weights[i - 1] * features_column->getFloat64(offset + row_num);
     }
-    container.reserve(rows_num);
-    for (size_t row_num = 0; row_num != rows_num; ++row_num)
-    {
+    container.reserve(container.size() + limit);
+    for (size_t row_num = 0; row_num < limit; ++row_num)
         container.emplace_back(1 / (1 + exp(-results[row_num])));
-    }
 }
 void LogisticRegression::compute(
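The substantive interface change in this hunk, mirrored in LinearRegression::predict below, is that predict() no longer scores the whole block: it validates offset and limit against block.rows(), computes predictions only for rows in [offset, offset + limit), and appends them to the container that the caller may have already partially filled. A minimal sketch of that contract under these assumptions, with plain std::vector standing in for ClickHouse's Block and IColumn and a hypothetical predictRange name:

#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

/// features[c][r] is feature c of row r; results are appended to `out`,
/// mirroring how predict() appends to `container`.
void predictRange(std::vector<double> & out,
                  const std::vector<std::vector<double>> & features,
                  const std::vector<double> & weights, double bias,
                  size_t offset, size_t limit, bool logistic)
{
    const size_t rows = features.empty() ? 0 : features[0].size();
    if (offset > rows || offset + limit > rows)
        throw std::out_of_range("offset/limit outside of block");

    std::vector<double> results(limit, bias);
    for (size_t c = 0; c < features.size(); ++c)
        for (size_t r = 0; r < limit; ++r)
            results[r] += weights[c] * features[c][offset + r];

    out.reserve(out.size() + limit);
    for (size_t r = 0; r < limit; ++r)
        out.push_back(logistic ? 1.0 / (1.0 + std::exp(-results[r])) : results[r]);
}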
......@@ -413,10 +414,12 @@ void LogisticRegression::compute(
 void LinearRegression::predict(
     ColumnVector<Float64>::Container & container,
     Block & block,
+    size_t offset,
+    size_t limit,
     const ColumnNumbers & arguments,
     const std::vector<Float64> & weights,
     Float64 bias,
-    const Context & context) const
+    const Context & /*context*/) const
 {
     if (weights.size() + 1 != arguments.size())
     {
......@@ -424,36 +427,33 @@ void LinearRegression::predict(
     }
     size_t rows_num = block.rows();
-    std::vector<Float64> results(rows_num, bias);
+    if (offset > rows_num || offset + limit > rows_num)
+        throw Exception("Invalid offset and limit for LogisticRegression::predict. "
+            "Block has " + toString(rows_num) + " rows, but offset is " + toString(offset) +
+            " and limit is " + toString(limit), ErrorCodes::LOGICAL_ERROR);
+    std::vector<Float64> results(limit, bias);
     for (size_t i = 1; i < arguments.size(); ++i)
     {
         const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
         if (!isNativeNumber(cur_col.type))
         {
             throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
         }
-        /// If column type is already Float64 then castColumn simply returns it
-        auto features_col_ptr = castColumn(cur_col, std::make_shared<DataTypeFloat64>(), context);
-        auto features_column = typeid_cast<const ColumnFloat64 *>(features_col_ptr.get());
+        auto features_column = cur_col.column;
         if (!features_column)
         {
             throw Exception("Unexpectedly cannot dynamically cast features column " + std::to_string(i), ErrorCodes::LOGICAL_ERROR);
         }
-        for (size_t row_num = 0; row_num != rows_num; ++row_num)
-        {
-            results[row_num] += weights[i - 1] * features_column->getElement(row_num);
-        }
+        for (size_t row_num = 0; row_num < limit; ++row_num)
+            results[row_num] += weights[i - 1] * features_column->getFloat64(row_num + offset);
     }
-    container.reserve(rows_num);
-    for (size_t row_num = 0; row_num != rows_num; ++row_num)
-    {
+    container.reserve(container.size() + limit);
+    for (size_t row_num = 0; row_num < limit; ++row_num)
         container.emplace_back(results[row_num]);
-    }
 }
 void LinearRegression::compute(
......
......@@ -3,6 +3,7 @@
 #include <Columns/ColumnVector.h>
 #include <Columns/ColumnsCommon.h>
 #include <Columns/ColumnsNumber.h>
+#include <Common/typeid_cast.h>
 #include <DataTypes/DataTypesNumber.h>
 #include <DataTypes/DataTypeTuple.h>
 #include <DataTypes/DataTypeArray.h>
......@@ -42,6 +43,8 @@ public:
     virtual void predict(
         ColumnVector<Float64>::Container & container,
         Block & block,
+        size_t offset,
+        size_t limit,
         const ColumnNumbers & arguments,
         const std::vector<Float64> & weights,
         Float64 bias,
......@@ -67,6 +70,8 @@ public:
     void predict(
         ColumnVector<Float64>::Container & container,
         Block & block,
+        size_t offset,
+        size_t limit,
         const ColumnNumbers & arguments,
         const std::vector<Float64> & weights,
         Float64 bias,
......@@ -92,6 +97,8 @@ public:
     void predict(
         ColumnVector<Float64>::Container & container,
         Block & block,
+        size_t offset,
+        size_t limit,
         const ColumnNumbers & arguments,
         const std::vector<Float64> & weights,
         Float64 bias,
......@@ -218,8 +225,13 @@ public:
     void read(ReadBuffer & buf);
-    void
-    predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
+    void predict(
+        ColumnVector<Float64>::Container & container,
+        Block & block,
+        size_t offset,
+        size_t limit,
+        const ColumnNumbers & arguments,
+        const Context & context) const;
     void returnWeights(IColumn & to) const;
 private:
......@@ -228,11 +240,11 @@ private:
     Float64 learning_rate;
     Float64 l2_reg_coef;
-    UInt32 batch_capacity;
+    UInt64 batch_capacity;
-    UInt32 iter_num = 0;
+    UInt64 iter_num = 0;
     std::vector<Float64> gradient_batch;
-    UInt32 batch_size;
+    UInt64 batch_size;
     std::shared_ptr<IGradientComputer> gradient_computer;
     std::shared_ptr<IWeightsUpdater> weights_updater;
......@@ -316,7 +328,13 @@ public:
     void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).read(buf); }
     void predictValues(
-        ConstAggregateDataPtr place, IColumn & to, Block & block, const ColumnNumbers & arguments, const Context & context) const override
+        ConstAggregateDataPtr place,
+        IColumn & to,
+        Block & block,
+        size_t offset,
+        size_t limit,
+        const ColumnNumbers & arguments,
+        const Context & context) const override
     {
         if (arguments.size() != param_num + 1)
             throw Exception(
......@@ -325,17 +343,12 @@ public:
                 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
         /// This cast might be correct because column type is based on getReturnTypeToPredict.
-        ColumnVector<Float64> * column;
-        try
-        {
-            column = &dynamic_cast<ColumnVector<Float64> &>(to);
-        } catch (const std::bad_cast &)
-        {
+        auto * column = typeid_cast<ColumnFloat64 *>(&to);
+        if (!column)
             throw Exception("Cast of column of predictions is incorrect. getReturnTypeToPredict must return same value as it is casted to",
                 ErrorCodes::BAD_CAST);
-        }
-        this->data(place).predict(column->getData(), block, arguments, context);
+        this->data(place).predict(column->getData(), block, offset, limit, arguments, context);
     }
     /** This function is called if aggregate function without State modifier is selected in a query.
......
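The casting change in this hunk replaces a dynamic_cast to a reference inside try/catch with a typeid_cast to a pointer plus an explicit null check, so a mismatched prediction column is reported without using exceptions as control flow. A rough sketch of the pattern with toy column types; dynamic_cast stands in here for ClickHouse's typeid_cast, which additionally requires an exact type match:

#include <stdexcept>

/// Toy column types standing in for IColumn / ColumnFloat64.
struct IColumnToy { virtual ~IColumnToy() = default; };
struct ColumnFloat64Toy : IColumnToy { };

/// Cast to a pointer and test it explicitly instead of catching std::bad_cast.
ColumnFloat64Toy & castPredictionColumn(IColumnToy & to)
{
    auto * column = dynamic_cast<ColumnFloat64Toy *>(&to);
    if (!column)
        throw std::runtime_error("Cast of column of predictions is incorrect");  /// BAD_CAST in the real code
    return *column;
}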
......@@ -100,9 +100,16 @@ public:
     /// Inserts results into a column.
     virtual void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const = 0;
-    /// This function is used for machine learning methods
-    virtual void predictValues(ConstAggregateDataPtr /* place */, IColumn & /*to*/,
-        Block & /*block*/, const ColumnNumbers & /*arguments*/, const Context & /*context*/) const
+    /// Used for machine learning methods. Predict result from trained model.
+    /// Will insert result into `to` column for rows in range [offset, offset + limit).
+    virtual void predictValues(
+        ConstAggregateDataPtr /* place */,
+        IColumn & /*to*/,
+        Block & /*block*/,
+        size_t /*offset*/,
+        size_t /*limit*/,
+        const ColumnNumbers & /*arguments*/,
+        const Context & /*context*/) const
     {
         throw Exception("Method predictValues is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
     }
......
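Only the ML aggregate functions override predictValues; every other aggregate function keeps this throwing default, so callers such as ColumnAggregateFunction::predictValues can rely on virtual dispatch instead of a type switch. A condensed sketch of the pattern with toy class names, not the real IAggregateFunction hierarchy:

#include <cstddef>
#include <stdexcept>
#include <string>

/// A throwing default in the interface means only ML-capable
/// aggregate functions have to override predictValues.
struct IAggregateFunctionToy
{
    virtual ~IAggregateFunctionToy() = default;
    virtual std::string getName() const = 0;

    virtual void predictValues(size_t /*offset*/, size_t /*limit*/) const
    {
        throw std::logic_error("Method predictValues is not supported for " + getName());
    }
};

struct MLMethodToy : IAggregateFunctionToy
{
    std::string getName() const override { return "stochasticLinearRegression"; }

    void predictValues(size_t offset, size_t limit) const override
    {
        /// Would write predictions for rows [offset, offset + limit) into a result column.
        (void)offset;
        (void)limit;
    }
};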
......@@ -92,13 +92,21 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
     auto ML_function = func.get();
     if (ML_function)
     {
-        size_t row_num = 0;
-        for (auto val : data)
+        if (data.size() == 1)
         {
-            ML_function->predictValues(val, *res, block, arguments, context);
-            ++row_num;
+            /// Case for const column. Predict using single model.
+            ML_function->predictValues(data[0], *res, block, 0, block.rows(), arguments, context);
         }
+        else
+        {
+            /// Case for non-constant column. Use different aggregate function for each row.
+            size_t row_num = 0;
+            for (auto val : data)
+            {
+                ML_function->predictValues(val, *res, block, row_num, 1, arguments, context);
+                ++row_num;
+            }
+        }
     }
     else
     {
......
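The new branch distinguishes two shapes of the state column: a single trained model (the case of a materialized constant column), which scores every row of the block in one call, and one model per row, where each state predicts only its own row with offset = row and limit = 1. A sketch of that dispatch with plain types and hypothetical names:

#include <cstddef>
#include <functional>
#include <vector>

/// `states` plays the role of ColumnAggregateFunction's data: either one shared
/// model or one model per row. `predict` stands in for IAggregateFunction::predictValues.
void predictValuesSketch(
    const std::vector<const void *> & states,
    size_t block_rows,
    const std::function<void(const void * state, size_t offset, size_t limit)> & predict)
{
    if (states.size() == 1)
    {
        /// Single model: score every row of the block in one call.
        predict(states[0], 0, block_rows);
    }
    else
    {
        /// One model per row: each state predicts only its own row.
        for (size_t row = 0; row < states.size(); ++row)
            predict(states[row], row, 1);
    }
}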
......@@ -53,10 +53,10 @@ public:
     DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
     {
-        if (!arguments.size())
+        if (arguments.empty())
             throw Exception("Function " + getName() + " requires at least one argument", ErrorCodes::BAD_ARGUMENTS);
-        const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
+        const auto * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
         if (!type)
             throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
                 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
......@@ -66,19 +66,21 @@ public:
     void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
     {
-        if (!arguments.size())
+        if (arguments.empty())
             throw Exception("Function " + getName() + " requires at least one argument", ErrorCodes::BAD_ARGUMENTS);
-        const ColumnConst * column_with_states
-            = typeid_cast<const ColumnConst *>(&*block.getByPosition(arguments[0]).column);
+        const auto * model = block.getByPosition(arguments[0]).column.get();
+        if (const auto * column_with_states = typeid_cast<const ColumnConst *>(model))
+            model = column_with_states->getDataColumnPtr().get();
-        if (!column_with_states)
+        const auto * agg_function = typeid_cast<const ColumnAggregateFunction *>(model);
+        if (!agg_function)
             throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
                 + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
-        block.getByPosition(result).column =
-            typeid_cast<const ColumnAggregateFunction *>(&*column_with_states->getDataColumnPtr())->predictValues(block, arguments, context);
+        block.getByPosition(result).column = agg_function->predictValues(block, arguments, context);
     }
     const Context & context;
......
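After this change evalMLMethod accepts the model state either wrapped in a constant column or as a bare ColumnAggregateFunction: the constant wrapper, if present, is looked through before checking the underlying column type, instead of requiring a ColumnConst outright. A toy sketch of the unwrapping step; the *Sketch types are made up and dynamic_cast stands in for ClickHouse's typeid_cast:

/// Toy columns standing in for IColumn, ColumnConst and ColumnAggregateFunction.
struct IColumnSketch { virtual ~IColumnSketch() = default; };
struct ColumnConstSketch : IColumnSketch { const IColumnSketch * nested = nullptr; };
struct ColumnAggregateFunctionSketch : IColumnSketch { /* trained model states */ };

/// Look through a constant wrapper, if any, then check that the underlying
/// column really holds aggregate-function states.
const ColumnAggregateFunctionSketch * unwrapModelColumn(const IColumnSketch * column)
{
    if (const auto * constant = dynamic_cast<const ColumnConstSketch *>(column))
        column = constant->nested;

    /// nullptr here corresponds to the ILLEGAL_COLUMN exception in executeImpl.
    return dynamic_cast<const ColumnAggregateFunctionSketch *>(column);
}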