// section_worker.cc
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL)
L
lilong12 已提交
13
#include <float.h>
H
hutuxian 已提交
14
#include "paddle/fluid/framework/device_worker.h"
L
lilong12 已提交
15
#include "paddle/fluid/framework/executor_gc_helper.h"
H
hutuxian 已提交
16 17 18 19 20
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

L
lilong12 已提交
21 22
// Forward declaration; SectionWorker::Initialize() below takes a
// `const TrainerDesc &`.
class TrainerDesc;

L
lilong12 已提交
23 24
// Out-of-class definition of the batch counter, starting at 0. It is
// incremented once at the end of every SectionWorker::TrainFiles() call.
uint64_t SectionWorker::batch_id_(0);

L
lilong12 已提交
25
void SectionWorker::Initialize(const TrainerDesc &desc) {
H
hutuxian 已提交
26
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
27 28
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
L
lilong12 已提交
29
  for (auto &op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
30 31 32 33
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

L
lilong12 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
// Runs the forward-phase ops of this section for a single micro-batch.
//
// An op is executed when its "op_role" attribute equals kForward or
// kForward|kLoss. Ops whose role is kLRSched are executed only for
// micro-batch 0, to avoid increasing the @LR_DECAY_STEP@ multiple times.
//
// micro_id:     index into microbatch_scopes_ selecting the scope to run in.
// gc:           garbage collector; may be null, in which case
//               DeleteUnusedTensors is not called.
// unused_vars_: per-operator variable-name lists, forwarded unchanged to
//               DeleteUnusedTensors after each executed op.
void SectionWorker::RunForward(
    int micro_id, std::unique_ptr<GarbageCollector> &gc,
    std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
  const int forward_role = static_cast<int>(OpRole::kForward);
  const int forward_loss_role =
      forward_role | static_cast<int>(OpRole::kLoss);
  const int lr_sched_role = static_cast<int>(OpRole::kLRSched);
  const bool is_first_mbatch = (micro_id == 0);

  for (auto &op : ops_) {
    const int role = op->Attr<int>(std::string("op_role"));
    const bool is_forward_op =
        (role == forward_role) || (role == forward_loss_role);
    const bool is_lr_sched_op = (role == lr_sched_role);

    // Forward ops run for every micro-batch; LR-schedule ops only for the
    // first one.
    if (!is_forward_op && !(is_first_mbatch && is_lr_sched_op)) {
      continue;
    }

    VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
            << micro_id;
    auto &scope = *microbatch_scopes_[micro_id];
    op->Run(scope, place_);
    if (gc) {
      DeleteUnusedTensors(scope, op.get(), unused_vars_, gc.get());
    }
  }
}

// Runs the backward-phase ops of this section for a single micro-batch.
//
// An op is executed when its "op_role" attribute equals kBackward or
// kBackward|kLoss; all other ops are skipped.
//
// micro_id:     index into microbatch_scopes_ selecting the scope to run in.
// gc:           garbage collector; may be null, in which case
//               DeleteUnusedTensors is not called.
// unused_vars_: per-operator variable-name lists, forwarded unchanged to
//               DeleteUnusedTensors after each executed op.
void SectionWorker::RunBackward(
    int micro_id, std::unique_ptr<GarbageCollector> &gc,
    std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
  const int backward_role = static_cast<int>(OpRole::kBackward);
  const int backward_loss_role =
      backward_role | static_cast<int>(OpRole::kLoss);

  for (auto &op : ops_) {
    const int role = op->Attr<int>(std::string("op_role"));
    if (role != backward_role && role != backward_loss_role) {
      continue;
    }

    VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
            << micro_id;
    auto &scope = *microbatch_scopes_[micro_id];
    op->Run(scope, place_);
    if (gc) {
      DeleteUnusedTensors(scope, op.get(), unused_vars_, gc.get());
    }
  }
}

void SectionWorker::RunUpdate(
    std::unique_ptr<GarbageCollector> &gc,
    std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
  for (auto &op : ops_) {
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "Update: running op " << op->Type();
      op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
      if (gc) {
        DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
                            op.get(), unused_vars_, gc.get());
      }
    }
  }
}

H
hutuxian 已提交
98
void SectionWorker::TrainFiles() {
99
  VLOG(5) << "begin section_worker TrainFiles";
H
hutuxian 已提交
100

101
  int64_t max_memory_size = GetEagerDeletionThreshold();
L
lilong12 已提交
102 103
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
104
  if (max_memory_size >= 0) {
L
lilong12 已提交
105
#ifdef PADDLE_WITH_CUDA
106 107 108 109 110
    if (platform::is_gpu_place(place_)) {
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
H
hutuxian 已提交
111
    }
L
lilong12 已提交
112 113
#endif
  }
H
hutuxian 已提交
114

L
lilong12 已提交
115 116 117 118 119 120
  if (schedule_mode_ == 0) {
    // Gpipe scheduler which runs all forwards first, then backwards, then
    // update
    // step1: run forward
    for (int i = 0; i < num_microbatches_; ++i) {
      RunForward(i, gc, unused_vars_);
H
hutuxian 已提交
121
    }
L
lilong12 已提交
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
    // step2: run backward
    for (int i = 0; i < num_microbatches_; ++i) {
      RunBackward(i, gc, unused_vars_);
    }
    // step2: run update
    RunUpdate(gc, unused_vars_);
  } else {
    // 1F1B scheduler
    auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
    VLOG(3) << "startup_steps:" << startup_steps
            << ", num_stages: " << num_pipeline_stages_
            << ", stage:" << pipeline_stage_;
    if (startup_steps > num_microbatches_) {
      startup_steps = num_microbatches_;
    }
    int fw_step = 0;
    int bw_step = 0;
    // startup phase
    while (fw_step < startup_steps) {
      RunForward(fw_step, gc, unused_vars_);
      fw_step += 1;
H
hutuxian 已提交
143 144
    }

L
lilong12 已提交
145 146 147 148 149 150 151 152 153 154 155
    // 1f1b phase
    while (fw_step < num_microbatches_) {
      RunForward(fw_step, gc, unused_vars_);
      fw_step += 1;
      RunBackward(bw_step, gc, unused_vars_);
      bw_step += 1;
    }
    // backward phase
    while (bw_step < num_microbatches_) {
      RunBackward(bw_step, gc, unused_vars_);
      bw_step += 1;
H
hutuxian 已提交
156
    }
L
lilong12 已提交
157
    RunUpdate(gc, unused_vars_);
H
hutuxian 已提交
158
  }
159 160
  dev_ctx_->Wait();
  ++batch_id_;
H
hutuxian 已提交
161
}
162

H
hutuxian 已提交
163 164 165
}  // namespace framework
}  // namespace paddle
#endif