// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"

#include <algorithm>

#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_serial_run,
    false,
    "Enable serial execution for standalone executor, used for debug.");

DECLARE_bool(use_mkldnn);
DECLARE_bool(check_nan_inf);

namespace paddle {
namespace framework {
namespace interpreter {

using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2;
const char blocking_queue_prefix[] = "lod_tensor_blocking_queue";

const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
    size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
  std::vector<WorkQueueOptions> group_options;
  // for executing host kernels
  group_options.emplace_back(/*name*/ "HostTasks",
                             /*num_threads*/ host_num_threads,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ false,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  // for launching device kernels
  group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
                             /*num_threads*/ device_num_threads,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ true,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  // for preparing deps and others
  group_options.emplace_back(/*name*/ "Prepare",
                             /*num_threads*/ 1,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ false,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  return group_options;
}

AsyncWorkQueue::AsyncWorkQueue(size_t host_num_threads,
                               size_t device_num_threads,
                               EventsWaiter* waiter)
    : host_num_thread_(host_num_threads) {
  queue_group_ = CreateWorkQueueGroup(
      ConstructWorkQueueOptions(host_num_threads, device_num_threads, waiter));
}
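
// Usage sketch (illustrative only, not part of the executor): how the
// three-queue group above is typically driven. The thread counts and the
// lambda bodies are assumptions made for this example.
//
//   EventsWaiter waiter;
//   AsyncWorkQueue queue(/*host_num_threads=*/4,
//                        /*device_num_threads=*/1, &waiter);
//   queue.AddTask(OpFuncType::kQueueSync, [] { /* run a host kernel */ });
//   queue.AddTask(OpFuncType::kQueueAsync, [] { /* launch device kernel */ });
//   waiter.WaitEvent();  // block until one of the queues signals the waiter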

void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
  // NOTE(zhiqiu): use the second queue so that only one thread is used.
  if (FLAGS_new_executor_serial_run) {
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}
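
// Note on serial mode: when FLAGS_new_executor_serial_run is set, every task
// is routed to the kQueueAsync queue regardless of its op_func_type, e.g.:
//
//   FLAGS_new_executor_serial_run = true;
//   queue.AddTask(OpFuncType::kQueueSync, fn);  // still lands on the async
//                                               // (device-launch) queue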

std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";
  return queue_group_->AddAwaitableTask(
      kPrepareWorkQueueIdx, interpreter::PrepareAtomicDeps, dependency_count);
}

std::future<std::unique_ptr<AtomicVectorSizeT>>
AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  return queue_group_->AddAwaitableTask(
      kPrepareWorkQueueIdx, interpreter::PrepareAtomicVarRef, vec_meta_info);
}

std::unique_ptr<AtomicVectorSizeT> PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";

  auto op_deps = std::make_unique<AtomicVectorSizeT>(dependency_count.size());
  for (size_t i = 0; i < dependency_count.size(); ++i) {
    (*op_deps)[i] = dependency_count[i];
  }
  VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
  return op_deps;
}

std::unique_ptr<AtomicVectorSizeT> PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  auto var_ref = std::make_unique<AtomicVectorSizeT>(vec_meta_info.size());
  for (size_t i = 0; i < vec_meta_info.size(); ++i) {
    (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
  }
  VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
  return var_ref;
}
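
// The returned atomics drive the scheduler's countdown pattern. A minimal
// sketch, assuming hypothetical dependency counts {0, 2, 1}:
//
//   auto deps = PrepareAtomicDeps({0, 2, 1});
//   // when a predecessor of instruction 1 finishes:
//   if ((*deps)[1].fetch_sub(1) == 1) {
//     // the count just reached zero; instruction 1 becomes runnable
//   }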

bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}
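
// For example (hypothetical variable names in `block`):
//
//   var_can_be_deleted("fc_0.tmp_0", block);  // true: non-persistable tensor
//   var_can_be_deleted("fc_0.w_0", block);    // false: persistable parameter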

std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}
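
// Sketch of the result on a hypothetical three-op block
//   op0: {x} -> {y},  op1: {y} -> {z},  op2: {z} -> {out}
// where every variable is deletable: x is last used by op0, y by op1, and
// z and out by op2, so the returned map is
//   {op0: [x], op1: [y], op2: [z, out]}
// and each variable can be garbage-collected right after its last op runs.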

void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope,
                          bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, and all others are created in the local
  // scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): a user may create a variable whose name already
    // exists; in that case we should raise an error. Currently we cannot
    // get the var_desc of the startup_program, so this is left for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }

    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      // NOTE(zhiqiu): if var exists in scope and the type is right,
      // InitializeVariable will not create a new variable.
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
              << ptr << "Variable Type "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->AddVar(var_name, var_desc);
  }
}
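
// A minimal sketch of the resulting layout, assuming a block that declares
// a persistable parameter "w" and a temporary "tmp" (hypothetical names):
//
//   build_variable_scope(block, &var_scope, /*use_local_scope=*/true);
//   // "w"   lives in var_scope.GetMutableScope()      (outer scope)
//   // "tmp" lives in var_scope.GetMutableLocalScope() (local scope)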

void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();

    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      if (op->HasAttr("use_mkldnn")) {
        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
        op_base->SetAttr("use_mkldnn", true);
      }
    }
#endif

    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}
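
// Typical call (a sketch; `block` is any BlockDesc):
//
//   std::vector<std::unique_ptr<OperatorBase>> ops;
//   create_all_ops(block, &ops);  // one OperatorBase per OpDesc, with
//                                 // checked attributes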

std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map,
    VariableScope* var_scope,
    Scope* local_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!var_scope->HasVar(var_name)) {
        // Hot fix for variables used in the dataloader, such as
        // 'lod_tensor_blocking_queue_0'. These variables may be created in
        // the scope but do not exist as variables in the program.
        if (var_name.find(blocking_queue_prefix) != std::string::npos &&
            local_scope->FindVar(var_name)) {
          var_scope->AddVar(var_name, nullptr);
        } else if (!enforce_exist) {
          // skip the non-existent variable, such as recurrent_grad
          VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
          continue;
        }
      }
      auto* var = local_scope->FindVar(var_name);
      auto var_id = var_scope->VarId(var_name);
      vars.push_back(var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}
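
// For an op with an input slot {"X": {"a", "b"}} (hypothetical names), the
// returned pair maps "X" to the two Variable* found in local_scope and to
// their ids in var_scope; names missing from the scope are skipped when
// enforce_exist is false. Sketch:
//
//   VariableValueMap name2var;
//   VariableIdMap name2id;
//   std::tie(name2var, name2id) =
//       build_variable_map(op->Inputs(), var_scope, local_scope, false);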

void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               platform::is_gpu_place(place)) {
      // when an Op without a GPU kernel is assigned to GPU, the CPU kernel
      // will be executed instead and a warning will be given.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else if (op_device.find("npu") != std::string::npos &&
               platform::is_npu_place(place)) {
      // when an Op without an NPU kernel is assigned to NPU, the CPU kernel
      // will be executed instead and a warning will be given.
      if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no NPU implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else if (op_device.find("xpu") != std::string::npos &&
               platform::is_xpu_place(place)) {
      // when an Op without an XPU kernel is assigned to XPU, the CPU kernel
      // will be executed instead and a warning will be given.
      if (op_base->SupportXPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no XPU implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}
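
// A sketch of where this hook sits in kernel selection (see
// build_op_func_list below): the expected key is computed first, then the
// device guard may rewrite its place.
//
//   auto key = op_with_kernel->GetExpectedKernelType(exec_ctx);
//   apply_device_guard(op, place, &key);  // e.g. op_device="cpu" forces
//                                         // key.place_ = CPUPlace()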

void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node,
                        Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // inputs and outputs are prepared; set the other attributes.
  op_func_node->operator_base_ = op_base;
  if (IsSupportedHetePlace(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transformer.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // no index needs data transform
  op_func_node->dev_ctx_ = dev_ctx;
}

void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        const std::set<std::string>& skip_gc_vars,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope,
                        bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for the current block.
  create_all_ops(block, &ops_unique);
  // If gc is enabled and the program contains control-flow sub-blocks,
  // prepare their ops for safe eager deletion.
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

#ifdef PADDLE_WITH_MKLDNN
  platform::RegisterModelLayout(ops_unique, place);
#endif

  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, local_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) = build_variable_map(
        outputs_names, var_scope, local_scope, enforce_exist);

    // step 1: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

#ifdef PADDLE_WITH_ASCEND_CL
    // NOTE(wangxi): nan/inf cannot be detected on NPU by checking variable
    // values; overflow can only be checked through the special
    // `float_status` variable. More about `float_status`, see:
    // https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
    if (FLAGS_check_nan_inf) {
      framework::details::NPUAllocAndClearFloatStatus(*op, *local_scope, place);
    }
#endif

    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
    } else {
      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(op));
      // construct RuntimeContext and analyze KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      Scope scope, *runtime_scope = &scope;
      // NOTE(Ruibiao): We do not encourage directly using scopes in OP
      // kernels, but some OPs do have such behavior (e.g., the cinn_launch
      // OP). Here we give them special treatment.
      if (op_with_kernel->Type() == "cinn_launch") {
        VLOG(6) << "OP(" << op_with_kernel->Type()
                << ") use scope in kernel, "
                   "so pass a real scope to "
                   "ExecutionContext";
        runtime_scope = local_scope;
      }

      auto& pool = platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      auto exec_ctx = ExecutionContext(
          *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
      auto expected_kernel_key =
          op_with_kernel->GetExpectedKernelType(exec_ctx);
      // change device by the device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(4) << "expected_kernel_key : " << expected_kernel_key;

      // step 2. select op kernel
      auto run_phi_kernel = false;
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;

        if (op_with_kernel->PhiKernel()->IsValid()) {
          run_phi_kernel = true;
        } else {
          if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPhiKernel(
                new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PhiKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PhiKernel());
              op_with_kernel->ResetKernelType(new OpKernelType(
                  TransPhiKernelKeyToOpKernelType(pt_cpu_kernel_key)));
              run_phi_kernel = true;
            }
          }
        }
      }
      if (!run_phi_kernel) {
        op_with_kernel->ChooseKernel(exec_ctx);
        op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
      } else {
        op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
      }
      auto kernel_type = *(op_with_kernel->kernel_type());
      if (kernel_type.place_ != dev_ctx->GetPlace()) {
        dev_ctx = pool.Get(kernel_type.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      if (IsSupportedHetePlace(kernel_type.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(kernel_type.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             kernel_type.place_));
      }
      VLOG(3) << op_with_kernel->Type()
              << " : finally selected kernel_key: " << kernel_type;

      // step 3. data transform
      VariableValueMap& ins_map_temp = runtime_context.inputs;
      VariableValueMap& outs_map_temp = runtime_context.outputs;
      ApplyDataTransform(kernel_type,
                         place,
                         &ins_map_temp,
                         &outs_map_temp,
                         var_scope,
                         &op_func_node,
                         vec_func_list,
                         use_local_scope);

      // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc for
      // why.
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): In case of control flow ops, they are NOT
        // inherited from OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

      // step 5. run kernel
      if (run_phi_kernel) {
        phi::KernelContext pt_kernel_context;
        op_with_kernel->BuildPhiKernelContext(
            runtime_context, dev_ctx, &pt_kernel_context);
        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        // the place of exec_ctx may have changed.
        op_func_node.kernel_func_(ExecutionContext(
            *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context));
      }

      // post-process grad_op.outputs if we need to cast complex grad into
      // real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
      if (framework::IsComplexType(kernel_type.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(op_func_node,
                                                 place,
                                                 outputs_names,
                                                 &runtime_context.outputs,
                                                 var_scope,
                                                 vec_func_list,
                                                 local_scope);
      }
      if (!op_func_node.inplace_back_map.empty()) {
        auto& m = op_func_node.inplace_back_map;
        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
        // operator.cc
        for (auto& p : m) {
          auto* transformed_tensor =
              GetMutableLoDTensorOrSelectedRowsValueFromVar(
                  local_scope->FindVar(var_scope->GetNameById(p.first)));
          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
              local_scope->FindVar(var_scope->GetNameById(p.second)));
          original_tensor->ShareDataWith(*transformed_tensor);
          VLOG(4) << "Transfer inplace variable back form "
                  << var_scope->GetNameById(p.first) << " to "
                  << var_scope->GetNameById(p.second);
        }
      }

      // for debug nan/inf
      if (FLAGS_check_nan_inf) {
        VLOG(4) << "Check nan/inf";
        framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
      }
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = local_scope->FindVar(var_name);
      if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<phi::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s is not supported eager deletion.",
            framework::ToTypeName(var->Type()),
            var_name));
      }
    }
    delete garbages;  // free mem
  }
}
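
// End-to-end sketch of building the instruction list for a block; the
// arguments here are assumptions made for the example:
//
//   std::vector<OpFuncNode> func_list;
//   build_op_func_list(platform::CPUPlace(), block, /*skip_gc_vars=*/{},
//                      &func_list, &var_scope, /*use_local_scope=*/true);
//   // func_list now holds one OpFuncNode per op, with kernels selected,
//   // data transforms inserted, and unused variables already gc-ed.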

void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}
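
// For instance, add_fetch({"loss", "acc"}, block) appends two fetch_v2 ops
// writing into the persistable kFetchVarName list at columns 0 and 1.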

std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(
      first.begin(), first.end(), second.begin(), second.end(), out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}
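
// Example (inputs must already be sorted, as std::merge and std::unique
// assume):
//
//   merge_vector({1, 3, 5}, {2, 3, 6});  // -> {1, 2, 3, 5, 6}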

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle