section_worker.cc 4.3 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
L
lilong12 已提交
13
#include <float.h>
H
hutuxian 已提交
14
#include "paddle/fluid/framework/device_worker.h"
15
#include "paddle/fluid/framework/executor_gc_helper.h"
H
hutuxian 已提交
16 17 18 19 20
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

21 22
class TrainerDesc;

L
lilong12 已提交
23 24
uint64_t SectionWorker::batch_id_(0);

H
hutuxian 已提交
25
void SectionWorker::Initialize(const TrainerDesc& desc) {
H
hutuxian 已提交
26
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
27 28
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
L
lilong12 已提交
29
  for (auto& op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
30 31 32 33 34
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

void SectionWorker::TrainFiles() {
35
  VLOG(5) << "begin section_worker TrainFiles";
H
hutuxian 已提交
36

37
  int64_t max_memory_size = GetEagerDeletionThreshold();
L
lilong12 已提交
38 39
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
40
  if (max_memory_size >= 0) {
41
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
42 43 44 45 46
    if (platform::is_gpu_place(place_)) {
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
H
hutuxian 已提交
47
    }
L
lilong12 已提交
48 49
#endif
  }
H
hutuxian 已提交
50

51 52 53 54 55 56
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      // We run op with op_role = kLRSched only for the first microbatch
      // to avoid increasing the @LR_DECAY_STEP@ multiple times.
      bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
L
lilong12 已提交
57
                              op_role == (static_cast<int>(OpRole::kForward) |
58 59 60 61 62 63 64 65 66 67 68 69
                                          static_cast<int>(OpRole::kLoss)) ||
                              op_role == static_cast<int>(OpRole::kLRSched);
      bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                        op_role == (static_cast<int>(OpRole::kForward) |
                                    static_cast<int>(OpRole::kLoss));
      if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
        VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
                << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
L
lilong12 已提交
70
        }
H
hutuxian 已提交
71
      }
H
hutuxian 已提交
72
    }
73 74 75
#ifdef PADDLE_WITH_RCCL
    hipDeviceSynchronize();
#else
76
    cudaDeviceSynchronize();
77
#endif
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
  }

  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
                << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
L
lilong12 已提交
93 94
        }
      }
H
hutuxian 已提交
95
    }
96 97 98
#ifdef PADDLE_WITH_RCCL
    hipDeviceSynchronize();
#else
99
    cudaDeviceSynchronize();
100
#endif
H
hutuxian 已提交
101 102
  }

103
  // update pass
H
hutuxian 已提交
104
  for (auto& op : ops_) {
105 106 107 108 109 110 111
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "Update: running op " << op->Type();
      op->Run(*microbatch_scopes_[0], place_);
      if (gc) {
        DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
                            gc.get());
H
hutuxian 已提交
112 113 114
      }
    }
  }
115 116
  dev_ctx_->Wait();
  ++batch_id_;
H
hutuxian 已提交
117
}
118

H
hutuxian 已提交
119 120 121
}  // namespace framework
}  // namespace paddle
#endif