heter_pipeline_trainer.cc 11.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_PSCORE)
16
#include "paddle/fluid/distributed/ps/service/heter_server.h"
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"

namespace paddle {
namespace framework {

class Variable;

// Per-worker tables handed to the heter server in HeterPipelineTrainer::Run().
// Each map is keyed by the same worker index used as the key of `workers_`.

// worker index -> minibatch scope (a child of root_scope_) of that worker
using MiniScope = std::unordered_map<int, Scope*>;

// worker index -> the microbatch scopes of that worker
using MicroScope =
    std::unordered_map<int, std::shared_ptr<std::vector<Scope*>>>;

// worker index -> that worker's thread queue of (std::string, int) items
using TaskQueue = std::unordered_map<
    int,
    std::shared_ptr<
        ::paddle::framework::BlockingQueue<std::pair<std::string, int>>>>;

// Rebinds the workers of the CPU trainer (pipeline stage 0) to the readers of
// `dataset`. Every worker gets one reader and has its reader place reset to
// place_. Other stages return immediately without touching the dataset.
// The reader count must equal thread_num_; changing the number of threads
// between datasets is rejected with an InvalidArgument error.
void HeterPipelineTrainer::ResetDataset(Dataset* dataset) {
  if (pipeline_stage_ != 0) {
    return;
  }
  SetDataset(dataset);
  const std::vector<paddle::framework::DataFeed*> readers =
      dataset->GetReaders();
  VLOG(3) << "readers num: " << readers.size();
  // change thread num is not supported
  PADDLE_ENFORCE_EQ(thread_num_, readers.size(),
                    platform::errors::InvalidArgument(
                        "change Dataset thread_num is not supported"));
  size_t reader_idx = 0;
  for (auto& entry : workers_) {
    auto section_worker =
        std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
            entry.second);
    section_worker->SetDataFeed(readers[reader_idx]);
    section_worker->SetReaderPlace(place_);
    ++reader_idx;
  }
}

void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
                                      Dataset* dataset) {
  thread_num_ = trainer_desc.thread_num();
  ParseDumpConfig(trainer_desc);
  SetDebug(trainer_desc.debug());
  const std::vector<paddle::framework::DataFeed*> readers =
      dataset->GetReaders();
  VLOG(3) << "readers num: " << readers.size();
  // change thread num to readers num
  thread_num_ = readers.size();
  VLOG(3) << "worker thread num: " << thread_num_;
  const auto& heter_section_params = trainer_desc.heter_section_param();
  num_pipeline_stages_ = heter_section_params.num_pipeline_stages();
  pipeline_stage_ = heter_section_params.pipeline_stage();
  num_microbatches_ = heter_section_params.num_microbatches();
  VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
  trainer_desc_ = trainer_desc;
  trainer_id_ = trainer_desc.trainer_id();
  for (int i = 0; i < num_pipeline_stages_; ++i) {
    auto trainer_num = trainer_desc.trainers(i);
    trainers_.push_back(trainer_num);
  }
  int cpu_trainer_num = trainers_[0];
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
  // int cur_stage_trainer_num = trainers_[pipeline_stage_];
  // int global_thread_num = cpu_trainer_num * thread_num_;
  // int previous_trainers = 0;
  // for (int i = 0; i < pipeline_stage_; i++) previous_trainers +=
  // trainers_[i];
  // int stage_trainer_id =
  //    trainer_id_ - previous_trainers;  // trainer id in current stage

  if (pipeline_stage_ == 0) {  // for cpu trainer
    int cnt = -1;
    int real_thread_id = trainer_id_;
    for (int i = 0; i < thread_num_; i++) {
      cnt++;
      workers_[real_thread_id] = DeviceWorkerFactory::CreateDeviceWorker(
          trainer_desc.device_worker_name());
      auto this_worker =
          std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
              workers_[real_thread_id]);
      this_worker->SetDebug(debug_);
      this_worker->SetNeedDumpField(need_dump_field_);
      this_worker->SetNeedDumpParam(need_dump_param_);
      this_worker->SetDumpFieldVector(dump_fields_);
      this_worker->SetDumpParamVector(dump_param_);
      this_worker->InitRandomDumpConfig(trainer_desc);
      this_worker->SetDeviceIndex(real_thread_id);
      real_thread_id += cpu_trainer_num;
      // if (pipeline_stage_ == 0) {
      this_worker->SetDataFeed(readers[cnt]);
      //}
      this_worker->SetMicrobatchNum(num_microbatches_);
      this_worker->SetPipelineStageNum(num_pipeline_stages_);
      this_worker->SetPipelineStage(pipeline_stage_);
    }
  } else {  // for heter_trainer
    // heter trainer with thread_id == -1 is not for
    // real training
    workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker(
117 118 119
        trainer_desc.device_worker_name());
    auto this_worker =
        std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
120
            workers_[-1]);
121 122 123
    this_worker->SetMicrobatchNum(num_microbatches_);
    this_worker->SetPipelineStageNum(num_pipeline_stages_);
    this_worker->SetPipelineStage(pipeline_stage_);
124
    this_worker->SetDeviceIndex(-1);
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
  }
}

// Sets up the dump environment when field dumping is enabled. Without field
// dumping there is nothing to prepare. `main_program` is not used here.
void HeterPipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
  if (!need_dump_field_) {
    return;
  }
  InitDumpEnv();
}

// Builds the dump file path for thread `tid`:
// "<dump_fields_path_>/part-<tid>", where tid is zero-padded to at least five
// digits.
std::string HeterPipelineTrainer::GetDumpPath(int tid) {
  const char* dump_dir = dump_fields_path_.c_str();
  return string::format_string("%s/part-%05d", dump_dir, tid);
}

void HeterPipelineTrainer::InitDumpEnv() {
  queue_ = paddle::framework::MakeChannel<std::string>();
  for (int i = 0; i < thread_num_; ++i) {
    workers_[i]->SetChannelWriter(queue_.get());
  }
  dump_thread_num_ = 1;
  for (int i = 0; i < dump_thread_num_; i++) {
    dump_thread_.push_back(
        std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
  }
}

// Prepares every worker for training on `place`.
//
// mini_scopes_, micro_scopes_ and task_queue_ are reset to empty tables. Then,
// for each worker:
// - set its place and initialize it from trainer_desc_;
// - set its reader place (stage 0 only);
// - attach root_scope_;
// - give it a fresh child scope of root_scope_ as its minibatch scope;
// - create its microbatch scopes.
// The minibatch scope, microbatch scopes and thread queue are recorded under
// the worker's key. `main_program` is not used here.
void HeterPipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
                                          const platform::Place& place) {
  place_ = place;
  PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
                                           "root_scope_ can not be nullptr"));
  // start from empty per-worker tables
  mini_scopes_.reset(new MiniScope{});
  micro_scopes_.reset(new MicroScope{});
  task_queue_.reset(new TaskQueue{});
  const bool is_cpu_stage = (pipeline_stage_ == 0);
  for (auto& entry : workers_) {
    const int idx = entry.first;
    auto section_worker =
        std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
            entry.second);
    section_worker->SetPlace(place);
    section_worker->Initialize(trainer_desc_);
    if (is_cpu_stage) {
      // only stage-0 workers are given data feeds, so only they need this
      section_worker->SetReaderPlace(place);
    }
    section_worker->SetRootScope(root_scope_);
    Scope* const minibatch_scope = &root_scope_->NewScope();
    (*mini_scopes_)[idx] = minibatch_scope;
    section_worker->SetMinibatchScope(minibatch_scope);
    // microbatch scopes can only be created once the microbatch count and
    // the minibatch scope are set
    section_worker->CreateMicrobatchScopes();
    (*micro_scopes_)[idx] = section_worker->GetMicrobatchScopes();
    (*task_queue_)[idx] = section_worker->GetThreadQueue();
  }
}

void HeterPipelineTrainer::Run() {
  VLOG(3) << "Going to run HeterPipelineTrainer::Run()";
  if (listen_ptr_ == nullptr) {
    for (auto& worker_pair : workers_) {
      auto& device_worker = worker_pair.second;
      auto worker_0 =
          std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
              device_worker);
      listen_ptr_.reset(new std::thread(
          std::bind(&HeterSectionWorker::RunListen, worker_0.get())));
      break;
    }
  }
  auto heter_server = paddle::distributed::HeterServer::GetInstance();
  heter_server->WaitServerReady();
197
  heter_server->SetMiniBatchScopes(mini_scopes_);
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
  heter_server->SetMicroBatchScopes(micro_scopes_);
  heter_server->SetTaskQueue(task_queue_);
  // main training logic
  if (pipeline_stage_ == 0) {  // for cpu trainer
    for (auto& worker_pair : workers_) {
      auto device_worker = worker_pair.second;
      if (!debug_) {
        threads_.push_back(
            std::thread(&DeviceWorker::TrainFiles, device_worker.get()));
      } else {
        threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
                                       device_worker.get()));
      }
    }
  } else {  // for heter worker
213
    // start thread_worker with thread_id = -1
214 215 216 217 218 219 220 221 222 223
    for (auto& worker_pair : workers_) {
      auto device_worker = worker_pair.second;
      if (!debug_) {
        threads_.push_back(
            std::thread(&DeviceWorker::TrainFiles, device_worker.get()));
      } else {
        threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
                                       device_worker.get()));
      }
    }
224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
    bool epoch_finish = false;
    auto heter_server = paddle::distributed::HeterServer::GetInstance();
    while (!epoch_finish) {
      if (heter_server->IsStop()) {
        epoch_finish = true;
        continue;
      }
      // create new thread_worker
      // size_t thread_num = (*micro_scopes_).size();
      // size_t thread_num = (*task_queue_).size();
      size_t thread_num = heter_server->GetThreadNum();
      while (thread_num > threads_.size()) {
        for (auto& worker_pair : (*micro_scopes_)) {
          auto worker_index = worker_pair.first;
          if (workers_.find(worker_index) != workers_.end()) continue;
          workers_[worker_index] = DeviceWorkerFactory::CreateDeviceWorker(
              trainer_desc_.device_worker_name());
          auto this_worker =
              std::dynamic_pointer_cast<paddle::framework::HeterSectionWorker>(
                  workers_[worker_index]);
          this_worker->SetDebug(debug_);
          this_worker->SetNeedDumpField(need_dump_field_);
          this_worker->SetNeedDumpParam(need_dump_param_);
          this_worker->SetDumpFieldVector(dump_fields_);
          this_worker->SetDumpParamVector(dump_param_);
          this_worker->InitRandomDumpConfig(trainer_desc_);
          this_worker->SetDeviceIndex(worker_index);
          this_worker->SetMicrobatchNum(num_microbatches_);
          this_worker->SetPipelineStageNum(num_pipeline_stages_);
          this_worker->SetPipelineStage(pipeline_stage_);
          this_worker->SetPlace(place_);
          this_worker->Initialize(trainer_desc_);
          this_worker->SetRootScope(root_scope_);

          // generate mini_batch scope for every worker
          // auto* minibatch_scope = &root_scope_->NewScope();
          auto* minibatch_scope = (*mini_scopes_)[worker_index];
          // (*mini_scopes_)[worker_index] = minibatch_scope;
          this_worker->SetMinibatchScope(minibatch_scope);
          // after set micro num & mini batch scope
          this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
          this_worker->CreateMicrobatchScopes();
          // this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]);
          this_worker->SetThreadQueue((*task_queue_)[worker_index]);
          if (!debug_) {
            threads_.push_back(
                std::thread(&DeviceWorker::TrainFiles, this_worker.get()));
          } else {
            threads_.push_back(std::thread(
                &DeviceWorker::TrainFilesWithProfiler, this_worker.get()));
          }
        }
      }
    }
278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
  }
  for (auto& th : threads_) {
    th.join();
  }
  if (threads_.size() > 0) {
    threads_.clear();
  }
  VLOG(3) << "Epoch Trainging done";
}

// Shuts the trainer down, in this order:
// 1. stop the heter server;
// 2. join and release the listener thread, if one was started;
// 3. tear down the dump environment when field dumping is enabled;
// 4. drop all child scopes of root_scope_.
void HeterPipelineTrainer::Finalize() {
  VLOG(3) << "HeterPipelineTrainer Finalize";
  paddle::distributed::HeterServer::GetInstance()->Stop();
  if (listen_ptr_ != nullptr) {
    listen_ptr_->join();
    listen_ptr_.reset(nullptr);
  }
  if (need_dump_field_) {
    FinalizeDumpEnv();
  }
  root_scope_->DropKids();
}

// Returns the thread scope of the worker stored under key `thread_id`, or
// nullptr when no such worker exists.
Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
  // Single lookup: reuse the iterator instead of find() + operator[].
  auto it = workers_.find(thread_id);
  if (it == workers_.end()) {
    return nullptr;
  }
  return it->second->GetThreadScope();
}

}  // end namespace framework
}  // end namespace paddle
#endif