// section_worker.cc
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#if defined(PADDLE_WITH_NCCL)
#include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

// State shared by every SectionWorker thread in the process.

// Next CPU id handed out by AutoSetCPUAffinity.
std::atomic<int> SectionWorker::cpu_id_{0};
// Used together with thread_condition to synchronize batch_id_ and
// threads_completed between the section threads.
std::mutex SectionWorker::thread_mutex;
// Serializes the timing lines the profiling path writes to std::cout.
std::mutex SectionWorker::cout_mutex;
// Notified by thread 0 whenever it publishes a microbatch or hits EOF.
std::condition_variable SectionWorker::thread_condition;
// Raised by thread 0 once its reader reports EOF.
bool SectionWorker::threads_completed{false};
// Running count of microbatches thread 0 has finished forwarding.
uint64_t SectionWorker::batch_id_{0};
void SectionWorker::Initialize(const TrainerDesc& desc) {
H
hutuxian 已提交
41
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
L
lilong12 已提交
42
  program_.reset(new ProgramDesc(
H
hutuxian 已提交
43
      desc.section_param().section_config(section_id_).program_desc()));
L
lilong12 已提交
44
  for (auto& op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

void SectionWorker::AutoSetCPUAffinity(bool reuse) {
  int thread_cpu_id = cpu_id_.fetch_add(1);

  unsigned concurrency_cap = std::thread::hardware_concurrency();
  unsigned proc = thread_cpu_id;

  if (proc >= concurrency_cap) {
    if (reuse) {
      proc %= concurrency_cap;
    } else {
      LOG(INFO) << "All " << concurrency_cap
                << " CPUs have been set affinities. Fail to set "
                << thread_cpu_id << "th thread";
      return;
    }
  }

  cpu_set_t mask;
  CPU_ZERO(&mask);
  CPU_SET(proc, &mask);

  if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
    return;
  }

  CPU_ZERO(&mask);
  if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
      (0 == CPU_ISSET(proc, &mask))) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
  }
L
lilong12 已提交
80
  VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
H
hutuxian 已提交
81 82 83
}

void SectionWorker::TrainFiles() {
L
lilong12 已提交
84
  VLOG(3) << "begin section_worker TrainFiles";
H
hutuxian 已提交
85 86
  AutoSetCPUAffinity(true);

L
lilong12 已提交
87 88 89 90 91 92 93 94
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
95
    } else {
L
lilong12 已提交
96 97
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
98
    }
L
lilong12 已提交
99 100 101 102 103 104 105
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
106

S
update  
sandyhouse 已提交
107 108
  platform::Timer batch_timer;

L
lilong12 已提交
109 110 111
  if (thread_id_ == 0) {
    while (true) {
      // Start a minibatch.
112 113
      // real number of microbatches run
      int real_microbatch_num = 0;
S
update  
sandyhouse 已提交
114
      batch_timer.Start();
L
lilong12 已提交
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
      for (int i = 0; i < num_microbatches_; ++i) {
        try {
          for (auto& op : ops_) {
            int op_role = op->Attr<int>(std::string("op_role"));
            // We run op with op_role = kLRSched only for the first microbatch
            // to avoid increasing the @LR_DECAY_STEP@ multiple times.
            bool run_first_mbatch =
                op_role == static_cast<int>(OpRole::kForward) ||
                op_role == (static_cast<int>(OpRole::kForward) |
                            static_cast<int>(OpRole::kLoss)) ||
                op_role == static_cast<int>(OpRole::kLRSched);
            bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                              op_role == (static_cast<int>(OpRole::kForward) |
                                          static_cast<int>(OpRole::kLoss));
            if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
              VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                      << " for scope " << i;
              op->Run(*microbatch_scopes_[i], place_);
              if (gc) {
                DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                    unused_vars_, gc.get());
              }
            }
          }
        } catch (platform::EOFException&) {
          std::unique_lock<std::mutex> lk(thread_mutex);
          threads_completed = true;
          VLOG(3) << "thread " << thread_id_ << " completed.";
          VLOG(3) << "called notify all";
          thread_condition.notify_all();
          VLOG(0) << "EOF encountered";
146
          break;
L
lilong12 已提交
147
        }
148 149 150
        {
          real_microbatch_num += 1;
          batch_id_ += 1;
L
lilong12 已提交
151 152 153 154
          VLOG(3) << "called notify all";
          std::unique_lock<std::mutex> lk(thread_mutex);
          thread_condition.notify_all();
        }
H
hutuxian 已提交
155
      }
S
update  
sandyhouse 已提交
156
      dev_ctx_->Wait();
157
      VLOG(0) << "real_microbatch_num for thread 0 " << real_microbatch_num;
L
lilong12 已提交
158
      // backward pass
159
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
160 161 162 163 164 165 166 167 168 169 170 171 172
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
H
hutuxian 已提交
173 174
        }
      }
S
update  
sandyhouse 已提交
175
      dev_ctx_->Wait();
176 177 178 179 180
      if (real_microbatch_num == 0) {
        batch_timer.Pause();
        VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
        return;
      }
L
lilong12 已提交
181 182 183 184 185 186 187 188 189 190 191 192
      // update pass
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
        }
H
hutuxian 已提交
193
      }
L
lilong12 已提交
194
      dev_ctx_->Wait();
S
update  
sandyhouse 已提交
195 196
      batch_timer.Pause();
      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
S
sandyhouse 已提交
197 198 199 200 201 202
      {
        std::unique_lock<std::mutex> lk(thread_mutex);
        if (threads_completed) {
          return;
        }
      }
H
hutuxian 已提交
203
    }
L
lilong12 已提交
204 205 206
  } else {
    while (true) {
      // forward pass:
S
sandyhouse 已提交
207
      bool local_completed = false;
208
      int real_microbatch_num = 0;
L
lilong12 已提交
209
      for (int i = 0; i < num_microbatches_; ++i) {
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
        {
          PADDLE_ENFORCE_LE(
              local_batch_id_, batch_id_,
              platform::errors::InvalidArgument(
                  "local_batch_id_ (%d) must be less than or equal to "
                  "batch_id_ (%d)",
                  local_batch_id_, batch_id_));
          std::unique_lock<std::mutex> lk(thread_mutex);
          if (local_batch_id_ == batch_id_ && !threads_completed) {
            thread_condition.wait(lk);
          }
          VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                  << local_batch_id_ << " batch_id_ " << batch_id_;
          if (threads_completed) {
            VLOG(3) << "thread " << thread_id_ << " completed.";
            lk.unlock();
            threads_completed = false;
S
sandyhouse 已提交
227
            local_completed = true;
228 229 230 231 232 233
            break;
          }
          lk.unlock();
          local_batch_id_ += 1;
          real_microbatch_num += 1;
        }
L
lilong12 已提交
234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
          bool run_first_mbatch =
              op_role == static_cast<int>(OpRole::kForward) ||
              op_role == (static_cast<int>(OpRole::kForward) |
                          static_cast<int>(OpRole::kLoss)) ||
              op_role == static_cast<int>(OpRole::kLRSched);
          bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                            op_role == (static_cast<int>(OpRole::kForward) |
                                        static_cast<int>(OpRole::kLoss));
          if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
        }
      }
S
update  
sandyhouse 已提交
257
      dev_ctx_->Wait();
L
lilong12 已提交
258
      // backward pass
259
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
        }
      }
S
update  
sandyhouse 已提交
275
      dev_ctx_->Wait();
L
lilong12 已提交
276
      // update pass
277 278 279
      if (real_microbatch_num == 0) {
        return;
      }
L
lilong12 已提交
280 281 282 283 284 285 286 287 288 289 290 291 292
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
        }
      }
      dev_ctx_->Wait();
S
sandyhouse 已提交
293 294 295
      if (local_completed) {
        return;
      }
H
hutuxian 已提交
296 297 298 299 300
    }
  }
}

void SectionWorker::TrainFilesWithProfiler() {
L
lilong12 已提交
301
  VLOG(3) << "begin section_worker TrainFiles with profiler";
H
hutuxian 已提交
302 303
  AutoSetCPUAffinity(true);

L
lilong12 已提交
304 305
  platform::Timer batch_timer;
  platform::Timer timeline;
H
hutuxian 已提交
306 307 308

  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
L
lilong12 已提交
309 310 311
  std::vector<double> op_max_time;
  std::vector<double> op_min_time;
  std::vector<uint64_t> op_count;
H
hutuxian 已提交
312 313 314 315
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
L
lilong12 已提交
316 317 318 319
  op_max_time.resize(ops_.size());
  op_min_time.resize(ops_.size());
  for (size_t i = 0; i < op_min_time.size(); ++i) {
    op_min_time[i] = DBL_MAX;
H
hutuxian 已提交
320
  }
L
lilong12 已提交
321 322 323 324 325 326 327 328 329 330 331
  op_count.resize(ops_.size());

  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  // const std::vector<std::string> keep_vars;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
332
    } else {
L
lilong12 已提交
333 334
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
335
    }
L
lilong12 已提交
336 337 338 339 340 341 342
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
343

L
lilong12 已提交
344
  if (thread_id_ == 0) {
S
update  
sandyhouse 已提交
345 346 347 348
    struct timeval start;
    struct timeval end;
    struct timeval micro_start;
    struct timeval micro_end;
L
lilong12 已提交
349 350 351
    while (true) {
      // Start a minibatch.
      batch_timer.Start();
352
      int real_microbatch_num = 0;
L
lilong12 已提交
353 354 355
      for (int i = 0; i < num_microbatches_; ++i) {
        try {
          int op_idx = 0;
S
update  
sandyhouse 已提交
356
          gettimeofday(&micro_start, NULL);
L
lilong12 已提交
357
          for (auto& op : ops_) {
S
update  
sandyhouse 已提交
358
            gettimeofday(&start, NULL);
L
lilong12 已提交
359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
            int op_role = op->Attr<int>(std::string("op_role"));
            // We run op with op_role = kLRSched only for the first microbatch
            // to avoid increasing the @LR_DECAY_STEP@ multiple times.
            bool run_first_mbatch =
                op_role == static_cast<int>(OpRole::kForward) ||
                op_role == (static_cast<int>(OpRole::kForward) |
                            static_cast<int>(OpRole::kLoss)) ||
                op_role == static_cast<int>(OpRole::kLRSched);
            bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                              op_role == (static_cast<int>(OpRole::kForward) |
                                          static_cast<int>(OpRole::kLoss));
            if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
              VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                      << " for scope " << i;
              timeline.Start();
              op->Run(*microbatch_scopes_[i], place_);
              if (gc) {
                DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                    unused_vars_, gc.get());
              }
S
update  
sandyhouse 已提交
379
              cudaDeviceSynchronize();
L
lilong12 已提交
380
              timeline.Pause();
S
update  
sandyhouse 已提交
381
              gettimeofday(&end, NULL);
L
lilong12 已提交
382 383 384 385 386 387 388 389 390 391
              auto time = timeline.ElapsedUS();
              op_total_time[op_idx] += time;
              if (time > op_max_time[op_idx]) {
                op_max_time[op_idx] = time;
              }
              if (time < op_min_time[op_idx]) {
                op_min_time[op_idx] = time;
              }
              op_count[op_idx] += 1;
              op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
392 393 394 395 396 397 398 399
              {
                std::unique_lock<std::mutex> lk(cout_mutex);
	            std::cout << std::fixed;
                std::cout.precision(0);
                std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                        << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                        << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
              }
L
lilong12 已提交
400 401 402
            }
            op_idx++;
          }
S
update  
sandyhouse 已提交
403 404 405 406 407 408 409 410 411
          gettimeofday(&micro_end, NULL);
          {
            std::unique_lock<std::mutex> lk(cout_mutex);
	        std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                      << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                      << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
          }
L
lilong12 已提交
412 413 414 415 416 417 418 419 420 421 422 423 424 425
        } catch (platform::EOFException&) {
          std::unique_lock<std::mutex> lk(thread_mutex);
          threads_completed = true;
          VLOG(3) << "thread " << thread_id_ << " completed.";
          VLOG(3) << "called notify all";
          thread_condition.notify_all();
          VLOG(0) << "EOF encountered";
          VLOG(0) << "============timeline============";
          for (size_t i = 0; i < ops_.size(); ++i) {
            VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                    << ", min_time: " << op_min_time[i]
                    << ", mean_time: " << op_total_time[i] / op_count[i];
          }
          VLOG(0) << "================================";
426
          break;
L
lilong12 已提交
427
        }
428
        {
L
lilong12 已提交
429 430
          VLOG(3) << "called notify all";
          std::unique_lock<std::mutex> lk(thread_mutex);
431
          real_microbatch_num += 1;
L
lilong12 已提交
432 433 434
          batch_id_ += 1;
          thread_condition.notify_all();
        }
H
hutuxian 已提交
435
      }
S
update  
sandyhouse 已提交
436
      dev_ctx_->Wait();
L
lilong12 已提交
437
      // backward pass
438
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
439
        int op_idx = 0;
S
update  
sandyhouse 已提交
440
        gettimeofday(&micro_start, NULL);
L
lilong12 已提交
441
        for (auto& op : ops_) {
S
update  
sandyhouse 已提交
442
          gettimeofday(&start, NULL);
L
lilong12 已提交
443 444 445 446 447 448 449 450 451 452 453 454
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
S
update  
sandyhouse 已提交
455 456
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
L
lilong12 已提交
457 458 459 460 461 462 463 464 465 466 467
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
468 469 470 471 472 473 474 475
            {
              std::unique_lock<std::mutex> lk(cout_mutex);
	          std::cout << std::fixed;
              std::cout.precision(0);
              std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                      << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
            }
L
lilong12 已提交
476 477
          }
          op_idx++;
H
hutuxian 已提交
478
        }
S
update  
sandyhouse 已提交
479 480 481 482 483 484 485 486 487
        gettimeofday(&micro_end, NULL);
        {
          std::unique_lock<std::mutex> lk(cout_mutex);
	      std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
        }
H
hutuxian 已提交
488
      }
S
update  
sandyhouse 已提交
489
      dev_ctx_->Wait();
490 491 492
      if (real_microbatch_num == 0) {
        batch_timer.Pause();
        VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
S
sandyhouse 已提交
493
        return;
494
      }
L
lilong12 已提交
495 496
      // update pass
      int op_idx = 0;
S
update  
sandyhouse 已提交
497
      gettimeofday(&micro_start, NULL);
L
lilong12 已提交
498
      for (auto& op : ops_) {
S
update  
sandyhouse 已提交
499
        gettimeofday(&start, NULL);
L
lilong12 已提交
500 501 502 503 504 505 506 507 508 509
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          timeline.Start();
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
S
update  
sandyhouse 已提交
510 511
          cudaDeviceSynchronize();
          gettimeofday(&end, NULL);
L
lilong12 已提交
512 513 514 515 516 517 518 519 520 521 522
          timeline.Pause();
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
523 524 525 526 527 528 529 530
          {
            std::unique_lock<std::mutex> lk(cout_mutex);
	        std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << num_microbatches_
                    << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                    << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
          }
L
lilong12 已提交
531 532 533
        }
        op_idx++;
      }
S
update  
sandyhouse 已提交
534 535 536 537 538 539 540 541 542
      gettimeofday(&micro_end, NULL);
      {
        std::unique_lock<std::mutex> lk(cout_mutex);
	    std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                  << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
      }
H
hutuxian 已提交
543
      dev_ctx_->Wait();
L
lilong12 已提交
544 545
      batch_timer.Pause();
      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
S
sandyhouse 已提交
546 547 548 549 550 551
      {
        std::unique_lock<std::mutex> lk(thread_mutex);
        if (threads_completed) {
          return;
        }
      }
H
hutuxian 已提交
552
    }
L
lilong12 已提交
553
  } else {
S
update  
sandyhouse 已提交
554 555 556 557 558 559 560
    struct timeval start;
    struct timeval end;
    struct timeval micro_start;
    struct timeval micro_end;
    cudaEvent_t cu_start, cu_stop;
    cudaEventCreate(&cu_start);
    cudaEventCreate(&cu_stop);
S
sandyhouse 已提交
561
    bool local_completed = false;
L
lilong12 已提交
562 563
    while (true) {
      // forward pass:
564
      int real_microbatch_num = 0;
L
lilong12 已提交
565
      for (int i = 0; i < num_microbatches_; ++i) {
566 567 568 569 570 571 572 573 574 575 576 577 578 579
        {
          PADDLE_ENFORCE_LE(
              local_batch_id_, batch_id_,
              platform::errors::InvalidArgument(
                  "local_batch_id_ (%d) must be less than or equal to "
                  "batch_id_ (%d)",
                  local_batch_id_, batch_id_));
          std::unique_lock<std::mutex> lk(thread_mutex);
          if (local_batch_id_ == batch_id_ && !threads_completed) {
            thread_condition.wait(lk);
          }
          VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                  << local_batch_id_ << " batch_id_ " << batch_id_;
          if (threads_completed) {
S
sandyhouse 已提交
580
            local_completed = true;
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595
            VLOG(3) << "thread " << thread_id_ << " completed.";
            lk.unlock();
            VLOG(0) << "============timeline============";
            for (size_t i = 0; i < ops_.size(); ++i) {
              VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                      << ", min_time: " << op_min_time[i]
                      << ", mean_time: " << op_total_time[i] / op_count[i];
            }
            VLOG(0) << "================================";
            break;
          }
          lk.unlock();
          real_microbatch_num += 1;
          local_batch_id_ += 1;
        }
L
lilong12 已提交
596
        int op_idx = 0;
S
update  
sandyhouse 已提交
597
        gettimeofday(&micro_start, NULL);
L
lilong12 已提交
598
        for (auto& op : ops_) {
S
update  
sandyhouse 已提交
599
          gettimeofday(&start, NULL);
L
lilong12 已提交
600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
          bool run_first_mbatch =
              op_role == static_cast<int>(OpRole::kForward) ||
              op_role == (static_cast<int>(OpRole::kForward) |
                          static_cast<int>(OpRole::kLoss)) ||
              op_role == static_cast<int>(OpRole::kLRSched);
          bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                            op_role == (static_cast<int>(OpRole::kForward) |
                                        static_cast<int>(OpRole::kLoss));
          if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
S
update  
sandyhouse 已提交
620 621
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
L
lilong12 已提交
622 623 624 625 626 627 628 629 630 631 632
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
633 634 635 636 637 638 639 640
            {
              std::unique_lock<std::mutex> lk(cout_mutex);
	          std::cout << std::fixed;
              std::cout.precision(0);
              std::cout << "::FWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                      << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
            }
L
lilong12 已提交
641 642 643
          }
          op_idx++;
        }
S
update  
sandyhouse 已提交
644 645 646 647 648 649 650 651 652
        gettimeofday(&micro_end, NULL);
        {
          std::unique_lock<std::mutex> lk(cout_mutex);
	      std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
        }
H
hutuxian 已提交
653
      }
S
update  
sandyhouse 已提交
654
      dev_ctx_->Wait();
L
lilong12 已提交
655
      // backward pass
656
      for (int i = 0; i < real_microbatch_num; ++i) {
L
lilong12 已提交
657
        int op_idx = 0;
S
update  
sandyhouse 已提交
658
        gettimeofday(&micro_start, NULL);
L
lilong12 已提交
659
        for (auto& op : ops_) {
S
update  
sandyhouse 已提交
660
          gettimeofday(&start, NULL);
L
lilong12 已提交
661 662 663 664 665 666 667 668 669 670 671 672
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
S
update  
sandyhouse 已提交
673 674
            cudaDeviceSynchronize();
            gettimeofday(&end, NULL);
L
lilong12 已提交
675 676 677 678 679 680 681 682 683 684 685
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
686 687 688 689 690 691 692 693
            {
              std::unique_lock<std::mutex> lk(cout_mutex);
	          std::cout << std::fixed;
              std::cout.precision(0);
              std::cout << "::BWD:B[" << local_batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << i
                      << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                      << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
            }
L
lilong12 已提交
694 695 696
          }
          op_idx++;
        }
S
update  
sandyhouse 已提交
697 698 699 700 701 702 703 704 705
        gettimeofday(&micro_end, NULL);
        {
          std::unique_lock<std::mutex> lk(cout_mutex);
	      std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                    << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
        }
L
lilong12 已提交
706
      }
S
update  
sandyhouse 已提交
707
      dev_ctx_->Wait();
708 709 710
      if (real_microbatch_num == 0) {
        return;
      }
L
lilong12 已提交
711 712
      // update pass
      int op_idx = 0;
S
update  
sandyhouse 已提交
713
      gettimeofday(&micro_start, NULL);
L
lilong12 已提交
714
      for (auto& op : ops_) {
S
update  
sandyhouse 已提交
715
        gettimeofday(&start, NULL);
L
lilong12 已提交
716 717 718 719 720 721 722 723 724 725
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          timeline.Start();
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
S
update  
sandyhouse 已提交
726 727
          cudaDeviceSynchronize();
          gettimeofday(&end, NULL);
L
lilong12 已提交
728 729 730 731 732 733 734 735 736 737 738
          timeline.Pause();
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
S
update  
sandyhouse 已提交
739 740 741 742 743 744 745 746
          {
            std::unique_lock<std::mutex> lk(cout_mutex);
	        std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:SCOPE[" << num_microbatches_
                    << "]:OP[" << op->Type() << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                    << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
          }
L
lilong12 已提交
747 748 749
        }
        op_idx++;
      }
S
update  
sandyhouse 已提交
750 751 752 753 754 755 756 757 758
      gettimeofday(&micro_end, NULL);
      {
        std::unique_lock<std::mutex> lk(cout_mutex);
	    std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                  << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
      }
L
lilong12 已提交
759
      dev_ctx_->Wait();
S
sandyhouse 已提交
760 761 762
      if (local_completed) {
        return;
      }
H
hutuxian 已提交
763 764 765 766 767 768
    }
  }
}
}  // namespace framework
}  // namespace paddle
#endif