section_worker.cc 4.1 KB
Newer Older
H
hutuxian 已提交
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

12
#if defined(PADDLE_WITH_NCCL)
L
lilong12 已提交
13
#include <float.h>
S
sandyhouse 已提交
14
#include "paddle/fluid/framework/device_worker.h"
L
lilong12 已提交
15
#include "paddle/fluid/framework/executor_gc_helper.h"
H
hutuxian 已提交
16 17 18 19 20 21

#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

S
sandyhouse 已提交
22 23
class TrainerDesc;

L
lilong12 已提交
24 25
uint64_t SectionWorker::batch_id_(0);

H
hutuxian 已提交
26
void SectionWorker::Initialize(const TrainerDesc& desc) {
H
hutuxian 已提交
27
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
28 29
  program_.reset(
      new ProgramDesc(desc.section_param().section_config().program_desc()));
L
lilong12 已提交
30
  for (auto& op_desc : program_->Block(0).AllOps()) {
H
hutuxian 已提交
31 32 33 34 35
    ops_.push_back(OpRegistry::CreateOp(*op_desc));
  }
}

void SectionWorker::TrainFiles() {
36
  VLOG(5) << "begin section_worker TrainFiles";
H
hutuxian 已提交
37

38
  int64_t max_memory_size = GetEagerDeletionThreshold();
L
lilong12 已提交
39 40
  std::unique_ptr<GarbageCollector> gc;
  auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
41
  if (max_memory_size >= 0) {
L
lilong12 已提交
42
#ifdef PADDLE_WITH_CUDA
43 44 45 46 47
    if (platform::is_gpu_place(place_)) {
      if (IsFastEagerDeletionModeEnabled()) {
        gc.reset(new UnsafeFastGPUGarbageCollector(
            BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size));
      }
H
hutuxian 已提交
48
    }
L
lilong12 已提交
49 50
#endif
  }
H
hutuxian 已提交
51

52 53 54 55 56 57
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      // We run op with op_role = kLRSched only for the first microbatch
      // to avoid increasing the @LR_DECAY_STEP@ multiple times.
      bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
L
lilong12 已提交
58
                              op_role == (static_cast<int>(OpRole::kForward) |
59 60 61 62 63 64 65 66 67 68 69 70
                                          static_cast<int>(OpRole::kLoss)) ||
                              op_role == static_cast<int>(OpRole::kLRSched);
      bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
                        op_role == (static_cast<int>(OpRole::kForward) |
                                    static_cast<int>(OpRole::kLoss));
      if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
        VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
                << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
L
lilong12 已提交
71
        }
H
hutuxian 已提交
72
      }
H
hutuxian 已提交
73
    }
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
    cudaDeviceSynchronize();
  }

  // backward pass
  for (int i = 0; i < num_microbatches_; ++i) {
    for (auto& op : ops_) {
      int op_role = op->Attr<int>(std::string("op_role"));
      if (op_role == static_cast<int>(OpRole::kBackward) ||
          op_role == (static_cast<int>(OpRole::kBackward) |
                      static_cast<int>(OpRole::kLoss))) {
        VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
                << i;
        op->Run(*microbatch_scopes_[i], place_);
        if (gc) {
          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
                              gc.get());
L
lilong12 已提交
90 91
        }
      }
H
hutuxian 已提交
92
    }
93
    cudaDeviceSynchronize();
H
hutuxian 已提交
94 95
  }

96
  // update pass
H
hutuxian 已提交
97
  for (auto& op : ops_) {
98 99 100
    int op_role = op->Attr<int>(std::string("op_role"));
    if (op_role == static_cast<int>(OpRole::kOptimize)) {
      VLOG(3) << "Update: running op " << op->Type();
S
sandyhouse 已提交
101
      op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
102 103 104
      if (gc) {
        DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
                            gc.get());
H
hutuxian 已提交
105 106 107
      }
    }
  }
108 109
  dev_ctx_->Wait();
  ++batch_id_;
H
hutuxian 已提交
110
}
111

H
hutuxian 已提交
112 113 114
}  // namespace framework
}  // namespace paddle
#endif