section_worker.cc 9.0 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
13
    defined(PADDLE_WITH_ASCEND_CL)
L
lilong12 已提交
14
#include <float.h>
H
hutuxian 已提交
15
#include "paddle/fluid/framework/device_worker.h"
16
#include "paddle/fluid/framework/executor_gc_helper.h"
H
hutuxian 已提交
17 18 19 20 21
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

22 23
class TrainerDesc;

L
lilong12 已提交
24 25
uint64_t SectionWorker::batch_id_(0);

26
void SectionWorker::Initialize(const TrainerDesc &desc) {
H
hutuxian 已提交
27
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
28 29
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
30
  for (auto &op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
31 32
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
33

34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
  for (auto &op : ops_) {
    // cache the op type during the init part
    // reduce unnecessary op visit during running
    int op_role = op->Attr<int>("op_role");
    if ((op_role == static_cast<int>(OpRole::kForward)) ||
        (op_role == (static_cast<int>(OpRole::kForward) |
                     static_cast<int>(OpRole::kLoss))) ||
        (op_role == static_cast<int>(OpRole::kLRSched))) {
      // forward ops and lr schedule ops, used for first micro step
      forward_and_lr_ops_.push_back(op.get());
      if ((op_role != static_cast<int>(OpRole::kLRSched))) {
        // only forward ops, used for second and later micro steps
        forward_ops_.push_back(op.get());
      }
    } else if ((op_role == static_cast<int>(OpRole::kBackward)) ||
               (op_role == (static_cast<int>(OpRole::kBackward) |
                            static_cast<int>(OpRole::kLoss)))) {
      backward_ops_.push_back(op.get());
    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
      optimizer_ops_.push_back(op.get());
    } else {
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "The op %s is None of LRSched, Forward, Backward or Optimize.",
          op->Type()));
    }
  }

61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
  // if not 1F1B scheduler
  if (schedule_mode_ != 1) return;

  bool is_first_stage = (pipeline_stage_ == 0);
  int BACKWARD = static_cast<int>(OpRole::kBackward);
  for (auto &op : ops_) {
    int op_role = op->Attr<int>("op_role");
    auto op_type = op->Type();

    // pipeline backward send op
    if (op_role != BACKWARD) continue;
    if (op_type != "send_v2" && op_type != "partial_send") continue;

    auto var_name = op->InputVars()[0];
    VLOG(3) << "Pipeline backward send var " << var_name;
    PADDLE_ENFORCE_NE(is_first_stage, true,
                      platform::errors::PreconditionNotMet(
                          "The first pipeline stage must do not have a "
                          "backward send var, please check var %s",
                          var_name));

    backward_send_vars_.push_back(var_name);
    skip_vars_.push_back(var_name);
  }
}

// Fills unused_vars_ (op -> variable names) via GetUnusedVars over block 0
// of the section program and ops_. skip_vars_ (the backward send vars that
// Initialize() records under 1F1B) is passed along, presumably so those vars
// are kept out of the per-op deletion lists; they are freed explicitly by
// Run1F1B instead.
void SectionWorker::PrepareUnusedVar() {
  VLOG(5) << "begin prepare the unused vars";
  unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
}

void SectionWorker::RunForward(
    int micro_id, std::unique_ptr<GarbageCollector> &gc,
    std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
96 97 98 99 100 101 102 103 104
  std::vector<OperatorBase *> &forward_tmp =
      micro_id == 0 ? forward_and_lr_ops_ : forward_ops_;
  for (auto &op : forward_tmp) {
    VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
            << micro_id;
    op->Run(*microbatch_scopes_[micro_id], place_);
    if (gc) {
      DeleteUnusedTensors(*microbatch_scopes_[micro_id], op, unused_vars_,
                          gc.get());
105 106 107 108 109 110 111 112
    }
  }
}

void SectionWorker::RunBackward(
    int micro_id, std::unique_ptr<GarbageCollector> &gc,
    std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
113 114 115 116 117 118 119
  for (auto &op : backward_ops_) {
    VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
            << micro_id;
    op->Run(*microbatch_scopes_[micro_id], place_);
    if (gc) {
      DeleteUnusedTensors(*microbatch_scopes_[micro_id], op, unused_vars_,
                          gc.get());
120 121 122 123 124 125 126 127
    }
  }
}

void SectionWorker::RunUpdate(
    std::unique_ptr<GarbageCollector> &gc,
    std::unordered_map<const OperatorBase *, std::vector<std::string>>
        &unused_vars_) {
128 129 130 131 132 133
  for (auto &op : optimizer_ops_) {
    VLOG(3) << "Update: running op " << op->Type();
    op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
    if (gc) {
      DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], op,
                          unused_vars_, gc.get());
134 135 136 137
    }
  }
}

138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
void SectionWorker::RunFThenB(std::unique_ptr<GarbageCollector> &gc) {
  // F-then-B scheduler which runs Forward phase for all microbatches,
  // then runs Backward phase for all microbatches.
  // step1: run forward
  for (int i = 0; i < num_microbatches_; ++i) {
    RunForward(i, gc, unused_vars_);
  }
  // step2: run backward
  for (int i = 0; i < num_microbatches_; ++i) {
    RunBackward(i, gc, unused_vars_);
  }
  // step3: run update
  RunUpdate(gc, unused_vars_);
}

void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
  // 1F1B scheduler, which runs forward phase and backward phase altertively
  // after startup phase. For a stage, the number of microbatches for
  // startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
  // num_pipeline_stages_ is the total number of pipeline stages and
  // pipeline_stage_ is the pipeline stage of the current device.
  auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
  VLOG(3) << "startup_steps:" << startup_steps
          << ", num_stages: " << num_pipeline_stages_
          << ", stage:" << pipeline_stage_;
  PADDLE_ENFORCE_GT(
      num_microbatches_, startup_steps,
      platform::errors::InvalidArgument(
          "To use pipeline with 1F1B scheduler, please make sure number of "
          "microbatches (%d) is than startup steps (%d).",
          num_microbatches_, startup_steps));
  int fw_step = 0;
  int bw_step = 0;

  // startup phase
  while (fw_step < startup_steps) {
    RunForward(fw_step, gc, unused_vars_);
    fw_step += 1;
176
    VLOG(2) << "micro steps fw_step:" << fw_step;
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
  }

  // 1f1b phase
  while (fw_step < num_microbatches_) {
    RunForward(fw_step, gc, unused_vars_);

    // delete backward send var at step=(bw_step - 2)
    if (gc && bw_step >= 2) {
      DeleteUnusedTensors(*microbatch_scopes_[bw_step - 2], backward_send_vars_,
                          gc.get());
    }

    RunBackward(bw_step, gc, unused_vars_);

    fw_step += 1;
    bw_step += 1;
193
    VLOG(2) << "micro steps fw_step:" << fw_step << ", bw_step:" << bw_step;
194 195 196 197 198 199 200
  }

  int reserve_bw_send_step = bw_step - 2;
  // backward phase
  while (bw_step < num_microbatches_) {
    RunBackward(bw_step, gc, unused_vars_);
    bw_step += 1;
201
    VLOG(2) << "micro steps  bw_step:" << bw_step;
202 203
  }

204
  VLOG(2) << "run update";
205 206 207 208 209 210 211 212 213 214
  RunUpdate(gc, unused_vars_);

  if (gc) {
    // NOTE(wangxi): program must add sync backward send comm at update
    // delete backward send var
    for (int i = reserve_bw_send_step; i < num_microbatches_; ++i) {
      DeleteUnusedTensors(*microbatch_scopes_[i], backward_send_vars_,
                          gc.get());
    }
  }
215 216
}

H
hutuxian 已提交
217
void SectionWorker::TrainFiles() {
218
  VLOG(5) << "begin section_worker TrainFiles";
219
  VLOG(2) << "mini batch steps:" << batch_id_;
H
hutuxian 已提交
220

221
  int64_t max_memory_size = GetEagerDeletionThreshold();
L
lilong12 已提交
222
  std::unique_ptr<GarbageCollector> gc;
223
  if (max_memory_size >= 0) {
224
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
225 226 227 228 229
    if (platform::is_gpu_place(place_)) {
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
H
hutuxian 已提交
230
    }
B
Baibaifan 已提交
231 232 233 234 235 236 237 238 239 240 241 242 243 244
#elif defined(PADDLE_WITH_ASCEND_CL)
    if (IsFastEagerDeletionModeEnabled()) {
      VLOG(4) << "Use unsafe fast gc for NPU.";
      gc.reset(new NPUUnsafeFastGarbageCollector(
          BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Please set FLAGS_fast_eager_deletion_mode=true to use "
          "GarbageCollector on NPU."));
      // TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
      VLOG(4) << "Use default stream gc for NPU.";
      gc.reset(new NPUDefaultStreamGarbageCollector(
          BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
    }
L
lilong12 已提交
245
#endif
B
Baibaifan 已提交
246
  }  // max_memory_size >= 0
H
hutuxian 已提交
247

248
  if (schedule_mode_ == 0) {
249
    RunFThenB(gc);
250
  } else {
251
    Run1F1B(gc);
H
hutuxian 已提交
252
  }
253

254 255
  dev_ctx_->Wait();
  ++batch_id_;
H
hutuxian 已提交
256
}
257

H
hutuxian 已提交
258 259 260
}  // namespace framework
}  // namespace paddle
#endif