提交 2000cafe 编写于 作者: 武毅 提交者: GitHub

Merge pull request #5132 from typhoonzero/fix_ft_job_converge

fix ft job converge
...@@ -110,43 +110,10 @@ void NewRemoteParameterUpdater::init( ...@@ -110,43 +110,10 @@ void NewRemoteParameterUpdater::init(
// overwrite optimizerConfigV2 for per-parameter(layer) configs // overwrite optimizerConfigV2 for per-parameter(layer) configs
for (int i = 0; i < parameterSize(); ++i) { for (int i = 0; i < parameterSize(); ++i) {
auto paramConfig = parameters_[i]->getConfig(); // FIXME(typhoonzero): paramConfig always have default values,
if (paramConfig.has_momentum() && // how to check if it's default?
trainerConfig_.learning_method() == "momentum") { // TODO(typhoonzero): log output: optimizerConfigV2.DebugString();
optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum()); LOG(INFO) << "trainerConfig_: " << trainerConfig_.DebugString();
}
if (paramConfig.has_learning_rate()) {
switch (optimizerConfigV2.lr_policy()) {
case 0:
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
case 1:
optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
}
}
if (paramConfig.has_decay_rate()) {
switch (optimizerConfigV2.optimizer()) {
case 1: // SGD
optimizerConfigV2.mutable_sgd()->set_decay(
paramConfig.decay_rate());
break;
case 2: // Adadelta
optimizerConfigV2.mutable_adadelta()->set_decay(
paramConfig.decay_rate());
break;
case 3: // Adagrad
optimizerConfigV2.mutable_adagrad()->set_decay(
paramConfig.decay_rate());
break;
case 4: // Adam
optimizerConfigV2.mutable_adam()->set_decay(
paramConfig.decay_rate());
break;
}
}
// send param and config to pserver // send param and config to pserver
std::string bytes = optimizerConfigV2.SerializeAsString(); std::string bytes = optimizerConfigV2.SerializeAsString();
const char *array = bytes.data(); const char *array = bytes.data();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册