// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include <algorithm>
#include <iterator>
#include <queue>
#include <sstream>

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_sequential_run, false,
    "Enable sequential execution for standalone executor, used for debug");

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace interpreter {

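// Index of the queue used by PrepareAtomicDeps / PrepareAtomicVarRef; the
// first two queues in the group serve the sync and async op kernels (see
// AddTask below).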
constexpr size_t kPrepareWorkQueueIdx = 2;

void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
  // NOTE(zhiqiu): use the second (async) queue, so only one thread is used.
  if (FLAGS_new_executor_sequential_run) {
    VLOG(4) << "FLAGS_new_executor_sequential_run:"
            << FLAGS_new_executor_sequential_run;
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

using VariableIdMap = std::map<std::string, std::vector<int>>;

void AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";
  atomic_deps_ =
      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependency_count] {
        auto op_deps = std::make_unique<std::vector<std::atomic<size_t>>>(
            dependency_count.size());
        for (size_t i = 0; i < dependency_count.size(); ++i) {
          (*op_deps)[i] = dependency_count[i];
        }
        VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
        return op_deps;
      });
}

void AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  atomic_var_ref_ =
      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] {
        auto var_ref = std::make_unique<std::vector<std::atomic<size_t>>>(
            vec_meta_info.size());
        for (size_t i = 0; i < vec_meta_info.size(); ++i) {
          (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
        }
        VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
        return var_ref;
      });
}

bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}

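// For each op, collect the variables whose last read/write happens in that
// op; these are the candidates for eager garbage collection right after the
// op runs.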
std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}

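// Create every variable declared in the block. Persistable variables are
// created in the outer scope; the rest go into the local scope when
// use_local_scope is true.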
void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope, bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_ , and the other variables are created in the
  // local scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): a user may create a variable whose name already exists;
    // under such circumstances, we should raise an error. Currently we can't
    // get the var_desc of the startup_program, so leave it for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      // NOTE(zhiqiu): if var exists in scope and the type is right,
      // InitializeVariable will not create a new variable.
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
              << ptr << "Variable Type "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->SetVarDesc(var_name, var_desc);
  }
}

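// Instantiate an OperatorBase for every OpDesc in the block, running the
// attribute checker where one is registered.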
void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();

    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      if (op->HasAttr("use_mkldnn")) {
        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
        op_base->SetAttr("use_mkldnn", true);
      }
    }
#endif

    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

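// Resolve each slot name in var_name_map to the Variable pointers and
// variable ids recorded in var_scope. Missing variables are skipped when
// enforce_exist is false.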
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map, VariableScope* var_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!enforce_exist && !var_scope->HasVar(var_name)) {
        // skip non-existent variables, e.g. the ones used by recurrent_grad
        VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
        continue;
      }
      auto var_id = var_scope->VarId(var_name);
      auto* in_var = var_scope->Var(var_id);
      vars.push_back(in_var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}

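// Honor the "op_device" attribute set by device_guard: switch the expected
// kernel place to the requested device, falling back to CPUPlace when the op
// has no kernel for it.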
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               (platform::is_gpu_place(place) ||
                platform::is_npu_place(place))) {
      // when an Op that has only a CPU kernel is assigned to GPU, the CPU
      // kernel will be executed and a warning will be given at the same time.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}

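// Run an op that is not an OperatorWithKernel (e.g. a control-flow op)
// directly through OperatorBase::Run, and record it as an OpFuncNode with no
// kernel function.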
void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node, Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // input, output is prepared. set the other attributes.
  op_func_node->operator_base_ = op_base;
  if (platform::is_gpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transformer.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // no input index needs data transform
  op_func_node->dev_ctx_ = dev_ctx;
}

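// Build an OpFuncNode for every op in the block: create the ops, build the
// variable maps, apply data transform, choose a phi or fluid kernel, run it,
// and eagerly collect variables whose last user has run.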
void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope, bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);
  // If gc is enabled and block size > 1, prepare safe eager deletion for the
  // control-flow ops (conditional_block / while / recurrent).
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

#ifdef PADDLE_WITH_MKLDNN
  platform::RegisterModelLayout(ops_unique, place);
#endif

  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) =
        build_variable_map(outputs_names, var_scope, enforce_exist);

    // step 2: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
      VLOG(4) << "End run " << place << " "
              << op_func_node.operator_base_->DebugStringEx(local_scope);
    } else {
      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(op));
      // construct RuntimeContext and analyze KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      Scope scope;
      Scope* runtime_scope = &scope;
      // NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
      // But some OPs do have such behavior (e.g., cinn_launch OP), so we give
      // them special treatment here.
      if (op_with_kernel->Type() == "cinn_launch") {
        VLOG(6) << "OP(" << op_with_kernel->Type() << ") use scope in kernel, "
                                                      "so pass a real scope to "
                                                      "ExecutionContext";
        runtime_scope = local_scope;
      }

      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
          ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context));
      op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));

      // change device by the device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(3) << "expected_kernel_key : " << expected_kernel_key;

      // step 3. apply data transforms and insert data transfer ops
      VariableValueMap& ins_map_temp = runtime_context.inputs;
      VariableValueMap& outs_map_temp = runtime_context.outputs;

      // NOTE(zhiqiu): op_func_node->operator_base_ may be changed in
      // ApplyDataTransform
      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp,
                         &outs_map_temp, var_scope, &op_func_node,
                         vec_func_list, use_local_scope);
      op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(
              op_func_node.operator_base_.get()));

      // step 4. Run op kernel
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      if (platform::is_gpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             expected_kernel_key.place_));
      }
      if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
        dev_ctx = pool.Get(expected_kernel_key.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      // see OperatorWithKernel::RunImpl in operator.cc for why
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): In case of control flow ops, they are NOT
        // inherited from OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

      auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope,
                                       *dev_ctx, runtime_context);

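      // Prefer a phi kernel if one is registered for this op type. If the
      // chosen phi kernel is invalid and no matching fluid kernel exists
      // either, fall back to a phi kernel on CPU.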
      auto run_phi_kernel = false;
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;

        if (op_with_kernel->PhiKernel()->IsValid()) {
          run_phi_kernel = true;
        } else {
          auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
          if (kernels_iter == all_op_kernels.end() ||
              kernels_iter->second.find(expected_kernel_key) ==
                  kernels_iter->second.end()) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPhiKernel(
                new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PhiKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PhiKernel());
              run_phi_kernel = true;
            }
          }
        }
      }
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
      if (run_phi_kernel) {
        phi::KernelContext pt_kernel_context;
        op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
                                              &pt_kernel_context);
        op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        auto kernels_iter = all_op_kernels.find(op->Type());
        PADDLE_ENFORCE_NE(
            kernels_iter, all_op_kernels.end(),
            platform::errors::Unavailable(
                "There are no kernels which are registered in the %s operator.",
                op->Type()));
        OpKernelMap& kernels = kernels_iter->second;

        auto kernel_iter = kernels.find(expected_kernel_key);
        PADDLE_ENFORCE_NE(
            kernel_iter, kernels.end(),
            platform::errors::NotFound(
                "Operator (%s) does not have kernel for %s.", op->Type(),
                KernelTypeToString(expected_kernel_key)));
        // TODO(zhiqiu): add fallback logic
        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
        op_func_node.kernel_func_(exec_ctx);
      }

      // post-process grad_op.outputs if we need to cast complex grad into
      // real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(
            op_func_node, place, outputs_names, &runtime_context.outputs,
            var_scope, vec_func_list, local_scope);
      }
      if (!op_func_node.inplace_back_map.empty()) {
        auto& m = op_func_node.inplace_back_map;
        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
        for (auto& p : m) {
          auto* transformed_tensor =
              GetMutableLoDTensorOrSelectedRowsValueFromVar(
                  var_scope->Var(p.first));
          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
              var_scope->Var(p.second));
          original_tensor->ShareDataWith(*transformed_tensor);
          VLOG(4) << "Transfer inplace variable back form "
                  << var_scope->GetNameById(p.first) << " to "
                  << var_scope->GetNameById(p.second);
        }
      }
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = var_scope->FindVar(var_name);
      if (var == nullptr) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<phi::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s is not supported eager deletion.",
            framework::ToTypeName(var->Type()), var_name));
      }
    }
    delete garbages;  // free mem
  }
}

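// Append one fetch_v2 op per fetch target; e.g. for fetch_names {"x", "y"}
// this appends fetch_v2(X="x", col=0) and fetch_v2(X="y", col=1), both
// writing into the kFetchVarName FETCH_LIST variable.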
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}

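// Merge two sorted index vectors into one sorted vector with duplicates
// removed (std::merge requires both inputs to be sorted).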
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}

void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op;
  // this function updates the var2min_rw_op map.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
    (*var2min_rw_op)[rw_var] = std::list<int>();
  }
  for (auto dep_op : op2dependences.at(cur_op)) {
    var2min_rw_op->at(rw_var).remove(dep_op);
  }
  var2min_rw_op->at(rw_var).push_back(cur_op);
}

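// Convert op2dependences (op -> ops it must run after) into a downstream map
// (op -> ops that must run after it), fill op_happens_before with the
// transitive closure, and transitively reduce each downstream list.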
std::map<int, std::list<int>> get_downstream_map(
    const std::map<int, std::set<int>>& op2dependences,
    std::vector<std::vector<bool>>* op_happens_before) {
  // step1: convert op2dependences to downstream_map directly
  // op2dependences is op -> its dependences.
  // we want to get op -> [next ops] map,
  // where [next ops] are the ops that must run after op.
  std::map<int, std::list<int>> downstream;
  for (auto& item : op2dependences) {
    int op = item.first;
    for (auto dep_op : item.second) {
      if (downstream.find(dep_op) == downstream.end())
        downstream[dep_op] = std::list<int>();
      downstream[dep_op].push_back(op);
    }
  }

  auto downstream_map_to_str = [&]() -> std::string {
    std::ostringstream oss;
    for (auto pair : downstream) {
      oss << pair.first << " -> ";
      std::copy(pair.second.begin(), pair.second.end(),
                std::ostream_iterator<int>(oss, " "));
      oss << std::endl;
    }
    return oss.str();
  };

  auto downstream_map_count = [&]() -> size_t {
    size_t count = 0;
    for (auto pair : downstream) {
      count += pair.second.size();
    }
    return count;
  };

  VLOG(6) << "downstream count: " << downstream_map_count();
  VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();

  // step2: remove unnecessary downstream ops
  // for example, a->b->c
  // a: b, c
  // b: c
  // =>
  // a: b
  // b: c

  // NOTE(zhiqiu): the size of downstream != size of op2dependences
  // since there are some ops that have no downstream-op.
  auto op_num = op2dependences.size();
  // happens_before[i][j] means i should be executed before j
  op_happens_before->resize(op_num);
  for (size_t i = 0; i < op_num; ++i) {
    (*op_happens_before)[i].resize(op_num);
    std::fill((*op_happens_before)[i].begin(), (*op_happens_before)[i].end(),
              false);
  }

  // bfs to get all next ops
  auto bfs = [&](size_t op_idx) {
    std::queue<size_t> q;
    std::vector<bool> visited(op_num, false);
    q.push(op_idx);
    while (!q.empty()) {
      size_t op = q.front();
      q.pop();
      visited[op] = true;
      if (!downstream.count(op)) {
        continue;
      }
      for (auto next : downstream[op]) {
        if (!visited[next]) {
          PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false,
                            paddle::platform::errors::AlreadyExists(
                                "There exists circle in graph, expected "
                                "%d->%d, but already got %d->%d",
                                op_idx, next, next, op_idx));
          (*op_happens_before)[op_idx][next] = true;
          VLOG(8) << "happens before: " << op_idx << " " << next;
          q.push(next);
        }
      }
    }
  };

  for (size_t i = 0; i < op_num; ++i) {
    bfs(i);
  }

  // shrink: keep only the downstream ops that no other op in the same
  // downstream list happens before (a transitive reduction)
  for (size_t i = 0; i < op_num; ++i) {
    std::list<int> minimum_nexts;
    for (size_t item : downstream[i]) {
      bool not_after_any = true;
      // find the op that is not executed after any
      for (size_t other_item : downstream[i]) {
        if ((*op_happens_before)[other_item][item]) {
          VLOG(8) << "happens_before: " << other_item << "->" << item
                  << ", so skip " << item;
          not_after_any = false;
          break;
        }
      }
      if (not_after_any) {
        VLOG(8) << "downstream op of " << i << ": " << item;
        minimum_nexts.push_back(item);
      }
    }
    downstream[i] = minimum_nexts;
  }
  VLOG(6) << "downstream count: " << downstream_map_count();
  VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();

  return downstream;
}

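// Derive the op -> downstream-ops map from the instruction list: first from
// plain read/write dependences, then with extra edges that serialize random
// ops and communication ops and protect coalesce_tensor outputs.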
std::map<int, std::list<int>> build_op_downstream_map(
    const std::vector<Instruction>& vec_instruction,
    std::vector<std::vector<bool>>* op_happens_before) {
  auto var2min_rw_op = std::map<
      int, std::list<int>>();  // map from variable id to read / write op ids.
  auto var2recent_write_op =
      std::map<int, int>();  // map from variable to its most recent write op.
  auto op2dependences =
      std::map<int, std::set<int>>();  // map from op to its dependence list;
                                       // op must run after its dependences.
  std::set<int>
      remove_duplicate;  // remove the duplicate between inputs and outputs

  // pre-create an empty dependence set for every op
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    op2dependences[op_idx] = std::set<int>();
  }

  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    remove_duplicate.clear();
    // step1: update the op2dependences structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (var2recent_write_op.count(var))
          op2dependences[op_idx].insert(var2recent_write_op[var]);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        if (var2min_rw_op.count(var)) {
          for (auto dep_op : var2min_rw_op[var]) {
            op2dependences[op_idx].insert(dep_op);
          }
        }
      }
    }

    // step2: update 2 var2xxxx data structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        remove_duplicate.insert(var);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        var2recent_write_op[var] = op_idx;
        if (remove_duplicate.count(var) ==
            0) {  // var in input list and in output list, so remove it.
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }

    // NOTE(zhiqiu): an inplace op wrapped with `transfer` also changes the
    // original output afterwards, so add the original output as well
    // original: a->op->a
    // after: a->data_transfer->a'->op->a'->transfer_back->a
    // which means op writes a and a'
    if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
      auto& m = vec_instruction[op_idx].InplaceBackMap();
      for (auto& p : m) {
        auto var = p.second;
        var2recent_write_op[var] = op_idx;
        // var in input list and in output list, so remove it.
        if (remove_duplicate.count(var) == 0) {
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
  }

  // add dependences for random ops, making sure that they are scheduled
  // sequentially
  const std::set<std::string> random_op_set = {
      "bernoulli", "poisson", "multinomial", "gaussian_random",
      "truncated_gaussian_random", "uniform_random", "randint", "randperm",
      "exponential",
      "sampling_id"
      "dropout",
      "class_center_sample",
  };

  int dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        op2dependences[op_idx].insert(dependence_op_idx);
      }
      dependence_op_idx = op_idx;
    }
  }

  // add dependency for communication op
  auto is_comm_op = [](std::string op) -> bool {
    const std::set<std::string> special_comm_op_set = {
        "send", "recv", "send_v2", "recv_v2",
    };
    const std::string communication_op_prefix = "c_";
    if (op.find(communication_op_prefix) != std::string::npos ||
        special_comm_op_set.count(op)) {
      return true;
    }
    return false;
  };

  dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        op2dependences[op_idx].insert(dependence_op_idx);
        VLOG(4) << "Add depend from "
                << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
                << vec_instruction[op_idx].OpBase()->Type();
      }
      dependence_op_idx = op_idx;
    }
  }

  // TODO(zhiqiu): there are still some cases not handled
  // add dependency for c_sync_comm_stream

  // in program, we can add only one c_sync_comm_stream to sync all
  // communication ops.
  // c_allreduce_sum(a)
  // c_allreduce_sum(b)
  // c_allreduce_sum(c)
  // c_sync_comm_stream(a)
  const std::string kSyncComm = "c_sync_comm_stream";
  dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (vec_instruction[op_idx].OpBase()->Type() == kSyncComm) {
      dependence_op_idx = op_idx;
    } else {
      if (dependence_op_idx != -1) {
        VLOG(4) << "Add depend from "
                << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
                << vec_instruction[op_idx].OpBase()->Type();
        op2dependences[op_idx].insert(dependence_op_idx);
      }
    }
  }

  // add dependency for coalesce_tensor
  const std::string kCoalesceTensor = "coalesce_tensor";
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (vec_instruction[op_idx].OpBase()->Type() == kCoalesceTensor) {
      VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx;
      auto fused_out = vec_instruction[op_idx].Outputs().at("FusedOutput")[0];
      auto outputs = vec_instruction[op_idx].Outputs().at("Output");

      auto is_read = [](const Instruction& inst, int var_id) -> bool {
        for (auto pair : inst.Inputs()) {
          for (auto item : pair.second) {
            if (item == var_id) {
              return true;
            }
          }
        }
        return false;
      };

      auto is_write = [](const Instruction& inst, int var_id) -> bool {
        for (auto pair : inst.Outputs()) {
          for (auto item : pair.second) {
            if (item == var_id) {
              return true;
            }
          }
        }
        return false;
      };

      // find first op that reads fused_out
      auto first_read_fused_out_op = -1;
      for (auto j = op_idx + 1; j < vec_instruction.size(); ++j) {
        if (is_read(vec_instruction[j], fused_out)) {
          first_read_fused_out_op = j;
          break;
        }
      }

      if (UNLIKELY(first_read_fused_out_op == -1)) {
        VLOG(4) << "No op read FusedOutput";
        continue;
      }

      // find ops that write 'outputs' between (op_idx,
      // first_read_fused_out_op)
      // add depend: them -> first_read_fused_out_op
      for (auto j = op_idx + 1;
           j < static_cast<size_t>(first_read_fused_out_op); ++j) {
        for (auto var_id : outputs) {
          if (is_write(vec_instruction[j], var_id)) {
            op2dependences[first_read_fused_out_op].insert(j);
            VLOG(4) << j << " -> " << first_read_fused_out_op;
            VLOG(4)
                << "Add depend from " << vec_instruction[j].OpBase()->Type()
                << " to "
                << vec_instruction[first_read_fused_out_op].OpBase()->Type();
          }
        }
      }

      // find the first op that reads 'outputs' between
      // (first_read_fused_out_op, end)
      // add depend: first_read_fused_out_op -> first op that reads 'outputs'

      // special case for consecutive communication ops, for example,
      // FusedOutput = c_sync_calc_stream(FusedOutput)
      // FusedOutput = c_allreduce_sum(FusedOutput)
      // FusedOutput = c_sync_comm_stream(FusedOutput)
      // we should take the last one to add the dependency instead of
      // 'first_read_fused_out_op'
      size_t target = first_read_fused_out_op;
      for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
           ++j) {
        if (j == target + 1 &&
            is_comm_op(vec_instruction[target].OpBase()->Type()) &&
            is_comm_op(vec_instruction[j].OpBase()->Type())) {
          VLOG(4) << "Found consecutive communication ops, "
                  << vec_instruction[target].OpBase()->Type() << " -> "
                  << vec_instruction[j].OpBase()->Type();
          target = j;
          continue;
        }

        for (auto var_id : outputs) {
          if (is_read(vec_instruction[j], var_id)) {
            op2dependences[j].insert(target);
            VLOG(4) << target << " -> " << j;
            VLOG(4) << "Add depend from "
                    << vec_instruction[target].OpBase()->Type() << " to "
                    << vec_instruction[j].OpBase()->Type();
          }
        }
      }
    }
  }
  for (auto pair : op2dependences) {
    std::ostringstream oss;
    oss << pair.first << " Depends on " << pair.second.size() << " ops: ";
    std::copy(pair.second.begin(), pair.second.end(),
              std::ostream_iterator<int>(oss, " "));
    VLOG(10) << oss.str();
  }
  return get_downstream_map(op2dependences, op_happens_before);
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle