// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <condition_variable>  // NOLINT
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
class Executor;
class Scope;
struct ExecutorPrepareContext;
}  // namespace framework
}  // namespace paddle

DECLARE_double(eager_delete_tensor_gb);

namespace paddle {
namespace distributed {

#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"

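// TensorTable hosts a Paddle program on the parameter server so that tensor
// computation (e.g. a learning-rate schedule) can run server-side. The base
// class deliberately implements every Table hook as a no-op; subclasses
// override only the pieces they need.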
class TensorTable : public Table {
 public:
  TensorTable() {}
  virtual ~TensorTable() {}

  int32_t Pull(TableContext &context) override { return 0; }
  int32_t Push(TableContext &context) override { return 0; }

  int32_t Shrink(const std::string &param) override { return 0; }

  void *GetShard(size_t shard_idx) override { return nullptr; }

  int32_t InitializeShard() override { return 0; }

  int32_t Flush() override { return 0; }

  int32_t Load(const std::string &path, const std::string &param) override {
    return 0;
  }
  int32_t Save(const std::string &path, const std::string &param) override {
    return 0;
  }

  void Clear() override {}

  int32_t Initialize() override { return 0; }

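  // Wires in everything needed to execute programs: the scope that holds the
  // variables, the device place, and the candidate ProgramDescs. A minimal
  // call sketch (LoadPrograms is illustrative, not part of this file):
  //
  //   framework::Scope scope;
  //   std::vector<framework::ProgramDesc> programs = LoadPrograms();
  //   table->SetProgramEnv(&scope, platform::CPUPlace(), &programs);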
  int32_t SetProgramEnv(
      framework::Scope *scope, platform::Place place,
      const std::vector<framework::ProgramDesc> *sub_program) override {
    scope_ = scope;
    place_ = place;
    executor_ = new framework::Executor(place_);
    sub_program_ = sub_program;
    return 0;
  }

 protected:
  framework::Executor *executor_ = nullptr;
  framework::Scope *scope_ = nullptr;
  platform::Place place_ = platform::CPUPlace();
  const std::vector<framework::ProgramDesc> *sub_program_ = nullptr;
  paddle::distributed::TensorAccessorParameter program_config_;
  std::shared_ptr<framework::ExecutorPrepareContext> exec_context_ = nullptr;
};

// DenseTensorTable specializes TensorTable for dense tensor parameters;
// program Load/Save are still TODO and the Table hooks remain no-ops.
class DenseTensorTable : public TensorTable {
 public:
  DenseTensorTable() {}
  virtual ~DenseTensorTable() {}

  int32_t Shrink(const std::string &param) override { return 0; }

  void *GetShard(size_t shard_idx) override { return nullptr; }

  int32_t InitializeShard() override { return 0; }

  int32_t Flush() override { return 0; }

  void Clear() override {}

  // TODO: support program Load & Save
  int32_t Load(const std::string &path, const std::string &param) override {
    return 0;
  }
  int32_t Save(const std::string &path, const std::string &param) override {
    return 0;
  }

  /*----------------------------------------------------------------------*/

  int32_t Initialize() override { return 0; }

 protected:
  virtual int32_t _RunProgram(const float *values, size_t num,
                              const uint32_t trainer_id) {
    return 0;
  }

  int startup_program_id_ = -1;
  int main_program_id_ = -1;
  std::string feed_var_name_ = "";
  std::string fetch_var_name_ = "";
};

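// GlobalStepTable accumulates per-trainer step counters into a global step,
// feeds that step to a learning-rate decay program, and fetches the resulting
// learning rate so SetTableMap can propagate it to the other tables.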
class GlobalStepTable : public DenseTensorTable {
 public:
  GlobalStepTable() {}
  virtual ~GlobalStepTable() {}

  int32_t Shrink(const std::string &param) override { return 0; }

  void *GetShard(size_t shard_idx) override { return nullptr; }

  int32_t InitializeShard() override { return 0; }

  int32_t Flush() override { return 0; }

  void Clear() override {}

  int32_t Load(const std::string &path, const std::string &param) override {
    return 0;
  }
  int32_t Save(const std::string &path, const std::string &param) override {
    return 0;
  }

  /*----------------------------------------------------------------------*/

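  // Reads program ids and feed/fetch variable names from the table config,
  // runs the startup program once, prepares the main (decay) program for
  // reuse, and zero-initializes one counter per trainer.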
  int32_t Initialize() override {
    auto _program_config = _config.tensor();
    trainers_ = _config.common().trainer_num();
    // A negative value disables eager deletion of tensors.
    FLAGS_eager_delete_tensor_gb = -1;
    // Get Config
    if (_program_config.has_startup_program_id()) {
      startup_program_id_ = _program_config.startup_program_id();
    }
    if (_program_config.has_main_program_id()) {
      main_program_id_ = _program_config.main_program_id();
    }
    if (_program_config.has_feed_var_name()) {
      feed_var_name_ = _program_config.feed_var_name();
    }
    if (_program_config.has_fetch_var_name()) {
      fetch_var_name_ = _program_config.fetch_var_name();
    }

    // Run startup program
    if (startup_program_id_ != -1) {
      std::map<std::string, const framework::LoDTensor *> fake_feed;
      std::map<std::string, framework::FetchType *> fake_fetch;
      auto startup_program_desc = sub_program_->at(startup_program_id_);
      auto ctx = executor_->Prepare(startup_program_desc, 0);
      executor_->RunPreparedContext(ctx.get(), scope_, false);
    }

    if (main_program_id_ != -1) {
      // Run the main program (used for learning-rate decay).
      auto main_program_desc = sub_program_->at(main_program_id_);
      auto main_ctx = executor_->Prepare(main_program_desc, 0);
      exec_context_ = std::move(main_ctx);
      executor_->RunPreparedContext(exec_context_.get(), scope_, false);
      // init decay_counters
      decay_counters_.reserve(trainers_);
      for (int32_t i = 0; i < trainers_; ++i) {
        decay_counters_[i] = 0;
      }
    }
    return 0;
  }

  //  int32_t PushDense(const float *values, size_t num) override { return 0; }

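  // Push treats the incoming values as step increments for this trainer. A
  // minimal caller sketch (the field setup is illustrative only):
  //
  //   TableContext ctx;
  //   int64_t steps[] = {1};
  //   ctx.push_context.push_steps = steps;
  //   ctx.trainer_id = 0;
  //   table.Push(ctx);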
  int32_t Push(TableContext &context) override {
    return _RunProgram(context.push_context.push_steps, context.trainer_id);
  }

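  // Broadcasts the fetched learning rate to every other registered table
  // (skipping this table's own id) via SetGlobalLR.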
  int32_t SetTableMap(std::unordered_map<uint32_t, std::shared_ptr<Table>>
                          *table_map) override {
    auto *lr_var = scope_->FindVar(fetch_var_name_);
    auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
    auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
    VLOG(3) << "GlobalStepTable::set_table_map set global lr: " << *lr_value;

    for (auto iter = table_map->begin(); iter != table_map->end(); iter++) {
      auto table_id = iter->first;
      if (table_id == _config.table_id()) {
        continue;
      }
      iter->second->SetGlobalLR(lr_value);
    }
    return 0;
  }

 private:
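  // Adds this trainer's step delta to its counter, writes the summed global
  // step into the feed variable, reruns the prepared decay program, and logs
  // the fetched learning rate.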
  virtual int32_t _RunProgram(const int64_t *values,
                              const uint32_t trainer_id) {
    FLAGS_eager_delete_tensor_gb = -1;
    auto counter = decay_counters_.at(trainer_id);
    counter += values[0];
    decay_counters_.at(trainer_id) = counter;

    auto *global_step_var = scope_->FindVar(feed_var_name_);
    auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
    auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());

    int64_t global_counter = 0;
    for (auto &trainer_counter : decay_counters_) {
      global_counter += trainer_counter.second;
    }

    // TODO: hard-coded -1 to compensate for the increment op in the program.
    value[0] = global_counter - 1;
    VLOG(3) << "GlobalStepTable::_run_program global_counter " << value[0];

    executor_->RunPreparedContext(exec_context_.get(), scope_, false, false);
    auto *lr_var = scope_->FindVar(fetch_var_name_);
    auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
    auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
    VLOG(3) << "GlobalStepTable::LR value: " << lr_value[0];
    return 0;
  }

 private:
  // Per-trainer accumulated step counts, keyed by trainer_id.
  std::unordered_map<int, int64_t> decay_counters_;
  int32_t trainers_ = 0;
};

}  // namespace distributed
}  // namespace paddle