/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "NewRemoteParameterUpdater.h"

#include <utility>

#include "Trainer.h"
#include "paddle/utils/Stat.h"

DECLARE_int32(trainer_id);
DECLARE_string(save_dir);

namespace paddle {
// Creates an updater whose init() passes pserverSpec straight to
// paddle_new_pserver_client (etcd is not used).
//
// Delegates to the three-argument constructor so that every member, including
// useEtcd_, is initialized in exactly one place; init() reads useEtcd_, so it
// must never be left uninitialized.
//
// pserverSpec is taken by value (dropping the declaration's top-level const
// does not change the signature) so it can be moved rather than copied again.
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
    const OptimizationConfig &config, std::string pserverSpec)
    : NewRemoteParameterUpdater(config, std::move(pserverSpec), false) {}
// Creates an updater for the parameter server(s) described by pserverSpec.
// If useEtcd is true, init() passes pserverSpec to
// paddle_new_etcd_pserver_client; otherwise to paddle_new_pserver_client.
// The client handle and parameter/gradient arrays are created later, in init().
//
// pserverSpec is taken by value (dropping the declaration's top-level const
// does not change the signature) so it can be moved into pserverSpec_ instead
// of being copied a second time.
NewRemoteParameterUpdater::NewRemoteParameterUpdater(
    const OptimizationConfig &config,
    std::string pserverSpec,
    const bool useEtcd)
    : trainerConfig_(config),
      parameterClient_(-1),
      newParameters_(nullptr),
      newGradients_(nullptr),
      pserverSpec_(std::move(pserverSpec)),
      useEtcd_(useEtcd) {}
void NewRemoteParameterUpdater::init(
    const std::vector<ParameterPtr> &parameters) {
  ParameterUpdater::init(parameters);

  // create parameter server client.
47
  if (useEtcd_) {
H
Helin Wang 已提交
48 49
    parameterClient_ =
        paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
50 51 52 53
  } else {
    parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
                                                 FLAGS_trainer_id == 0);
  }
54 55

  // init new parameter and gradient.
Q
qiaolongfei 已提交
56 57
  newParameters_ = initNewParameter(PARAMETER_VALUE);
  newGradients_ = initNewParameter(PARAMETER_GRADIENT);
58 59 60 61 62 63

  // init parameter, one trainer will get the opportunity to int parameter and
  // send them to parameter server. Others will get the initialized parameter
  // from parameter server
  if (paddle_begin_init_params(parameterClient_)) {
    LOG(INFO) << "paddle_begin_init_params start";
64 65
    // NOTE: convert V1 OptimizatioinConfig proto to V2 OptimizerConfig.
    // This makes golang pserver compatible with handy V1 demos.
H
Helin Wang 已提交
66
    // TODO(wuyi): Refine or remove these ugly converting lines
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
    OptimizerConfig optimizerConfigV2;
    if (trainerConfig_.learning_method() == "momentum") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
    } else if (trainerConfig_.learning_method() == "adagrad") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adagrad);
      optimizerConfigV2.mutable_adagrad()->set_epsilon(
          trainerConfig_.ada_epsilon());
    } else if (trainerConfig_.learning_method() == "adadelta") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adagrad);
      optimizerConfigV2.mutable_adadelta()->set_epsilon(
          trainerConfig_.ada_epsilon());
      optimizerConfigV2.mutable_adadelta()->set_rho(trainerConfig_.ada_rou());
    } else if (trainerConfig_.learning_method() == "adam") {
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adam);
      optimizerConfigV2.mutable_adam()->set_beta_1(trainerConfig_.adam_beta1());
      optimizerConfigV2.mutable_adam()->set_beta_2(trainerConfig_.adam_beta2());
      optimizerConfigV2.mutable_adam()->set_epsilon(
          trainerConfig_.adam_epsilon());
    } else {
      LOG(ERROR) << "got unsupported v1 optimizer config: "
                 << trainerConfig_.learning_method();
      optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
    }

    if (trainerConfig_.learning_rate_schedule() == "constant") {
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
      optimizerConfigV2.mutable_const_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
    } else if (trainerConfig_.learning_rate_schedule() == "linear") {
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Linear);
      optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_a(
          trainerConfig_.learning_rate_decay_a());
      optimizerConfigV2.mutable_linear_lr()->set_lr_decay_b(
          trainerConfig_.learning_rate_decay_b());
    } else {
      LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
                 << trainerConfig_.learning_rate_schedule() << ", set to const";
      optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
107 108
      optimizerConfigV2.mutable_const_lr()->set_learning_rate(
          trainerConfig_.learning_rate());
109 110 111
    }

    // overwrite optimizerConfigV2 for per-parameter(layer) configs
112
    for (int i = 0; i < parameterSize(); ++i) {
T
typhoonzero 已提交
113 114
      // FIXME(typhoonzero): paramConfig always have default values,
      // how to check if it's default?
H
helinwang 已提交
115
      // TODO(typhoonzero): log output: optimizerConfigV2.DebugString();
T
typhoonzero 已提交
116
      LOG(INFO) << "trainerConfig_: " << trainerConfig_.DebugString();
117 118
      // send param and config to pserver
      std::string bytes = optimizerConfigV2.SerializeAsString();
Q
qiaolongfei 已提交
119 120 121 122
      const char *array = bytes.data();
      int size = (int)bytes.size();
      paddle_init_param(
          parameterClient_, *newParameters_[i], (void *)array, size);
123 124 125 126
    }
    paddle_finish_init_params(parameterClient_);
    LOG(INFO) << "paddle_begin_init_params done";
  } else {
Q
qiaolongfei 已提交
127
    paddle_get_params(parameterClient_, newParameters_, parameterSize());
128 129 130 131 132 133 134 135 136
  }

  LOG(INFO) << "NewRemoteParameterUpdater initialized";
}

void NewRemoteParameterUpdater::updateImpl(Parameter *para) {}

// Ends a batch by round-tripping through the parameter server: push this
// batch's gradients, pull the parameters back, then zero the local gradient
// buffers. The batch cost is not used.
void NewRemoteParameterUpdater::finishBatch(real /*cost*/) {
  const int count = parameterSize();

  // Gradients out, updated parameters in.
  paddle_send_grads(parameterClient_, newGradients_, count);
  paddle_get_params(parameterClient_, newParameters_, count);

  // Reset every gradient buffer so the next batch starts from zero.
  for (const auto &parameter : parameters_) {
    parameter->getBuf(PARAMETER_GRADIENT)->zeroMem();
  }
}

// This updater performs no work at the start of a pass.
void NewRemoteParameterUpdater::startPass() {
  // Intentionally empty.
}

// Performs no end-of-pass work and always returns true.
bool NewRemoteParameterUpdater::finishPass() {
  return true;
}
}  // namespace paddle