/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "NewRemoteParameterUpdater.h"
#include "Trainer.h"
#include "paddle/utils/Stat.h"

DECLARE_int32(trainer_id);
DECLARE_string(save_dir);

namespace paddle {
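
// Typical call sequence (a minimal sketch; the surrounding Trainer normally
// drives these calls, and the pserver address below is only an example):
//
//   NewRemoteParameterUpdater updater(optConfig, "127.0.0.1:3000", /*useEtcd=*/false);
//   updater.init(parameters);   // connect to pservers, init or fetch parameters
//   // ... run forward/backward so the PARAMETER_GRADIENT buffers are filled ...
//   updater.finishBatch(cost);  // send gradients, fetch updated parameters
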
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
    const OptimizationConfig &config, const std::string pserverSpec)
    : trainerConfig_(config),
      parameterClient_(-1),
      newParameters_(nullptr),
      newGradients_(nullptr),
      pserverSpec_(pserverSpec),
      useEtcd_(false) {}

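// when useEtcd is true, the parameter server addresses are resolved through
// etcd instead of being taken verbatim from pserverSpec (see init() below).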
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
    const OptimizationConfig &config,
    const std::string pserverSpec,
    const bool useEtcd)
    : trainerConfig_(config),
      parameterClient_(-1),
      newParameters_(nullptr),
      newGradients_(nullptr),
      pserverSpec_(pserverSpec),
      useEtcd_(useEtcd) {}

void NewRemoteParameterUpdater::init(
    const std::vector<ParameterPtr> &parameters) {
  ParameterUpdater::init(parameters);

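  // start with zeroed value and gradient buffers before syncing with the
  // parameter server.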
  for (auto &para : parameters_) {
    para->getBuf(PARAMETER_VALUE)->zeroMem();
    para->getBuf(PARAMETER_GRADIENT)->zeroMem();
  }

  // create parameter server client.
  if (useEtcd_) {
    parameterClient_ =
        paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
  } else {
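    // connect directly to the pservers listed in pserverSpec_; the second
    // argument marks whether this trainer (trainer 0) is the one selected to
    // initialize the parameters (see paddle_begin_init_params below).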
    parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
                                                 FLAGS_trainer_id == 0);
  }

  // init new parameters and gradients.
  newParameters_ = initNewParameter(PARAMETER_VALUE);
  newGradients_ = initNewParameter(PARAMETER_GRADIENT);

  // init parameters: one trainer gets the opportunity to initialize the
  // parameters and send them to the parameter server; the other trainers
  // fetch the initialized parameters from the parameter server.
  if (paddle_begin_init_params(parameterClient_)) {
    LOG(INFO) << "paddle_begin_init_params start";
    // NOTE: convert the V1 OptimizationConfig proto to a V2 OptimizerConfig.
    // This keeps the golang pserver compatible with the existing V1 demos.
    // TODO(wuyi): Refine or remove these ugly converting lines
    OptimizerConfig optimizerConfigV2;
    if (trainerConfig_.learning_method() == "momentum") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
    } else if (trainerConfig_.learning_method() == "adagrad") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adagrad);
      optimizerConfigV2.mutable_adagrad()->set_epsilon(
          trainerConfig_.ada_epsilon());
    } else if (trainerConfig_.learning_method() == "adadelta") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adadelta);
      optimizerConfigV2.mutable_adadelta()->set_epsilon(
          trainerConfig_.ada_epsilon());
      optimizerConfigV2.mutable_adadelta()->set_rho(trainerConfig_.ada_rou());
    } else if (trainerConfig_.learning_method() == "adam") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adam);
      optimizerConfigV2.mutable_adam()->set_beta_1(trainerConfig_.adam_beta1());
      optimizerConfigV2.mutable_adam()->set_beta_2(trainerConfig_.adam_beta2());
      optimizerConfigV2.mutable_adam()->set_epsilon(
          trainerConfig_.adam_epsilon());
    } else {
      LOG(ERROR) << "got unsupported v1 optimizer config: "
                 << trainerConfig_.learning_method();
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
    }

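    // map the V1 learning_rate_schedule onto a V2 learning-rate policy
    // (Const or Linear); anything else falls back to Const below.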
    if (trainerConfig_.learning_rate_schedule() == "constant") {
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
      optimizerConfigV2.mutable_const_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
    } else if (trainerConfig_.learning_rate_schedule() == "linear") {
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Linear);
      optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_a(
          trainerConfig_.learning_rate_decay_a());
      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_b(
          trainerConfig_.learning_rate_decay_b());
    } else {
      LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
                 << trainerConfig_.learning_rate_schedule() << ", set to const";
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
    }

    // overwrite optimizerConfigV2 with per-parameter (per-layer) settings
    // wherever the parameter config provides them.
    for (int i = 0; i < parameterSize(); ++i) {
      auto paramConfig = parameters_[i]->getConfig();
      if (paramConfig.has_momentum() &&
          trainerConfig_.learning_method() == "momentum") {
        optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
      }
      if (paramConfig.has_learning_rate()) {
        switch (optimizerConfigV2.lr_policy()) {
          case 0:  // Const
            optimizerConfigV2.mutable_const_lr()->set_learning_rate(
                paramConfig.learning_rate());
            break;
          case 1:  // Linear
            optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
                paramConfig.learning_rate());
            break;
        }
      }
      if (paramConfig.has_decay_rate()) {
        switch (optimizerConfigV2.optimizer()) {
          case 1:  // SGD
            optimizerConfigV2.mutable_sgd()->set_decay(
                paramConfig.decay_rate());
            break;
          case 2:  // Adadelta
            optimizerConfigV2.mutable_adadelta()->set_decay(
                paramConfig.decay_rate());
            break;
          case 3:  // Adagrad
            optimizerConfigV2.mutable_adagrad()->set_decay(
                paramConfig.decay_rate());
            break;
          case 4:  // Adam
            optimizerConfigV2.mutable_adam()->set_decay(
                paramConfig.decay_rate());
            break;
        }
      }
      // send this parameter and its (possibly per-layer-overridden) serialized
      // optimizer config to the pserver.
      std::string bytes = optimizerConfigV2.SerializeAsString();
      const char *array = bytes.data();
      int size = (int)bytes.size();
      paddle_init_param(
          parameterClient_, *newParameters_[i], (void *)array, size);
    }
    paddle_finish_init_params(parameterClient_);
    LOG(INFO) << "paddle_begin_init_params done";
  } else {
    paddle_get_params(parameterClient_, newParameters_, parameterSize());
  }

  LOG(INFO) << "NewRemoteParameterUpdater initialized";
}

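// per-parameter updates are a no-op here: gradients for all parameters are
// sent to the pserver in one batch from finishBatch() below.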
void NewRemoteParameterUpdater::updateImpl(Parameter *para) {}

void NewRemoteParameterUpdater::finishBatch(real cost) {
  // send gradients to the parameter server.
  paddle_send_grads(parameterClient_, newGradients_, parameterSize());
  // get the updated parameters back from the parameter server.
  paddle_get_params(parameterClient_, newParameters_, parameterSize());

  // clear gradients after updating the parameters.
  for (auto &para : parameters_) {
    para->getBuf(PARAMETER_GRADIENT)->zeroMem();
  }
}

void NewRemoteParameterUpdater::startPass() {}

bool NewRemoteParameterUpdater::finishPass() { return true; }
}  // namespace paddle