提交 39af2559 编写于 作者: 武毅 提交者: GitHub

Fix new optimizer lr (#3074)

* default learning rate, temperary fix

* update
上级 eff17a68
...@@ -38,7 +38,7 @@ def main(): ...@@ -38,7 +38,7 @@ def main():
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
print "etcd endoint: ", etcd_endpoint print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
......
...@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init( ...@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init(
sgdConfigV2->set_decay(paramConfig.decay_rate()); sgdConfigV2->set_decay(paramConfig.decay_rate());
optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const); optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
auto constlr = optimizeConfigV2.mutable_const_lr(); auto constlr = optimizeConfigV2.mutable_const_lr();
constlr->set_learning_rate(paramConfig.learning_rate()); if (paramConfig.has_learning_rate()) {
constlr->set_learning_rate(paramConfig.learning_rate());
} else {
constlr->set_learning_rate(trainerConfig_.learning_rate());
}
if (trainerConfig_.algorithm() == "sgd") { if (trainerConfig_.algorithm() == "sgd") {
optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD); optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
// FIXME: config all algorithms // FIXME: config all algorithms
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册