section_worker.cc 21.7 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL)
L
lilong12 已提交
13 14 15 16 17
#include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"

H
hutuxian 已提交
18 19 20 21 22
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "paddle/fluid/framework/device_worker.h"
H
hutuxian 已提交
23
#include "paddle/fluid/framework/fleet/box_wrapper.h"
H
hutuxian 已提交
24 25 26 27 28 29 30 31 32 33
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

// Counter handed out by AutoSetCPUAffinity(): each call atomically takes the
// current value (fetch_add) and uses it to pick the CPU to pin to.
std::atomic<int> SectionWorker::cpu_id_(0);
L
lilong12 已提交
34 35 36 37 38
// Mutex taken around updates of batch_id_ / threads_completed and around the
// condition-variable wait in TrainFiles() / TrainFilesWithProfiler().
std::mutex SectionWorker::thread_mutex;
// Notified by the worker with thread_id_ == 0 whenever batch_id_ is
// incremented or the end of the data (EOFException) is reached.
std::condition_variable SectionWorker::thread_condition;
// Set to true by the worker with thread_id_ == 0 when it catches an
// EOFException; reset to false by a waiting worker when it observes it.
bool SectionWorker::threads_completed = false;
// Incremented by the worker with thread_id_ == 0 after the forward ops of the
// first microbatch have run; the other workers compare it with their
// local_batch_id_ to decide whether to wait.
uint64_t SectionWorker::batch_id_(0);

H
hutuxian 已提交
39
void SectionWorker::Initialize(const TrainerDesc& desc) {
H
hutuxian 已提交
40
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
L
lilong12 已提交
41
  program_.reset(new ProgramDesc(
H
hutuxian 已提交
42
      desc.section_param().section_config(section_id_).program_desc()));
L
lilong12 已提交
43
  for (auto& op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

void SectionWorker::AutoSetCPUAffinity(bool reuse) {
  int thread_cpu_id = cpu_id_.fetch_add(1);

  unsigned concurrency_cap = std::thread::hardware_concurrency();
  unsigned proc = thread_cpu_id;

  if (proc >= concurrency_cap) {
    if (reuse) {
      proc %= concurrency_cap;
    } else {
      LOG(INFO) << "All " << concurrency_cap
                << " CPUs have been set affinities. Fail to set "
                << thread_cpu_id << "th thread";
      return;
    }
  }

  cpu_set_t mask;
  CPU_ZERO(&mask);
  CPU_SET(proc, &mask);

  if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
    return;
  }

  CPU_ZERO(&mask);
  if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
      (0 == CPU_ISSET(proc, &mask))) {
    LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
  }
L
lilong12 已提交
79
  VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
H
hutuxian 已提交
80 81 82
}

void SectionWorker::TrainFiles() {
L
lilong12 已提交
83
  VLOG(3) << "begin section_worker TrainFiles";
H
hutuxian 已提交
84 85
  AutoSetCPUAffinity(true);

L
lilong12 已提交
86 87 88 89 90 91 92 93
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
94
    } else {
L
lilong12 已提交
95 96
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
97
    }
L
lilong12 已提交
98 99 100 101 102 103 104
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
105

L
lilong12 已提交
106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
  if (thread_id_ == 0) {
    while (true) {
      // Start a minibatch.
      for (int i = 0; i < num_microbatches_; ++i) {
        try {
          for (auto& op : ops_) {
            int op_role = op->Attr<int>(std::string("op_role"));
            // We run op with op_role = kLRSched only for the first microbatch
            // to avoid increasing the @LR_DECAY_STEP@ multiple times.
            bool run_first_mbatch =
                op_role == static_cast<int>(OpRole::kForward) ||
                op_role == (static_cast<int>(OpRole::kForward) |
                            static_cast<int>(OpRole::kLoss)) ||
                op_role == static_cast<int>(OpRole::kLRSched);
            bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                              op_role == (static_cast<int>(OpRole::kForward) |
                                          static_cast<int>(OpRole::kLoss));
            if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
              VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                      << " for scope " << i;
              op->Run(*microbatch_scopes_[i], place_);
              if (gc) {
                DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                    unused_vars_, gc.get());
              }
            }
          }
        } catch (platform::EOFException&) {
          std::unique_lock<std::mutex> lk(thread_mutex);
          threads_completed = true;
          VLOG(3) << "thread " << thread_id_ << " completed.";
          VLOG(3) << "called notify all";
          thread_condition.notify_all();
          VLOG(0) << "EOF encountered";
          return;
        }
        if (i == 0) {
          VLOG(3) << "called notify all";
          std::unique_lock<std::mutex> lk(thread_mutex);
          batch_id_ += 1;
          thread_condition.notify_all();
        }
H
hutuxian 已提交
148
      }
L
lilong12 已提交
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
      // backward pass
      for (int i = 0; i < num_microbatches_; ++i) {
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
H
hutuxian 已提交
164 165
        }
      }
L
lilong12 已提交
166 167 168 169 170 171 172 173 174 175 176 177
      // update pass
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
        }
H
hutuxian 已提交
178
      }
L
lilong12 已提交
179
      dev_ctx_->Wait();
H
hutuxian 已提交
180
    }
L
lilong12 已提交
181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
  } else {
    while (true) {
      {
        PADDLE_ENFORCE_LE(
            local_batch_id_, batch_id_,
            platform::errors::InvalidArgument(
                "local_batch_id_ (%d) must be less than or equal to "
                "batch_id_ (%d)",
                local_batch_id_, batch_id_));
        std::unique_lock<std::mutex> lk(thread_mutex);
        if (local_batch_id_ == batch_id_ && !threads_completed) {
          thread_condition.wait(lk);
        }
        VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                << local_batch_id_ << " batch_id_ " << batch_id_;
        if (threads_completed) {
          VLOG(3) << "thread " << thread_id_ << " completed.";
          lk.unlock();
          threads_completed = false;
          return;
        }
        lk.unlock();
        local_batch_id_ += 1;
H
hutuxian 已提交
204
      }
L
lilong12 已提交
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
      // forward pass:
      for (int i = 0; i < num_microbatches_; ++i) {
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
          bool run_first_mbatch =
              op_role == static_cast<int>(OpRole::kForward) ||
              op_role == (static_cast<int>(OpRole::kForward) |
                          static_cast<int>(OpRole::kLoss)) ||
              op_role == static_cast<int>(OpRole::kLRSched);
          bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                            op_role == (static_cast<int>(OpRole::kForward) |
                                        static_cast<int>(OpRole::kLoss));
          if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
        }
      }
      // backward pass
      for (int i = 0; i < num_microbatches_; ++i) {
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
          }
        }
      }
      // update pass
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
        }
      }
      dev_ctx_->Wait();
H
hutuxian 已提交
261 262 263 264 265
    }
  }
}

void SectionWorker::TrainFilesWithProfiler() {
L
lilong12 已提交
266
  VLOG(3) << "begin section_worker TrainFiles with profiler";
H
hutuxian 已提交
267 268
  AutoSetCPUAffinity(true);

L
lilong12 已提交
269 270
  platform::Timer batch_timer;
  platform::Timer timeline;
H
hutuxian 已提交
271 272 273

  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
L
lilong12 已提交
274 275 276
  std::vector<double> op_max_time;
  std::vector<double> op_min_time;
  std::vector<uint64_t> op_count;
H
hutuxian 已提交
277 278 279 280
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
L
lilong12 已提交
281 282 283 284
  op_max_time.resize(ops_.size());
  op_min_time.resize(ops_.size());
  for (size_t i = 0; i < op_min_time.size(); ++i) {
    op_min_time[i] = DBL_MAX;
H
hutuxian 已提交
285
  }
L
lilong12 已提交
286 287 288 289 290 291 292 293 294 295 296
  op_count.resize(ops_.size());

  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  // const std::vector<std::string> keep_vars;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
297
    } else {
L
lilong12 已提交
298 299
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
H
hutuxian 已提交
300
    }
L
lilong12 已提交
301 302 303 304 305 306 307
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif
H
hutuxian 已提交
308

L
lilong12 已提交
309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
  if (thread_id_ == 0) {
    while (true) {
      // Start a minibatch.
      // int batch_size = 0;
      batch_timer.Start();
      for (int i = 0; i < num_microbatches_; ++i) {
        try {
          int op_idx = 0;
          for (auto& op : ops_) {
            int op_role = op->Attr<int>(std::string("op_role"));
            // We run op with op_role = kLRSched only for the first microbatch
            // to avoid increasing the @LR_DECAY_STEP@ multiple times.
            bool run_first_mbatch =
                op_role == static_cast<int>(OpRole::kForward) ||
                op_role == (static_cast<int>(OpRole::kForward) |
                            static_cast<int>(OpRole::kLoss)) ||
                op_role == static_cast<int>(OpRole::kLRSched);
            bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                              op_role == (static_cast<int>(OpRole::kForward) |
                                          static_cast<int>(OpRole::kLoss));
            if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
              VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                      << " for scope " << i;
              timeline.Start();
              op->Run(*microbatch_scopes_[i], place_);
              if (gc) {
                DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                    unused_vars_, gc.get());
              }
              timeline.Pause();
              auto time = timeline.ElapsedUS();
              op_total_time[op_idx] += time;
              if (time > op_max_time[op_idx]) {
                op_max_time[op_idx] = time;
              }
              if (time < op_min_time[op_idx]) {
                op_min_time[op_idx] = time;
              }
              op_count[op_idx] += 1;
              op_total_time[op_idx] += time;
            }
            op_idx++;
          }
        } catch (platform::EOFException&) {
          std::unique_lock<std::mutex> lk(thread_mutex);
          threads_completed = true;
          VLOG(3) << "thread " << thread_id_ << " completed.";
          VLOG(3) << "called notify all";
          thread_condition.notify_all();
          VLOG(0) << "EOF encountered";
          VLOG(0) << "============timeline============";
          for (size_t i = 0; i < ops_.size(); ++i) {
            VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                    << ", min_time: " << op_min_time[i]
                    << ", mean_time: " << op_total_time[i] / op_count[i];
          }
          VLOG(0) << "================================";
          return;
        }
        if (i == 0) {
          VLOG(3) << "called notify all";
          std::unique_lock<std::mutex> lk(thread_mutex);
          batch_id_ += 1;
          thread_condition.notify_all();
        }
H
hutuxian 已提交
374
      }
L
lilong12 已提交
375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
      // backward pass
      for (int i = 0; i < num_microbatches_; ++i) {
        int op_idx = 0;
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
          }
          op_idx++;
H
hutuxian 已提交
404 405
        }
      }
L
lilong12 已提交
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432
      // update pass
      int op_idx = 0;
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          timeline.Start();
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
          timeline.Pause();
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
        }
        op_idx++;
      }
H
hutuxian 已提交
433
      dev_ctx_->Wait();
L
lilong12 已提交
434 435
      batch_timer.Pause();
      VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
H
hutuxian 已提交
436
    }
L
lilong12 已提交
437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466
  } else {
    while (true) {
      {
        PADDLE_ENFORCE_LE(
            local_batch_id_, batch_id_,
            platform::errors::InvalidArgument(
                "local_batch_id_ (%d) must be less than or equal to "
                "batch_id_ (%d)",
                local_batch_id_, batch_id_));
        std::unique_lock<std::mutex> lk(thread_mutex);
        if (local_batch_id_ == batch_id_ && !threads_completed) {
          thread_condition.wait(lk);
        }
        VLOG(3) << "thread " << thread_id_ << " local_batch_id_ "
                << local_batch_id_ << " batch_id_ " << batch_id_;
        if (threads_completed) {
          VLOG(3) << "thread " << thread_id_ << " completed.";
          lk.unlock();
          VLOG(0) << "============timeline============";
          for (size_t i = 0; i < ops_.size(); ++i) {
            VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                    << ", min_time: " << op_min_time[i]
                    << ", mean_time: " << op_total_time[i] / op_count[i];
          }
          VLOG(0) << "================================";
          threads_completed = false;
          return;
        }
        lk.unlock();
        local_batch_id_ += 1;
H
hutuxian 已提交
467
      }
L
lilong12 已提交
468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505
      // forward pass:
      for (int i = 0; i < num_microbatches_; ++i) {
        int op_idx = 0;
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          // We run op with op_role = kLRSched only for the first microbatch
          // to avoid increasing the @LR_DECAY_STEP@ multiple times.
          bool run_first_mbatch =
              op_role == static_cast<int>(OpRole::kForward) ||
              op_role == (static_cast<int>(OpRole::kForward) |
                          static_cast<int>(OpRole::kLoss)) ||
              op_role == static_cast<int>(OpRole::kLRSched);
          bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                            op_role == (static_cast<int>(OpRole::kForward) |
                                        static_cast<int>(OpRole::kLoss));
          if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
          }
          op_idx++;
        }
H
hutuxian 已提交
506
      }
L
lilong12 已提交
507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565
      // backward pass
      for (int i = 0; i < num_microbatches_; ++i) {
        int op_idx = 0;
        for (auto& op : ops_) {
          int op_role = op->Attr<int>(std::string("op_role"));
          if (op_role == static_cast<int>(OpRole::kBackward) ||
              op_role == (static_cast<int>(OpRole::kBackward) |
                          static_cast<int>(OpRole::kLoss))) {
            VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                    << " for scope " << i;
            timeline.Start();
            op->Run(*microbatch_scopes_[i], place_);
            if (gc) {
              DeleteUnusedTensors(*microbatch_scopes_[i], op.get(),
                                  unused_vars_, gc.get());
            }
            timeline.Pause();
            auto time = timeline.ElapsedUS();
            op_total_time[op_idx] += time;
            if (time > op_max_time[op_idx]) {
              op_max_time[op_idx] = time;
            }
            if (time < op_min_time[op_idx]) {
              op_min_time[op_idx] = time;
            }
            op_count[op_idx] += 1;
            op_total_time[op_idx] += time;
          }
          op_idx++;
        }
      }
      // update pass
      int op_idx = 0;
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        if (op_role == static_cast<int>(OpRole::kOptimize)) {
          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
                  << " for minibatch scope";
          timeline.Start();
          op->Run(*microbatch_scopes_[0], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                                op.get(), unused_vars_, gc.get());
          }
          timeline.Pause();
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          op_total_time[op_idx] += time;
        }
        op_idx++;
      }
      dev_ctx_->Wait();
H
hutuxian 已提交
566 567 568 569 570 571
    }
  }
}
}  // namespace framework
}  // namespace paddle
#endif