提交 cec5e651 编写于 作者: T typhoonzero

fix ft job converge

上级 d18d75da
......@@ -110,43 +110,10 @@ void NewRemoteParameterUpdater::init(
// overwrite optimizerConfigV2 for per-parameter(layer) configs
for (int i = 0; i < parameterSize(); ++i) {
auto paramConfig = parameters_[i]->getConfig();
if (paramConfig.has_momentum() &&
trainerConfig_.learning_method() == "momentum") {
optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
}
if (paramConfig.has_learning_rate()) {
switch (optimizerConfigV2.lr_policy()) {
case 0:
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
case 1:
optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
}
}
if (paramConfig.has_decay_rate()) {
switch (optimizerConfigV2.optimizer()) {
case 1: // SGD
optimizerConfigV2.mutable_sgd()->set_decay(
paramConfig.decay_rate());
break;
case 2: // Adadelta
optimizerConfigV2.mutable_adadelta()->set_decay(
paramConfig.decay_rate());
break;
case 3: // Adagrad
optimizerConfigV2.mutable_adagrad()->set_decay(
paramConfig.decay_rate());
break;
case 4: // Adam
optimizerConfigV2.mutable_adam()->set_decay(
paramConfig.decay_rate());
break;
}
}
// FIXME(typhoonzero): paramConfig always have default values,
// how to check if it's default?
// TODO: log output: optimizerConfigV2.DebugString();
LOG(INFO) << "trainerConfig_: " << trainerConfig_.DebugString();
// send param and config to pserver
std::string bytes = optimizerConfigV2.SerializeAsString();
const char *array = bytes.data();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册