// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"

#include <algorithm>

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

// The difference between "sequential_run" and "serial_run":
// "sequential_run" dispatches OPs one by one according to the sequence in the
// Program, while "serial_run" ensures that all Ops are scheduled in a single
// thread. In standalone executor, "sequential_run" is also "serial_run", while
// "serial_run" is not necessarily "sequential_run".
PADDLE_DEFINE_EXPORTED_bool(new_executor_sequential_run,
                            false,
                            "Enable sequential execution for standalone "
                            "executor, only applied to GPU OPs.");

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_serial_run,
    false,
    "Enable serial execution for standalone executor, used for debug.");

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace interpreter {

using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2;
const char blocking_queue_prefix[] = "lod_tensor_blocking_queue";

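// Build the options for the executor's three work queues: one for running
// host (CPU) kernels, one for launching device kernels, and a single-thread
// queue used to prepare dependency counts and other bookkeeping.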
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
    size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
  std::vector<WorkQueueOptions> group_options;
  // for execute host Kernel
  group_options.emplace_back(/*name*/ "HostTasks",
                             /*num_threads*/ host_num_threads,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ false,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  // for launch device Kernel
  group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
                             /*num_threads*/ device_num_threads,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ true,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  // for prepare deps and others
  group_options.emplace_back(/*name*/ "Prepare",
                             /*num_threads*/ 1,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ false,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  return group_options;
}

AsyncWorkQueue::AsyncWorkQueue(size_t host_num_threads,
                               size_t device_num_threads,
                               EventsWaiter* waiter)
    : host_num_thread_(host_num_threads) {
  queue_group_ = CreateWorkQueueGroup(
      ConstructWorkQueueOptions(host_num_threads, device_num_threads, waiter));
}

void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
  // NOTE(zhiqiu): in serial run, all tasks are pushed into the same (async)
  // queue, so only one thread is used.
  if (FLAGS_new_executor_serial_run) {
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

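// The two Prepare* methods below run on the dedicated "Prepare" queue
// (kPrepareWorkQueueIdx), so dependency counts and variable reference counts
// can be computed without blocking kernel execution.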
std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicDeps(const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";
  return queue_group_->AddAwaitableTask(
      kPrepareWorkQueueIdx, interpreter::PrepareAtomicDeps, dependency_count);
}

std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  return queue_group_->AddAwaitableTask(
      kPrepareWorkQueueIdx, interpreter::PrepareAtomicVarRef, vec_meta_info);
}

std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";

  auto op_deps = std::make_unique<AtomicVectorSizeT>(dependency_count.size());
  for (size_t i = 0; i < dependency_count.size(); ++i) {
    (*op_deps)[i] = dependency_count[i];
  }
  VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
  return op_deps;
}

std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  auto var_ref = std::make_unique<AtomicVectorSizeT>(vec_meta_info.size());
  for (size_t i = 0; i < vec_meta_info.size(); ++i) {
    (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
  }
  VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
  return var_ref;
}

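// Only non-persistable variables of tensor-like types (LOD_TENSOR,
// SELECTED_ROWS, LOD_TENSOR_ARRAY) are eligible for garbage collection.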
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}

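// For each op, collect the variables whose last read/write happens in that
// op; such variables can be released right after the op finishes. For
// example, given {op1 -> a, op2(a) -> b, op3(b)}, a is attached to op2 and
// b to op3.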
std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}

void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope,
                          bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, and all other variables are created in the
  // local scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): a user may create a variable whose name already exists
    // before. Under such circumstances, we should raise an error. Currently
    // we can't get the var_desc of startup_program, so leave it for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }

    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      // NOTE(zhiqiu): if var exists in scope and the type is right,
      // InitializeVariable will not create a new variable.
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
              << ptr << ", Variable Type "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->AddVar(var_name, var_desc);
  }
}

void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();

    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      if (op->HasAttr("use_mkldnn")) {
        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
        op_base->SetAttr("use_mkldnn", true);
      }
    }
#endif

    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map,
    VariableScope* var_scope,
    Scope* local_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!var_scope->HasVar(var_name)) {
        // Hot fix for variables used in dataloader, like
        // 'lod_tensor_blocking_queue_0'. These variables may be created in
        // the scope, but they do not exist as variables in the program.
        if (var_name.find(blocking_queue_prefix) != std::string::npos &&
            local_scope->FindVar(var_name)) {
          var_scope->AddVar(var_name, nullptr);
        } else if (!enforce_exist) {
          // skip the non-exist variable: such as recurrent_grad
          VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
          continue;
        }
      }
      auto* var = local_scope->FindVar(var_name);
      auto var_id = var_scope->VarId(var_name);
      vars.push_back(var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}

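// Honor the device_guard hint attached to an op (its "op_device" attribute):
// it may pin the kernel to the CPU, or keep it on the current GPU/NPU place
// if a matching kernel exists.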
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               (platform::is_gpu_place(place) ||
                platform::is_npu_place(place))) {
      // when an Op that only has a CPUKernel is assigned to GPU, the CPUKernel
      // will be executed and a warning will be issued at the same time.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}

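// Run an op that is not an OperatorWithKernel (e.g. control-flow ops that
// implement RunImpl directly): there is no kernel to select, so the op is
// executed right away and all of its inputs are marked as no-data-transform.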
void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node,
                        Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // input, output is prepared. set the other attributes.
  op_func_node->operator_base_ = op_base;
  if (IsSupportedHetePlace(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transformer.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // all index is no-need-transform
  op_func_node->dev_ctx_ = dev_ctx;
}

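// Build one OpFuncNode for every op in the block: create the ops, select a
// kernel (a phi kernel when available, otherwise a fluid kernel), apply data
// transform, run infershape and the kernel once, and eagerly free variables
// whose last user is the current op.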
void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        const std::set<std::string>& skip_gc_vars,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope,
                        bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);
  // Prepare safe eager deletion for control-flow ops. This only takes effect
  // if gc is enabled and the program contains more than one block.
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

#ifdef PADDLE_WITH_MKLDNN
  platform::RegisterModelLayout(ops_unique, place);
#endif

  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, local_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) = build_variable_map(
        outputs_names, var_scope, local_scope, enforce_exist);

    // step 1: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
      VLOG(4) << "End run " << place << " "
              << op_func_node.operator_base_->DebugStringEx(local_scope);
    } else {
      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(op));
      // construct RuntimeContext and analysis KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      Scope scope, *runtime_scope = &scope;
      // NOTE(Ruibiao): We do not encourage directly using scope in OP kernels,
      // but some OPs do have such behavior (e.g., cinn_launch OP), so they get
      // special treatment here.
      if (op_with_kernel->Type() == "cinn_launch") {
        VLOG(6) << "OP(" << op_with_kernel->Type()
                << ") use scope in kernel, "
                   "so pass a real scope to "
                   "ExecutionContext";
        runtime_scope = local_scope;
      }

      auto& pool = platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      auto exec_ctx = ExecutionContext(
          *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
      auto expected_kernel_key =
          op_with_kernel->GetExpectedKernelType(exec_ctx);
      // change device by the device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(4) << "expected_kernel_key : " << expected_kernel_key;

      // step 2. select op kernel
      auto run_phi_kernel = false;
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;

        if (op_with_kernel->PhiKernel()->IsValid()) {
          run_phi_kernel = true;
        } else {
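          // The chosen phi kernel is invalid (e.g. not registered for this
          // place). If the fluid kernel cannot handle expected_kernel_key
          // either, fall back to a phi kernel on CPUPlace.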
          if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPhiKernel(
                new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PhiKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PhiKernel());
              op_with_kernel->ResetKernelType(new OpKernelType(
                  TransPhiKernelKeyToOpKernelType(pt_cpu_kernel_key)));
              run_phi_kernel = true;
            }
          }
        }
      }
      if (!run_phi_kernel) {
        op_with_kernel->ChooseKernel(exec_ctx);
        op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
      } else {
        op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
      }
      auto kernel_type = *(op_with_kernel->kernel_type());
      if (kernel_type.place_ != dev_ctx->GetPlace()) {
        dev_ctx = pool.Get(kernel_type.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      if (IsSupportedHetePlace(kernel_type.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(kernel_type.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             kernel_type.place_));
      }
      VLOG(3) << op_with_kernel->Type()
              << " : finally selected kernel_key: " << kernel_type;

      // step 3. data transform
      VariableValueMap& ins_map_temp = runtime_context.inputs;
      VariableValueMap& outs_map_temp = runtime_context.outputs;
      ApplyDataTransform(kernel_type,
                         place,
                         &ins_map_temp,
                         &outs_map_temp,
                         var_scope,
                         &op_func_node,
                         vec_func_list,
                         use_local_scope);

      // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc for
      // why.
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): In case of control flow ops, they are NOT
        // inheritted from OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

      // step 5. run kernel
      if (run_phi_kernel) {
        phi::KernelContext pt_kernel_context;
        op_with_kernel->BuildPhiKernelContext(
            runtime_context, dev_ctx, &pt_kernel_context);
        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        // the place of exec_ctx may have changed.
        op_func_node.kernel_func_(ExecutionContext(
            *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context));
      }

      // post-process grad_op.outputs if we need to cast a complex grad into a
      // real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
      if (framework::IsComplexType(kernel_type.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(op_func_node,
                                                 place,
                                                 outputs_names,
                                                 &runtime_context.outputs,
                                                 var_scope,
                                                 vec_func_list,
                                                 local_scope);
      }
      if (!op_func_node.inplace_back_map.empty()) {
        auto& m = op_func_node.inplace_back_map;
        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
        // operator.cc
        for (auto& p : m) {
          auto* transformed_tensor =
              GetMutableLoDTensorOrSelectedRowsValueFromVar(
                  local_scope->FindVar(var_scope->GetNameById(p.first)));
          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
              local_scope->FindVar(var_scope->GetNameById(p.second)));
          original_tensor->ShareDataWith(*transformed_tensor);
          VLOG(4) << "Transfer inplace variable back from "
                  << var_scope->GetNameById(p.first) << " to "
                  << var_scope->GetNameById(p.second);
        }
      }
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = local_scope->FindVar(var_name);
      if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<phi::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s is not supported eager deletion.",
635 636
            framework::ToTypeName(var->Type()),
            var_name));
      }
    }
    delete garbages;  // free mem
  }
}

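// Append one fetch_v2 op per fetch target; all of them write into the shared
// fetch-list variable (kFetchVarName), indexed by the "col" attribute.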
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}

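// Merge two sorted index vectors into a sorted, de-duplicated vector, e.g.
// merge_vector({1, 3, 5}, {2, 3}) == {1, 2, 3, 5}.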
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(
      first.begin(), first.end(), second.begin(), second.end(), out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}

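// var2min_rw_op maps each variable to the ops holding the current "minimal"
// reads/writes on it, i.e. the ops that a later writer of the variable must
// depend on.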
void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op,
                          int rw_var) {
  // rw_var is an input or output of cur_op;
  // this function updates the var2min_rw_op map.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
    (*var2min_rw_op)[rw_var] = std::list<int>();
  }
  for (auto dep_op : op2dependences.at(cur_op)) {
    var2min_rw_op->at(rw_var).remove(dep_op);
  }
  var2min_rw_op->at(rw_var).push_back(cur_op);
}

void AddDownstreamOp(int prior_op_idx,
                     int posterior_op_idx,
                     std::map<int, std::list<int>>* op_downstream_map) {
  if (op_downstream_map->find(prior_op_idx) == op_downstream_map->end()) {
    op_downstream_map->emplace(std::make_pair(prior_op_idx, std::list<int>()));
  }
  op_downstream_map->at(prior_op_idx).push_back(posterior_op_idx);
}

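// Overload that skips adding the edge prior_op_idx -> posterior_op_idx when
// it is already implied transitively by an existing downstream op.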
void AddDownstreamOp(int prior_op_idx,
                     int posterior_op_idx,
                     std::map<int, std::list<int>>* op_downstream_map,
                     const std::vector<std::vector<bool>>& op_happens_before) {
  if (op_downstream_map->find(prior_op_idx) != op_downstream_map->end()) {
    for (int op_idx : op_downstream_map->at(prior_op_idx)) {
      if (op_happens_before[op_idx][posterior_op_idx]) {
        VLOG(7) << "Find dependencies " << prior_op_idx << "->" << op_idx
                << "->" << posterior_op_idx << ", skip adding " << prior_op_idx
                << "->" << posterior_op_idx;
        return;
      }
    }
  }

  AddDownstreamOp(prior_op_idx, posterior_op_idx, op_downstream_map);
}

size_t CountDownstreamMap(const std::map<int, std::list<int>>& downstream_map) {
  size_t count = 0;
  for (auto pair : downstream_map) {
    count += pair.second.size();
  }
  return count;
}

const std::string StringizeDownstreamMap(
    const std::map<int, std::list<int>>& downstream_map) {
  std::ostringstream oss;
  for (auto pair : downstream_map) {
    oss << pair.first << " -> ";
    std::copy(pair.second.begin(),
              pair.second.end(),
              std::ostream_iterator<int>(oss, " "));
    oss << std::endl;
  }
  return oss.str();
}

// Convert op2dependences to downstream_map directly. op2dependences maps an
// op to its dependencies; we want the op -> [next ops] map, where [next ops]
// are the ops that must be scheduled after op.
std::map<int, std::list<int>> GetDownstreamMap(
    const std::map<int, std::set<int>>& op2dependences) {
  std::map<int, std::list<int>> downstream_map;
  for (auto& item : op2dependences) {
    int op = item.first;
    for (auto dep_op : item.second) {
      AddDownstreamOp(dep_op, op, &downstream_map);
    }
  }

  VLOG(6) << "downstream count: " << CountDownstreamMap(downstream_map);
  VLOG(6) << "downstream_map: " << std::endl
          << StringizeDownstreamMap(downstream_map);

  return downstream_map;
}

void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map,
                         std::vector<std::vector<bool>>* op_happens_before,
                         size_t op_num) {
  // remove unnecessary downstream ops
  // for example, a->b->c
  // a: b, c
  // b: c
  // =>
  // a: b
  // b: c

  // happens_before[i][j] means i should be executed before j
  op_happens_before->resize(op_num);
  for (size_t i = 0; i < op_num; ++i) {
    (*op_happens_before)[i].resize(op_num);
    std::fill(
        (*op_happens_before)[i].begin(), (*op_happens_before)[i].end(), false);
  }

  // bfs to get all next ops
  auto bfs = [&](size_t op_idx) {
    std::queue<size_t> q;
    std::vector<bool> visited(op_num, false);
    q.push(op_idx);
    while (!q.empty()) {
      size_t op = q.front();
      q.pop();
      visited[op] = true;
      if (!downstream_map->count(op)) {
        continue;
      }
      for (auto next : downstream_map->at(op)) {
        if (!visited[next]) {
          PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx],
                            false,
                            paddle::platform::errors::AlreadyExists(
                                "There exists circle in graph, expected "
                                "%d->%d, but already got %d->%d",
                                op_idx,
                                next,
                                next,
                                op_idx));
          (*op_happens_before)[op_idx][next] = true;
          VLOG(8) << "happens before: " << op_idx << " " << next;
          q.push(next);
        }
      }
    }
  };

  for (size_t i = 0; i < op_num; ++i) {
    bfs(i);
  }

  // shrink: keep only those downstream ops that no other op in the
  // downstream list happens before
  for (size_t i = 0; i < op_num; ++i) {
    if (downstream_map->find(i) == downstream_map->end()) {
      continue;
    }

    std::list<int> minimum_nexts;
    for (size_t item : downstream_map->at(i)) {
      bool not_after_any = true;
      // find the op that is not executed after any
      for (size_t other_item : downstream_map->at(i)) {
        if ((*op_happens_before)[other_item][item]) {
          VLOG(8) << "happens_before: " << other_item << "->" << item
                  << ", so skip " << item;
          not_after_any = false;
          break;
        }
      }
      if (not_after_any) {
        VLOG(8) << "downstream op of " << i << ": " << item;
        minimum_nexts.push_back(item);
      }
    }
    downstream_map->at(i) = minimum_nexts;
  }
  VLOG(6) << "downstream count: " << CountDownstreamMap(*downstream_map);
  VLOG(6) << "downstream_map: " << std::endl
          << StringizeDownstreamMap(*downstream_map);
}

std::map<int, std::list<int>> build_op_downstream_map(
    const std::vector<Instruction>& vec_instruction,
    std::vector<std::vector<bool>>* op_happens_before) {
  auto var2min_rw_op =
      std::map<int, std::list<int>>();  // map from variable id to read /
                                        // write op ids.
  auto var2recent_write_op =
      std::map<int, int>();  // map from variable to its most recent write op.
  auto op2dependences =
      std::map<int, std::set<int>>();  // map from op to its dependence list;
                                       // the op must run after its dependences.
  std::set<int>
      remove_duplicate;  // remove duplicates between inputs and outputs

  size_t op_num = vec_instruction.size();

  // pre-create an empty dependence set for every op
  for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
    op2dependences[op_idx] = std::set<int>();
  }

  for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
    remove_duplicate.clear();
    // step1: update the op2dependences structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (var2recent_write_op.count(var))
          op2dependences[op_idx].insert(var2recent_write_op[var]);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        if (var2min_rw_op.count(var)) {
          for (auto dep_op : var2min_rw_op[var]) {
            op2dependences[op_idx].insert(dep_op);
          }
        }
      }
    }

    // step2: update 2 var2xxxx data structure
    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        var2recent_write_op[var] = op_idx;
        var2min_rw_op[var] = {static_cast<int>(op_idx)};
        remove_duplicate.insert(var);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (remove_duplicate.count(var) ==
            0) {  // var in input list and in output list, so remove it.
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }

    // NOTE(zhiqiu): An inplace op with `transfer` also changes the
    // original output afterwards, so add the original output as well.
    // original: a->op->a
    // after: a->data_transfer->a'->op->a'->transfer_back->a
    // which means op writes a and a'
    if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
      auto& m = vec_instruction[op_idx].InplaceBackMap();
      for (auto& p : m) {
        auto var = p.second;
        var2recent_write_op[var] = op_idx;
        // var in input list and in output list, so remove it.
        if (remove_duplicate.count(var) == 0) {
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
  }

  // NOTE(zhiqiu): the size of downstream_map != the size of op2dependences,
  // since some ops have no downstream op.
  std::map<int, std::list<int>> op_downstream_map =
      GetDownstreamMap(op2dependences);

  ShrinkDownstreamMap(&op_downstream_map, op_happens_before, op_num);

  // add dependencies for random ops, making sure that random ops are
  // scheduled sequentially
  const std::set<std::string> random_op_set = {
      "bernoulli",
      "poisson",
      "multinomial",
      "gaussian_random",
      "truncated_gaussian_random",
      "uniform_random",
      "randint",
      "randperm",
      "exponential",
      "sampling_id",
      "dropout",
      "class_center_sample",
  };

  int dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
    if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        AddDownstreamOp(
            dependence_op_idx, op_idx, &op_downstream_map, *op_happens_before);
      }
      dependence_op_idx = op_idx;
    }
  }

  // add dependency for communication op
  auto is_comm_op = [](const std::string& op) -> bool {
    const std::set<std::string> special_comm_op_set = {
        "send",
        "recv",
        "send_v2",
        "recv_v2",
    };
    const std::string communication_op_prefix = "c_";
    if (op.find(communication_op_prefix) != std::string::npos ||
        special_comm_op_set.count(op)) {
      return true;
    }
    return false;
  };

  dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
    if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        AddDownstreamOp(
            dependence_op_idx, op_idx, &op_downstream_map, *op_happens_before);
        VLOG(4) << "Add depend from "
                << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
                << vec_instruction[op_idx].OpBase()->Type();
      }
      dependence_op_idx = op_idx;
    }
  }

  // TODO(zhiqiu): there are still some cases not handled
  // add dependency for c_sync_comm_stream

  // in program, we can add only one c_sync_comm_stream to sync all
  // communication ops.
  // c_allreduce_sum(a)
  // c_allreduce_sum(b)
  // c_allreduce_sum(c)
  // c_sync_comm_stream(a)
  const std::string kSyncComm = "c_sync_comm_stream";
  dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
    if (vec_instruction[op_idx].OpBase()->Type() == kSyncComm) {
      dependence_op_idx = op_idx;
    } else {
      if (dependence_op_idx != -1) {
        VLOG(4) << "Add depend from "
                << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
                << vec_instruction[op_idx].OpBase()->Type();
        AddDownstreamOp(
            dependence_op_idx, op_idx, &op_downstream_map, *op_happens_before);
      }
    }
  }

  // add dependency for coalesce_tensor
  const std::string kCoalesceTensor = "coalesce_tensor";
  for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
    if (vec_instruction[op_idx].OpBase()->Type() == kCoalesceTensor) {
      VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx;
      auto fused_out = vec_instruction[op_idx].Outputs().at("FusedOutput")[0];
      auto outputs = vec_instruction[op_idx].Outputs().at("Output");

      auto is_read = [](const Instruction& inst, int var_id) -> bool {
        for (auto pair : inst.Inputs()) {
          for (auto item : pair.second) {
            if (item == var_id) {
              return true;
            }
          }
        }
        return false;
      };

      auto is_write = [](const Instruction& inst, int var_id) -> bool {
        for (auto pair : inst.Outputs()) {
          for (auto item : pair.second) {
            if (item == var_id) {
              return true;
            }
          }
        }
        return false;
      };

      // find first op that reads fused_out
      auto first_read_fused_out_op = -1;
      for (auto j = op_idx + 1; j < op_num; ++j) {
        if (is_read(vec_instruction[j], fused_out)) {
          first_read_fused_out_op = j;
          break;
        }
      }

      if (UNLIKELY(first_read_fused_out_op == -1)) {
        VLOG(4) << "No op read FusedOutput";
        continue;
      }

      // find ops that write 'outputs' between (op_index,
      // first_read_fused_out_op)
      // add depend: them->first_read_fused_out_op
      for (auto j = op_idx + 1;
           j < static_cast<size_t>(first_read_fused_out_op);
           ++j) {
        for (auto var_id : outputs) {
          if (is_write(vec_instruction[j], var_id)) {
            AddDownstreamOp(j,
                            first_read_fused_out_op,
                            &op_downstream_map,
                            *op_happens_before);
            VLOG(4) << j << " -> " << first_read_fused_out_op;
            VLOG(4)
                << "Add depend from " << vec_instruction[j].OpBase()->Type()
                << " to "
                << vec_instruction[first_read_fused_out_op].OpBase()->Type();
          }
        }
      }

      // find the first op that reads 'outputs' between
      // (first_read_fused_out_op, end)
      // add depend: first_read_fused_out_op -> the first op that reads
      // 'outputs'

      // special case for consecutive communication ops, for example,
      // FusedOutput = c_sync_calc_stream(FusedOutput)
      // FusedOutput = c_allreduce_sum(FusedOutput)
      // FusedOutput = c_sync_comm_stream(FusedOutput)
      // we should take the last one to add the depend instead of
      // 'first_read_fused_out_op'
      size_t target = first_read_fused_out_op;
      for (size_t j = first_read_fused_out_op + 1; j < op_num; ++j) {
        if (j == target + 1 &&
            is_comm_op(vec_instruction[target].OpBase()->Type()) &&
            is_comm_op(vec_instruction[j].OpBase()->Type())) {
          VLOG(4) << "Found consecutive communication ops, "
                  << vec_instruction[target].OpBase()->Type() << " -> "
                  << vec_instruction[j].OpBase()->Type();
          target = j;
          continue;
        }

        for (auto var_id : outputs) {
          if (is_read(vec_instruction[j], var_id)) {
            AddDownstreamOp(target, j, &op_downstream_map, *op_happens_before);
            VLOG(4) << target << " -> " << j;
            VLOG(4) << "Add depend from "
                    << vec_instruction[target].OpBase()->Type() << " to "
                    << vec_instruction[j].OpBase()->Type();
          }
        }
      }
    }
  }

  if (FLAGS_new_executor_sequential_run) {
    dependence_op_idx = -1;
    for (size_t op_idx = 0; op_idx < op_num; ++op_idx) {
      if (!IsCpuOp(vec_instruction[op_idx])) {
        if (dependence_op_idx != -1) {
          AddDownstreamOp(dependence_op_idx,
                          op_idx,
                          &op_downstream_map,
                          *op_happens_before);
          VLOG(4) << "Add depend from "
                  << vec_instruction[dependence_op_idx].OpBase()->Type() << "("
                  << dependence_op_idx << ") to "
                  << vec_instruction[op_idx].OpBase()->Type() << "(" << op_idx
                  << ")";
        }
        dependence_op_idx = op_idx;
      }
    }
  }

  VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map);
  VLOG(8) << "downstream_map: " << std::endl
          << StringizeDownstreamMap(op_downstream_map);

  return op_downstream_map;
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle