section_worker.cc 29.5 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL)
L
lilong12 已提交
13 14 15 16 17
#include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"

H
hutuxian 已提交
18 19 20 21 22
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "paddle/fluid/framework/device_worker.h"
H
hutuxian 已提交
23
#include "paddle/fluid/framework/fleet/box_wrapper.h"
H
hutuxian 已提交
24 25 26 27 28 29 30 31 32 33
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

// Next CPU index handed out by AutoSetCPUAffinity (shared by all workers).
std::atomic<int> SectionWorker::cpu_id_{0};
// Taken by the training loops around updates/reads of threads_completed and
// batch_id_, and used together with thread_condition.
std::mutex SectionWorker::thread_mutex;
// Serializes the profiler's trace lines written to std::cout.
std::mutex SectionWorker::cout_mutex;
// Notified by thread 0 after each finished forward microbatch and on EOF.
std::condition_variable SectionWorker::thread_condition;
// Set by thread 0 when its reader raises EOFException.
bool SectionWorker::threads_completed = false;
// Count of forward microbatches thread 0 has finished so far.
uint64_t SectionWorker::batch_id_ = 0;

H
hutuxian 已提交
40
void SectionWorker::Initialize(const TrainerDesc& desc) {
H
hutuxian 已提交
41
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
L
lilong12 已提交
42
  program_.reset(new ProgramDesc(
H
hutuxian 已提交
43
      desc.section_param().section_config(section_id_).program_desc()));
L
lilong12 已提交
44
  for (auto& op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

void SectionWorker::AutoSetCPUAffinity(bool reuse) {
  int thread_cpu_id = cpu_id_.fetch_add(1);

  unsigned concurrency_cap = std::thread::hardware_concurrency();
  unsigned proc = thread_cpu_id;

  if (proc >= concurrency_cap) {
    if (reuse) {
      proc %= concurrency_cap;
    } else {
      LOG(INFO) << "All " << concurrency_cap
                << " CPUs have been set affinities. Fail to set "
                << thread_cpu_id << "th thread";
      return;
    }
  }

  cpu_set_t mask;
  CPU_ZERO(&mask);
  CPU_SET(proc, &mask);

  if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
    return;
  }

  CPU_ZERO(&mask);
  if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
      (0 == CPU_ISSET(proc, &mask))) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
  }
L
lilong12 已提交
80
  VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
H
hutuxian 已提交
81 82 83
}

void SectionWorker::TrainFiles() {
L
lilong12 已提交
84
  VLOG(3) << "begin section_worker TrainFiles";
H
hutuxian 已提交
85 86
  AutoSetCPUAffinity(true);

L
lilong12 已提交
87 88 89 90 91 92 93 94
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
95
    } else {
L
lilong12 已提交
96 97
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
98
    }
L
lilong12 已提交
99 100 101 102 103 104 105
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
106

S
update  
sandyhouse 已提交
107 108
  platform::Timer batch_timer;

L
lilong12 已提交
109 110 111
  if (thread_id_ == 0) {
    while (true) {
      // Start a minibatch.
112 113
      // real number of microbatches run
      int real_microbatch_num = 0;
S
update  
sandyhouse 已提交
114
      batch_timer.Start();
L
lilong12 已提交
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
      for (int i = 0; i < num_microbatches_; ++i) {
        try {
          for (auto& op : ops_) {
            int op_role = op->Attr<int>(std::string("op_role"));
            // We run op with op_role = kLRSched only for the first microbatch
            // to avoid increasing the @LR_DECAY_STEP@ multiple times.
            bool run_first_mbatch =
                op_role == static_cast<int>(OpRole::kForward) ||
                op_role == (static_cast<int>(OpRole::kForward) |
                            static_cast<int>(OpRole::kLoss)) ||
                op_role == static_cast<int>(OpRole::kLRSched);
            bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                              op_role == (static_cast<int>(OpRole::kForward) |
                                          static_cast<int>(OpRole::kLoss));
            if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
              VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                      << " for scope " << i;
              op->Run(*microbatch_scopes_[i], place_);
              if (gc) {
                DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                    unused_vars_, gc.get());
              }
            }
          }
        } catch (platform::EOFException&) {
          std::unique_lock<std::mutex> lk(thread_mutex);
          threads_completed = true;
          VLOG(3) << "thread " << thread_id_ << " completed.";
          VLOG(3) << "called notify all";
          thread_condition.notify_all();
          VLOG(0) << "EOF encountered";
146
          break;
L
lilong12 已提交
147
        }
148 149 150
        {
          real_microbatch_num += 1;
          batch_id_ += 1;
L
lilong12 已提交
151 152 153 154
          VLOG(3) << "called notify all";
          std::unique_lock<std::mutex> lk(thread_mutex);
          thread_condition.notify_all();
        }
H
hutuxian 已提交
155
      }
S
update  
sandyhouse 已提交
156
      dev_ctx_->Wait();
157
      VLOG(0) << "real_microbatch_num for thread 0 " << real_microbatch_num;
L
lilong12 已提交
158
      // backward pass
159
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
160 161 162 163 164 165 166 167 168 169 170 171 172
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
H
hutuxian 已提交
173 174
        }
      }
S
update  
sandyhouse 已提交
175
      dev_ctx_->Wait();
176 177 178 179 180
      if (real_microbatch_num == 0) {
        batch_timer.Pause();
        VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
        return;
      }
L
lilong12 已提交
181 182 183 184 185 186 187 188 189 190 191 192
      // update pass
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
        }
H
hutuxian 已提交
193
      }
L
lilong12 已提交
194
      dev_ctx_->Wait();
S
update  
sandyhouse 已提交
195 196
      batch_timer.Pause();
      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
H
hutuxian 已提交
197
    }
L
lilong12 已提交
198 199 200
  } else {
    while (true) {
      // forward pass:
201
      int real_microbatch_num = 0;
L
lilong12 已提交
202
      for (int i = 0; i < num_microbatches_; ++i) {
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
        {
          PADDLE_ENFORCE_LE(
              local_batch_id_, batch_id_,
              platform::errors::InvalidArgument(
                  "local_batch_id_ (%d) must be less than or equal to "
                  "batch_id_ (%d)",
                  local_batch_id_, batch_id_));
          std::unique_lock<std::mutex> lk(thread_mutex);
          if (local_batch_id_ == batch_id_ && !threads_completed) {
            thread_condition.wait(lk);
          }
          VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                  << local_batch_id_ << " batch_id_ " << batch_id_;
          if (threads_completed) {
            VLOG(3) << "thread " << thread_id_ << " completed.";
            lk.unlock();
            threads_completed = false;
            break;
          }
          lk.unlock();
          local_batch_id_ += 1;
          real_microbatch_num += 1;
        }
L
lilong12 已提交
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
          bool run_first_mbatch =
              op_role == static_cast<int>(OpRole::kForward) ||
              op_role == (static_cast<int>(OpRole::kForward) |
                          static_cast<int>(OpRole::kLoss)) ||
              op_role == static_cast<int>(OpRole::kLRSched);
          bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                            op_role == (static_cast<int>(OpRole::kForward) |
                                        static_cast<int>(OpRole::kLoss));
          if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
        }
      }
S
update  
sandyhouse 已提交
249
      dev_ctx_->Wait();
L
lilong12 已提交
250
      // backward pass
251
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
        }
      }
S
update  
sandyhouse 已提交
267
      dev_ctx_->Wait();
L
lilong12 已提交
268
      // update pass
269 270 271
      if (real_microbatch_num == 0) {
        return;
      }
L
lilong12 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
        }
      }
      dev_ctx_->Wait();
H
hutuxian 已提交
285 286 287 288 289
    }
  }
}

void SectionWorker::TrainFilesWithProfiler() {
L
lilong12 已提交
290
  VLOG(3) << "begin section_worker TrainFiles with profiler";
H
hutuxian 已提交
291 292
  AutoSetCPUAffinity(true);

L
lilong12 已提交
293 294
  platform::Timer batch_timer;
  platform::Timer timeline;
H
hutuxian 已提交
295 296 297

  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
L
lilong12 已提交
298 299 300
  std::vector<double> op_max_time;
  std::vector<double> op_min_time;
  std::vector<uint64_t> op_count;
H
hutuxian 已提交
301 302 303 304
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
L
lilong12 已提交
305 306 307 308
  op_max_time.resize(ops_.size());
  op_min_time.resize(ops_.size());
  for (size_t i = 0; i < op_min_time.size(); ++i) {
    op_min_time[i] = DBL_MAX;
H
hutuxian 已提交
309
  }
L
lilong12 已提交
310 311 312 313 314 315 316 317 318 319 320
  op_count.resize(ops_.size());

  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  // const std::vector<std::string> keep_vars;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
321
    } else {
L
lilong12 已提交
322 323
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
324
    }
L
lilong12 已提交
325 326 327 328 329 330 331
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
332

L
lilong12 已提交
333
  if (thread_id_ == 0) {
S
update  
sandyhouse 已提交
334 335 336 337
    struct timeval start;
    struct timeval end;
    struct timeval micro_start;
    struct timeval micro_end;
L
lilong12 已提交
338 339 340
    while (true) {
      // Start a minibatch.
      batch_timer.Start();
341
      int real_microbatch_num = 0;
L
lilong12 已提交
342 343 344
      for (int i = 0; i < num_microbatches_; ++i) {
        try {
          int op_idx = 0;
S
update  
sandyhouse 已提交
345
          gettimeofday(&micro_start, NULL);
L
lilong12 已提交
346
          for (auto& op : ops_) {
S
update  
sandyhouse 已提交
347
            gettimeofday(&start, NULL);
L
lilong12 已提交
348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367
            int op_role = op->Attr<int>(std::string("op_role"));
            // We run op with op_role = kLRSched only for the first microbatch
            // to avoid increasing the @LR_DECAY_STEP@ multiple times.
            bool run_first_mbatch =
                op_role == static_cast<int>(OpRole::kForward) ||
                op_role == (static_cast<int>(OpRole::kForward) |
                            static_cast<int>(OpRole::kLoss)) ||
                op_role == static_cast<int>(OpRole::kLRSched);
            bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                              op_role == (static_cast<int>(OpRole::kForward) |
                                          static_cast<int>(OpRole::kLoss));
            if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
              VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                      << " for scope " << i;
              timeline.Start();
              op->Run(*microbatch_scopes_[i], place_);
              if (gc) {
                DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                    unused_vars_, gc.get());
              }
S
update  
sandyhouse 已提交
368
              cudaDeviceSynchronize();
L
lilong12 已提交
369
              timeline.Pause();
S
update  
sandyhouse 已提交
370
              gettimeofday(&end, NULL);
L
lilong12 已提交
371 372 373 374 375 376 377 378 379 380
              auto time = timeline.ElapsedUS();
              op_total_time[op_idx] += time;
              if (time > op_max_time[op_idx]) {
                op_max_time[op_idx] = time;
              }
              if (time < op_min_time[op_idx]) {
                op_min_time[op_idx] = time;
              }
              op_count[op_idx] += 1;
              op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
381 382 383 384 385 386 387 388
              {
                std::unique_lock<std::mutex> lk(cout_mutex);
	            std::cout << std::fixed;
                std::cout.precision(0);
                std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                        << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                        << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
              }
L
lilong12 已提交
389 390 391
            }
            op_idx++;
          }
S
update  
sandyhouse 已提交
392 393 394 395 396 397 398 399 400
          gettimeofday(&micro_end, NULL);
          {
            std::unique_lock<std::mutex> lk(cout_mutex);
	        std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                      << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                      << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
          }
L
lilong12 已提交
401 402 403 404 405 406 407 408 409 410 411 412 413 414
        } catch (platform::EOFException&) {
          std::unique_lock<std::mutex> lk(thread_mutex);
          threads_completed = true;
          VLOG(3) << "thread " << thread_id_ << " completed.";
          VLOG(3) << "called notify all";
          thread_condition.notify_all();
          VLOG(0) << "EOF encountered";
          VLOG(0) << "============timeline============";
          for (size_t i = 0; i < ops_.size(); ++i) {
            VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                    << ", min_time: " << op_min_time[i]
                    << ", mean_time: " << op_total_time[i] / op_count[i];
          }
          VLOG(0) << "================================";
415
          break;
L
lilong12 已提交
416
        }
417
        {
L
lilong12 已提交
418 419
          VLOG(3) << "called notify all";
          std::unique_lock<std::mutex> lk(thread_mutex);
420
          real_microbatch_num += 1;
L
lilong12 已提交
421 422 423
          batch_id_ += 1;
          thread_condition.notify_all();
        }
H
hutuxian 已提交
424
      }
S
update  
sandyhouse 已提交
425
      dev_ctx_->Wait();
L
lilong12 已提交
426
      // backward pass
427
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
428
        int op_idx = 0;
S
update  
sandyhouse 已提交
429
        gettimeofday(&micro_start, NULL);
L
lilong12 已提交
430
        for (auto& op : ops_) {
S
update  
sandyhouse 已提交
431
          gettimeofday(&start, NULL);
L
lilong12 已提交
432 433 434 435 436 437 438 439 440 441 442 443
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
S
update  
sandyhouse 已提交
444 445
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
L
lilong12 已提交
446 447 448 449 450 451 452 453 454 455 456
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
457 458 459 460 461 462 463 464
            {
              std::unique_lock<std::mutex> lk(cout_mutex);
	          std::cout << std::fixed;
              std::cout.precision(0);
              std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                      << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
            }
L
lilong12 已提交
465 466
          }
          op_idx++;
H
hutuxian 已提交
467
        }
S
update  
sandyhouse 已提交
468 469 470 471 472 473 474 475 476
        gettimeofday(&micro_end, NULL);
        {
          std::unique_lock<std::mutex> lk(cout_mutex);
	      std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
        }
H
hutuxian 已提交
477
      }
S
update  
sandyhouse 已提交
478
      dev_ctx_->Wait();
479 480 481 482
      if (real_microbatch_num == 0) {
        batch_timer.Pause();
        VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
      }
L
lilong12 已提交
483 484
      // update pass
      int op_idx = 0;
S
update  
sandyhouse 已提交
485
      gettimeofday(&micro_start, NULL);
L
lilong12 已提交
486
      for (auto& op : ops_) {
S
update  
sandyhouse 已提交
487
        gettimeofday(&start, NULL);
L
lilong12 已提交
488 489 490 491 492 493 494 495 496 497
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          timeline.Start();
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
S
update  
sandyhouse 已提交
498 499
          cudaDeviceSynchronize();
          gettimeofday(&end, NULL);
L
lilong12 已提交
500 501 502 503 504 505 506 507 508 509 510
          timeline.Pause();
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
511 512 513 514 515 516 517 518
          {
            std::unique_lock<std::mutex> lk(cout_mutex);
	        std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << num_microbatches_
                    << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                    << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
          }
L
lilong12 已提交
519 520 521
        }
        op_idx++;
      }
S
update  
sandyhouse 已提交
522 523 524 525 526 527 528 529 530 531 532 533
      gettimeofday(&micro_end, NULL);
      {
        std::unique_lock<std::mutex> lk(cout_mutex);
	    std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                  << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
      }
      struct timeval wait_start;
      struct timeval wait_end;
      gettimeofday(&wait_start, NULL);
H
hutuxian 已提交
534
      dev_ctx_->Wait();
S
update  
sandyhouse 已提交
535 536
      gettimeofday(&wait_end, NULL);
      VLOG(0) << "device wait: " << wait_end.tv_sec * 1e6 + wait_end.tv_usec - wait_start.tv_sec * 1e6 - wait_start.tv_usec;
L
lilong12 已提交
537 538
      batch_timer.Pause();
      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
H
hutuxian 已提交
539
    }
L
lilong12 已提交
540
  } else {
S
update  
sandyhouse 已提交
541 542 543 544 545 546 547
    struct timeval start;
    struct timeval end;
    struct timeval micro_start;
    struct timeval micro_end;
    cudaEvent_t cu_start, cu_stop;
    cudaEventCreate(&cu_start);
    cudaEventCreate(&cu_stop);
L
lilong12 已提交
548 549
    while (true) {
      // forward pass:
550
      int real_microbatch_num = 0;
L
lilong12 已提交
551
      for (int i = 0; i < num_microbatches_; ++i) {
552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580
        {
          PADDLE_ENFORCE_LE(
              local_batch_id_, batch_id_,
              platform::errors::InvalidArgument(
                  "local_batch_id_ (%d) must be less than or equal to "
                  "batch_id_ (%d)",
                  local_batch_id_, batch_id_));
          std::unique_lock<std::mutex> lk(thread_mutex);
          if (local_batch_id_ == batch_id_ && !threads_completed) {
            thread_condition.wait(lk);
          }
          VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                  << local_batch_id_ << " batch_id_ " << batch_id_;
          if (threads_completed) {
            VLOG(3) << "thread " << thread_id_ << " completed.";
            lk.unlock();
            VLOG(0) << "============timeline============";
            for (size_t i = 0; i < ops_.size(); ++i) {
              VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                      << ", min_time: " << op_min_time[i]
                      << ", mean_time: " << op_total_time[i] / op_count[i];
            }
            VLOG(0) << "================================";
            break;
          }
          lk.unlock();
          real_microbatch_num += 1;
          local_batch_id_ += 1;
        }
L
lilong12 已提交
581
        int op_idx = 0;
S
update  
sandyhouse 已提交
582
        gettimeofday(&micro_start, NULL);
L
lilong12 已提交
583
        for (auto& op : ops_) {
S
update  
sandyhouse 已提交
584
          gettimeofday(&start, NULL);
L
lilong12 已提交
585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
          bool run_first_mbatch =
              op_role == static_cast<int>(OpRole::kForward) ||
              op_role == (static_cast<int>(OpRole::kForward) |
                          static_cast<int>(OpRole::kLoss)) ||
              op_role == static_cast<int>(OpRole::kLRSched);
          bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                            op_role == (static_cast<int>(OpRole::kForward) |
                                        static_cast<int>(OpRole::kLoss));
          if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
S
update  
sandyhouse 已提交
605 606
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
L
lilong12 已提交
607 608 609 610 611 612 613 614 615 616 617
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
618 619 620 621 622 623 624 625
            {
              std::unique_lock<std::mutex> lk(cout_mutex);
	          std::cout << std::fixed;
              std::cout.precision(0);
              std::cout << "::FWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                      << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
            }
L
lilong12 已提交
626 627 628
          }
          op_idx++;
        }
S
update  
sandyhouse 已提交
629 630 631 632 633 634 635 636 637
        gettimeofday(&micro_end, NULL);
        {
          std::unique_lock<std::mutex> lk(cout_mutex);
	      std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
        }
H
hutuxian 已提交
638
      }
S
update  
sandyhouse 已提交
639
      dev_ctx_->Wait();
L
lilong12 已提交
640
      // backward pass
641
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
642
        int op_idx = 0;
S
update  
sandyhouse 已提交
643
        gettimeofday(&micro_start, NULL);
L
lilong12 已提交
644
        for (auto& op : ops_) {
S
update  
sandyhouse 已提交
645
          gettimeofday(&start, NULL);
L
lilong12 已提交
646 647 648 649 650 651 652 653 654 655 656 657
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
S
update  
sandyhouse 已提交
658 659
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
L
lilong12 已提交
660 661 662 663 664 665 666 667 668 669 670
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
671 672 673 674 675 676 677 678
            {
              std::unique_lock<std::mutex> lk(cout_mutex);
	          std::cout << std::fixed;
              std::cout.precision(0);
              std::cout << "::BWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                      << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
            }
L
lilong12 已提交
679 680 681
          }
          op_idx++;
        }
S
update  
sandyhouse 已提交
682 683 684 685 686 687 688 689 690
        gettimeofday(&micro_end, NULL);
        {
          std::unique_lock<std::mutex> lk(cout_mutex);
	      std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
        }
L
lilong12 已提交
691
      }
S
update  
sandyhouse 已提交
692
      dev_ctx_->Wait();
693 694 695
      if (real_microbatch_num == 0) {
        return;
      }
L
lilong12 已提交
696 697
      // update pass
      int op_idx = 0;
S
update  
sandyhouse 已提交
698
      gettimeofday(&micro_start, NULL);
L
lilong12 已提交
699
      for (auto& op : ops_) {
S
update  
sandyhouse 已提交
700
        gettimeofday(&start, NULL);
L
lilong12 已提交
701 702 703 704 705 706 707 708 709 710
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          timeline.Start();
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
S
update  
sandyhouse 已提交
711 712
          cudaDeviceSynchronize();
          gettimeofday(&end, NULL);
L
lilong12 已提交
713 714 715 716 717 718 719 720 721 722 723
          timeline.Pause();
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
724 725 726 727 728 729 730 731
          {
            std::unique_lock<std::mutex> lk(cout_mutex);
	        std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << num_microbatches_
                    << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                    << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
          }
L
lilong12 已提交
732 733 734
        }
        op_idx++;
      }
S
update  
sandyhouse 已提交
735 736 737 738 739 740 741 742 743
      gettimeofday(&micro_end, NULL);
      {
        std::unique_lock<std::mutex> lk(cout_mutex);
	    std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                  << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
      }
L
lilong12 已提交
744
      dev_ctx_->Wait();
H
hutuxian 已提交
745 746 747 748 749 750
    }
  }
}
}  // namespace framework
}  // namespace paddle
#endif