/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

namespace paddle {
namespace framework {

extern Barrier g_barrier;

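// Initialize reads the thread count, dump settings, and stat-var names from
// the TrainerDesc, builds one DeviceWorker per dataset reader (thread_num_ is
// reset to the reader count), and resets the global thread barrier.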
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
                              Dataset* dataset) {
  thread_num_ = trainer_desc.thread_num();
  SetDataset(dataset);

  ParseDumpConfig(trainer_desc);
  mpi_rank_ = trainer_desc.mpi_rank();
  mpi_size_ = trainer_desc.mpi_size();
  dump_file_num_ = trainer_desc.dump_file_num();
  for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
       i++) {
    need_merge_var_names_.push_back(
        trainer_desc.downpour_param().stat_var_names(i));
  }
#ifdef PADDLE_WITH_HETERPS
  for (int i = 0; i < thread_num_; ++i) {
    int num = trainer_desc.worker_places(i);
    platform::CUDAPlace place = platform::CUDAPlace(num);
    places_.push_back(place);
  }
#endif
  user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
  // get the data feed readers from the dataset
  const std::vector<paddle::framework::DataFeed*> readers =
      dataset->GetReaders();
  VLOG(3) << "readers num: " << readers.size();
  // change thread num to readers num
  thread_num_ = readers.size();
  VLOG(3) << "worker thread num: " << thread_num_;
  workers_.resize(thread_num_);

#if defined PADDLE_WITH_PSCORE
  if (trainer_desc.thread_barrier()) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset(
        thread_num_);
  }
#endif
  g_barrier.reset(thread_num_);
  for (int i = 0; i < thread_num_; ++i) {
    workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
        trainer_desc.device_worker_name());
    workers_[i]->SetNeedDumpField(need_dump_field_);
    workers_[i]->SetNeedDumpParam(need_dump_param_);
    workers_[i]->SetDumpFieldVector(dump_fields_);
    workers_[i]->SetDumpParamVector(dump_param_);
    workers_[i]->InitRandomDumpConfig(trainer_desc);
    workers_[i]->Initialize(trainer_desc);
    workers_[i]->SetDeviceIndex(i);
    workers_[i]->SetDataFeed(readers[i]);
    workers_[i]->SetThreadNum(thread_num_);
  }

  // set debug here
  SetDebug(trainer_desc.debug());
}

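// Build the dump file path for worker `tid`: use the user-defined filename
// when one is given, otherwise encode the MPI rank into the part name.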
std::string MultiTrainer::GetDumpPath(int tid) {
  if (!user_define_dump_filename_.empty()) {
    return string::format_string("%s/part-%s-%05d",
                                 dump_fields_path_.c_str(),
                                 user_define_dump_filename_.c_str(),
                                 tid);
  }
  return string::format_string(
      "%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
}

void MultiTrainer::InitDumpEnv() {
  queue_ = paddle::framework::MakeChannel<std::string>();
  for (int i = 0; i < thread_num_; ++i) {
    workers_[i]->SetChannelWriter(queue_.get());
  }
  dump_thread_num_ = 1;
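  // When more dump files than MPI ranks are requested, give this rank
  // dump_file_num_ / mpi_size_ writer threads, plus one extra if its rank
  // index is below the remainder; otherwise a single writer thread is used.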
  if (dump_file_num_ > mpi_size_) {
    dump_thread_num_ = dump_file_num_ / mpi_size_;
    if (dump_file_num_ % mpi_size_ > mpi_rank_) {
      dump_thread_num_ += 1;
    }
  }
  for (int i = 0; i < dump_thread_num_; i++) {
    dump_thread_.push_back(std::thread([this, i] { DumpWork(i); }));
  }
}

// call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
                                  const platform::Place& place) {
  for (int i = 0; i < thread_num_; ++i) {
#ifdef PADDLE_WITH_HETERPS
    workers_[i]->SetPlace(places_[i]);
    workers_[i]->SetReaderPlace(places_[i]);
    workers_[i]->SetDeviceContext(
        platform::DeviceContextPool::Instance().Get(places_[i]));
#else
    workers_[i]->SetPlace(place);
    workers_[i]->SetReaderPlace(place);
#endif
    workers_[i]->SetRootScope(root_scope_);
    workers_[i]->CreateDeviceResource(main_program);  // Program
    workers_[i]->BindingDataFeedMemory();
    workers_[i]->CacheProgram(main_program);
  }
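// With HETERPS, seed each thread scope with a device-local copy of every
// persistable dense parameter found in the root scope.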
#ifdef PADDLE_WITH_HETERPS
  for (int num = 0; num < thread_num_; ++num) {
    auto place = places_[num];
    Scope* scope = workers_[num]->GetThreadScope();
    auto& block = main_program.Block(0);
    for (auto& var : block.AllVars()) {
      if (var->Persistable()) {
        auto name = var->Name();
        Variable* root_var = root_scope_->FindVar(name);
        if (!root_var) {
          continue;
        }
        if (root_var->IsType<phi::SelectedRows>()) {
          continue;
        }
        phi::DenseTensor* root_tensor =
            root_var->GetMutable<phi::DenseTensor>();
        auto* ptr = scope->Var(name);
        InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
        phi::DenseTensor* thread_tensor = ptr->GetMutable<phi::DenseTensor>();
        TensorCopy(*root_tensor, place, thread_tensor);
      }
    }
  }
#endif
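  // Record persistable dense variables (excluding SelectedRows) that are not
  // already in need_merge_var_names_; these are the candidates MergeDenseParam
  // copies back to the root scope when no Communicator is available.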
  for (auto& var : main_program.Block(0).AllVars()) {
    if (var->Persistable()) {
      auto it = std::find(need_merge_var_names_.begin(),
                          need_merge_var_names_.end(),
                          var->Name());
      if (it == need_merge_var_names_.end() &&
          var->GetType() != proto::VarType::SELECTED_ROWS) {
        VLOG(2) << "train param: " << var->Name();
        trainable_param_.push_back(var->Name());
      }
    }
  }
}

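// InitOtherEnv sets up the dump channel and writer threads when dumping is
// enabled and, under PSCORE, pulls the latest dense parameters from the
// parameter server before training starts.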
void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
  if (need_dump_field_ || need_dump_param_) {
    InitDumpEnv();
  }

#ifdef PADDLE_WITH_PSCORE
  // pull dense param first
  auto communicator = paddle::distributed::Communicator::GetInstance();
  // for unittest which call train_from_dataset but does not call
  // fleet.init_worker() first
  if (communicator == nullptr) {
    VLOG(1) << "MultiTrainer::InitOtherEnv Communicator is null!";
  } else {
    auto& recv_ctx = communicator->GetRecvCtxMap();
    communicator->PullDense(recv_ctx);
    VLOG(3) << "init other env done.";
  }
#endif
}

Scope* MultiTrainer::GetWorkerScope(int thread_id) {
  return workers_[thread_id]->GetThreadScope();
}
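// Lazily create one single-threaded ThreadPool per worker; the pools are
// static, so repeated calls to Run() reuse them.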
inline std::vector<std::shared_ptr<paddle::framework::ThreadPool>>&
GetThreadPool(int thread_num) {
  static std::vector<std::shared_ptr<paddle::framework::ThreadPool>>
      thread_pools;
  if (!thread_pools.empty()) {
    return thread_pools;
  }
  thread_pools.resize(thread_num);
  for (int i = 0; i < thread_num; ++i) {
    thread_pools[i].reset(new paddle::framework::ThreadPool(1));
  }
  return thread_pools;
}
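// Run dispatches every DeviceWorker onto its own thread pool (the profiling
// variant when debug_ is set) and blocks until all workers have finished.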
void MultiTrainer::Run() {
  VLOG(3) << "Going to run";
  auto pool = GetThreadPool(thread_num_);
  std::vector<std::future<void>> wait_futures;
  CHECK_EQ(static_cast<int>(pool.size()), thread_num_);
  for (int i = 0; i < thread_num_; ++i) {
    if (!debug_) {
      wait_futures.emplace_back(
          pool[i]->Run([this, i]() { workers_[i]->TrainFiles(); }));
    } else {
      wait_futures.emplace_back(
          pool[i]->Run([this, i]() { workers_[i]->TrainFilesWithProfiler(); }));
    }
  }
  for (auto& th : wait_futures) {
    th.get();
  }
}

#ifdef PADDLE_WITH_HETERPS
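// Copy trained dense parameters from worker 0's thread scope back into the
// root scope: when a Communicator exists, only the variables listed in its
// recv contexts are copied; otherwise the recorded trainable_param_ list is used.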
void MultiTrainer::MergeDenseParam() {
#ifdef PADDLE_WITH_PSCORE
  auto communicator = paddle::distributed::Communicator::GetInstance();
  auto thread_scope = workers_[0]->GetThreadScope();
  if (communicator == nullptr) {
    for (auto& name : trainable_param_) {
      VLOG(2) << "merge var " << name << " to root scope";
      Variable* root_var = root_scope_->FindVar(name);
      phi::DenseTensor* root_tensor = root_var->GetMutable<phi::DenseTensor>();
      Variable* var = thread_scope->FindVar(name);
      phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
      TensorCopySync((*tensor), root_tensor->place(), root_tensor);
    }
  } else {
    auto& recv_ctx = communicator->GetRecvCtxMap();
    for (auto& iter : recv_ctx) {
      auto& varnames = iter.second;
      for (auto& name : varnames) {
        VLOG(2) << "merge var " << name << " to root scope";
        Variable* root_var = root_scope_->FindVar(name);
        phi::DenseTensor* root_tensor =
            root_var->GetMutable<phi::DenseTensor>();
        Variable* var = thread_scope->FindVar(name);
        phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
        TensorCopySync((*tensor), root_tensor->place(), root_tensor);
      }
    }
  }
#endif
}
#endif

template <typename T>
void MultiTrainer::MergeToRootScope(phi::DenseTensor* root_tensor,
                                    phi::DenseTensor* tensor) {
  phi::DenseTensor tmp_root;
  TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root);
  T* tmp_root_data = tmp_root.data<T>();
  phi::DenseTensor tmp_tensor;
  TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor);
  T* data = tmp_tensor.data<T>();
  for (int i = 0; i < tmp_tensor.numel(); i++) {
    tmp_root_data[i] += data[i];
  }
  TensorCopy(tmp_root, platform::CPUPlace(), root_tensor);
}

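// Finalize flushes the dump channel if dumping was enabled, element-wise
// merges the configured stat vars from every thread scope into the root scope,
// merges dense parameters under HETERPS, flushes the PS client under PSCORE,
// and finally drops the root scope's child scopes.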
void MultiTrainer::Finalize() {
  if (need_dump_field_ || need_dump_param_) {
    FinalizeDumpEnv();
  }
  for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
    Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]);
    if (root_var == nullptr) {
      continue;
    }
    phi::DenseTensor* root_tensor = root_var->GetMutable<phi::DenseTensor>();

    for (int j = 1; j < thread_num_; j++) {
      Scope* cur_thread_scope = workers_[j]->GetThreadScope();
      Variable* thread_var =
          cur_thread_scope->FindVar(need_merge_var_names_[i]);
      if (thread_var == nullptr) {
        continue;
      }
      phi::DenseTensor* thread_tensor =
          thread_var->GetMutable<phi::DenseTensor>();
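// MergeCallback dispatches on the root tensor's runtime dtype: it checks that
// the thread tensor has the same type (the process exits on a mismatch) and
// accumulates it into the root tensor via MergeToRootScope<T>.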
#define MergeCallback(cpp_type, proto_type)                                    \
  do {                                                                         \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) {  \
      if (framework::TransToProtoVarType(thread_tensor->dtype()) !=            \
          proto_type) {                                                        \
        VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \
                << "] " << need_merge_var_names_[i]                            \
                << ", root tensor type=" << root_tensor->dtype()               \
                << ", thread tensor type=" << thread_tensor->dtype();          \
        exit(-1);                                                              \
      }                                                                        \
      MergeToRootScope<cpp_type>(root_tensor, thread_tensor);                  \
    }                                                                          \
  } while (0)
      _ForEachDataType_(MergeCallback);
    }
  }
#ifdef PADDLE_WITH_HETERPS
  MergeDenseParam();
#endif

#if defined PADDLE_WITH_PSCORE
  auto communicator = paddle::distributed::Communicator::GetInstance();
  // for unittest which does not call fleet.init_worker() first
  if (communicator == nullptr) {
    VLOG(1) << "MultiTrainer::Finalize communicator is null!";
  } else {
    if (communicator->_worker_ptr != nullptr) {
      communicator->_worker_ptr->Flush();
      VLOG(1) << "MultiTrainer::Finalize ps client flush done";
    } else {
      VLOG(1) << "communicator->_worker_ptr is null";
    }
  }
#endif
  root_scope_->DropKids();
}

}  // end namespace framework
}  // end namespace paddle