/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

namespace paddle {
namespace framework {

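// Read the trainer configuration: thread count, dump settings, and the
// stat_var names that must be merged back into the root scope after
// training. One DeviceWorker is created and bound to each data reader.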
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
                              Dataset* dataset) {
  thread_num_ = trainer_desc.thread_num();
  SetDataset(dataset);

  ParseDumpConfig(trainer_desc);
  mpi_rank_ = trainer_desc.mpi_rank();
  mpi_size_ = trainer_desc.mpi_size();
  dump_file_num_ = trainer_desc.dump_file_num();

  for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
       i++) {
    need_merge_var_names_.push_back(
        trainer_desc.downpour_param().stat_var_names(i));
  }
#ifdef PADDLE_WITH_HETERPS
  for (int i = 0; i < thread_num_; ++i) {
    int num = trainer_desc.worker_places(i);
    platform::CUDAPlace place = platform::CUDAPlace(num);
    places_.push_back(place);
  }
#endif
  // get the data readers prepared by the dataset
  const std::vector<paddle::framework::DataFeed*> readers =
      dataset->GetReaders();
  VLOG(3) << "readers num: " << readers.size();
  // change thread num to readers num
  thread_num_ = readers.size();
  VLOG(3) << "worker thread num: " << thread_num_;
  workers_.resize(thread_num_);

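// With PSCORE, reset the communicator's barrier trigger count to the number
// of worker threads so that every worker participates in the barrier.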
#if defined PADDLE_WITH_PSCORE
  if (trainer_desc.thread_barrier()) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset(
        thread_num_);
  }
#endif

  for (int i = 0; i < thread_num_; ++i) {
    workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
        trainer_desc.device_worker_name());
    workers_[i]->SetNeedDumpField(need_dump_field_);
    workers_[i]->SetNeedDumpParam(need_dump_param_);
    workers_[i]->SetDumpFieldVector(dump_fields_);
    workers_[i]->SetDumpParamVector(dump_param_);
    workers_[i]->InitRandomDumpConfig(trainer_desc);
    workers_[i]->Initialize(trainer_desc);
    workers_[i]->SetDeviceIndex(i);
    workers_[i]->SetDataFeed(readers[i]);
  }

  // set debug here
  SetDebug(trainer_desc.debug());
}

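// Dump files are named part-<mpi rank>-<thread id>, or
// part-<user defined name>-<thread id> when a custom dump file name is set.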
std::string MultiTrainer::GetDumpPath(int tid) {
  if (user_define_dump_filename_ != "") {
    return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
                                 user_define_dump_filename_.c_str(), tid);
  }
  return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(),
                               mpi_rank_, tid);
}

void MultiTrainer::InitDumpEnv() {
  queue_ = paddle::framework::MakeChannel<std::string>();
  for (int i = 0; i < thread_num_; ++i) {
    workers_[i]->SetChannelWriter(queue_.get());
  }
  dump_thread_num_ = 1;
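  // Give every rank at least one dump thread; when more dump files than
  // ranks are requested, the remainder goes to the lower-numbered ranks.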
  if (dump_file_num_ > mpi_size_) {
    dump_thread_num_ = dump_file_num_ / mpi_size_;
    if (dump_file_num_ % mpi_size_ > mpi_rank_) {
      dump_thread_num_ += 1;
    }
  }
  for (int i = 0; i < dump_thread_num_; i++) {
    dump_thread_.push_back(
        std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
  }
}

// call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
                                  const platform::Place& place) {
  for (int i = 0; i < thread_num_; ++i) {
#ifdef PADDLE_WITH_HETERPS
    workers_[i]->SetPlace(places_[i]);
    workers_[i]->SetReaderPlace(places_[i]);
    workers_[i]->SetDeviceContext(
        platform::DeviceContextPool::Instance().Get(places_[i]));
#else
    workers_[i]->SetPlace(place);
    workers_[i]->SetReaderPlace(place);
#endif
    workers_[i]->SetRootScope(root_scope_);
    workers_[i]->CreateDeviceResource(main_program);  // Program
    workers_[i]->BindingDataFeedMemory();
    workers_[i]->CacheProgram(main_program);
  }
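// With HeterPS, give each worker its own copy of every persistable LoDTensor
// from the root scope, placed on that worker's device (SelectedRows
// variables are skipped).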
#ifdef PADDLE_WITH_HETERPS
  for (int num = 0; num < thread_num_; ++num) {
    auto place = places_[num];
    Scope* scope = workers_[num]->GetThreadScope();
    auto& block = main_program.Block(0);
    for (auto& var : block.AllVars()) {
      if (var->Persistable()) {
        auto name = var->Name();
        Variable* root_var = root_scope_->FindVar(name);
        if (!root_var) {
          continue;
        }
        if (root_var->IsType<pten::SelectedRows>()) {
          continue;
        }
        LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
        auto* ptr = scope->Var(name);
        InitializeVariable(ptr, proto::VarType::LOD_TENSOR);
        LoDTensor* thread_tensor = ptr->GetMutable<LoDTensor>();
        TensorCopy(*root_tensor, place, thread_tensor);
      }
    }
  }
#endif
}

void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
  if (need_dump_field_ || need_dump_param_) {
    InitDumpEnv();
  }

#ifdef PADDLE_WITH_PSCORE
  // pull dense param first
  auto communicator = paddle::distributed::Communicator::GetInstance();
  // for unittest which call train_from_dataset but does not call
  // fleet.init_worker() first
  if (communicator == nullptr) {
    VLOG(0) << "MultiTrainer::InitOtherEnv Communicator is null!";
  } else {
    auto& recv_ctx = communicator->GetRecvCtxMap();
    communicator->PullDense(recv_ctx);
    VLOG(3) << "init other env done.";
  }
#endif
}

Scope* MultiTrainer::GetWorkerScope(int thread_id) {
  return workers_[thread_id]->GetThreadScope();
}

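// Launch one training thread per DeviceWorker and wait for all of them to
// finish; in debug mode the profiling variant of the training loop is used.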
void MultiTrainer::Run() {
  VLOG(3) << "Going to run";
  for (int thidx = 0; thidx < thread_num_; ++thidx) {
    if (!debug_) {
      threads_.push_back(
          std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
    } else {
      threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
                                     workers_[thidx].get()));
    }
  }
  for (auto& th : threads_) {
    th.join();
  }
}

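// With HeterPS, copy the dense parameters trained in worker 0's thread scope
// back into the root scope once training has finished.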
#ifdef PADDLE_WITH_HETERPS
void MultiTrainer::MergeDenseParam() {
#ifdef PADDLE_WITH_PSCORE
  auto communicator = paddle::distributed::Communicator::GetInstance();
  auto& recv_ctx = communicator->GetRecvCtxMap();
  Scope* thread_scope = workers_[0]->GetThreadScope();
  for (auto& iter : recv_ctx) {
    auto& varnames = iter.second;
    for (auto& name : varnames) {
      Variable* root_var = root_scope_->FindVar(name);
      LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
      Variable* var = thread_scope->FindVar(name);
      LoDTensor* tensor = var->GetMutable<LoDTensor>();
      TensorCopy((*tensor), root_tensor->place(), root_tensor);
    }
  }
#endif
}
#endif

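// Element-wise accumulation of a worker tensor into the root-scope tensor;
// both tensors are staged through CPU memory and the summed result is
// copied back into root_tensor.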
template <typename T>
void MultiTrainer::MergeToRootScope(LoDTensor* root_tensor, LoDTensor* tensor) {
  LoDTensor tmp_root;
  TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root);
  T* tmp_root_data = tmp_root.data<T>();
  LoDTensor tmp_tensor;
  TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor);
  T* data = tmp_tensor.data<T>();
  for (int i = 0; i < tmp_tensor.numel(); i++) {
    tmp_root_data[i] += data[i];
  }
  TensorCopy(tmp_root, platform::CPUPlace(), root_tensor);
}

void MultiTrainer::Finalize() {
  if (need_dump_field_ || need_dump_param_) {
    FinalizeDumpEnv();
  }

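  // Accumulate each worker's copy of the stat variables into the root scope.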
  for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
    Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]);
    if (root_var == nullptr) {
      continue;
    }
    LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();

#ifdef PADDLE_WITH_HETERPS
    for (size_t j = 0; j < places_.size(); j++) {
#else
    for (int j = 1; j < thread_num_; j++) {
#endif
      Scope* cur_thread_scope = workers_[j]->GetThreadScope();
      Variable* thread_var =
          cur_thread_scope->FindVar(need_merge_var_names_[i]);
      if (thread_var == nullptr) {
        continue;
      }
      LoDTensor* thread_tensor = thread_var->GetMutable<LoDTensor>();
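// MergeCallback is expanded once per supported data type by
// _ForEachDataType_; it verifies that the thread tensor's dtype matches the
// root tensor's and then accumulates it via MergeToRootScope.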
#define MergeCallback(cpp_type, proto_type)                                    \
  do {                                                                         \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) {  \
      if (framework::TransToProtoVarType(thread_tensor->dtype()) !=            \
          proto_type) {                                                        \
        VLOG(0) << "Error: thread id=" << j << ", need_merge_var_names_[" << i \
                << "] " << need_merge_var_names_[i]                            \
                << ", root tensor type=" << root_tensor->dtype()               \
                << ", thread tensor type=" << thread_tensor->dtype();          \
        exit(-1);                                                              \
      }                                                                        \
      MergeToRootScope<cpp_type>(root_tensor, thread_tensor);                  \
    }                                                                          \
  } while (0)
      _ForEachDataType_(MergeCallback);
    }
  }
#ifdef PADDLE_WITH_HETERPS
  MergeDenseParam();
#endif

#if defined PADDLE_WITH_PSCORE
  auto communicator = paddle::distributed::Communicator::GetInstance();
  // for unittest which does not call fleet.init_worker() first
  if (communicator == nullptr) {
    VLOG(0) << "MultiTrainer::Finalize communicator is null!";
  } else {
    communicator->_worker_ptr->flush();
    VLOG(1) << "MultiTrainer::Finalize ps client flush done";
  }
#endif
  root_scope_->DropKids();
}

}  // end namespace framework
}  // end namespace paddle