ParameterOptimizer.h 6.7 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "LearningRateScheduler.h"
#include "Parameter.h"

namespace paddle {

/**
 * Some member functions are set to const for two reasons:
 *
 * 1. For sparse update thread safe: update(), traverse callback(const this)
 *    may be called many times, each time one row, and these function
 *    can be called parallelly by multi worker, to speed up large block.
 *
 * 2. For predicate functions, needSpecialTraversal(), startCatchUpWith()
 *    may be called many times, should be no state change between calls.
 */
class ParameterOptimizer {
public:
34 35
  typedef std::function<void(
      const VectorPtr vecs[], const ParameterConfig& config, size_t sparseId)>
Z
zhangjinchao01 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
      TraverseCallback;

public:
  explicit ParameterOptimizer(const OptimizationConfig& optConfig)
      : applyDecay_(true),
        optConfig_(optConfig),
        parameterTypes_{PARAMETER_VALUE, PARAMETER_GRADIENT},
        learningRate_(optConfig.learning_rate()),
        learningRateScheduler_(LearningRateScheduler::create(optConfig)),
        pass_(0),
        firstTime_(true) {}

  real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    return learningRateScheduler_->calcLearningRate(numSamplesProcessed, pass);
  }

  virtual ~ParameterOptimizer() {}

  /**
   * For sparse update, optimizer can maintain numRows of timer(t0).
   * Some sparse optimizer depends on parameter config in functions
   * such as startBatch(). Optimizer can get it here. But notice that,
   * not all callers can pass config here, so the optimizer should check
   * config passed in is not null ptr.
   */
  virtual void init(size_t numRows, const ParameterConfig* config) {}

  virtual void startPass() {}
  virtual void finishPass() { ++pass_; }

  /// called by Trainer before forward() of a batch.
  virtual void startBatch(int64_t numSamplesProcessed) {
    (void)numSamplesProcessed;
  }

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
  /**
   * following hooks useful for sparse update,
   * because the traversal in block costs.
   * called by Trainer after update and before finishBatch
   * e.g. Trainer call like this:
   *
   * @code
   * startBatch();
   * if (dense) {
   *   update(blockVec);
   * } else {//sparse
   *   for (row : rows_in_block) {update(rowVec)}
   * }
   * auto callback = needSpecialTraversal();
   * if (callback) {
   *   // do traverse, maybe multi-thread
   *   if (dense) {
   *     callback();
   *   } else {//sparse
   *     for (row : all_rows_in_block) {callback();}
   *   }
   * }
   * finishBatch();
   * @endcode
   *
   * @return callback if need traverse,
   *         else return nullptr.
   *         It should be no state change.
   */
Z
zhangjinchao01 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113
  virtual TraverseCallback needSpecialTraversal(
      const ParameterConfig& config) const {
    return nullptr;
  }

  /// called by Trainer after backward() of a batch
  virtual void finishBatch() {}

  /**
   * between startBatch() and finishBatch(), update() will be called
   * by the trainer multiple times, each time for updating one Parameter
   * with its gradient in PARAMETER_GRADIENT. sparseId is row id,
   * when sparseId set, update is sparse, each time one row.
   */
114 115
  virtual void update(const VectorPtr vecs[],
                      const ParameterConfig& config,
Z
zhangjinchao01 已提交
116 117
                      size_t sparseId = -1LU) const = 0;

118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
  /**
   * following hooks catch up with current time for sparse update,
   * In the beginning, call startCatchUpWith() and check return.
   * In the end, call finishCatchUpWith() to finish state.
   * callback do the actual works, can call many times for sparse data.
   * e.g. Trainer call like this:
   *
   * @code
   * auto callback = startCatchUpWith();
   * if (callback) {
   *   // do catch up with, maybe multi-thread
   *   if (dense) {
   *     callback();
   *   } else {//sparse
   *     for (row : rows_in_block) {callback();}
   *   }
   *   // finish catch up with, main thread
   *   finishCatchUpWith();
   * }
   * @endcode
   *
   * @return callback if need catch up with,
   *         else return nullptr.
   *         It should be no state change.
   */
Z
zhangjinchao01 已提交
143 144 145
  virtual TraverseCallback startCatchUpWith() const { return nullptr; }
  virtual void finishCatchUpWith() {}

146 147 148 149 150 151 152 153 154 155
  /**
   * following two hooks used by averager,
   * apply to final parameter value (PARAMETER_VALUE or PARAMETER_APPLY).
   *
   * restore() will restore orginal value if it apply to PARAMETER_VALUE.
   * Caller must ensure it's catched up with current time before apply.
   *
   * Use returned callback same way as callback returned by
   * ParameterOptimizer::needSpecialTraversal()
   */
Z
zhangjinchao01 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
  virtual TraverseCallback apply() { return nullptr; }
  virtual TraverseCallback restore() { return nullptr; }

  /// return the parameter types used by this updater
  const std::vector<ParameterType>& getParameterTypes() const {
    return parameterTypes_;
  }

  void addParameterType(ParameterType type) {
    for (auto t : parameterTypes_) {
      if (t == type) return;
    }
    parameterTypes_.push_back(type);
  }
  real getLearningRate() const { return learningRate_; }

  virtual void setNoDecay() { applyDecay_ = false; }

  static ParameterOptimizer* create(const OptimizationConfig& optConfig,
                                    bool inPserver = false);

protected:
  typedef std::vector<ParameterOptimizer::TraverseCallback> TraverseCallbackVec;

  static TraverseCallback composeCallbacks(
      const TraverseCallbackVec& callbacks) {
    if (callbacks.size() > 1LU) {
183 184
      return [callbacks](const VectorPtr vecs[],
                         const ParameterConfig& config,
Z
zhangjinchao01 已提交
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
                         size_t sparseId) {
        for (auto callback : callbacks) {
          callback(vecs, config, sparseId);
        }
      };
    }
    return (callbacks.size() == 1LU) ? callbacks[0] : nullptr;
  }

  bool applyDecay_;
  const OptimizationConfig& optConfig_;
  std::vector<ParameterType> parameterTypes_;

  /**
   * global learning rate, init value is opt_config.learning_rate,
   * sparse regularizer get this value per batch, after StartBatch() called
   * so, if lr change in StartBatch, please assign to learningRate_
   */
  real learningRate_;
  std::unique_ptr<LearningRateScheduler> learningRateScheduler_;
  int64_t pass_;  // current training pass (starting from 0)
  bool firstTime_;
};

}  // namespace paddle