/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#if defined(PADDLE_WITH_NCCL)
#include <float.h>
#include <sched.h>     // cpu_set_t, sched_setaffinity / sched_getaffinity
#include <sys/time.h>  // gettimeofday, struct timeval

#include <iostream>  // std::cout timeline output
#include <thread>    // std::thread::hardware_concurrency

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

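// Counter of minibatches processed so far; a static member shared by all
// SectionWorker instances and advanced once per completed minibatch.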
uint64_t SectionWorker::batch_id_(0);

void SectionWorker::Initialize(const TrainerDesc& desc) {
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
  for (auto& op_desc : program_->Block(0).AllOps()) {
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

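// Pin the calling worker thread to logical CPU cpu_id_. If cpu_id_ exceeds
// the hardware concurrency, either wrap around (reuse == true) or leave the
// affinity unchanged.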
void SectionWorker::AutoSetCPUAffinity(bool reuse) {
  unsigned concurrency_cap = std::thread::hardware_concurrency();
  unsigned proc = cpu_id_;

  if (proc >= concurrency_cap) {
    if (reuse) {
      proc %= concurrency_cap;
    } else {
      LOG(INFO) << "All " << concurrency_cap
                << " CPUs already have an affinity set; failed to set one for "
                << "thread " << cpu_id_ << ".";
      return;
    }
  }

  cpu_set_t mask;
  CPU_ZERO(&mask);
  CPU_SET(proc, &mask);

  if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
    LOG(WARNING) << "Failed to set thread affinity to CPU " << proc;
    return;
  }

  CPU_ZERO(&mask);
  if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
      (0 == CPU_ISSET(proc, &mask))) {
    LOG(WARNING) << "Failed to set thread affinity to CPU " << proc;
  }
  VLOG(3) << "Set " << cpu_id_ << "th thread affinity to CPU " << proc;
}

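// Run one minibatch through this pipeline section: forward ops for every
// microbatch first, then backward ops for every microbatch, and finally the
// optimizer ops once (a GPipe-style schedule).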
void SectionWorker::TrainFiles() {
  VLOG(5) << "begin section_worker TrainFiles";
  AutoSetCPUAffinity(true);

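  // Pick a garbage collector matching the place this section runs on: the
  // fast (unsafe) GPU collector when eager deletion is enabled, the
  // default-stream GPU collector otherwise, and the CPU collector for CPU
  // places.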
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    } else {
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    }
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif

  platform::Timer batch_timer;
  batch_timer.Start();
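  // forward pass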
  for (int i = 0; i < num_microbatches_; ++i) {
    try {
      for (auto& op : ops_) {
        int op_role = op->Attr<int>(std::string("op_role"));
        // Ops with op_role kLRSched are run only for the first microbatch
        // to avoid advancing @LR_DECAY_STEP@ multiple times per minibatch.
        bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
                                op_role == (static_cast<int>(OpRole::kForward) |
                                            static_cast<int>(OpRole::kLoss)) ||
                                op_role == static_cast<int>(OpRole::kLRSched);
        bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                          op_role == (static_cast<int>(OpRole::kForward) |
                                      static_cast<int>(OpRole::kLoss));
        if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
          VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
                  << i;
          op->Run(*microbatch_scopes_[i], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                                gc.get());
          }
        }
      }
    } catch (platform::EOFException& e) {
      VLOG(3) << "EOF encountered and completed.";
      throw;
    }
  }

  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
                << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
        }
      }
    }
  }

  // update pass
  for (auto& op : ops_) {
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "Update: running op " << op->Type();
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
                            gc.get());
      }
    }
  }
  dev_ctx_->Wait();
  batch_timer.Pause();
  VLOG(0) << "batch: " << batch_id_ << ", time: " << batch_timer.ElapsedUS();
  ++batch_id_;
}

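// Same schedule as TrainFiles, but additionally records per-op wall-clock
// statistics (count/min/max/mean) and prints START/END microsecond
// timestamps for offline timeline analysis.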
void SectionWorker::TrainFilesWithProfiler() {
  VLOG(5) << "begin section_worker TrainFiles with profiler";
  AutoSetCPUAffinity(true);

  platform::Timer batch_timer;
  platform::Timer timeline;

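  // Per-op timing accumulators, indexed by the op's position in ops_.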
  std::vector<std::string> op_name;
  std::vector<double> op_total_time;
  std::vector<double> op_max_time;
  std::vector<double> op_min_time;
  std::vector<uint64_t> op_count;
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  op_max_time.resize(ops_.size());
  op_min_time.resize(ops_.size());
  for (size_t i = 0; i < op_min_time.size(); ++i) {
    op_min_time[i] = DBL_MAX;
    op_max_time[i] = 0.0;
  }
  op_count.resize(ops_.size());

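  // Garbage-collector selection, identical to TrainFiles above.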
  int64_t max_memory_size = 0;
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
#ifdef PADDLE_WITH_CUDA
  if (platform::is_gpu_place(place_)) {
    if (IsFastEagerDeletionModeEnabled()) {
      gc.reset(new UnsafeFastGPUGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    } else {
      gc.reset(new DefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
    }
  } else if (platform::is_cpu_place(place_)) {
#endif
    gc.reset(new CPUGarbageCollector(
        BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size));
#ifdef PADDLE_WITH_CUDA
  }
#endif

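  // Raw gettimeofday() stamps for the START/END lines printed to stdout.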
  struct timeval start;
  struct timeval end;
  struct timeval micro_start;
  struct timeval micro_end;

  // Start a minibatch.
  batch_timer.Start();
  for (int i = 0; i < num_microbatches_; ++i) {
    try {
      int op_idx = 0;
      gettimeofday(&micro_start, NULL);
      for (auto& op : ops_) {
        gettimeofday(&start, NULL);
        int op_role = op->Attr<int>(std::string("op_role"));
        // Ops with op_role kLRSched are run only for the first microbatch
        // to avoid advancing @LR_DECAY_STEP@ multiple times per minibatch.
        bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
                                op_role == (static_cast<int>(OpRole::kForward) |
                                            static_cast<int>(OpRole::kLoss)) ||
                                op_role == static_cast<int>(OpRole::kLRSched);
        bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                          op_role == (static_cast<int>(OpRole::kForward) |
                                      static_cast<int>(OpRole::kLoss));
        if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
          VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
                  << i;
          timeline.Start();
          op->Run(*microbatch_scopes_[i], place_);
          if (gc) {
            DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                                gc.get());
          }
          cudaDeviceSynchronize();
          timeline.Pause();
          gettimeofday(&end, NULL);
          auto time = timeline.ElapsedUS();
          op_total_time[op_idx] += time;
          if (time > op_max_time[op_idx]) {
            op_max_time[op_idx] = time;
          }
          if (time < op_min_time[op_idx]) {
            op_min_time[op_idx] = time;
          }
          op_count[op_idx] += 1;

          std::cout << std::fixed;
          std::cout.precision(0);
          std::cout << "::FWD:B[" << batch_id_ << "]:SCOPE[" << i << "]:OP["
                    << op->Type() << "]:START["
                    << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                    << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
        }
        op_idx++;
      }

      gettimeofday(&micro_end, NULL);
      std::cout << std::fixed;
      std::cout.precision(0);
      std::cout << "!!FWD:B[" << batch_id_ << "]:START["
                << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
                << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
                << std::endl;
    } catch (platform::EOFException& e) {
      VLOG(0) << "EOF encountered, and completed";
      VLOG(0) << "============timeline============";
      for (size_t i = 0; i < ops_.size(); ++i) {
        VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i]
                << ", min_time: " << op_min_time[i]
                << ", mean_time: " << op_total_time[i] / op_count[i];
      }
      VLOG(0) << "================================";
      throw;
    }
  }

  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    int op_idx = 0;
    gettimeofday(&micro_start, NULL);
    for (auto& op : ops_) {
      gettimeofday(&start, NULL);
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "Backward: running an op " << op->Type()
                << " for micro-batch " << i;
        timeline.Start();
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
        }
        cudaDeviceSynchronize();
        gettimeofday(&end, NULL);
        timeline.Pause();
        auto time = timeline.ElapsedUS();
        op_total_time[op_idx] += time;
        if (time > op_max_time[op_idx]) {
          op_max_time[op_idx] = time;
        }
        if (time < op_min_time[op_idx]) {
          op_min_time[op_idx] = time;
        }
        op_count[op_idx] += 1;

        std::cout << std::fixed;
        std::cout.precision(0);
        std::cout << "::BWD:B[" << batch_id_ << "]:SCOPE[" << i << "]:OP["
                  << op->Type() << "]:START["
                  << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                  << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
      }
      op_idx++;
    }

    gettimeofday(&micro_end, NULL);
    std::cout << std::fixed;
    std::cout.precision(0);
    std::cout << "!!BWD:B[" << batch_id_ << "]:START["
              << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
              << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
  }

  // update pass
  int op_idx = 0;
  gettimeofday(&micro_start, NULL);
  for (auto& op : ops_) {
    gettimeofday(&start, NULL);
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "Update: running op " << op->Type();
      timeline.Start();
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
                            gc.get());
      }
      cudaDeviceSynchronize();
      gettimeofday(&end, NULL);
      timeline.Pause();
      auto time = timeline.ElapsedUS();
      op_total_time[op_idx] += time;
      if (time > op_max_time[op_idx]) {
        op_max_time[op_idx] = time;
      }
      if (time < op_min_time[op_idx]) {
        op_min_time[op_idx] = time;
      }
      op_count[op_idx] += 1;

      std::cout << std::fixed;
      std::cout.precision(0);
      std::cout << "::UPD:B[" << batch_id_ << "]:OP[" << op->Type()
                << "]:START[" << start.tv_sec * 1e6 + start.tv_usec << "]:END["
                << end.tv_sec * 1e6 + end.tv_usec << "]" << std::endl;
    }
    op_idx++;
  }
  gettimeofday(&micro_end, NULL);
  std::cout << std::fixed;
  std::cout.precision(0);
  std::cout << "!!UPD:B[" << batch_id_ << "]:START["
            << micro_start.tv_sec * 1e6 + micro_start.tv_usec << "]:END["
            << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]" << std::endl;
  dev_ctx_->Wait();
  batch_timer.Pause();
  VLOG(0) << "batch: " << batch_id_ << ", time: " << batch_timer.ElapsedUS();
  ++batch_id_;
}

}  // namespace framework
}  // namespace paddle
#endif