From 39af25595935876614c8ea938510b302ac8b4547 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?=
Date: Thu, 27 Jul 2017 09:25:11 +0800
Subject: [PATCH] Fix new optimizer lr (#3074)

* default learning rate, temperary fix

* update
---
 go/pserver/client/c/test/test_train.py       | 2 +-
 paddle/trainer/NewRemoteParameterUpdater.cpp | 6 +++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/go/pserver/client/c/test/test_train.py b/go/pserver/client/c/test/test_train.py
index e9264592b4..17082cf892 100644
--- a/go/pserver/client/c/test/test_train.py
+++ b/go/pserver/client/c/test/test_train.py
@@ -38,7 +38,7 @@ def main():
     parameters = paddle.parameters.create(cost)

     # create optimizer of new remote updater to pserver
-    optimizer = paddle.optimizer.Momentum(momentum=0)
+    optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)

     print "etcd endoint: ", etcd_endpoint
     trainer = paddle.trainer.SGD(cost=cost,
diff --git a/paddle/trainer/NewRemoteParameterUpdater.cpp b/paddle/trainer/NewRemoteParameterUpdater.cpp
index a830ceba57..e1558e3fdf 100644
--- a/paddle/trainer/NewRemoteParameterUpdater.cpp
+++ b/paddle/trainer/NewRemoteParameterUpdater.cpp
@@ -76,7 +76,11 @@ void NewRemoteParameterUpdater::init(
       sgdConfigV2->set_decay(paramConfig.decay_rate());
       optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
       auto constlr = optimizeConfigV2.mutable_const_lr();
-      constlr->set_learning_rate(paramConfig.learning_rate());
+      if (paramConfig.has_learning_rate()) {
+        constlr->set_learning_rate(paramConfig.learning_rate());
+      } else {
+        constlr->set_learning_rate(trainerConfig_.learning_rate());
+      }
       if (trainerConfig_.algorithm() == "sgd") {
         optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
         // FIXME: config all algorithms
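For context on the NewRemoteParameterUpdater.cpp change: the patch makes the const-lr policy prefer the per-parameter learning rate when one is set, and otherwise fall back to the trainer-level learning rate instead of silently using an unset value. A minimal standalone sketch of that fallback follows; it uses hypothetical plain structs (ParamConfig, TrainerConfig, resolveConstLr) in place of Paddle's protobuf ParameterConfig/TrainerConfig messages, so it illustrates the selection logic only, not the real API.

#include <iostream>
#include <optional>

// Hypothetical stand-ins for the protobuf ParameterConfig / TrainerConfig
// messages consulted in NewRemoteParameterUpdater::init().
struct ParamConfig {
  std::optional<double> learning_rate;  // per-parameter lr, may be unset
};

struct TrainerConfig {
  double learning_rate = 1e-3;  // trainer-wide default lr
};

// Mirrors the patched logic: use the per-parameter learning rate if present,
// otherwise fall back to the trainer-wide default.
double resolveConstLr(const ParamConfig& param, const TrainerConfig& trainer) {
  if (param.learning_rate.has_value()) {
    return *param.learning_rate;
  }
  return trainer.learning_rate;
}

int main() {
  TrainerConfig trainer;                  // learning_rate = 1e-3
  ParamConfig withLr{0.01};               // per-parameter lr set
  ParamConfig withoutLr{};                // per-parameter lr unset

  std::cout << resolveConstLr(withLr, trainer) << "\n";     // prints 0.01
  std::cout << resolveConstLr(withoutLr, trainer) << "\n";  // prints 0.001
  return 0;
}

The test_train.py hunk is the complementary workaround on the Python side: it passes learning_rate=1e-3 explicitly to the Momentum optimizer so the remote updater always receives a usable value.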