/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#if defined(PADDLE_WITH_NCCL)
#include <float.h>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

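// cpu_id_ hands out a distinct core index to each section thread (see
// AutoSetCPUAffinity); batch_id_ is the shared minibatch counter, bumped
// once per finished minibatch and used in the trace output.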
std::atomic<int> SectionWorker::cpu_id_(0);
uint64_t SectionWorker::batch_id_(0);

void SectionWorker::Initialize(const TrainerDesc& desc) {
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
  // Instantiate every op of the section program once; the same op objects
  // are reused across all microbatches.
  for (auto& op_desc : program_->Block(0).AllOps()) {
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

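// Pin the calling thread to the next unclaimed CPU core. With `reuse`,
// core ids wrap around once every core has been claimed; otherwise threads
// beyond the core count are left unpinned.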
void SectionWorker::AutoSetCPUAffinity(bool reuse) {
  int thread_cpu_id = cpu_id_.fetch_add(1);

  unsigned concurrency_cap = std::thread::hardware_concurrency();
  unsigned proc = thread_cpu_id;

  if (proc >= concurrency_cap) {
    if (reuse) {
      proc %= concurrency_cap;
    } else {
      LOG(INFO) << "All " << concurrency_cap
                << " CPUs already have an affinity set; thread "
                << thread_cpu_id << " is left unpinned";
      return;
    }
  }

  cpu_set_t mask;
  CPU_ZERO(&mask);
  CPU_SET(proc, &mask);

  if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
    LOG(WARNING) << "Failed to set thread affinity to CPU " << proc;
    return;
  }

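  // Read the affinity back to verify that the setting took effect.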
  CPU_ZERO(&mask);
  if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
      (0 == CPU_ISSET(proc, &mask))) {
    LOG(WARNING) << "Failed to verify thread affinity for CPU " << proc;
  }
  VLOG(3) << "Set thread " << thread_cpu_id << " affinity to CPU " << proc;
}

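// Run one minibatch in three phases: forward over all microbatches, backward
// over all microbatches, then a single optimizer pass; each phase ends with
// dev_ctx_->Wait() so the next phase sees its results.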
void SectionWorker::TrainFiles() {
  VLOG(3) << "begin section_worker TrainFiles";

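  // Build a garbage collector matched to the place so tensors in the
  // microbatch scopes are freed right after their last use.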
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    } else {
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    }
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif

  platform::Timer batch_timer;

  // Start a minibatch.
  batch_timer.Start();
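  // forward pass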
  for (int i = 0; i < num_microbatches_; ++i) {
    try {
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        // We run op with op_role = kLRSched only for the first microbatch
        // to avoid increasing the @LR_DECAY_STEP@ multiple times.
        bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
                                op_role == (static_cast<int>(OpRole::kForward) |
                                            static_cast<int>(OpRole::kLoss)) ||
                                op_role == static_cast<int>(OpRole::kLRSched);
        bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                          op_role == (static_cast<int>(OpRole::kForward) |
                                      static_cast<int>(OpRole::kLoss));
        if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
          VLOG(3) << "running an op " << op->Type() << " for scope " << i;
          op->Run(*microbatch_scopes_[i], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                                gc.get());
          }
        }
      }
    } catch (platform::EOFException&) {
      VLOG(3) << "thread completed.";
      VLOG(0) << "EOF encountered";
      break;
    }
  }
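  // Fence: let all queued forward kernels finish before backward starts.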
  dev_ctx_->Wait();

  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "running an op " << op->Type() << " for scope " << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
        }
      }
    }
  }
  dev_ctx_->Wait();
  // update pass: optimizer ops run once per minibatch, on the scope of
  // microbatch 0
  for (auto& op : ops_) {
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "running an op " << op->Type() << " for minibatch scope";
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        for (int i = 0; i < num_microbatches_; ++i) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
        }
      }
    }
  }
  dev_ctx_->Wait();
  batch_timer.Pause();
  VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
  ++batch_id_;
}

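// Same three-phase schedule as TrainFiles(), but every op run is timed and
// per-op / per-phase trace records are printed to stdout.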
void SectionWorker::TrainFilesWithProfiler() {
  VLOG(3) << "begin section_worker TrainFiles with profiler";

  platform::Timer batch_timer;
  platform::Timer timeline;

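  // Per-op accumulators for the timeline summary printed at EOF.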
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  std::vector<double> op_max_time;
  std::vector<double> op_min_time;
  std::vector<uint64_t> op_count;
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  op_max_time.resize(ops_.size());
  op_min_time.resize(ops_.size());
  for (size_t i = 0; i < op_min_time.size(); ++i) {
    op_min_time[i] = DBL_MAX;
  }
  op_count.resize(ops_.size());

  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    } else {
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    }
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif

  struct timeval start;
  struct timeval end;
  struct timeval micro_start;
  struct timeval micro_end;
  // Start a minibatch.
  batch_timer.Start();
  int real_microbatch_num = 0;
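  // forward pass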
  for (int i = 0; i < num_microbatches_; ++i) {
    try {
      int op_idx = 0;
      gettimeofday(&micro_start, NULL);
      for (auto& op : ops_) {
        gettimeofday(&start, NULL);
        int op_role = op->Attr<int>(std::string("op_role"));
        // We run op with op_role = kLRSched only for the first microbatch
        // to avoid increasing the @LR_DECAY_STEP@ multiple times.
        bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
                                op_role == (static_cast<int>(OpRole::kForward) |
                                            static_cast<int>(OpRole::kLoss)) ||
                                op_role == static_cast<int>(OpRole::kLRSched);
        bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                          op_role == (static_cast<int>(OpRole::kForward) |
                                      static_cast<int>(OpRole::kLoss));
        if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
          VLOG(3) << "running an op " << op->Type() << " for section "
                  << thread_id_ << " for scope " << i;
          timeline.Start();
          op->Run(*microbatch_scopes_[i], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                                gc.get());
          }
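          // Synchronize the device so the timer captures kernel execution
          // time, not just launch latency.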
          cudaDeviceSynchronize();
          timeline.Pause();
          gettimeofday(&end, NULL);
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;
          {
S
sandyhouse 已提交
278
            // One trace record per op:
            // ::FWD:B[batch]:SEC[section]:SCOPE[microbatch]:OP[type]:START[us]:END[us]
            std::cout << std::fixed;
            std::cout.precision(0);
            std::cout << "::FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                      << "]:SCOPE[" << i << "]:OP[" << op->Type() << "]:START["
                      << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                      << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
          }
        }
        op_idx++;
      }
      gettimeofday(&micro_end, NULL);
      {
        // One record per microbatch:
        // !!FWD:B[batch]:SEC[section]:START[us]:END[us]
        std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "!!FWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:START["
                  << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
                  << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
                  << std::endl;
      }
      // Another microbatch ran to completion without hitting EOF.
      ++real_microbatch_num;
    } catch (platform::EOFException&) {
      VLOG(3) << "thread completed.";
      VLOG(0) << "EOF encountered";
      VLOG(0) << "============timeline============";
      for (size_t i = 0; i < ops_.size(); ++i) {
        VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                << ", min_time: " << op_min_time[i]
                << ", mean_time: " << op_total_time[i] / op_count[i];
      }
      VLOG(0) << "================================";
      break;
    }
  }
  dev_ctx_->Wait();
  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    int op_idx = 0;
    gettimeofday(&micro_start, NULL);
    for (auto& op : ops_) {
      gettimeofday(&start, NULL);
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "running an op " << op->Type() << " for scope " << i;
        timeline.Start();
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
        }
        cudaDeviceSynchronize();
        gettimeofday(&end, NULL);
        timeline.Pause();
        auto time = timeline.ElapsedUS();
        op_total_time[op_idx] += time;
        if (time > op_max_time[op_idx]) {
          op_max_time[op_idx] = time;
        }
        if (time < op_min_time[op_idx]) {
          op_min_time[op_idx] = time;
        }
S
sandyhouse 已提交
342 343
        op_count[op_idx] += 1;
        {
          std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "::BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                    << "]:SCOPE[" << i << "]:OP[" << op->Type() << "]:START["
                    << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                    << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
        }
      }
      op_idx++;
    }
    gettimeofday(&micro_end, NULL);
    {
      std::cout << std::fixed;
      std::cout.precision(0);
      std::cout << "!!BWD:B[" << batch_id_ << "]:SEC[" << thread_id_
                << "]:START[" << micro_start.tv_sec * 1e6 + micro_start.tv_usec
                << "]:END[" << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
                << std::endl;
    }
  }
  dev_ctx_->Wait();
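  // If EOF arrived before any microbatch completed, skip the update pass.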
  if (real_microbatch_num == 0) {
    batch_timer.Pause();
    VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
    return;
  }
  // update pass
  int op_idx = 0;
  gettimeofday(&micro_start, NULL);
  for (auto& op : ops_) {
    gettimeofday(&start, NULL);
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "running an op " << op->Type() << " for section "
              << thread_id_ << " for minibatch scope";
      timeline.Start();
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        for (int i = 0; i < num_microbatches_; ++i) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
        }
      }
      cudaDeviceSynchronize();
      gettimeofday(&end, NULL);
      timeline.Pause();
      auto time = timeline.ElapsedUS();
      op_total_time[op_idx] += time;
      if (time > op_max_time[op_idx]) {
        op_max_time[op_idx] = time;
      }
      if (time < op_min_time[op_idx]) {
        op_min_time[op_idx] = time;
      }
      op_count[op_idx] += 1;
      {
        std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "::UPD:B[" << batch_id_ << "]:SEC[" << thread_id_
                  << "]:SCOPE[" << num_microbatches_ << "]:OP[" << op->Type()
                  << "]:START[" << start.tv_sec * 1e6 + start.tv_usec
                  << "]:END[" << end.tv_sec * 1e6 + end.tv_usec << "]"
                  << std::endl;
      }
    }
    op_idx++;
  }
  gettimeofday(&micro_end, NULL);
  {
    std::cout << std::fixed;
    std::cout.precision(0);
    std::cout << "!!UPD:B[" << batch_id_ << "]:SEC[" << thread_id_ << "]:START["
              << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
              << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
  }
  dev_ctx_->Wait();
  batch_timer.Pause();
  VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
  ++batch_id_;
}
}  // namespace framework
}  // namespace paddle
#endif