// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/program_interpreter.h"

#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_context.h"
#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/device_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace framework {

ProgramInterpreter::ProgramInterpreter(const platform::Place& place,
                                       const BlockDesc& block,
                                       framework::Scope* scope,
                                       const ExecutionConfig& execution_config)
    : place_(place),
      block_(block),
      stream_analyzer_(place),
      execution_config_(execution_config),
      var_scope_(scope) {
  VLOG(4) << "ProgramInterpreter(): " << this << " on " << place_;

  static_build_ = FLAGS_new_executor_static_build &&
                  !FLAGS_new_executor_use_cuda_graph &&
                  interpreter::BlockCanBeStaticBuilt(block);

  exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
  completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);

  dependecy_count_ = std::make_shared<std::vector<size_t>>();

  if (!FLAGS_new_executor_use_local_scope) {
    execution_config_.create_local_scope = false;
  }
  execution_config_.AnalyzeThreadPoolConfig(place, block.OpSize());
  execution_config_.Log(/*log_level=*/8);

  if (execution_config_.create_local_scope) {
    auto local_scope = &var_scope_.GetMutableScope()->NewScope();
    local_scope_ = local_scope;
  }
  var_scope_.SetLocalScope(local_scope_);

  instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
    SchedulingPriority lhs_scheduling_priority =
        vec_instruction_[lhs].GetSchedulingPriority();
    SchedulingPriority rhs_scheduling_priority =
        vec_instruction_[rhs].GetSchedulingPriority();
    if (lhs_scheduling_priority == rhs_scheduling_priority) {
      return lhs < rhs;
    }
    return lhs_scheduling_priority > rhs_scheduling_priority;
  };

  PrepareForCUDAGraphCapture();
}

ProgramInterpreter::~ProgramInterpreter() {
  // cancel gc's thread
  gc_.reset(nullptr);
  async_work_queue_.reset();
  VLOG(4) << "~ProgramInterpreter(): " << this << " on " << place_;

#ifdef PADDLE_WITH_DNNL
  // Clear mkl-dnn cache,
  // this is needed to have mkl-dnn unit tests working
  platform::ClearMKLDNNCache(place_, this);
#endif
}

void ProgramInterpreter::RunImpl() {
  // lazy initialization of gc, do not create gc if the program only runs once
  if (!gc_) {
    gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
  }

  interpreter::ResetAtomicGuard guard(&deps_, &refs_);

  if ((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
      (sync_op_num_ == 0)) {
    VLOG(4) << "Tracing Instruction List";
    TraceInstructionList(vec_instruction_);
  } else {
    VLOG(4) << "Non-tracing";
    // For a program that only runs once, there is no need to create a
    // work queue, so async_work_queue_ is not created until the second run.
    async_work_queue_ = GetWorkQueue();
    ExecuteInstructionList(vec_instruction_);
  }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (platform::is_custom_place(place_)) {
    platform::DeviceContextPool::Instance().Get(place_)->Wait();
  }
#endif
}

FetchList ProgramInterpreter::Run(
    const std::vector<std::string>& feed_names,
    const std::vector<phi::DenseTensor>& feed_tensors) {
  SetDeviceId(place_);
  CheckCUDAGraphBeforeRun(feed_names);

#ifdef PADDLE_WITH_DNNL
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif

  bool is_build = is_build_;
  Prepare(feed_names, feed_tensors, is_build);

  if (is_build) {
    RunImpl();
  }

  if (HasLocalScope()) {
    ClearLoDTensorArrayInLocalScope();
  }

  // return Fetch Tensors
  auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
  if (fetch_var) {
    auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
#ifdef PADDLE_WITH_CUDA
    if (platform::IsCUDAGraphCapturing()) {
      PADDLE_ENFORCE_EQ(fetch_list.empty(),
                        true,
                        platform::errors::InvalidArgument(
                            "Cannot fetch data when using CUDA Graph."));
    }
#endif
    return fetch_list;
  } else {
    return {};
  }
}

FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
                                  bool need_fetch) {
  SetDeviceId(place_);
  CheckCUDAGraphBeforeRun(feed_names);

#ifdef PADDLE_WITH_DNNL
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif

  if (!is_build_) {
    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
    paddle::framework::interpreter::BuildVariableScope(
        block_, execution_config_, &var_scope_);

    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
    paddle::framework::interpreter::BuildOpFuncList(
        place_,
        block_,
        execution_config_.skip_gc_vars,
        &op_func_nodes,
        &var_scope_,
        execution_config_,
        HasLocalScope(),
        static_build_);
    SetFeedVarsInplaceSkip(feed_names);
    // convert vec func_list to graph
    Convert(&op_func_nodes);
    UpdateSyncOpNum();
    if (static_build_) {
      VLOG(4) << "RUN impl";
      RunImpl();
    }
    is_build_ = true;
    is_shared_results_build_ = true;
  } else {
    RunImpl();
  }

  if (HasLocalScope()) {
    ClearLoDTensorArrayInLocalScope();
  }

  // return Fetch Tensors
  Scope* inner_scope =
      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
  if (fetch_var && need_fetch) {
    auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
#ifdef PADDLE_WITH_CUDA
    if (platform::IsCUDAGraphCapturing()) {
      PADDLE_ENFORCE_EQ(fetch_list.empty(),
                        true,
                        platform::errors::InvalidArgument(
                            "Cannot fetch data when using CUDA Graph."));
    }
#endif
    return fetch_list;
  } else {
    return {};
  }
}

void ProgramInterpreter::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
  copy_program_ = prog;
}

void ProgramInterpreter::SetSkipGcVars(
    const std::set<std::string>& skip_gc_vars) {
  PADDLE_ENFORCE_EQ(
      execution_config_.skip_gc_vars.empty(),
      true,
      platform::errors::PreconditionNotMet(
          "execution_config_.skip_gc_vars can only be initialized once, now "
          "execution_config_.skip_gc_vars is "
          "not empty, do not call SetSkipGcVars method repeatedly."));
  execution_config_.skip_gc_vars = skip_gc_vars;
}

void ProgramInterpreter::SetJitInputVars(
    const std::set<std::string>& jit_input_vars) {
  PADDLE_ENFORCE_EQ(
      execution_config_.jit_input_vars.empty(),
      true,
      platform::errors::PreconditionNotMet(
          "execution_config_.jit_input_vars can only be initialized once, now "
          "execution_config_.jit_input_vars is "
          "not empty, do not call SetJitInputVars method repeatedly."));
  execution_config_.jit_input_vars = jit_input_vars;
}

const std::set<std::string>& ProgramInterpreter::JitInputVars() const {
  return execution_config_.jit_input_vars;
}

const VariableScope* ProgramInterpreter::GetVariableScope() const {
  return &var_scope_;
}

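// Rebind the cached Variable pointers (the var list, the ref-count info and
// the cached instruction contexts) to the variables in new_scope.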
void ProgramInterpreter::reset_scope(Scope* new_scope) {
  var_scope_.SetScope(new_scope);
  auto& var_list = var_scope_.MutableVarList();
  for (size_t i = 0; i < var_list.size(); i++) {
    const auto& var_name = var_scope_.GetNameById(i);
    var_list[i] = new_scope->FindVar(var_name);
  }
  // The index should be assured valid, cause the InterpreterCore may not be
  // fully built, but was still cached and used. For example, see unit test
  // `test_assert.py`, it may exit before `ProgramInterpreter::Convert`,
  // but still was cached and used by later tests.
  for (size_t i = 0; i < std::min(refs_.size(), var_list.size()); i++) {
    refs_[i]->ResetVariable(var_list[i]);
  }

  for (auto& ins : vec_instruction_) {
    BuildAndCacheInstructionCtx(&ins);
  }
}

const Scope* ProgramInterpreter::local_scope() const { return local_scope_; }
void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
  async_work_queue_ =
      reinterpret_cast<ProgramInterpreter*>(src)->GetWorkQueue();
  VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src
          << ") to InterpreterCore(" << this << ")";
}

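// Reuse the build results (op dependencies, dependency counts and
// stream-event info) already produced by another ProgramInterpreter, so the
// analysis does not have to be repeated.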
void ProgramInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
  const ProgramInterpreter& impl = dynamic_cast<const ProgramInterpreter&>(src);
  if (is_shared_results_build_ || !impl.IsSharedResultsBuild()) {
    return;
  }
  // share op dependency
  dependency_builder_.ShareDependencyFrom(impl.GetDependencyBuilder());
  dependecy_count_ = impl.GetDependencyCount();
  // share event analysis
  stream_analyzer_.ShareEventInfoFrom(impl.GetStreamAnalyzer());
  is_shared_results_build_ = true;
  VLOG(8) << "Share Build Results from InterpreterCore(" << &impl
          << ") to InterpreterCore(" << this << ")";
}

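// A var is an inplace candidate only if exactly one instruction actually
// needs its buffer as an input.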
bool ProgramInterpreter::BuildInplaceCheckVarIsOnlyInput(
    const std::vector<std::vector<size_t>>& input_var2op, size_t var_index) {
  if (!var_scope_.VarDesc(var_index)) {
    return input_var2op.at(var_index).size() == 1;
  } else {
    int is_input_cnt = 0;
    for (auto inst_id : input_var2op.at(var_index)) {
      OpInOutInfo info;
      info.Build(vec_instruction_.at(inst_id).OpBase());
      if (info.IsInArgBufferNeeded(var_scope_.VarDesc(var_index)->Name())) {
        is_input_cnt++;
      }
    }
    return is_input_cnt == 1;
  }
}

std::shared_ptr<interpreter::AsyncWorkQueue>
ProgramInterpreter::GetWorkQueue() {
  if (async_work_queue_ == nullptr) {
    async_work_queue_ = std::make_shared<interpreter::AsyncWorkQueue>(
        execution_config_.host_num_threads,
        execution_config_.device_num_threads,
        nullptr);
  }
  return async_work_queue_;
}

const interpreter::DependencyBuilder& ProgramInterpreter::GetDependencyBuilder()
    const {
  return dependency_builder_;
}

std::shared_ptr<std::vector<size_t>> ProgramInterpreter::GetDependencyCount()
    const {
  return dependecy_count_;
}

const interpreter::StreamAnalyzer& ProgramInterpreter::GetStreamAnalyzer()
    const {
  return stream_analyzer_;
}

bool ProgramInterpreter::IsSharedResultsBuild() const {
  return is_shared_results_build_;
}

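// Look up the input/output Variable pointers of an instruction in the (local)
// scope and cache them in the instruction's runtime and infershape contexts.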
void ProgramInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
  Scope* inner_scope =
      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
  VariableValueMap ins_map;
  for (auto& var_name_item : instr_node->Inputs()) {
    std::vector<Variable*> input_vars;

    input_vars.reserve(var_name_item.second.size());
    for (auto& id : var_name_item.second) {
      input_vars.emplace_back(inner_scope->FindVar(var_scope_.GetNameById(id)));
    }
    ins_map.emplace(var_name_item.first, std::move(input_vars));
  }

  VariableValueMap outs_map;
  for (auto& var_name_item : instr_node->Outputs()) {
    std::vector<Variable*> out_vars;

    out_vars.reserve(var_name_item.second.size());
    for (auto& id : var_name_item.second) {
      out_vars.emplace_back(inner_scope->FindVar(var_scope_.GetNameById(id)));
    }
    outs_map.emplace(var_name_item.first, std::move(out_vars));
  }

  // set runtime_ctx and infershape_ctx_
  if (instr_node->OpBase()->Type() == "cinn_launch" ||
      instr_node->OpBase()->Type() == "cinn_instruction_run") {  // OP use scope
                                                                 // in kernel
    Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
                                         : var_scope_.GetMutableScope();
    instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
  } else {
    instr_node->ResetContext(ins_map, outs_map);
  }
}

void ProgramInterpreter::BuildInplace() {
  // NOTE(Ruibiao): coalesce_tensor_op outputs a FusedOutput phi::DenseTensor
  // and a list of Output Tensors which are sliced from the FusedOutput. These
  // outputs should not be the outvar of the in-place var-pair since memory
  // reuse between FusedOutput and Output Tensors is assumed. For the following
  // example:
  // fused_var, var1, var2, var3 = coalesce_tensor(var1, var2, var3)
  // var1 = sum(var4, var5)
  // ...
  //
  // After running coalesce_tensor_op, var1 is assumed to share the buffer
  // slices from fused_var. However, if sum_op is in-place, then var1 would
  // re-share the buffer with var4 instead of fused_var.
  std::set<std::string> skip_inplace_outvars;
  for (Instruction& instr : vec_instruction_) {
    OperatorBase* op = instr.OpBase();
    if (op->Type() == kCoalesceTensor) {
      const std::vector<std::string>& outputs =
          op->OutputVars(/*has_intermediate=*/false);
      skip_inplace_outvars.insert(outputs.begin(), outputs.end());
    }
  }

  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
                                       : var_scope_.GetMutableScope();
  std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
  for (Instruction& instr : vec_instruction_) {
    for (auto& item : instr.Inputs()) {
      for (int var_id : item.second) {
        if (var_id != kEmptyVarIndex) {
          input_var2op.at(var_id).push_back(instr.Id());
        }
      }
    }
  }

  for (auto& instr : vec_instruction_) {
    auto* op_base = instr.OpBase();
    if (!op_base->Info().infer_inplace_) {
      continue;
    }

    auto in_to_outs = op_base->Info().infer_inplace_(
        platform::is_gpu_place(instr.DeviceContext().GetPlace()));

    auto& inputs = instr.Inputs();
    auto& outputs = instr.Outputs();
    for (auto& pair : in_to_outs) {
      auto iter = inputs.find(pair.first);
      if (iter != inputs.end() && !iter->second.empty()) {
        auto in_var_desc = var_scope_.VarDesc(iter->second[0]);
        if (in_var_desc && in_var_desc->Persistable()) {
          continue;
        }
        if (var_scope_.GetVarSikpInplace(iter->second[0])) {
          continue;
        }
        if (BuildInplaceCheckVarIsOnlyInput(input_var2op, iter->second[0])) {
          auto iterout = outputs.find(pair.second);
          if (iterout != outputs.end() && !iterout->second.empty()) {
            const std::string& invar_name =
                var_scope_.GetNameById(iter->second[0]);
            const std::string& outvar_name =
                var_scope_.GetNameById(iterout->second[0]);
            auto invar = local_scope->FindVar(invar_name);
            auto outvar = local_scope->FindVar(outvar_name);

            if (invar && outvar && invar->IsType<phi::DenseTensor>() &&
                outvar->IsType<phi::DenseTensor>() &&
                skip_inplace_outvars.find(outvar_name) ==
                    skip_inplace_outvars.end()) {
              instr.AddInplace(invar, outvar);
              VLOG(3) << "inplace " << op_base->Type() << " " << invar_name
                      << " -> " << outvar_name;
            }
          }
        }
      }
    }
  }
}

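// Check that the current configuration allows CUDA Graph capture and mark the
// outputs of coalesce_tensor ops as persistable so they survive across runs.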
void ProgramInterpreter::PrepareForCUDAGraphCapture() {
  if (!FLAGS_new_executor_use_cuda_graph) return;
#ifdef PADDLE_WITH_CUDA
  PADDLE_ENFORCE_EQ(
      platform::IsCUDAGraphCapturing(),
      false,
      platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
                                         "before prepare."));
  PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
                    true,
                    platform::errors::InvalidArgument(
                        "CUDA Graph is only supported on NVIDIA GPU device."));
  // If set to true, `cudaStreamSynchronize(nccl_stream)` will be called after
  // allreduce, which may cause errors during CUDA Graph capture. This behavior
  // is consistent with PE.
  PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce,
                    false,
                    platform::errors::InvalidArgument(
                        "FLAGS_sync_nccl_allreduce must be False to support "
                        "CUDA Graph capturing."));

  // All output vars of coalesce_tensor op should be persistable.
  // If the fused output var of coalesce_tensor is GCed, it will cause an
  // accuracy problem. The specific reasons still need to be analyzed.
  for (auto& op_desc : block_.AllOps()) {
    if (op_desc->Type() == kCoalesceTensor) {
      for (auto& out_var_name : op_desc->OutputArgumentNames()) {
        // The fused var needs to be set to persistable, not just added to
        // skip_gc_vars.
        // In the case where the feed fetch var is changed, StandaloneExecutor
        // will be newly constructed. If the fused var is not persistable,
        // these vars will be recreated and initialized, resulting in
        // precision problems.
        auto* out_var = op_desc->Block()->FindVarRecursive(out_var_name);
        if (out_var) {
          out_var->SetPersistable(true);
          VLOG(4) << "Mark Var(" << out_var_name << ") as Persistable.";
        }
      }
    }
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}

void ProgramInterpreter::CheckCUDAGraphBeforeRun(
    const std::vector<std::string>& feed_names) {
#ifdef PADDLE_WITH_CUDA
  if (platform::IsCUDAGraphCapturing()) {
    PADDLE_ENFORCE_EQ(
        feed_names.empty(),
        true,
        platform::errors::InvalidArgument(
            "Feeding data is not permitted when capturing CUDA Graph."));
    PADDLE_ENFORCE_EQ(
        FLAGS_new_executor_use_cuda_graph,
        true,
        platform::errors::InvalidArgument(
            "You must turn on FLAGS_new_executor_use_cuda_graph to True "
            "to enable CUDA Graph capturing."));
    PADDLE_ENFORCE_EQ(
        place_,
        platform::CUDAGraphCapturingPlace(),
        platform::errors::InvalidArgument("The place to capture CUDAGraph is "
                                          "not the same as the place to run."));
  }
#endif
}

void ProgramInterpreter::BuildOperatorDependences() {
  // Analyze the dependencies between ops, add next_instr_list to each instr,
  // and set dependecy_count_.
  size_t instr_num = vec_instruction_.size();
  dependecy_count_ = GetDependencyCount();
  if (!is_shared_results_build_) {
    dependecy_count_->assign(instr_num, 0);
  }

  auto downstream_map = dependency_builder_.Build(vec_instruction_);

  for (size_t instr_id = 0; instr_id < instr_num; ++instr_id) {
    Instruction& cur_instr = vec_instruction_[instr_id];
    const std::set<size_t>& next_instr_ids = downstream_map[instr_id];

    if (FLAGS_new_executor_serial_run) {
      for (size_t next_instr_id : next_instr_ids) {
        cur_instr.AddNextInstrInSameThread(next_instr_id);
      }
    } else {
      if (cur_instr.KernelType() == OpFuncType::kGpuAsync) {
        for (size_t next_instr_id : next_instr_ids) {
          if (vec_instruction_[next_instr_id].KernelType() ==
              OpFuncType::kGpuAsync) {
            cur_instr.AddNextInstrInSameThread(next_instr_id);
          } else {
            cur_instr.AddNextInstrInDifferentThread(next_instr_id);
          }
        }
      } else {
        bool has_instr_in_same_thread = false;
        for (size_t next_instr_id : next_instr_ids) {
          if (!has_instr_in_same_thread &&
              vec_instruction_[next_instr_id].KernelType() !=
                  OpFuncType::kGpuAsync) {
            cur_instr.AddNextInstrInSameThread(next_instr_id);
            has_instr_in_same_thread = true;
          } else {
            cur_instr.AddNextInstrInDifferentThread(next_instr_id);
          }
        }
      }
    }

    if (!is_shared_results_build_) {
      for (size_t next_instr_id : next_instr_ids) {
        ++(*dependecy_count_)[next_instr_id];
      }
    }
  }
}

// At the end of each step, the holder of phi::DenseTensor in LoDTensorArray is
// null. Clear these Tensors and leave LoDTensorArray empty, otherwise an
// exception will occur in the next step
void ProgramInterpreter::ClearLoDTensorArrayInLocalScope() {
  auto vars = local_scope_->LocalVars();
  for (auto var : vars) {
    if (var->IsType<LoDTensorArray>()) {
      auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
      lod_tensor_arr->clear();
    }
  }
}

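// Convert the list of OpFuncNodes into the instruction list: bind each op to
// a device context, build operator dependencies and stream events, compute
// the last-live-op information used for GC, cache instruction contexts, and
// apply inplace if enabled.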
void ProgramInterpreter::Convert(
    std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
  auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
  auto nodes = *op_func_nodes;
  auto op_nums = nodes.size();
  vec_instruction_.clear();
  vec_instruction_.reserve(op_nums);
  for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
    auto& op_func_node = nodes[op_idx];
    auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
    vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
#ifdef PADDLE_WITH_CUDA
    if (FLAGS_new_executor_use_cuda_graph) {
      auto& op = op_func_node.operator_base_;
      auto& op_type = op->Type();
      if (op_type == interpreter::kMemcpyD2H ||
          op_type == interpreter::kMemcpyH2D) {
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "Cuda memory copy d2h/h2d is not allowed while using cuda graph."));
      }
      PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext),
                        true,
                        platform::errors::InvalidArgument(
                            "Device context of op %s must be [%s] while using "
                            "cuda graph, but got [%s].",
                            op_type,
                            typeid(phi::GPUContext).name(),
                            typeid(*dev_ctx_).name()));
      // cuda graph needs to record all streams
      phi::backends::gpu::CUDAGraphContextManager::Instance()
          .RecordCapturingDeviceContext(dev_ctx_);
    }
#endif
  }

  BuildOperatorDependences();

  // NOTE(Ruibiao): For cross-step stream synchronization, an event may be
  // recorded in the first step and waited in the second step. So, in the first
  // step, the WaitEvent may be called without RecordEvent. Considering that
  // before the first call to RecordEvent, an Event represents an empty set of
  // work and WaitEvent always returns immediately with success, we omit the
  // prelude-record for the first step here.
  stream_analyzer_.ConstructEvents(&vec_instruction_);

  // add events for the input vars of the jit program, since they are copied
  // asynchronously from gpu_pinned place to gpu place on the compute stream.
  for (size_t i = 0; i < dependecy_count_->size(); ++i) {
    if ((*dependecy_count_)[i] == 0) {
      auto& inst = vec_instruction_[i];
      if (inst.OpBase()->Type() == interpreter::kMemcpyD2H &&
          platform::is_gpu_place(place_)) {
        for (auto& item : inst.Inputs()) {
          for (auto var_id : item.second) {
            auto name = var_scope_.GetNameById(var_id);
            if (JitInputVars().count(name)) {
              auto device_event = std::make_shared<platform::DeviceEvent>(
                  place_, platform::GenerateDeviceEventFlag());
              VLOG(4) << "Add input event for input: " << name << " of "
                      << inst.OpBase()->Type();
              inst.AddEventToWait(
                  i, device_event, stream_analyzer_.GetWaiterType(inst));
            }
          }
        }
      }
    }
  }

  // calculate last_live_ops_
  for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
    Instruction& instr = vec_instruction_[op_idx];
    OpInOutInfo info;
    info.Build(instr.OpBase());

    std::set<size_t> gc_check_vars;

    const std::map<std::string, std::vector<int>>& ins = instr.Inputs();
    const std::map<std::string, std::vector<int>>& outs = instr.Outputs();
    std::multimap<std::string, std::vector<int>> ins_and_outs{ins.begin(),
                                                              ins.end()};
    ins_and_outs.insert(outs.begin(), outs.end());

    for (auto& item : ins_and_outs) {
      for (auto id : item.second) {
        if (id == kEmptyVarIndex) {
          continue;
        }
        auto* var_desc = var_scope_.VarDesc(id);
        // skip no_need_buffer input vars
        if (var_desc && ins.count(item.first) &&
            !info.IsInArgBufferNeeded(var_desc->Name())) {
          continue;
        }
        // skip when this var is not in block and not a data_transferred var,
        // which means this var is managed by other block
        const auto& var_name = var_scope_.GetNameById(id);
        bool not_owned = !block_.HasVar(var_name);
        const auto& transferred_vars = var_scope_.DataTransferAddedVars();
        bool not_transferred =
            std::all_of(transferred_vars.begin(),
                        transferred_vars.end(),
                        [&](const std::pair<std::string, int>& elem) {
                          return elem.first != var_name;
                        });
        if (not_owned && not_transferred) {
          VLOG(10) << "[gc_check_inputs] skip gc: " << var_name;
          continue;
        }
        gc_check_vars.insert(id);
      }
    }

    for (auto var_id : gc_check_vars) {
      Scope* inner_scope =
          HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
      paddle::framework::Variable* var =
          inner_scope->FindVar(var_scope_.GetNameById(var_id));
      if (var->IsType<phi::DenseTensor>() || var->IsType<phi::SelectedRows>() ||
          var->IsType<LoDTensorArray>()) {
        last_live_ops_[var_id].insert(op_idx);
      } else {
        VLOG(4) << "not clear " << var_scope_.GetNameById(var_id) << " after "
                << instr.OpBase()->Type() << " because its type is "
                << framework::ToTypeName(var->Type());
      }
    }
  }

  // clear the last_live_ops list for all vars in skip_gc_vars
  for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) {
    int var_id = var_scope_.GetIdByName(skip_gc_var);
    if (var_id != -1) {
      last_live_ops_[var_id].clear();
      VLOG(8) << "Skip gc for var: " << skip_gc_var;
    }
  }

  // shrink: find the downstream ops that have no other op in the
  // downstream list happening before them.
  // For example,
  // b = op1(a)
  // c = op2(a, b)
  // in this case, a is the input of op1 and op2, we only need to check
  // a after op2, because op2 always uses a after op1.
  for (size_t i = 0; i < last_live_ops_.size(); ++i) {
    std::set<size_t> minumum_last_live_ops;
    for (size_t item : last_live_ops_[i]) {
      bool not_before_any = true;
      // keep the op only if it does not happen before any other op in the list
      for (size_t other_item : last_live_ops_[i]) {
        if (dependency_builder_.OpHappensBefore(item, other_item)) {
          VLOG(8) << "happens_before: " << item << "->" << other_item
                  << ", so skip " << item;
          not_before_any = false;
          break;
        }
      }
      if (not_before_any) {
        VLOG(8) << "last live op of var " << i << " "
                << var_scope_.GetNameById(i) << " : " << item << " "
                << vec_instruction_[item].OpBase()->Type();
        minumum_last_live_ops.insert(item);
        vec_instruction_[item].AddGCCheckVar(i);
      }
    }
    last_live_ops_[i] = minumum_last_live_ops;
    vec_meta_info[i].var_ref_count_ = last_live_ops_[i].size();
  }

  for (auto& ins : vec_instruction_) {
    BuildAndCacheInstructionCtx(&ins);
  }

  bool inplaced = false;
  for (const Instruction& inst : vec_instruction_) {
    if (inst.OpBase()->Type() == "share_buffer" ||
        inst.OpBase()->Type() == "share_data") {
      VLOG(4) << "Already inplaced, skip inplace now.";
      inplaced = true;
    }
  }

  if (FLAGS_new_executor_use_inplace && !inplaced) {
    BuildInplace();
  }

  for (auto& dep : *dependecy_count_) {
    deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
  }
  for (size_t i = 0; i < vec_meta_info.size(); ++i) {
    refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
        vec_meta_info[i].var_ref_count_, var_scope_.VarRef(i)));
  }

  AnalyseExecuteOrderForTrace();
}

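// Mark instructions whose inputs are all LoD-free DenseTensors so that their
// InferShapeContext can skip sharing LoD at run time.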
void ProgramInterpreter::BuildSkipShareLoDInfo() {
  for (size_t i = 0; i < vec_instruction_.size(); ++i) {
    bool can_skip_lod = true;
    for (auto& input : vec_instruction_[i].InnerRuntimeContext()->inputs) {
      for (auto& var : input.second) {
        if (var->IsType<phi::DenseTensor>()) {
          if (!var->Get<phi::DenseTensor>().lod().empty()) {
            can_skip_lod = false;
            break;
          }
        } else {
          can_skip_lod = false;
          break;
        }
      }
    }
    if (can_skip_lod) {
      VLOG(8) << "skip share lod for: " << vec_instruction_[i].OpBase()->Type()
              << " (" << i << ")";
    }
    vec_instruction_[i].InnerInferShapeContext()->SetSkipLoD(can_skip_lod);
  }
}

void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
  auto* op = instr_node.OpBase();
  auto place = instr_node.DeviceContext().GetPlace();
  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
                                       : var_scope_.GetMutableScope();
  VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

  auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
  {
    // If it is OperatorBase, InferShape does nothing.
    if (op_with_kernel != nullptr) {
      platform::RecordEvent infershape_event(
          "infer_shape",
          platform::TracerEventType::OperatorInner,
          1,
          platform::EventRole::kInnerOp);

      // see OperatorWithKernel::RunImpl in operator.cc for why
      if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op_with_kernel->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        op_with_kernel->Info().infer_shape_(
            instr_node.InnerInferShapeContext().get());
      }
      infershape_event.End();
      platform::RecordOpInfoSupplement(op->Type(),
                                       op->Attrs(),
                                       *(instr_node.InnerInferShapeContext()),
                                       *(instr_node.InnerRuntimeContext()),
                                       op->Id());
    }
  }
  if (op_with_kernel != nullptr && FLAGS_new_executor_use_inplace) {
    // TODO(xiongkun03) Does operator base support inplace?
    for (auto& pair : instr_node.InplaceInfo()) {
      const auto& in = paddle::framework::details::GetTensorFromVar(pair.first);
      auto* out =
          paddle::framework::details::GetMutableTensorFromVar(pair.second);
      if (in.dims() == out->dims()) {
        out->ShareBufferWith(in);
      }
    }
  }

  {
    platform::RecordEvent compute_event(
        "compute",
        platform::TracerEventType::OperatorInner,
        1,
        platform::EventRole::kInnerOp);
    if (op_with_kernel == nullptr) {  // operator base
      instr_node.OpBase()->Run(*local_scope, place_);
    } else {
      phi::Kernel* kernel = instr_node.PhiKernel();
      if (kernel && kernel->IsValid()) {  // phi kernel
        if (kernel->GetKernelRegisteredType() ==
            phi::KernelRegisteredType::FUNCTION) {
          VLOG(4) << "Run function kernel: " << op->Type();
          VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
                  << &instr_node.DeviceContext();
          phi::KernelContext phi_kernel_context;
          op_with_kernel->BuildPhiKernelContext(
              *instr_node.InnerRuntimeContext().get(),
              const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
              &phi_kernel_context);

          (*kernel)(&phi_kernel_context);
        } else {
          VLOG(4) << "Run structure kernel: " << op->Type();
          (*kernel)(instr_node.InnerExecutionContext().get());
        }
      } else {  // fluid kernel
        instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
      }
    }
  }

  VLOG(4) << "End run " << place << " " << op->DebugStringEx(local_scope);

  if (!instr_node.InplaceBackMap().empty()) {
    platform::RecordEvent inplaceback_event(
        "InplaceVarsBack", platform::TracerEventType::UserDefined, 10);
    auto& m = instr_node.InplaceBackMap();
    // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
    for (auto& p : m) {
      auto* transformed_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
          var_scope_.VarRef(p.first));
      auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
          var_scope_.VarRef(p.second));
      original_tensor->ShareDataWith(*transformed_tensor);
      VLOG(4) << "Transfer inplace variable back form "
              << var_scope_.GetNameById(p.first) << " to "
              << var_scope_.GetNameById(p.second);
    }
  }

  /*For profiling/benchmark only*/
  if (FLAGS_benchmark) {
    instr_node.DeviceContext().Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    VLOG(4) << "Operator(" << op->Type()
            << "): context wait and get last error";
#endif
  }

  for (auto& hook : hookfuncs_) {
    hook(op, local_scope);
  }

  // for debug nan/inf
  if (op_with_kernel != nullptr && FLAGS_check_nan_inf) {
    VLOG(4) << "Check nan/inf";
    try {
      framework::details::CheckOpHasNanOrInf(
          *op,
          *local_scope,
          place);  // TODO(xiongkun03) change it to inner scope.
    } catch (...) {
      const std::vector<std::string>* callstack = nullptr;
      auto attrs = op->Attrs();
      auto iter =
          attrs.find(OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
      if (iter != attrs.end()) {
        callstack = &PADDLE_GET_CONST(std::vector<std::string>, iter->second);
        if (callstack->empty()) callstack = nullptr;
      }
      std::ostringstream sout;
      if (callstack) {
        if (FLAGS_call_stack_level > 1) {
          sout << "\n\n  Compile Traceback (most recent call last):";
        } else {
          sout << "In user code:\n";
        }
        for (auto& line : *callstack) {
          sout << "\n  " << line;
        }
      }
      std::cout << sout.str() << std::endl;
      std::rethrow_exception(std::current_exception());
    }
  }
}

void ProgramInterpreter::RunInstruction(const Instruction& instr_node) {
  VLOG(5) << __func__ << " OP id:" << instr_node.Id()
          << " name:" << instr_node.OpBase()->Type() << " type:"
          << (instr_node.KernelType() == OpFuncType::kCpuSync
                  ? "kCpuSync"
                  : (instr_node.KernelType() == OpFuncType::kGpuSync
                         ? "kGpuSync"
                         : "kGpuAsync"))
          << " runs on " << platform::GetCurrentThreadName();

  auto* op = instr_node.OpBase();
  platform::RecordEvent instruction_event(
      op->Type(), platform::TracerEventType::Operator, 1);

  SetDeviceId(instr_node.DeviceContext().GetPlace());

  try {
    instr_node.WaitEvent(place_);

    if (!instr_node.IsArtificial()) {
      RunOperator(instr_node);
      CheckGC(instr_node);
      interpreter::LogDeviceMemoryStats(place_);
    }

    instr_node.RecordEvent(place_);
  } catch (platform::EnforceNotMet& ex) {
    framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex);
    exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
  } catch (platform::EOFException&) {
    exception_holder_.Catch(std::current_exception());
  } catch (std::exception& ex) {
    LOG(WARNING) << op->Type() << " raises an exception "
                 << platform::demangle(typeid(ex).name()) << ", " << ex.what();
    exception_holder_.Catch(std::current_exception());
  } catch (...) {
    LOG(WARNING) << op->Type() << " raises an unknown exception";
    exception_holder_.Catch(std::current_exception());
  }
}

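// Dump the static/dynamic dependency state of every instruction; used by the
// periodic logging task in ExecuteInstructionList to debug hangs.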
std::string ProgramInterpreter::GetDepsString() const {
  std::stringstream ss;
  auto downstream_map = dependency_builder_.OpDownstreamMap();
  ss << "Note: when static_dep is 1, it is ok that the dynamic_dep will not "
        "be decreased to 0."
     << std::endl;
  ss << "unfinished_op_number_:" << unfinished_op_number_ << std::endl;
  for (size_t i = 0; i < deps_.size(); ++i) {
    ss << "op:" << i << ", type: " << vec_instruction_[i].OpBase()->Type()
       << ", static_dep:" << deps_[i]->StaticDep()
       << ", dynamic_dep:" << deps_[i]->DynamicDep() << ", downstream op: ";
    for (auto id : downstream_map[i]) {
      ss << id << ", ";
    }
    ss << std::endl;
  }
  return ss.str();
}

void ProgramInterpreter::ExecuteInstructionList(
    const std::vector<Instruction>& vec_instr) {
  unfinished_op_number_ = vec_instr.size();
  if (unfinished_op_number_ == 0) {
    VLOG(4) << "No op to run, return";
    return;
  }

  exception_holder_.Clear();

  for (size_t i = 0; i < dependecy_count_->size(); ++i) {
    if ((*dependecy_count_)[i] == 0) {
      // NOTE(zhiqiu): hot fix for jit input var
      RecordMemcpyD2H(vec_instr.at(i));
      if (FLAGS_new_executor_serial_run) {
        RunInstructionAsync(i);
      } else {
        async_work_queue_->AddTask(vec_instr.at(i).KernelType(),
                                   [this, i] { RunInstructionAsync(i); });
      }
    }
  }

  // To debug hangs in main_thread_blocker_.WaitEvent(), launch an async task
  // that logs deps every FLAGS_executor_log_deps_every_microseconds, and
  // cancel it once main_thread_blocker_.WaitEvent() has returned. Why use
  // std::async instead of the workqueue? To make sure that the logging task
  // itself does not affect the workqueue used in InterpreterCore.

  std::future<int> logged_times;
  std::atomic_bool cancel_log = ATOMIC_VAR_INIT(false);
  if (FLAGS_executor_log_deps_every_microseconds) {
    logged_times = std::async(
        std::launch::async,
        [this](const std::atomic_bool& cancel) {
          int times = 0;
          while (!cancel) {
            std::this_thread::sleep_for(std::chrono::microseconds(
                FLAGS_executor_log_deps_every_microseconds));
            // check again, since cancel may be changed during sleep
            if (cancel) {
              break;
            }
            VLOG(0) << "deps:\n" << GetDepsString();
            times++;
          }
          return times;
        },
        std::ref(cancel_log));
  }

  auto event_name = main_thread_blocker_.WaitEvent();
  VLOG(1) << "main_thread_blocker_(" << &main_thread_blocker_
          << ") got event_name: " << event_name;

  cancel_log = true;
  if (logged_times.valid()) {
    VLOG(1) << "Logged deps for " << logged_times.get() << " times";
  }

  if (UNLIKELY(exception_holder_.IsCaught())) {
    VLOG(1) << "Exception caught " << exception_holder_.Type();
    // Graceful exit when the executor encountered a fatal error.
    // EOF is not a fatal error.
    if (exception_holder_.Type() != "EOF") {
      async_work_queue_->Cancel();
      async_work_queue_.reset();
    }
    VLOG(4) << "Cancel ok";
    PADDLE_ENFORCE_EQ(
        main_thread_blocker_.Clear(),
        0,
        platform::errors::PreconditionNotMet(
            "main_thread_blocker_.Clear() return -1, clear failed"));
    VLOG(4) << "clear ok";
    exception_holder_.ReThrow();
  }
}

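// Decrease the dynamic dependency count of the finished instruction's
// downstream instructions; ready ones are either queued in the current thread
// or submitted to the async work queue.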
void ProgramInterpreter::RunNextInstructions(
    const Instruction& instr, SchedulingQueue* reserved_next_ops) {
  platform::RecordEvent record(
      "RunNextInstructions", platform::TracerEventType::UserDefined, 10);

  auto IsReady = [this](size_t next_id) {
    VLOG(4) << "op_id: " << next_id
            << ", remain deps: " << deps_[next_id]->DynamicDep();
    return deps_[next_id]->CheckAndDecrease();
  };

  for (size_t next_instr_id : instr.NextInstrsInDifferenceThread()) {
    if (IsReady(next_instr_id)) {
      async_work_queue_->AddTask(
          vec_instruction_[next_instr_id].KernelType(),
          [this, next_instr_id]() { RunInstructionAsync(next_instr_id); });
    }
  }

  for (size_t next_instr_id : instr.NextInstrsInSameThread()) {
    if (IsReady(next_instr_id)) {
      reserved_next_ops->push(next_instr_id);
    }
  }
}

void ProgramInterpreter::RunInstructionAsync(size_t instr_id) {
  // NOTE(Ruibiao): Due to the uncertain order in multi-threaded asynchronous
  // scheduling, the priority order involving cross-thread scheduling is not
  // guaranteed. Only ops scheduled by the same AddTask call are guaranteed to
  // keep their priority order.
  SchedulingQueue ready_ops(instruction_scheduling_priority_less);
  ready_ops.push(instr_id);
  while (!ready_ops.empty()) {
    instr_id = ready_ops.top();
    ready_ops.pop();
    auto& instr_node = vec_instruction_.at(instr_id);

    RunInstruction(instr_node);

    if (UNLIKELY(exception_holder_.IsCaught())) {
      VLOG(4) << "Exception caught";
      if (exception_notifier_ != nullptr) {
        exception_notifier_->NotifyEvent();
      }
      return;
    }

    VLOG(4) << "unfinished_op_number_: " << unfinished_op_number_;
    if (UNLIKELY(unfinished_op_number_.fetch_sub(
                     1, std::memory_order_relaxed) == 1)) {
      if (completion_notifier_ != nullptr) {
        completion_notifier_->NotifyEvent();
      }
    }

    RunNextInstructions(instr_node, &ready_ops);
  }
}

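// For GPU-async instructions with fast GC enabled, record the stream the
// instruction runs on onto the allocations of its GC-check vars, so that the
// memory is not released while kernels on that stream may still be using it.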
void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
  PADDLE_THROW(platform::errors::Unimplemented(
      "RecordStreamForGC is only implemented when compiled with GPU."));
#else
  if (!IsInterpretercoreFastGCEnabled() ||
      instr.KernelType() != OpFuncType::kGpuAsync) {
    return;
  }

  if (instr.DeviceContext().GetPlace().GetType() ==
      phi::AllocationType::CUSTOM) {
    return;
  }

  platform::RecordEvent record(
      "RecordStreamForGC", platform::TracerEventType::UserDefined, 10);

  gpuStream_t stream =
      reinterpret_cast<const phi::GPUContext&>(instr.DeviceContext()).stream();
// TODO(lizhiyu): Only the 'send_v2' op of the GPT pipeline-parallel strategy
// is analyzed right now. Support for all communication operators will be
// added in the future.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto operator_base_ptr = instr.OpBase();
  if ((operator_base_ptr->Type() == "send_v2") &&
      (operator_base_ptr->Attr<bool>("use_calc_stream") == false)) {
    stream = platform::NCCLCommContext::Instance()
                 .Get(operator_base_ptr->Attr<int>("ring_id"),
                      instr.DeviceContext().GetPlace())
                 ->stream();
  }
#endif
  auto TensorRecordStream = [&stream](phi::DenseTensor& tensor) {
    auto allocation = tensor.Holder();
    if (allocation == nullptr) {
      return;
    }

    const platform::Place& place = allocation->place();
    if (platform::is_gpu_place(place)) {
      memory::RecordStream(allocation, stream);
    } else if (platform::is_cuda_pinned_place(place)) {
      // TODO(Ruibiao): Here we should do something to make sure that the
      // tensor is not freed until the H2D copies are done. However, simply
      // launching a CUDA runtime callback on the H2D stream may lead to high
      // performance overhead. As all the H2D cases we meet at present are
      // copies from CPUPlace, we just log a WARNING here. A better design is
      // required.
      LOG(WARNING) << "Copying data from a CUDAPinned tensor in an "
                      "asynchronous manner may lead to data inconsistency";
    } else {
      // memory copies involve CPUPlace are always synchronous, so just do
      // nothing here
    }
  };

  /* NOTE(Ruibiao): Cross-stream tensor synchronization is required only when
   * all the following conditions are satisfied:
   * 1. The tensor will be GCed after running the instruction, i.e., it is in
   * instr.GCCheckVars.
   * 2. The stream which initializes this tensor is different from the stream
   * which the instruction runs in.
   * 3. The tensor is the instruction's input, because we assume that the
   * instruction will initialize all output tensors with its running stream.
   * 4. In the OP function of this instruction, the tensor is an input of an
   * async CUDA kernel.
   *
   * Here we only process the first condition, because:
   * 1. Since the RecordStream function returns directly when the recorded
   * stream is equal to the owning stream, recording the same stream as the
   * one which initialized this tensor has less time overhead. Conversely, it
   * may take more time if we try to extract those cross-stream input vars
   * from instr.GCCheckVars.
   * 2. Currently the instruction has no idea of which vars are involved in
   * async running in the OP function, and thus we cannot recognize condition
   * 4. It should be supported later.
   */
  for (int var_id : instr.GCCheckVars()) {
    VLOG(4) << "GC sync " << var_scope_.GetNameById(var_id) << " "
            << var_scope_.VarDesc(var_id);

    // persistable vars will be ignored during GC
    if (var_scope_.VarDesc(var_id) &&
        var_scope_.VarDesc(var_id)->Persistable()) {
      continue;
    }

    paddle::framework::Variable* var = var_scope_.VarRef(var_id);
    if (var == nullptr) {
      continue;
    }

    if (var->IsType<phi::DenseTensor>()) {
      TensorRecordStream(*(var->GetMutable<phi::DenseTensor>()));
    } else if (var->IsType<
                   operators::reader::
                       OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
      // do nothing
    } else if (var->IsType<phi::SelectedRows>()) {
      TensorRecordStream(
          *(var->GetMutable<phi::SelectedRows>()->mutable_value()));
    } else if (var->IsType<LoDTensorArray>()) {
      auto* tensor_arr = var->GetMutable<LoDTensorArray>();
      for (auto& tensor : *tensor_arr) {
        TensorRecordStream(tensor);
      }
    } else if (var->IsType<std::vector<Scope*>>()) {
      // do nothing
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "The variable(%s) is not supported in eager deletion.",
          framework::ToTypeName(var->Type())));
    }
  }
#endif
}

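// Decrease the reference count of each GC-check var of the instruction and
// hand fully released, non-persistable vars over to the garbage collector.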
void ProgramInterpreter::CheckGC(const Instruction& instr) {
  platform::RecordEvent record(
      "CheckGC", platform::TracerEventType::UserDefined, 10);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  RecordStreamForGC(instr);
#endif
  auto& var_scope = var_scope_;

  for (auto var_id : instr.GCCheckVars()) {
    VLOG(4) << "GC:" << var_scope_.GetNameById(var_id) << ", id:" << var_id
            << ", ref:" << refs_[var_id]->DynamicRef();
    bool is_ready = refs_[var_id]->CheckAndDecrease();
    // ignore all persistable vars during GC
    if (var_scope.VarDesc(var_id) && var_scope.VarDesc(var_id)->Persistable()) {
      continue;
    }
    if (is_ready) {
      VLOG(6) << "Async delete variable with name : "
              << var_scope.GetNameById(var_id);
      gc_->Add(refs_[var_id]->Var(), instr);
    }
  }
}

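// Build the variable scope and instruction list on the first run and feed the
// input tensors; the feed may be repeated because the first feed can be GCed
// during BuildOpFuncList.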
void ProgramInterpreter::Prepare(
    const std::vector<std::string>& feed_names,
    const std::vector<phi::DenseTensor>& feed_tensors,
    bool prepare_feed) {
  PADDLE_ENFORCE_EQ(feed_names.size(),
                    feed_tensors.size(),
                    platform::errors::PreconditionNotMet(
                        "Required feed_names.size() == feed_tensors.size(), "
                        "but received %d != %d",
                        feed_names.size(),
                        feed_tensors.size()));
  auto FeedInput = [&] {
    VLOG(4) << "Feed inputs";
    for (size_t i = 0; i < feed_names.size(); ++i) {
      auto* feed_var = local_scope_->FindVar(feed_names[i]);
      PADDLE_ENFORCE_NOT_NULL(
          feed_var,
          platform::errors::NotFound("Variable %s should not be nullptr.",
                                     feed_names[i]));

      auto feed_tensor = feed_var->GetMutable<phi::DenseTensor>();
      feed_tensor->ShareDataWith(feed_tensors[i]);
      feed_tensor->set_lod(feed_tensors[i].lod());
    }
  };

  if (!is_build_) {
    paddle::framework::interpreter::BuildVariableScope(
        block_, execution_config_, &var_scope_);
    FeedInput();
    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
    paddle::framework::interpreter::BuildOpFuncList(
        place_,
        block_,
        execution_config_.skip_gc_vars,
        &op_func_nodes,
        &var_scope_,
        execution_config_,
        HasLocalScope(),
        static_build_);
    SetFeedVarsInplaceSkip(feed_names);
    // convert vec func_list to graph
    Convert(&op_func_nodes);
    UpdateSyncOpNum();
    if (static_build_) {
      VLOG(4) << "RUN impl";
      RunImpl();
    }
    BuildSkipShareLoDInfo();
    is_build_ = true;
    is_shared_results_build_ = true;
  }
  // NOTE: Because feed_tensor will be GCed after
  // paddle::framework::BuildOpFuncList, we should call FeedInput again.
  if (prepare_feed) {
    FeedInput();
  }
}

void ProgramInterpreter::SetFeedVarsInplaceSkip(
    const std::vector<std::string>& feed_names) {
  for (auto& feed_name : feed_names) {
    var_scope_.SetVarSikpInplace(feed_name, true);
  }
}

bool ProgramInterpreter::HasLocalScope() const {
  return local_scope_ != nullptr;
}

// Note(zhangbo):
// (1) What is "Trace"?
// By default, InterpreterCore schedules OP execution in a multi-threaded mode
// (see ExecuteInstructionList): a high-performance thread pool is maintained
// and the execution of OPs is distributed to its sub threads, while the main
// thread has no tasks. In Trace mode, the executor instead executes directly
// in the main thread according to the pre-provided OP sequence
// (trace_execute_order_), without distributing work to the thread pool.
// (2) When do we use "Trace"?
// In dygraph-to-static, the multi-threaded scheduling means that the forward
// and backward OPs and the dygraph optimizer cannot run in the same thread,
// and the thread switches may cause CPU cache misses. Moreover, when all OPs
// in a model are of KQueueAsync type, they are all distributed to the
// DeviceThread for execution, so multi-threaded scheduling brings no benefit.
// Therefore, in dygraph-to-static, when the number of synchronous OPs is 0,
// we choose Trace mode.
void ProgramInterpreter::TraceInstructionList(
    const std::vector<Instruction>& vec_instr) {
  unfinished_op_number_ = vec_instr.size();
  if (unfinished_op_number_ == 0) {
    VLOG(4) << "No op to run, return";
    return;
  }

  exception_holder_.Clear();

  for (size_t i = 0; i < dependecy_count_->size(); ++i) {
    if ((*dependecy_count_)[i] == 0) {
      // NOTE(zhiqiu): hot fix for jit input var
      RecordMemcpyD2H(vec_instr.at(i));
    }
  }

  for (auto instr_id : trace_execute_order_) {
    auto& instr_node = vec_instruction_.at(instr_id);

    RunInstruction(instr_node);

    if (UNLIKELY(exception_holder_.IsCaught())) {
      VLOG(4) << "Exception caught";
      break;
    }
  }

  if (UNLIKELY(exception_holder_.IsCaught())) {
    VLOG(1) << "Exception caught " << exception_holder_.Type();
    PADDLE_ENFORCE_EQ(
        main_thread_blocker_.Clear(),
        0,
        platform::errors::PreconditionNotMet(
            "main_thread_blocker_.Clear() return -1, clear failed"));
    VLOG(4) << "clear ok";
    exception_holder_.ReThrow();
  }
}

void ProgramInterpreter::RecordMemcpyD2H(const Instruction& instr_node) {
  // NOTE(zhiqiu): hot fix for jit input var
  if (instr_node.OpBase()->Type() == interpreter::kMemcpyD2H) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* default_dev_ctx = pool.Get(place_);
    for (auto& event : instr_node.EventsToWait()) {
      platform::RecordEvent record(
          "RecordStreamEvent", platform::TracerEventType::UserDefined, 10);
      VLOG(3) << "Record event on default stream in jit_input_var at op: "
              << instr_node.OpBase()->Type();
      event.event_->Record(default_dev_ctx);
    }
  }
}

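// Count the synchronous instructions (kCpuSync / kGpuSync). RunImpl uses this
// count to decide whether the trace (single-thread) execution mode can be
// used.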
void ProgramInterpreter::UpdateSyncOpNum() {
  int64_t sync_op_num = 0;
  for (auto& ins : vec_instruction_) {
    if (ins.KernelType() == OpFuncType::kCpuSync ||
        ins.KernelType() == OpFuncType::kGpuSync) {
      sync_op_num = sync_op_num + 1;
    }
  }
  sync_op_num_ = sync_op_num;
  VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_;
}

// Note(zhangbo):
// When there is a KQueueSync type OP in the model, breadth-first traversal is
// better than depth-first traversal. For example:
//   OP(O) ->(direct_run)-> OP(A) ->(sync_run)-> OP(B)
//   OP(O) ->(direct_run)-> OP(C) ->(direct_run)-> OP(D)
// If B is run before C, B may block waiting for A to finish executing, while
// in fact C could have been executed first during this time.
void ProgramInterpreter::AnalyseExecuteOrderForTrace() {
  VLOG(4) << "Analyze the execution order of Trace scheduling mode.";
  interpreter::ResetAtomicGuard guard(&deps_, &refs_);

  auto op_downstream_map = dependency_builder_.OpDownstreamMap();

  auto IsReady = [this](size_t next_id) {
    VLOG(4) << "op_id: " << next_id
            << ", remain deps: " << deps_[next_id]->DynamicDep();
    return deps_[next_id]->CheckAndDecrease();
  };

  std::vector<size_t> trace_order;
  SchedulingQueue ready_ops(instruction_scheduling_priority_less);

  for (size_t instr_id = 0; instr_id < dependecy_count_->size(); ++instr_id) {
    if ((*dependecy_count_)[instr_id] == 0) {
      ready_ops.push(instr_id);
    }
  }

  while (!ready_ops.empty()) {
    size_t now_id = ready_ops.top();
    ready_ops.pop();
    trace_order.push_back(now_id);

    auto next_op_set = op_downstream_map[now_id];

    for (size_t next_op_id : next_op_set) {
      if (IsReady(next_op_id)) {
        ready_ops.push(next_op_id);
      }
    }
  }

  PADDLE_ENFORCE_EQ(
      trace_order.size(),
      dependecy_count_->size(),
      platform::errors::PreconditionNotMet(
          "trace_order size should be equal to dependecy_count_."));

  trace_execute_order_ = trace_order;

  if (VLOG_IS_ON(6)) {
    std::stringstream ss;
    ss << "trace order: ";
    for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
      ss << vec_instruction_[trace_execute_order_[idx]]
                .OpFunc()
                ->operator_base_->Type()
         << "[" << trace_execute_order_[idx] << "]"
         << " -> ";
    }
    ss << "end\n";
    VLOG(6) << ss.str();
  }
}

}  // namespace framework
}  // namespace paddle