/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "NewRemoteParameterUpdater.h"
#include "Trainer.h"
#include "paddle/utils/Stat.h"

DECLARE_int32(trainer_id);
DECLARE_string(save_dir);

namespace paddle {
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
    const OptimizationConfig &config, const std::string pserverSpec)
    : trainerConfig_(config),
      parameterClient_(-1),
      newParameters_(nullptr),
      newGradients_(nullptr),
      pserverSpec_(pserverSpec) {}

NewRemoteParameterUpdater::NewRemoteParameterUpdater(
    const OptimizationConfig &config,
    const std::string pserverSpec,
    const bool useEtcd)
    : trainerConfig_(config),
      parameterClient_(-1),
      newParameters_(nullptr),
      newGradients_(nullptr),
      pserverSpec_(pserverSpec),
      useEtcd_(useEtcd) {}
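
// A minimal usage sketch (illustrative only; the real call sites live in the
// trainer code, and the pserver address list below is made up):
//   OptimizationConfig config;   // filled from the V1 trainer config
//   NewRemoteParameterUpdater updater(
//       config, "127.0.0.1:8001,127.0.0.1:8002", /*useEtcd=*/false);
//   updater.init(parameters);    // connect and init or fetch parameters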

void NewRemoteParameterUpdater::init(
    const std::vector<ParameterPtr> &parameters) {
  ParameterUpdater::init(parameters);

  // create parameter server client.
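  // With useEtcd_, pserverSpec_ is presumably interpreted as the etcd
  // endpoint(s) used to discover the pservers; otherwise it is taken as the
  // pserver address list directly. The second argument of
  // paddle_new_pserver_client appears to mark whether this trainer
  // (trainer 0) is the one selected to initialize parameters.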
  if (useEtcd_) {
    parameterClient_ =
        paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
  } else {
    parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
                                                 FLAGS_trainer_id == 0);
  }

  // init new parameter and gradient.
  newParameters_ = initNewParameter(PARAMETER_VALUE);
  newGradients_ = initNewParameter(PARAMETER_GRADIENT);

  // Init parameters: one trainer gets the opportunity to initialize the
  // parameters and send them to the parameter server; the other trainers
  // fetch the initialized parameters from the parameter server.
  if (paddle_begin_init_params(parameterClient_)) {
    LOG(INFO) << "paddle_begin_init_params start";
    // NOTE: convert V1 OptimizationConfig proto to V2 OptimizerConfig.
    // This makes golang pserver compatible with handy V1 demos.
    // TODO(wuyi): Refine or remove these ugly converting lines
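    // For reference, a rough summary of the V1 -> V2 mapping performed below
    // (informal, derived from the conversion code itself, not a full spec):
    //   learning_method:  "momentum" -> SGD, "adagrad" -> Adagrad,
    //                     "adadelta" -> Adadelta, "adam" -> Adam
    //   learning_rate_schedule: "constant" -> Const, "linear" -> Linear
    //   anything else falls back to SGD / Const with a LOG(ERROR).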
    OptimizerConfig optimizerConfigV2;
    if (trainerConfig_.learning_method() == "momentum") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
    } else if (trainerConfig_.learning_method() == "adagrad") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adagrad);
      optimizerConfigV2.mutable_adagrad()->set_epsilon(
          trainerConfig_.ada_epsilon());
    } else if (trainerConfig_.learning_method() == "adadelta") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adadelta);
      optimizerConfigV2.mutable_adadelta()->set_epsilon(
          trainerConfig_.ada_epsilon());
      optimizerConfigV2.mutable_adadelta()->set_rho(trainerConfig_.ada_rou());
    } else if (trainerConfig_.learning_method() == "adam") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adam);
      optimizerConfigV2.mutable_adam()->set_beta_1(trainerConfig_.adam_beta1());
      optimizerConfigV2.mutable_adam()->set_beta_2(trainerConfig_.adam_beta2());
      optimizerConfigV2.mutable_adam()->set_epsilon(
          trainerConfig_.adam_epsilon());
    } else {
      LOG(ERROR) << "got unsupported v1 optimizer config: "
                 << trainerConfig_.learning_method();
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
    }

    if (trainerConfig_.learning_rate_schedule() == "constant") {
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
      optimizerConfigV2.mutable_const_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
    } else if (trainerConfig_.learning_rate_schedule() == "linear") {
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Linear);
      optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_a(
          trainerConfig_.learning_rate_decay_a());
      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_b(
          trainerConfig_.learning_rate_decay_b());
    } else {
      LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
                 << trainerConfig_.learning_rate_schedule() << ", set to const";
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
      optimizerConfigV2.mutable_const_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
    }

    // Overwrite optimizerConfigV2 with per-parameter (layer) configs.
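    // Fields that may be overridden per parameter below: momentum (SGD),
    // learning_rate (for the active lr policy), and decay_rate (for the
    // active optimizer).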
    for (int i = 0; i < parameterSize(); ++i) {
      auto paramConfig = parameters_[i]->getConfig();
      if (paramConfig.has_momentum() &&
          trainerConfig_.learning_method() == "momentum") {
        optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
      }
      if (paramConfig.has_learning_rate()) {
        switch (optimizerConfigV2.lr_policy()) {
          case 0:  // Const
            optimizerConfigV2.mutable_const_lr()->set_learning_rate(
                paramConfig.learning_rate());
            break;
          case 1:  // Linear
            optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
                paramConfig.learning_rate());
            break;
        }
      }
      if (paramConfig.has_decay_rate()) {
        switch (optimizerConfigV2.optimizer()) {
          case 1:  // SGD
            optimizerConfigV2.mutable_sgd()->set_decay(
                paramConfig.decay_rate());
            break;
          case 2:  // Adadelta
            optimizerConfigV2.mutable_adadelta()->set_decay(
                paramConfig.decay_rate());
            break;
          case 3:  // Adagrad
            optimizerConfigV2.mutable_adagrad()->set_decay(
                paramConfig.decay_rate());
            break;
          case 4:  // Adam
            optimizerConfigV2.mutable_adam()->set_decay(
                paramConfig.decay_rate());
            break;
        }
      }
      // send param and config to pserver
      std::string bytes = optimizerConfigV2.SerializeAsString();
      const char *array = bytes.data();
      int size = (int)bytes.size();
      paddle_init_param(
          parameterClient_, *newParameters_[i], (void *)array, size);
    }
    paddle_finish_init_params(parameterClient_);
    LOG(INFO) << "paddle_begin_init_params done";
  } else {
    paddle_get_params(parameterClient_, newParameters_, parameterSize());
  }

  LOG(INFO) << "NewRemoteParameterUpdater initialized";
}

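// updateImpl is a no-op for this updater: the actual optimizer step runs on
// the parameter server, and the client-side work is done in finishBatch().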
void NewRemoteParameterUpdater::updateImpl(Parameter *para) {}

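// Per-batch flow (summary of the calls below): push local gradients to the
// pserver, pull back the updated parameter values, then zero the local
// gradient buffers for the next batch.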
void NewRemoteParameterUpdater::finishBatch(real cost) {
  // send gradient to parameter server.
  paddle_send_grads(parameterClient_, newGradients_, parameterSize());
  // get the updated parameter from parameterClient.
  paddle_get_params(parameterClient_, newParameters_, parameterSize());

  // clear gradients after the parameter update.
  for (auto &para : parameters_) {
    para->getBuf(PARAMETER_GRADIENT)->zeroMem();
  }
}

void NewRemoteParameterUpdater::startPass() {}

bool NewRemoteParameterUpdater::finishPass() { return true; }
}  // namespace paddle