未验证 提交 00809c78 编写于 作者: A alexey-milovidov 提交者: GitHub

Merge pull request #5623 from yandex/Quid37-lin_ref_perf

Merging PR #5505
......@@ -49,8 +49,8 @@ namespace
auto l2_reg_coef = Float64(0.1);
UInt32 batch_size = 15;
std::string weights_updater_name = "\'SGD\'";
std::shared_ptr<IGradientComputer> gradient_computer;
std::string weights_updater_name = "SGD";
std::unique_ptr<IGradientComputer> gradient_computer;
if (!parameters.empty())
{
......@@ -66,20 +66,19 @@ namespace
}
if (parameters.size() > 3)
{
weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
weights_updater_name = parameters[3].safeGet<String>();
if (weights_updater_name != "SGD" && weights_updater_name != "Momentum" && weights_updater_name != "Nesterov")
throw Exception("Invalid parameter for weights updater. The only supported are 'SGD', 'Momentum' and 'Nesterov'",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
if (std::is_same<Method, FuncLinearRegression>::value)
{
gradient_computer = std::make_shared<LinearRegression>();
gradient_computer = std::make_unique<LinearRegression>();
}
else if (std::is_same<Method, FuncLogisticRegression>::value)
{
gradient_computer = std::make_shared<LogisticRegression>();
gradient_computer = std::make_unique<LogisticRegression>();
}
else
{
......@@ -88,7 +87,7 @@ namespace
return std::make_shared<Method>(
argument_types.size() - 1,
gradient_computer,
std::move(gradient_computer),
weights_updater_name,
learning_rate,
l2_reg_coef,
......
......@@ -37,8 +37,7 @@ public:
Float64 l2_reg_coef,
Float64 target,
const IColumn ** columns,
size_t row_num)
= 0;
size_t row_num) = 0;
virtual void predict(
ColumnVector<Float64>::Container & container,
......@@ -201,9 +200,8 @@ private:
};
/**
* LinearModelData is a class which manages current state of learning
*/
/** LinearModelData is a class which manages current state of learning
*/
class LinearModelData
{
public:
......@@ -249,9 +247,8 @@ private:
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
/**
* The function is called when we want to flush current batch and update our weights
*/
/** The function is called when we want to flush current batch and update our weights
*/
void update_state();
};
......@@ -268,7 +265,7 @@ public:
explicit AggregateFunctionMLMethod(
UInt32 param_num,
std::shared_ptr<IGradientComputer> gradient_computer,
std::unique_ptr<IGradientComputer> gradient_computer,
std::string weights_updater_name,
Float64 learning_rate,
Float64 l2_reg_coef,
......@@ -300,19 +297,15 @@ public:
void create(AggregateDataPtr place) const override
{
std::shared_ptr<IWeightsUpdater> new_weights_updater;
if (weights_updater_name == "\'SGD\'")
{
if (weights_updater_name == "SGD")
new_weights_updater = std::make_shared<StochasticGradientDescent>();
} else if (weights_updater_name == "\'Momentum\'")
{
else if (weights_updater_name == "Momentum")
new_weights_updater = std::make_shared<Momentum>();
} else if (weights_updater_name == "\'Nesterov\'")
{
else if (weights_updater_name == "Nesterov")
new_weights_updater = std::make_shared<Nesterov>();
} else
{
else
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
}
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册