section_worker.cc 15.9 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL)
L
lilong12 已提交
13 14 15 16 17
#include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"

H
hutuxian 已提交
18 19 20 21 22
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "paddle/fluid/framework/device_worker.h"
H
hutuxian 已提交
23
#include "paddle/fluid/framework/fleet/box_wrapper.h"
H
hutuxian 已提交
24 25 26 27 28 29 30 31 32
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

S
sandyhouse 已提交
33
// std::atomic<int> SectionWorker::cpu_id_(0);
S
sandyhouse 已提交
34 35 36 37
// std::mutex SectionWorker::thread_mutex;
// std::mutex SectionWorker::cout_mutex;
// std::condition_variable SectionWorker::thread_condition;
// bool SectionWorker::threads_completed = false;
L
lilong12 已提交
38 39
// Class-wide minibatch counter (static data member, starts at zero). It is
// incremented once at the end of TrainFiles / TrainFilesWithProfiler and
// printed as the "B[...]" field of the profiler trace lines.
uint64_t SectionWorker::batch_id_ = 0;

H
hutuxian 已提交
40
void SectionWorker::Initialize(const TrainerDesc& desc) {
H
hutuxian 已提交
41
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
S
sandyhouse 已提交
42 43 44
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
  // desc.section_param().section_config(section_id_).program_desc()));
L
lilong12 已提交
45
  for (auto& op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
46 47 48 49 50
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

void SectionWorker::AutoSetCPUAffinity(bool reuse) {
S
sandyhouse 已提交
51
  // int thread_cpu_id = cpu_id_.fetch_add(1);
H
hutuxian 已提交
52 53

  unsigned concurrency_cap = std::thread::hardware_concurrency();
S
sandyhouse 已提交
54 55
  // unsigned proc = thread_cpu_id;
  unsigned proc = cpu_id_;
H
hutuxian 已提交
56 57 58 59 60 61

  if (proc >= concurrency_cap) {
    if (reuse) {
      proc %= concurrency_cap;
    } else {
      LOG(INFO) << "All " << concurrency_cap
S
sandyhouse 已提交
62 63 64
                << " CPUs have been set affinities. Fail to set " << cpu_id_
                << "th thread.";
      // << thread_cpu_id << "th thread";
H
hutuxian 已提交
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
      return;
    }
  }

  cpu_set_t mask;
  CPU_ZERO(&mask);
  CPU_SET(proc, &mask);

  if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
    return;
  }

  CPU_ZERO(&mask);
  if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
      (0 == CPU_ISSET(proc, &mask))) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
  }
S
sandyhouse 已提交
83 84
  // VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
  VLOG(3) << "Set " << cpu_id_ << "th thread affinity to CPU " << proc;
H
hutuxian 已提交
85 86 87
}

void SectionWorker::TrainFiles() {
L
lilong12 已提交
88
  VLOG(3) << "begin section_worker TrainFiles";
S
sandyhouse 已提交
89
  // AutoSetCPUAffinity(true);
H
hutuxian 已提交
90

L
lilong12 已提交
91 92 93 94 95 96 97 98
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
99
    } else {
L
lilong12 已提交
100 101
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
102
    }
L
lilong12 已提交
103 104 105 106 107 108 109
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
110

S
update  
sandyhouse 已提交
111 112
  platform::Timer batch_timer;

S
sandyhouse 已提交
113 114 115 116 117 118 119 120
  // if (thread_id_ == 0) {
  // while (true) {
  // Start a minibatch.
  // real number of microbatches run
  // int real_microbatch_num = 0;
  batch_timer.Start();
  for (int i = 0; i < num_microbatches_; ++i) {
    try {
L
lilong12 已提交
121 122
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
S
sandyhouse 已提交
123 124 125 126 127 128 129 130 131 132 133 134
        // We run op with op_role = kLRSched only for the first microbatch
        // to avoid increasing the @LR_DECAY_STEP@ multiple times.
        bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
                                op_role == (static_cast<int>(OpRole::kForward) |
                                            static_cast<int>(OpRole::kLoss)) ||
                                op_role == static_cast<int>(OpRole::kLRSched);
        bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                          op_role == (static_cast<int>(OpRole::kForward) |
                                      static_cast<int>(OpRole::kLoss));
        if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
          VLOG(3) << "running an op " << op->Type() << " for scope " << i;
          op->Run(*microbatch_scopes_[i], place_);
L
lilong12 已提交
135
          if (gc) {
S
sandyhouse 已提交
136 137
            DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                                gc.get());
L
lilong12 已提交
138 139
          }
        }
H
hutuxian 已提交
140
      }
S
sandyhouse 已提交
141
    } catch (platform::EOFException& e) {
S
sandyhouse 已提交
142 143 144 145 146
      // std::unique_lock<std::mutex> lk(thread_mutex);
      // threads_completed = true;
      VLOG(3) << "thread  completed.";
      // VLOG(3) << "called notify all";
      // thread_condition.notify_all();
S
sandyhouse 已提交
147 148
      VLOG(3) << "EOF encountered";
      // throw platform::EOFException();
S
sandyhouse 已提交
149 150
      // throw e;
      PADDLE_THROW_EOF();
S
sandyhouse 已提交
151
      break;
H
hutuxian 已提交
152
    }
S
sandyhouse 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
  }
  dev_ctx_->Wait();

  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "running an op " << op->Type() << " for scope " << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
L
lilong12 已提交
168 169
        }
      }
S
sandyhouse 已提交
170 171 172 173 174 175 176 177 178 179 180 181 182
    }
  }
  dev_ctx_->Wait();
  // update pass
  for (auto& op : ops_) {
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "running an op " << op->Type() << " for minibatch scope";
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        for (int i = 0; i < num_microbatches_; ++i) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
L
lilong12 已提交
183 184
        }
      }
H
hutuxian 已提交
185 186
    }
  }
S
sandyhouse 已提交
187 188 189 190
  dev_ctx_->Wait();
  batch_timer.Pause();
  VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
  ++batch_id_;
H
hutuxian 已提交
191 192 193
}

void SectionWorker::TrainFilesWithProfiler() {
L
lilong12 已提交
194
  VLOG(3) << "begin section_worker TrainFiles with profiler";
S
sandyhouse 已提交
195
  // AutoSetCPUAffinity(true);
H
hutuxian 已提交
196

L
lilong12 已提交
197 198
  platform::Timer batch_timer;
  platform::Timer timeline;
H
hutuxian 已提交
199 200

  std::vector<std::string> op_name;
S
sandyhouse 已提交
201
  std::vector<double> op_total_time;
L
lilong12 已提交
202 203 204
  std::vector<double> op_max_time;
  std::vector<double> op_min_time;
  std::vector<uint64_t> op_count;
H
hutuxian 已提交
205 206 207 208
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
L
lilong12 已提交
209 210 211 212
  op_max_time.resize(ops_.size());
  op_min_time.resize(ops_.size());
  for (size_t i = 0; i < op_min_time.size(); ++i) {
    op_min_time[i] = DBL_MAX;
S
sandyhouse 已提交
213
    op_max_time[i] = 0.0;
H
hutuxian 已提交
214
  }
L
lilong12 已提交
215 216 217 218 219 220 221 222 223 224 225
  op_count.resize(ops_.size());

  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  // const std::vector<std::string> keep_vars;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
226
    } else {
L
lilong12 已提交
227 228
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
229
    }
L
lilong12 已提交
230 231 232 233 234 235 236
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
237

S
sandyhouse 已提交
238 239 240 241 242 243 244
  // if (thread_id_ == 0) {
  struct timeval start;
  struct timeval end;
  struct timeval micro_start;
  struct timeval micro_end;
  // Start a minibatch.
  batch_timer.Start();
S
sandyhouse 已提交
245
  // int real_microbatch_num = 0;
S
sandyhouse 已提交
246 247
  for (int i = 0; i < num_microbatches_; ++i) {
    try {
L
lilong12 已提交
248
      int op_idx = 0;
S
update  
sandyhouse 已提交
249
      gettimeofday(&micro_start, NULL);
L
lilong12 已提交
250
      for (auto& op : ops_) {
S
update  
sandyhouse 已提交
251
        gettimeofday(&start, NULL);
L
lilong12 已提交
252
        int op_role = op->Attr<int>(std::string("op_role"));
S
sandyhouse 已提交
253 254 255 256 257 258 259 260 261 262
        // We run op with op_role = kLRSched only for the first microbatch
        // to avoid increasing the @LR_DECAY_STEP@ multiple times.
        bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
                                op_role == (static_cast<int>(OpRole::kForward) |
                                            static_cast<int>(OpRole::kLoss)) ||
                                op_role == static_cast<int>(OpRole::kLRSched);
        bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                          op_role == (static_cast<int>(OpRole::kForward) |
                                      static_cast<int>(OpRole::kLoss));
        if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
S
sandyhouse 已提交
263 264 265
          // VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
          //        << " for scope " << i;
          VLOG(3) << "running an op " << op->Type() << " for scope " << i;
L
lilong12 已提交
266
          timeline.Start();
S
sandyhouse 已提交
267
          op->Run(*microbatch_scopes_[i], place_);
L
lilong12 已提交
268
          if (gc) {
S
sandyhouse 已提交
269 270
            DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                                gc.get());
L
lilong12 已提交
271
          }
S
update  
sandyhouse 已提交
272
          cudaDeviceSynchronize();
L
lilong12 已提交
273
          timeline.Pause();
S
sandyhouse 已提交
274
          gettimeofday(&end, NULL);
L
lilong12 已提交
275 276 277 278 279 280 281 282 283 284
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
285
          {
S
sandyhouse 已提交
286
            // std::unique_lock<std::mutex> lk(cout_mutex);
S
sandyhouse 已提交
287
            std::cout << std::fixed;
S
update  
sandyhouse 已提交
288
            std::cout.precision(0);
S
sandyhouse 已提交
289 290
            std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                      << "]:SCOPE[" << i << "]:OP[" << op->Type() << "]:START["
S
sandyhouse 已提交
291 292
                      << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                      << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
S
update  
sandyhouse 已提交
293
          }
L
lilong12 已提交
294 295 296
        }
        op_idx++;
      }
S
update  
sandyhouse 已提交
297 298
      gettimeofday(&micro_end, NULL);
      {
S
sandyhouse 已提交
299
        // std::unique_lock<std::mutex> lk(cout_mutex);
S
sandyhouse 已提交
300
        std::cout << std::fixed;
S
update  
sandyhouse 已提交
301
        std::cout.precision(0);
S
sandyhouse 已提交
302
        std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
S
sandyhouse 已提交
303 304 305 306
                  << "]:START["
                  << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
                  << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
                  << std::endl;
S
update  
sandyhouse 已提交
307
      }
S
sandyhouse 已提交
308
    } catch (platform::EOFException& e) {
S
sandyhouse 已提交
309 310 311 312 313 314 315
      VLOG(3) << "thread  completed.";
      VLOG(0) << "EOF encountered";
      VLOG(0) << "============timeline============";
      for (size_t i = 0; i < ops_.size(); ++i) {
        VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                << ", min_time: " << op_min_time[i]
                << ", mean_time: " << op_total_time[i] / op_count[i];
S
sandyhouse 已提交
316
      }
S
sandyhouse 已提交
317
      VLOG(0) << "================================";
S
sandyhouse 已提交
318
      throw e;
S
sandyhouse 已提交
319
      break;
H
hutuxian 已提交
320
    }
S
sandyhouse 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
  }
  dev_ctx_->Wait();
  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    int op_idx = 0;
    gettimeofday(&micro_start, NULL);
    for (auto& op : ops_) {
      gettimeofday(&start, NULL);
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "running an op " << op->Type() << " for scope " << i;
        timeline.Start();
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
339
        }
S
sandyhouse 已提交
340 341 342 343 344 345 346 347 348 349
        cudaDeviceSynchronize();
        gettimeofday(&end, NULL);
        timeline.Pause();
        auto time = timeline.ElapsedUS();
        op_total_time[op_idx] += time;
        if (time > op_max_time[op_idx]) {
          op_max_time[op_idx] = time;
        }
        if (time < op_min_time[op_idx]) {
          op_min_time[op_idx] = time;
L
lilong12 已提交
350
        }
S
sandyhouse 已提交
351 352
        op_count[op_idx] += 1;
        op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
353
        {
S
sandyhouse 已提交
354
          // std::unique_lock<std::mutex> lk(cout_mutex);
S
sandyhouse 已提交
355
          std::cout << std::fixed;
S
update  
sandyhouse 已提交
356
          std::cout.precision(0);
S
sandyhouse 已提交
357 358 359 360
          std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:SCOPE[" << i << "]:OP[" << op->Type() << "]:START["
                    << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                    << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
S
update  
sandyhouse 已提交
361
        }
H
hutuxian 已提交
362
      }
S
sandyhouse 已提交
363 364 365 366 367 368 369 370 371 372 373 374 375 376
      op_idx++;
    }
    gettimeofday(&micro_end, NULL);
    {
      // std::unique_lock<std::mutex> lk(cout_mutex);
      std::cout << std::fixed;
      std::cout.precision(0);
      std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
                << std::endl;
    }
  }
  dev_ctx_->Wait();
S
sandyhouse 已提交
377 378 379 380 381
  // if (real_microbatch_num == 0) {
  //   batch_timer.Pause();
  //   VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
  //   return;
  // }
S
sandyhouse 已提交
382 383 384 385 386 387 388 389 390 391 392 393 394 395 396
  // update pass
  int op_idx = 0;
  gettimeofday(&micro_start, NULL);
  for (auto& op : ops_) {
    gettimeofday(&start, NULL);
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
              << " for minibatch scope";
      timeline.Start();
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        for (int i = 0; i < num_microbatches_; ++i) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
S
update  
sandyhouse 已提交
397
        }
L
lilong12 已提交
398
      }
S
sandyhouse 已提交
399 400 401 402 403 404 405
      cudaDeviceSynchronize();
      gettimeofday(&end, NULL);
      timeline.Pause();
      auto time = timeline.ElapsedUS();
      op_total_time[op_idx] += time;
      if (time > op_max_time[op_idx]) {
        op_max_time[op_idx] = time;
406
      }
S
sandyhouse 已提交
407 408
      if (time < op_min_time[op_idx]) {
        op_min_time[op_idx] = time;
L
lilong12 已提交
409
      }
S
sandyhouse 已提交
410 411
      op_count[op_idx] += 1;
      op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
412
      {
S
sandyhouse 已提交
413
        std::cout << std::fixed;
S
update  
sandyhouse 已提交
414
        std::cout.precision(0);
S
sandyhouse 已提交
415 416 417 418
        std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:SCOPE[" << num_microbatches_ << "]:OP[" << op->Type()
                  << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                  << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
S
sandyhouse 已提交
419
                  << std::endl;
S
update  
sandyhouse 已提交
420
      }
H
hutuxian 已提交
421
    }
S
sandyhouse 已提交
422 423 424 425 426 427 428 429 430
    op_idx++;
  }
  gettimeofday(&micro_end, NULL);
  {
    std::cout << std::fixed;
    std::cout.precision(0);
    std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:START["
              << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
              << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
H
hutuxian 已提交
431
  }
S
sandyhouse 已提交
432 433 434 435
  dev_ctx_->Wait();
  batch_timer.Pause();
  VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
  ++batch_id_;
H
hutuxian 已提交
436 437 438 439
}
}  // namespace framework
}  // namespace paddle
#endif