// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/interpretercore_util.h"

#include <algorithm>

#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_serial_run,
    false,
    "Enable serial execution for standalone executor, used for debug.");

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_log_memory_stats,
    false,
    "Log memory stats after each op runs, just used for debug.");

DECLARE_bool(use_mkldnn);
DECLARE_bool(check_nan_inf);
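
// NOTE: the two exported flags above (FLAGS_new_executor_serial_run and
// FLAGS_new_executor_log_memory_stats) can usually be toggled through
// same-named environment variables; the exact mechanism depends on the build.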

namespace paddle {
namespace framework {
namespace interpreter {

using VariableIdMap = std::map<std::string, std::vector<int>>;

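// Build the options for the executor's two work queues: one thread pool runs
// host (CPU) kernels, the other launches device kernels. Both register the
// same EventsWaiter so the scheduler can wait on events from either queue.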
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
    size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
  std::vector<WorkQueueOptions> group_options;
  // for executing host kernels
  group_options.emplace_back(/*name*/ "HostTasks",
                             /*num_threads*/ host_num_threads,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ false,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  // for launching device kernels
  group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
                             /*num_threads*/ device_num_threads,
                             /*allow_spinning*/ true,
                             /*always_spinning*/ false,
                             /*track_task*/ false,
                             /*detached*/ true,
                             /*events_waiter*/ waiter);
  return group_options;
}

AsyncWorkQueue::AsyncWorkQueue(size_t host_num_threads,
                               size_t device_num_threads,
                               EventsWaiter* waiter)
    : host_num_thread_(host_num_threads) {
  queue_group_ = CreateWorkQueueGroup(
      ConstructWorkQueueOptions(host_num_threads, device_num_threads, waiter));
}

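// Hypothetical usage sketch (instance and lambdas illustrative only):
//   queue.AddTask(OpFuncType::kQueueSync, [] { /* run host kernel */ });
//   queue.AddTask(OpFuncType::kQueueAsync, [] { /* launch device kernel */ });
// With FLAGS_new_executor_serial_run on, both land in the async queue.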
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
  // NOTE(zhiqiu): in serial mode, use the second (device) queue only, so that
  // a single thread executes all tasks.
  if (FLAGS_new_executor_serial_run) {
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

void LogDeviceMemoryStats(const platform::Place& place) {
  if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) {
    VLOG(0) << "memory_allocated: "
            << static_cast<double>(memory::DeviceMemoryStatCurrentValue(
                   "Allocated", place.device)) /
                   1024 / 1024
            << " MB";
    VLOG(0) << "max_memory_allocated: "
            << static_cast<double>(memory::DeviceMemoryStatPeakValue(
                   "Allocated", place.device)) /
                   1024 / 1024
            << " MB";
  }
}

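// A variable is a GC candidate only if it is declared in this block, is not
// persistable, and holds one of the tensor-like types checked below.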
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}

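// For every op, collect the GC-eligible variables whose last use is that op;
// the interpreter can then free them as soon as that op finishes running.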
std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}

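// Create every variable declared in the block: persistable variables live in
// the outer scope, the rest in the local scope (when use_local_scope is set).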
void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope,
                          bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, and all other variables in the local scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): the user may create a variable whose name already
    // exists; under such circumstances we should raise an error. Currently we
    // can't get the var_desc of the startup_program, so leave it for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }

    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      // NOTE(zhiqiu): if var exists in scope and the type is right,
      // InitializeVariable will not create a new variable.
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
              << ptr << ", Variable Type "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->AddVar(var_name, var_desc);
  }
}

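// Instantiate an OperatorBase for every OpDesc in the block, running the
// attribute checkers (including extra runtime-attribute checkers) first.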
void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    auto op_type = op->Type();
    VLOG(8) << "CreateOp from : " << op_type;

    auto& info = OpInfoMap::Instance().Get(op_type);

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();

    AttributeMap op_attr_map = op->GetAttrMap();
    AttributeMap op_runtime_attr_map = op->GetRuntimeAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }

    const auto& extra_attr_checkers =
        operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(op_type);
    for (const auto& checker : extra_attr_checkers) {
      checker(&op_runtime_attr_map, false);
    }

    auto op_base =
        info.Creator()(op_type, inputs_names, outputs_names, op_attr_map);
    op_base->SetRuntimeAttributeMap(op_runtime_attr_map);

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      if (op->HasAttr("use_mkldnn")) {
        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
        op_base->SetAttr("use_mkldnn", true);
      }
    }
#endif

    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

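// Resolve a VariableNameMap (slot name -> variable names) into parallel maps
// of Variable pointers and variable ids, optionally tolerating variables that
// are missing from the program or from the scope.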
std::tuple<VariableValueMap, VariableIdMap> BuildVariableMap(
    const VariableNameMap& var_name_map,
    VariableScope* var_scope,
    Scope* local_scope,
    bool allow_var_not_in_program = false,
    bool allow_var_not_in_scope = false) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!var_scope->HasVar(var_name)) {
        if (allow_var_not_in_program && local_scope->FindVar(var_name)) {
          VLOG(3) << "Add " << var_name << " to var_scope";
          var_scope->AddVar(var_name, nullptr);
        } else if (allow_var_not_in_scope) {
          VLOG(4) << var_name
                  << " doesn't exist in the variable scope, skip it!";
          continue;
        }
      }
      auto* var = local_scope->FindVar(var_name);
      auto var_id = var_scope->VarId(var_name);
      vars.push_back(var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}

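// Honor the op_device attribute set by device_guard(): move the expected
// kernel place to the requested device when the op supports it; otherwise
// fall back to CPUPlace with a one-time warning.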
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               platform::is_gpu_place(place)) {
      // When an op without a GPU kernel is assigned to GPU, the CPU kernel is
      // executed instead and a warning is issued.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else if (op_device.find("npu") != std::string::npos &&
               platform::is_npu_place(place)) {
      // When an op without an NPU kernel is assigned to NPU, the CPU kernel
      // is executed instead and a warning is issued.
      if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no NPU implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else if (op_device.find("xpu") != std::string::npos &&
               platform::is_xpu_place(place)) {
      // When an op without an XPU kernel is assigned to XPU, the CPU kernel
      // is executed instead and a warning is issued.
      if (op_base->SupportXPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no XPU implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}

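// Handle ops that are plain OperatorBase (no kernel): run them eagerly here
// and record the queue type and device context on the OpFuncNode.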
void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node,
                        Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // Inputs and outputs are prepared; set the other attributes.
  op_func_node->operator_base_ = op_base;
  if (IsSupportedHetePlace(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }
  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transform.
  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // all indexes need no data transform
  op_func_node->dev_ctx_ = dev_ctx;
}

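// Trace the block once: create all ops, select a phi or fluid kernel for
// each, apply data transform, run infershape and the kernel itself, and
// eagerly GC variables after their last use, emitting one OpFuncNode per op.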
void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        const std::set<std::string>& skip_gc_vars,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope,
                        bool use_local_scope,
                        bool used_for_jit) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);

  if (!used_for_jit) {
    // If gc is enabled and block size > 1
    const ProgramDesc& main_program = *block.Program();
    operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
        main_program, block.ID(), ops_unique);
    operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
        main_program, block.ID(), ops_unique);
    operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
        main_program, block.ID(), ops_unique);
  }

#ifdef PADDLE_WITH_MKLDNN
  platform::RegisterModelLayout(ops_unique, place);
#endif
  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  bool flag_log_is_printed = false;
  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    const std::string& op_type = op->Type();

    VLOG(6) << "Build OpFuncNode from : " << op_type;

    // Print new executor log if grad op is used.
    // It's only for test and will be removed later.
    if (!flag_log_is_printed && op_type.find("_grad") != std::string::npos) {
      LOG_FIRST_N(INFO, 1) << "Standalone Executor is Used.";
      flag_log_is_printed = true;
    }

    // Hot fix for variables used in dataloader, like
    // 'lod_tensor_blocking_queue_0'. These variables may be created in scope
    // but do not exist as variables in the program.
    const std::set<std::string> ops_with_var_not_in_program = {
        "create_py_reader"};
    const std::set<std::string> ops_with_var_not_in_scope = {
        "conditional_block",
        "conditional_block_grad",
        "recurrent_grad",
        "rnn_memory_helper",
        "rnn_memory_helper_grad",
        "while",
        "while_grad"};
    bool allow_var_not_in_program = ops_with_var_not_in_program.count(op_type);
    bool allow_var_not_in_scope = ops_with_var_not_in_scope.count(op_type);

    framework::VariableNameMap& input_name_map = op->Inputs();
    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    std::tie(ins_map, ins_name2id) = BuildVariableMap(input_name_map,
                                                      var_scope,
                                                      local_scope,
                                                      allow_var_not_in_program,
                                                      allow_var_not_in_scope);

    framework::VariableNameMap& output_name_map = op->Outputs();
    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) =
        BuildVariableMap(output_name_map,
                         var_scope,
                         local_scope,
                         /*allow_var_not_in_program=*/false,
                         allow_var_not_in_scope);

    // step 1: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

#ifdef PADDLE_WITH_ASCEND_CL
    // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
    // values, but only through special `float_status` to checks whether
    // the operation is overflow. More about `float_status`, see:
    // https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
    if (FLAGS_check_nan_inf) {
      framework::details::NPUAllocAndClearFloatStatus(*op, *local_scope, place);
    }
#endif

    try {
      if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
        // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
        deal_operator_base(
            place, var_scope, ops[i], &op_func_node, local_scope);
        VLOG(4) << "deal_operator_base";
      } else {
        VLOG(4) << "OP is not null";
        auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
            static_cast<const framework::OperatorWithKernel*>(op));
        VLOG(4) << "get op_with_kernel";
        // construct RuntimeContext and analysis KernelType
        RuntimeContext runtime_context({}, {});
        runtime_context.inputs.swap(ins_map);
        runtime_context.outputs.swap(outs_map);
        VLOG(4) << "get RuntimeContext";

        Scope scope, *runtime_scope = &scope;
        // NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
        // But some OPs do have such behavior (e.g., cinn_launch OP). Here
        // special treatment for them.
        if (op_with_kernel->Type() == "cinn_launch") {
          VLOG(6) << "OP(" << op_with_kernel->Type()
                  << ") use scope in kernel, "
                     "so pass a real scope to "
                     "ExecutionContext";
          runtime_scope = local_scope;
        }

        auto& pool = platform::DeviceContextPool::Instance();
        auto* dev_ctx = pool.Get(place);
        VLOG(4) << "get dev_ctx";
        auto exec_ctx = ExecutionContext(
            *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
        VLOG(4) << "get exec_ctx";
        auto expected_kernel_key =
            op_with_kernel->GetExpectedKernelType(exec_ctx);
        VLOG(4) << "get expected_kernel_key";
        // change device by the device_guard()
        apply_device_guard(op, place, &expected_kernel_key);
        VLOG(4) << "expected_kernel_key : " << expected_kernel_key;

        // step 2. select op kernel
        auto run_phi_kernel = false;
        if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
                op_with_kernel->Type())) {
          auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
          auto phi_kernel_name = op_with_kernel->PhiKernelSignature()->name;

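          // Prefer the phi kernel when it is valid; otherwise, if the fluid
          // kernel cannot serve expected_kernel_key either, fall back to the
          // phi CPU kernel.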
          if (op_with_kernel->PhiKernel()->IsValid()) {
            run_phi_kernel = true;
          } else {
            if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
              auto phi_cpu_kernel_key = FallBackToCpu(
                  expected_kernel_key, phi_kernel_key, *op_with_kernel);
              op_with_kernel->ResetPhiKernel(
                  new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                      phi_kernel_name, phi_cpu_kernel_key)));
              if (op_with_kernel->PhiKernel()->IsValid()) {
                VLOG(6) << "Static mode PrepareImpl - kernel name: "
                        << phi_kernel_name
                        << " | kernel key: " << phi_cpu_kernel_key
                        << " | kernel: " << *(op_with_kernel->PhiKernel());
                op_with_kernel->ResetKernelType(new OpKernelType(
                    TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key)));
                run_phi_kernel = true;
              }
            }
          }
        }
        VLOG(4) << "if run phi kernel? : " << run_phi_kernel;
        if (!run_phi_kernel) {
          op_with_kernel->ChooseKernel(exec_ctx);
          op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
        } else {
          op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
        }
        auto kernel_type = *(op_with_kernel->kernel_type());
        if (kernel_type.place_ != dev_ctx->GetPlace()) {
          dev_ctx = pool.Get(kernel_type.place_);
        }
        op_func_node.dev_ctx_ = dev_ctx;
        if (IsSupportedHetePlace(kernel_type.place_)) {
          op_func_node.type_ = OpFuncType::kQueueAsync;
        } else if (platform::is_cpu_place(kernel_type.place_)) {
          op_func_node.type_ = OpFuncType::kQueueSync;
        } else {
          PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                               kernel_type.place_));
        }
        VLOG(3) << op_with_kernel->Type()
                << " : finally selected kernel_key: " << kernel_type;

        // step 3. data transform
        VariableValueMap& ins_map_temp = runtime_context.inputs;
        VariableValueMap& outs_map_temp = runtime_context.outputs;
        ApplyDataTransform(kernel_type,
                           place,
                           &ins_map_temp,
                           &outs_map_temp,
                           var_scope,
                           &op_func_node,
                           vec_func_list,
                           use_local_scope);
        VLOG(4) << "apply data transform done. ";
        // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc
        // for why.
        if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
              op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
          InterpretercoreInferShapeContext infer_shape_ctx(*op,
                                                           runtime_context);
          // TODO(Aurelius84): In case of control flow ops, they are NOT
          // inherited from OperatorWithKernel.
          op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
        }

        // step 5. run kernel
        if (run_phi_kernel) {
          phi::KernelContext phi_kernel_context;
          op_with_kernel->BuildPhiKernelContext(
              runtime_context, dev_ctx, &phi_kernel_context);
          (*op_func_node.phi_kernel_)(&phi_kernel_context);
        } else {
          // the place of exec_ctx maybe has changed.
          op_func_node.kernel_func_(ExecutionContext(
              *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context));
        }

        // Post-process grad_op's outputs if we need to cast complex grad into
        // real grad.
        // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
        if (framework::IsComplexType(kernel_type.data_type_)) {
          interpreter::HandleComplexGradToRealGrad(op_func_node,
                                                   place,
                                                   output_name_map,
                                                   &runtime_context.outputs,
                                                   var_scope,
                                                   vec_func_list,
                                                   local_scope);
        }
        if (!op_func_node.inplace_back_map.empty()) {
          auto& m = op_func_node.inplace_back_map;
          // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
          // operator.cc
          for (auto& p : m) {
            auto* transformed_tensor =
                GetMutableLoDTensorOrSelectedRowsValueFromVar(
                    local_scope->FindVar(var_scope->GetNameById(p.first)));
            auto* original_tensor =
                GetMutableLoDTensorOrSelectedRowsValueFromVar(
                    local_scope->FindVar(var_scope->GetNameById(p.second)));
            original_tensor->ShareDataWith(*transformed_tensor);
            VLOG(4) << "Transfer inplace variable back form "
                    << var_scope->GetNameById(p.first) << " to "
                    << var_scope->GetNameById(p.second);
          }
        }

        // for debug nan/inf
        if (FLAGS_check_nan_inf) {
          VLOG(4) << "Check nan/inf";
          framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
        }
      }
    } catch (platform::EnforceNotMet& ex) {
      framework::InsertCallStackInfo(op_type, op->Attrs(), &ex);
      throw std::move(ex);
    } catch (platform::EOFException&) {
      std::rethrow_exception(std::current_exception());
    } catch (std::exception& ex) {
      LOG(WARNING) << op_type << " raises an exception "
                   << platform::demangle(typeid(ex).name()) << ", "
                   << ex.what();
      std::rethrow_exception(std::current_exception());
    } catch (...) {
      LOG(WARNING) << op_type << " raises an unknown exception";
      std::rethrow_exception(std::current_exception());
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      interpreter::LogDeviceMemoryStats(place);
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = local_scope->FindVar(var_name);
      if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      }
    }
    delete garbages;  // free mem

    interpreter::LogDeviceMemoryStats(place);
  }

  // NOTE(Ruibiao): Release memory cache to avoid memory fragments in Allocator.
  // It reduces memory usage by about 10% for V100 8-GPU training of
  // transformer_base_bs4096_amp_fp16 and transformer_base_bs4096_pure_fp16
  // model.
  memory::Release(place);
}

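// Append one fetch_v2 op per fetch target; all of them write into the shared
// kFetchVarName FETCH_LIST variable, each at its own column. For example,
// fetch_names = {"loss"} appends fetch_v2(X="loss", Out=kFetchVarName, col=0).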
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}

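// Merge two sorted index vectors into one sorted, de-duplicated vector;
// std::merge requires both inputs to be sorted. E.g. {1, 3} and {2, 3}
// yield {1, 2, 3}.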
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(
      first.begin(), first.end(), second.begin(), second.end(), out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle