// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/eager/backward.h"

#include <deque>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/eager/general_grad.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"

namespace egr {

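// Traverse the backward graph with a breadth-first search starting from
// init_queue and count, for each reachable GradNode, the number of incoming
// edges. The resulting in-degree map drives the topological visit in
// RunBackward.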
std::unordered_map<GradNodeBase*, int> getInDegreeMap(
    const std::deque<GradNodeBase*>& init_queue) {
  // Calculate in_degree for each node
  // We could remove this pass entirely if in_degree were set during the
  // forward pass.
  std::unordered_map<GradNodeBase*, int> node_in_degree_map;

  // Copy nodes
  std::deque<GradNodeBase*> queue = init_queue;
  std::unordered_set<GradNodeBase*> visited;

  // Visit each node exactly once in any order
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    queue.pop_front();

    if (visited.count(node)) {
      continue;
    }
    visited.insert(node);

    PADDLE_ENFORCE_NOT_NULL(
        node,
        paddle::platform::errors::Fatal(
            "We got null node when we traverse the backward graph, and this "
            "should not happened please check your code and contact us."));
    // Find and append next nodes
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    for (const auto& meta_list : metas) {
      for (const GradSlotMeta& meta : meta_list) {
        const auto& edge = meta.GetEdge();
        GradNodeBase* next_node = edge.GetMutableGradNode().get();
        // Next node could be nullptr if it is a leaf tensor with no
        // AccumulationNode attached, or it could originate from dispensable
        // inputs.
        if (!next_node) continue;

        // Update in_degree
        if (!node_in_degree_map.count(next_node))
          node_in_degree_map[next_node] = 0;
        node_in_degree_map[next_node]++;
        queue.push_back(next_node);
      }
    }
  }

  return node_in_degree_map;
}

// Enforce that the GradNode still has its TensorWrappers as inputs
void EnforceGradNodeHasInput(GradNodeBase* node) {
  PADDLE_ENFORCE_NE(
      node->IsTensorWrappersCleared(),
      true,
      paddle::platform::errors::Fatal(
          "The TensorWrappers of %s do not exist. This may be because:\n"
          "You calculate backward twice for the same subgraph without "
          "setting retain_graph=True. Please set retain_graph=True in the "
          "first backward/grad call.\n",
          node->name()));
}

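// Ensure that no tensor appears more than once in the inputs/outputs passed
// to Grad().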
void DuplicateCheck(const std::vector<paddle::Tensor>& inputs, bool is_input) {
  std::unordered_set<AutogradMeta*> visited_ins;
  std::string msg = is_input ? "inputs" : "outputs";
  for (auto in : inputs) {
    AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in);
    PADDLE_ENFORCE_EQ(
        visited_ins.count(auto_grad_meta),
        0,
        paddle::platform::errors::AlreadyExists(
            "%s contain duplicate tensor %s, please check %s carefully.",
            msg,
            in.name(),
            msg));
    visited_ins.insert(auto_grad_meta);
  }
}

GeneralGrad* GeneralGrad::general_grad_ = new GeneralGrad();

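// Core backward engine shared by Backward() and Grad(). It seeds the
// GradTensorHolders of the starting nodes, then performs a topological
// traversal of the backward graph, running each GradNode once its pending
// input gradients are ready. When `inputs` is non-empty (general-grad mode,
// used by Grad()), the gradients of `inputs` are returned; otherwise an
// empty vector is returned.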
std::vector<paddle::Tensor> RunBackward(
    const std::vector<paddle::Tensor>& tensors,  // output
    const std::vector<paddle::Tensor>& grad_tensors,
    bool retain_graph,
    bool create_graph = false,
    const std::vector<paddle::Tensor>& inputs = {},
    bool allow_unused = false,
    const std::vector<paddle::Tensor>& no_grad_vars = {}) {
  VLOG(3) << "Start Backward";

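  // Nodes marked "force sequential" must run in a fixed order during the
  // backward pass. Reverse the forward-recorded queue (via push_front) so
  // that the deque front is always the next force-sequential node allowed to
  // run, and keep a set for fast membership checks.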
  std::queue<GradNodeBase*> force_sequential_nodes_forward_queue =
      egr::Controller::Instance().GetForceSequentialNodes();
  std::deque<GradNodeBase*> force_sequential_nodes_queue;
  std::set<GradNodeBase*> force_sequential_nodes_set;
  std::set<GradNodeBase*> ready_force_sequential_nodes;
  auto force_sequential_nodes_size =
      force_sequential_nodes_forward_queue.size();
  for (size_t i = 0; i < force_sequential_nodes_size; ++i) {
    force_sequential_nodes_set.insert(
        force_sequential_nodes_forward_queue.front());
    force_sequential_nodes_queue.push_front(
        force_sequential_nodes_forward_queue.front());
    force_sequential_nodes_forward_queue.pop();
  }

  // * Gradient hooks should happen at node level
  // * Inplace version checks should be performed at node level
  // * Cross-batch accumulation happens during the forward pass

  // GeneralGrad
  bool is_general_grad = !inputs.empty();
  if (is_general_grad) GeneralGrad::Instance().Clear();

  /* --- Initialization --- */
  // 1. Init queue with starting nodes
  // 2. Prepare initial input buffers
  std::deque<GradNodeBase*> queue;
  std::deque<GradNodeBase*> orig_queue;
  std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
      node_input_buffers_dict;
  std::unordered_set<GradNodeBase*> visited;
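  // For each starting tensor: locate its GradNode, seed that node's
  // GradTensorHolder with the provided grad tensor (or with 1.0), and record
  // the node as a startup node for the traversal.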
  for (size_t i = 0; i < tensors.size(); i++) {
    const paddle::Tensor& tensor = tensors[i];

    AutogradMeta* auto_grad_meta = EagerUtils::nullable_autograd_meta(tensor);
    if (auto_grad_meta == nullptr) {
      VLOG(5) << "Skip auto grad since there is no grad op for var or loss is "
151 152 153 154
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }
    // Get grad input info from target tensors
    auto input_info = auto_grad_meta->OutRankInfo();

    VLOG(5) << "Out Rank of Tensor is slot: " << input_info.first
            << ", rank: " << input_info.second;
    // Get target GradNodeBase from target tensors
    auto shared_grad_node = auto_grad_meta->GetMutableGradNode();

    if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
        auto_grad_meta->StopGradient()) {
      VLOG(5) << "Skip auto grad since there is no grad op for var or loss is "
166 167 168 169 170
                 "stop_gradient=True: "
              << tensor.name();
      continue;
    }

    // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
    GradNodeBase* grad_node = shared_grad_node.get();
    if (is_general_grad) {
      // Save orig grad node
      orig_queue.push_back(grad_node);

      // Replace grad_node with copied grad_node
      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

      // Record potential startup grad node
      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
    }

    // Prepare GradTensorHolder
    if (!node_input_buffers_dict.count(grad_node)) {
      VLOG(5) << "Create Value for grad input tensor " << i
              << " of grad node: " << grad_node->name();
      node_input_buffers_dict[grad_node] =
          std::make_unique<GradTensorHolder>(grad_node->InputMeta());
    }

    // Copy the grad tensor so that running grad does not affect the forward
    // value.
    if (grad_tensors.size() > 0 && grad_tensors[i].initialized()) {
      PADDLE_ENFORCE(
          grad_tensors.size() == tensors.size(),
          paddle::platform::errors::Fatal(
              "Detected size mismatch between tensors and grad_tensors"
              "grad_tensors should either have "
200
              "size = 0 or same size as tensors."));
      // Feed given tensor if it's provided
      VLOG(3) << "Fill grad input tensor " << i << " with given grad tensor";

      // Deep copy
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, grad_tensors[i]);
    } else {
      VLOG(3) << "Fill grad input tensor " << i << " with 1.0";
      // Initialize the grad tensor with 1.0.
      // The forward Tensor "tensor" is passed to indicate the tensor type,
      // datatype and dims; GradTensorHolder will initialize another tensor
      // with the same tensor type, datatype and dims, filled with 1.0.
      node_input_buffers_dict[grad_node]->CopyValueFromTensor(
          input_info.first, input_info.second, tensor, /*fill_one=*/true);
    }

    // Prepare queue, potential startup_nodes
    if (visited.count(grad_node)) {
      continue;
    }
    visited.insert(grad_node);
    queue.push_back(grad_node);
  }

  if (is_general_grad) {
    // Run several vital preprocessing steps for GeneralGrad
    GeneralGrad::Instance().PreparedForGeneralGrad(
        inputs, no_grad_vars, orig_queue, &queue, node_input_buffers_dict);
  }

  VLOG(5) << "Update In degree Map for backward";
  // 3. Compute in_degree for each node
  std::unordered_map<GradNodeBase*, int> node_in_degree_map =
      getInDegreeMap(queue);

  VLOG(5) << "Startup_ops's size is " << queue.size();

  /* --- Topological Visit --- */
  // 1. Pop queue
  // 2. Run node
  //    |- Check and capture target result
  //    |- node(grads)
  //    |- Prepare for next node
  // 3. Update queue
  while (!queue.empty()) {
    GradNodeBase* node = queue.front();
    VLOG(3) << "Preparing GradNode:" << node->name() << " addr:" << node;
    paddle::platform::RecordEvent node_record_event(
        std::string((*node).name()),
        paddle::platform::TracerEventType::Operator,
        1);

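    // A node may still have pending incoming gradients (nonzero in-degree);
    // skip it for now, since it is re-enqueued once its last contribution
    // arrives. If it is the only node left in the queue, run it anyway so
    // the traversal can make progress.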
    if (queue.size() > 1 && node_in_degree_map[node] != 0) {
      queue.pop_front();
      continue;
    }
    queue.pop_front();

    // Run node: This is where Hook happens
    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
    PADDLE_ENFORCE_NE(
        node_input_buffer_iter,
        node_input_buffers_dict.end(),
        paddle::platform::errors::Fatal(
            "Unable to find the GradTensorHolder for this node.\n"
            "Trying to run a Node without configuring its GradTensorHolder."));

    std::unique_ptr<GradTensorHolder> node_input_buffer =
        std::move(node_input_buffer_iter->second);

    // Check input
    EnforceGradNodeHasInput(node);

    VLOG(7) << "Run Backward Kernel with GradTensorHolder.";
    // Run Pre Backward Node and get outputs
    paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize>
        grad_output_tensors = (*node)(
            node_input_buffer->Buffers(), create_graph, is_general_grad);

    if (!inputs.empty() && is_general_grad) {
      GeneralGrad::Instance().SetResultForEnddingNodes(grad_output_tensors,
                                                       node);
    }

    // retain_graph or not
    if (!retain_graph) {
      VLOG(3)
          << "retain_graph is false, need to clear the TensorWrapper of nodes.";
      node->ClearTensorWrappers();
    }

    // TODO(jiabin): Should we erase it or find a more efficient way?
    node_input_buffers_dict.erase(node_input_buffer_iter);

    // Prepare GradTensorHolder for next node
    const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
        metas = node->OutputMeta();
    PADDLE_ENFORCE(metas.size() == grad_output_tensors.size() || metas.empty(),
                   paddle::platform::errors::Fatal(
                       "Number of edges should be either empty ( for leaf node "
302 303
                       ") or the same as number of output grad tensors, but we "
                       "got edges size is: %d, grad_output size is: %d",
304 305
                       metas.size(),
                       grad_output_tensors.size()));

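    // Propagate this node's output grads along its edges: for every
    // initialized edge (i, j), accumulate grad_output_tensors[i][j] into the
    // GradTensorHolder of the destination node and decrement its in-degree.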
    for (size_t i = 0; i < metas.size(); i++) {
      for (size_t j = 0; j < metas[i].size(); j++) {
        const Edge& edge = metas[i][j].GetEdge();
        if (!edge.IsInitialized()) {
          continue;
        }
        auto edge_rank = edge.GetEdgeRankInfo();
        // Since edges have the same rank as the bwd outputs, we index them
        // with the same rank (i, j).
        auto next_node_shared = edge.GetMutableGradNode();
        VLOG(3) << "Node: " << node->name() << " addr:" << node
                << ", Found pending node: " << next_node_shared->name()
                << " addr: " << next_node_shared.get();
        // Next node could be nullptr if it is a leaf tensor with no
        // AccumulationNode attached, or it could originate from dispensable
        // inputs.
        if (!next_node_shared || !next_node_shared.get() ||
            grad_output_tensors[i].empty()) {
          continue;
        }

        PADDLE_ENFORCE_LT(
            j,
            grad_output_tensors[i].size(),
            paddle::platform::errors::Fatal(
                "Rank of grad_output_tensors should be less than "
                "grad_output_tensors[i].size(), which is: %d. This error may "
                "indicate an autoprune or autograd api error.",
                grad_output_tensors[i].size()));
        paddle::Tensor& grad_output_tensor = grad_output_tensors[i][j];

        if ((!grad_output_tensor.defined() ||
             !grad_output_tensor.initialized())) {
          VLOG(7) << "We get grad_output_tensor with slot: " << i
                  << ", rank: " << j << " as uninitialized or undefined tensor";
        }

        VLOG(7) << "Get Edge and grad_output_tensor with slot: " << i
                << ", rank: " << j
                << " 's name is: " << grad_output_tensor.name();

        auto* next_node = next_node_shared.get();
        if (!node_input_buffers_dict.count(next_node)) {
          const auto& input_meta = next_node->InputMeta();
          auto grad_tensor_holder =
              std::make_unique<GradTensorHolder>(input_meta);
          VLOG(7) << "Construct GradTensorHolder for grad node: "
                  << next_node->name();
          node_input_buffers_dict[next_node] = std::move(grad_tensor_holder);
        }

        VLOG(3) << "Sum or Move grad inputs for edge slot: " << edge_rank.first
                << ", rank: " << edge_rank.second;

        node_input_buffers_dict[next_node]->add(edge_rank.first,
                                                edge_rank.second,
                                                grad_output_tensor,
                                                create_graph);

        // Update queue
        node_in_degree_map[next_node]--;
        VLOG(7) << next_node->name()
                << " ref_cnt is: " << node_in_degree_map[next_node];

        PADDLE_ENFORCE(
            node_in_degree_map[next_node] >= 0,
            paddle::platform::errors::Fatal(
                "Detected in-degree value smaller than zero. For Node: %s"
375
                "Node's in-degree cannot be negative.",
376
                next_node->name()));

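        // Enqueue a node whose gradients are all ready. GradNodeAccumulation
        // nodes (leaf gradient accumulation) are pushed to the front so they
        // run before other pending nodes; everything else goes to the back.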
        auto add_next_node_func = [&node_in_degree_map,
                                   &queue](GradNodeBase* next_node) {
          if (dynamic_cast<egr::GradNodeAccumulation*>(next_node)) {
            queue.push_front(std::move(next_node));
          } else {
            queue.push_back(std::move(next_node));
          }
        };
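        // Once all of next_node's gradient contributions have arrived, decide
        // when to enqueue it. A force-sequential node may only run in the
        // recorded order: if it is at the deque front it is enqueued now
        // (together with any already-ready successors); otherwise it is
        // parked in ready_force_sequential_nodes until its turn comes.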
        if (node_in_degree_map[next_node] == 0) {
          if (force_sequential_nodes_set.count(next_node)) {
            if (force_sequential_nodes_queue.front() == next_node) {
              force_sequential_nodes_queue.pop_front();
              add_next_node_func(next_node);
              while (ready_force_sequential_nodes.count(
                  force_sequential_nodes_queue.front())) {
                ready_force_sequential_nodes.erase(
                    force_sequential_nodes_queue.front());
                add_next_node_func(force_sequential_nodes_queue.front());
                force_sequential_nodes_queue.pop_front();
              }
            } else {
              ready_force_sequential_nodes.insert(next_node);
              continue;
            }
          } else {
            add_next_node_func(next_node);
          }
        }
      }
    }
  }

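  // Run any hooks registered to fire after the whole backward pass has
  // finished, then clear them.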
  VLOG(7) << "Run Backward Final hook size: "
          << egr::Controller::Instance().FinalBackwardHooks().size();
  for (auto& hook : egr::Controller::Instance().FinalBackwardHooks()) {
    (*hook)();
  }
  egr::Controller::Instance().ClearFinalBackwardHooks();
  if (!is_general_grad) return {};
  VLOG(3) << "Finish Backward";
  return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}

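// Run a full backward pass starting from `tensors`. Gradients are
// accumulated into leaf tensors by their GradNodeAccumulation nodes; nothing
// is returned.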
void Backward(const std::vector<paddle::Tensor>& tensors,  // outputs
              const std::vector<paddle::Tensor>& grad_tensors,
              bool retain_graph) {
  VLOG(3) << "Run in Backward";
  paddle::platform::RecordEvent backward_record_event(
      "backward", paddle::platform::TracerEventType::UserDefined, 1);
  egr::Controller::Instance().ClearForceSequentialNodes();
  RunBackward(tensors, grad_tensors, retain_graph);
  phi::autotune::AutoTuneStatus::Instance().Update();
}

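// Run a partial backward pass (general grad mode): compute and return the
// gradients of `inputs` with respect to `tensors`, honoring no_grad_vars and
// allow_unused.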
std::vector<paddle::Tensor> Grad(
    const std::vector<paddle::Tensor>& tensors,  // outputs
    const std::vector<paddle::Tensor>& inputs,
    const std::vector<paddle::Tensor>& grad_tensors,
    bool retain_graph,
    bool create_graph,
    bool only_inputs,
    bool allow_unused,
    const std::vector<paddle::Tensor>& no_grad_vars) {
  VLOG(3) << "Run in Grad";

  DuplicateCheck(inputs, true /* is_input */);
  DuplicateCheck(tensors, false /* is_input */);

  return RunBackward(tensors,
                     grad_tensors,
                     retain_graph,
                     create_graph,
                     inputs,
                     allow_unused,
                     no_grad_vars);
}
}  // namespace egr