// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/basic_engine.h"

#include <algorithm>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/profiler.h"

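// Flag defined outside this file; when true, gradients are summed with
// SortedGradientAccumulator instead of EagerGradientAccumulator.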
DECLARE_bool(sort_sum_gradient);

namespace paddle {
namespace imperative {

// Prepares the engine for one backward pass that starts at `var`.
//
// Steps:
//   1. Records `retain_graph` for later use.
//   2. Takes the grad node attached to var's grad variable as the starting
//      node (init_node_) and detaches it from the grad variable.
//   3. If there is no grad node, or `var` is marked stop-gradient, returns
//      early; nothing else is initialised in that case.
//   4. Otherwise sizes the grad tensor like the forward tensor (same dims,
//      place and dtype) and fills it with 1.0.
//
// Raises NotFound (via PADDLE_ENFORCE_EQ) when `var` has no grad variable.
// `var` itself is not checked for null.
void BasicEngine::Init(VarBase* var, bool retain_graph) {
  retain_graph_ = retain_graph;

  // This check must precede the first GradVarBase() dereference below; it
  // previously ran only after GradVarBase() had already been used, so it
  // could not guard those accesses.
  // NOTE(review): assumes HasGradVar() == false means GradVarBase() is not
  // safe to dereference - confirm against VarBase.
  PADDLE_ENFORCE_EQ(
      var->HasGradVar(), true,
      platform::errors::NotFound("Grad variable not exist for variable %s",
                                 var->Name()));

  init_node_ = var->GradVarBase()->GradNode();
  var->GradVarBase()->ClearGradNode();

  if (init_node_ == nullptr || var->OverridedStopGradient()) {
    VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
               "stop_gradient=True: "
            << var->Name();
    return;
  }

  VLOG(3) << "start backward";

  auto& fwd_var = var->Var().Get<framework::LoDTensor>();
  auto* grad_var =
      var->GradVarBase()->MutableVar()->GetMutable<framework::LoDTensor>();
  VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
          << " as stop_gradient false";
  var->GradVarBase()->InnerSetOverridedStopGradient(false);

  // Seed gradient: same shape/place/dtype as the forward value, all ones.
  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
  grad_var->Resize(fwd_var.dims());
  grad_var->mutable_data(fwd_var.place(), fwd_var.type());
  operators::math::set_constant(*dev_ctx, grad_var, 1.0);
}

void BasicEngine::CheckBackwardInputs(const OpBase& op) {
  for (auto& pair : op.GetInsMap()) {
    if (!pair.second.IsGrad()) {
      continue;
    }

    for (auto& var : pair.second) {
      if (!var) {
        continue;
      }

      auto* inner_var = var->MutableVar();
      framework::Tensor* tensor = nullptr;
      if (!inner_var->IsInitialized() ||
          inner_var->IsType<framework::LoDTensor>()) {
        tensor = inner_var->GetMutable<framework::LoDTensor>();
      }

      if (tensor && !tensor->IsInitialized()) {
        VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place());
        tensor->mutable_data(op.place(), var->DataType());
        operators::math::set_constant(*dev_ctx, tensor, 0.0);
      }
    }
  }
}

void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
  for (const auto& pair : op.GetOutsMap()) {
    if (!pair.second.IsGrad()) {
      continue;
    }

    for (const auto& var : pair.second) {
      if (!var) continue;

      auto& accumulator = accumulators_[var.get()];
      if (!accumulator) {
108
        if (FLAGS_sort_sum_gradient) {
109 110 111 112 113 114 115 116
          accumulator.reset(new SortedGradientAccumulator(var.get()));
        } else {
          accumulator.reset(new EagerGradientAccumulator(var.get()));
        }
      }

      accumulator->IncreaseRefCnt();

117 118 119 120 121 122 123 124 125 126
      if (var->HasLeafHooks()) {
        VLOG(3) << "Grad variable wrapper (" << var->Name()
                << ") has leaf grad hooks.";
        PADDLE_ENFORCE_NE(var->HasGradNode(), true,
                          platform::errors::PermissionDenied(
                              "Only leaf Tensor's gradient can append hook to "
                              "Gradientaccumulator."));
        accumulator->SetPostHooks(var->GetLeafHooks());
      }

127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
      VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
              << var.get() << ")  with reference count "
              << accumulator->RefCnt();
    }
  }
}

void BasicEngine::PrepareDeps() {
  PADDLE_ENFORCE_EQ(
      node_deps_.empty(), true,
      platform::errors::AlreadyExists("Op deps must be initialized here"));
  PADDLE_ENFORCE_EQ(
      accumulators_.empty(), true,
      platform::errors::AlreadyExists("Accumulators must be initialized here"));

  std::queue<GradOpNode*> q;
  std::unordered_set<GradOpNode*> visited;

  q.push(init_node_.get());
  visited.insert(init_node_.get());

  while (!q.empty()) {
    auto* cur_node = q.front();
    q.pop();

    for (auto& cur_op : *cur_node) {
Z
Zeng Jinle 已提交
153
      cur_op.EnforceHasInOut();
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
      PrepareGradAccumulators(cur_op);
    }

    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
    for (auto& grad_pending_node : grad_pending_nodes) {
      PADDLE_ENFORCE_NOT_NULL(
          grad_pending_node,
          platform::errors::NotFound("Grad pending node should not be null"));
      ++node_deps_[grad_pending_node.get()];
      if (visited.count(grad_pending_node.get()) == 0) {
        visited.insert(grad_pending_node.get());
        q.push(grad_pending_node.get());
      }
    }
  }
}

void BasicEngine::Execute() {
  if (init_node_ == nullptr) {
    return;
  }

  PrepareDeps();
  // Start execute Computation graph
  std::queue<std::shared_ptr<GradOpNode>> q;
  q.push(std::move(init_node_));

  size_t op_num = 0;

  while (!q.empty()) {
    auto shared_cur_node = std::move(q.front());
    q.pop();

    for (auto& cur_op : *shared_cur_node) {
      ++op_num;

      // CheckBackWardInput
      CheckBackwardInputs(cur_op);

      // Step 1: Run Backward
      auto& bwd_ins = cur_op.GetInsMap();
      auto& bwd_outs = cur_op.GetOutsMap();

      NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
      // 1. construct the output map 2. replace the element in the map
      // A var may be coresponding to several grad var in one op
      for (auto& pair : tmp_outs) {
        if (!pair.second.IsGrad()) {
          continue;
        }

        for (auto& var : pair.second) {
          if (!var) {
            continue;
          }

          auto iter = accumulators_.find(var.get());
          PADDLE_ENFORCE_EQ(
              iter != accumulators_.end(), true,
              platform::errors::NotFound("Cannot find gradient of variable %s",
                                         var->Name()));
215

216
          if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) {
217
            no_need_run_accumulators_.emplace_back(iter->second.get());
218 219 220
            continue;
          }

221 222 223
          auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
          tmp_var->SetType(var->Type());
          var = tmp_var;
224 225 226 227 228 229 230 231 232 233
          need_accu_var_list_.emplace_back(iter->second.get(), var);
        }
      }

      {
        VLOG(3) << "Start to execute grad op " << cur_op.Type();
        OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
                    cur_op.place());
      }

234 235 236 237 238 239 240
      // Step 2: Sum Gradient & Call Accumulator Hooks
      for (auto* accumulator : no_need_run_accumulators_) {
        if (accumulator->HasPostHooks()) {
          accumulator->CallBackwardPostHooks();
        }
      }

241 242 243 244 245
      for (auto& pair : need_accu_var_list_) {
        pair.first->Add(std::move(pair.second), cur_op.id());
      }

      need_accu_var_list_.clear();
246
      no_need_run_accumulators_.clear();
247 248

      VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
249 250 251
      if (!retain_graph_) {
        cur_op.ClearBackwardTrace();
      }
252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
    }

    // Step 3: Collect ready ops
    for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
      PADDLE_ENFORCE_NOT_NULL(grad_pending_node,
                              platform::errors::NotFound(
                                  "Grad pending node should not be nullptr"));
      auto iter = node_deps_.find(grad_pending_node.get());
      if (iter == node_deps_.end()) {
        continue;
      }

      if (--(iter->second) == 0) {
        q.push(grad_pending_node);
      }
    }
  }
  Clear();

  VLOG(1) << "Backward op number: " << op_num;
}

void BasicEngine::Clear() {
  init_node_.reset();
  node_deps_.clear();
  accumulators_.clear();
  need_accu_var_list_.clear();
279
  no_need_run_accumulators_.clear();
280 281 282 283
}

}  // namespace imperative
}  // namespace paddle