// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"

#include <algorithm>

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_factory.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_sequential_run, false,
    "Enable sequential execution for standalone executor, used for debug");

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
namespace interpreter {

constexpr size_t kPrepareWorkQueueIdx = 2;
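// NOTE: this assumes the queue group is built with three work queues: index 0
// for synchronous (host) op tasks, index 1 for asynchronous (device) op
// tasks, and index 2 for the preparation tasks scheduled below.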

void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
  // NOTE(zhiqiu): in sequential run mode, push every task into the async
  // queue so that only one thread is used.
  if (FLAGS_new_executor_sequential_run) {
    VLOG(4) << "FLAGS_new_executor_sequential_run:"
            << FLAGS_new_executor_sequential_run;
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

using VariableIdMap = std::map<std::string, std::vector<int>>;

void AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependecy_count) {
  VLOG(4) << "PrepareAtomicDeps";
  atomic_deps_ =
      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependecy_count] {
        auto op_deps = std::make_unique<std::vector<std::atomic<size_t>>>(
            dependecy_count.size());
        for (size_t i = 0; i < dependecy_count.size(); ++i) {
          (*op_deps)[i] = dependecy_count[i];
        }
        VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
        return op_deps;
      });
}

void AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  atomic_var_ref_ =
      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] {
        auto var_ref = std::make_unique<std::vector<std::atomic<size_t>>>(
            vec_meta_info.size());
        for (size_t i = 0; i < vec_meta_info.size(); ++i) {
          (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
        }
        VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
        return var_ref;
      });
}

bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}

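// Computes, for every op, the list of variables whose last use is that op;
// the interpreter uses this map to free those variables eagerly right after
// the op runs.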
std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}

void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope, bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, and other variables are created in the local
  // scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): users may create a variable with a name that already
    // exists. Under such circumstances, we should raise an error. Currently
    // we can't get the var_desc of the startup_program, so we leave it for
    // later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
              << ptr << "Variable Type "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->SetVarDesc(var_name, var_desc);
  }
}

void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();
    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);

#ifdef PADDLE_WITH_MKLDNN
    if (FLAGS_use_mkldnn) {
      if (op->HasAttr("use_mkldnn")) {
        VLOG(4) << "Set use_mkldnn=True for " << op_base->Type();
        op_base->SetAttr("use_mkldnn", true);
      }
    }
#endif

    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map, VariableScope* var_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!enforce_exist && !var_scope->HasVar(var_name)) {
        // skip non-existent variables, such as recurrent_grad
        VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
        continue;
      }
      auto var_id = var_scope->VarId(var_name);
      auto* in_var = var_scope->Var(var_id);
      vars.push_back(in_var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               (platform::is_gpu_place(place) ||
                platform::is_npu_place(place))) {
      // When an op that has only a CPU kernel is assigned to GPU, the CPU
      // kernel will be executed and a warning will be given at the same time.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}
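
// NOTE: the op_device attribute consumed above is typically set from the
// Python side via the paddle.static.device_guard() context manager, which
// pins the ops created inside it to the given device.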

void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node, Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // input, output is prepared. set the other attributes.
  op_func_node->operator_base_ = op_base;
  if (platform::is_gpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transform.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // all index is no-need-transform
  op_func_node->dev_ctx_ = dev_ctx;
}

void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope, bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);
  // If gc is enabled and block size > 1, prepare safe eager deletion
  // for the control-flow ops (conditional_block, while, recurrent).
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

#ifdef PADDLE_WITH_MKLDNN
  platform::RegisterModelLayout(ops_unique, place);
#endif

  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) =
        build_variable_map(outputs_names, var_scope, enforce_exist);

    // step 2: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
      VLOG(4) << "End run " << place << " "
              << op_func_node.operator_base_->DebugStringEx(local_scope);
    } else {
      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(op));
      // construct RuntimeContext and analyze KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      Scope scope;
      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
          ExecutionContext(*op, scope, *dev_ctx, runtime_context));
      op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));

      // change device by the device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(3) << "expected_kernel_key : " << expected_kernel_key;

      // step 3. apply data transforms and insert data transfer ops
      VariableValueMap& ins_map_temp = runtime_context.inputs;
      VariableValueMap& outs_map_temp = runtime_context.outputs;

      // NOTE(zhiqiu): op_func_node->operator_base_ may be changed in
      // ApplyDataTransform
      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp,
                         &outs_map_temp, var_scope, &op_func_node,
                         vec_func_list, use_local_scope);
      op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(
              op_func_node.operator_base_.get()));

      // step 4. Run op kernel
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      if (platform::is_gpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             expected_kernel_key.place_));
      }
      if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
        dev_ctx = pool.Get(expected_kernel_key.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      // see OperatorWithKernel::RunImpl in operator.cc for why
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): In case of control flow ops, they are NOT
        // inherited from OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

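      // NOTE: the selection below first tries a compatible phi kernel for the
      // op; if that kernel is invalid and no fluid kernel matches
      // expected_kernel_key, it falls back to a phi kernel on CPUPlace.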
      auto exec_ctx =
          ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);

      auto run_phi_kernel = false;
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;

        if (op_with_kernel->PhiKernel()->IsValid()) {
          run_phi_kernel = true;
        } else {
          auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
          if (kernels_iter == all_op_kernels.end() ||
              kernels_iter->second.find(expected_kernel_key) ==
                  kernels_iter->second.end()) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPhiKernel(
                new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PhiKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PhiKernel());
              run_phi_kernel = true;
            }
          }
        }
      }
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
      if (run_phi_kernel) {
        phi::KernelContext pt_kernel_context;
        op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
                                              &pt_kernel_context);
        op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        auto kernels_iter = all_op_kernels.find(op->Type());
        PADDLE_ENFORCE_NE(
            kernels_iter, all_op_kernels.end(),
            platform::errors::Unavailable(
                "There are no kernels which are registered in the %s operator.",
                op->Type()));
        OpKernelMap& kernels = kernels_iter->second;

        auto kernel_iter = kernels.find(expected_kernel_key);
        PADDLE_ENFORCE_NE(
            kernel_iter, kernels.end(),
            platform::errors::NotFound(
                "Operator (%s) does not have kernel for %s.", op->Type(),
                KernelTypeToString(expected_kernel_key)));
        // TODO(zhiqiu): add fallback logic
        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
        op_func_node.kernel_func_(exec_ctx);
      }

      // post-process grad_op.outputs if need cast complex grad into real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(
            op_func_node, place, outputs_names, &runtime_context.outputs,
            var_scope, vec_func_list, local_scope);
      }
      if (!op_func_node.inplace_back_map.empty()) {
        auto& m = op_func_node.inplace_back_map;
        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
        for (auto& p : m) {
          auto* transformed_tensor =
              GetMutableLoDTensorOrSelectedRowsValueFromVar(
                  var_scope->Var(p.first));
          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
              var_scope->Var(p.second));
          original_tensor->ShareDataWith(*transformed_tensor);
          VLOG(4) << "Transfer inplace variable back form "
                  << var_scope->GetNameById(p.first) << " to "
                  << var_scope->GetNameById(p.second);
        }
      }
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
L
Leo Chen 已提交
531
    auto iter = unused_var_map.find(op);
W
wanghuancoder 已提交
532 533 534 535 536 537 538 539 540
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = var_scope->FindVar(var_name);
      if (var == nullptr) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<phi::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s is not supported eager deletion.",
            framework::ToTypeName(var->Type()), var_name));
      }
    }
    delete garbages;  // free mem
  }
}

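// add_fetch appends one fetch_v2 op per fetched variable: each op reads its
// fetch target and writes it into column `col` of the kFetchVarName
// FETCH_LIST holder, so the i-th fetch name maps to the i-th fetch result.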
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}

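// NOTE: merge_vector assumes both inputs are sorted ascending (as std::merge
// requires) and returns their deduplicated union; for example,
// merge_vector({1, 3, 5}, {2, 3, 6}) yields {1, 2, 3, 5, 6}.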
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}

void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op;
  // this function updates the var2min_rw_op map.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
    (*var2min_rw_op)[rw_var] = std::list<int>();
  }
  for (auto dep_op : op2dependences.at(cur_op)) {
    var2min_rw_op->at(rw_var).remove(dep_op);
  }
  var2min_rw_op->at(rw_var).push_back(cur_op);
}

std::map<int, std::list<int>> get_downstream_map(
    const std::map<int, std::set<int>>& op2dependences) {
  // op2dependences maps an op to its dependences. We want the reverse
  // op -> [ops] map, where [ops] are the next instructions of op.
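  // For example (illustrative only): op2dependences = { {1, {0}}, {2, {0, 1}} }
  // yields { {0, [1, 2]}, {1, [2]} }, i.e. op 0 runs before ops 1 and 2, and
  // op 1 runs before op 2.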
  std::map<int, std::list<int>> result;
  for (auto& item : op2dependences) {
    int op = item.first;
    for (auto dep_op : item.second) {
      if (result.find(dep_op) == result.end())
        result[dep_op] = std::list<int>();
      result[dep_op].push_back(op);
    }
  }
  return result;
}

std::map<int, std::list<int>> build_op_downstream_map(
    const std::vector<Instruction>& vec_instruction) {
  auto var2min_rw_op = std::map<
      int, std::list<int>>();  // map from variable id to read / write op id
  auto var2recent_write_op =
      std::map<int, int>();  // map from variable to its most recent write op
  auto op2dependences =
      std::map<int, std::set<int>>();  // map from op to its dependence list;
                                       // an op must run after its dependences
  std::set<int>
      remove_duplicate;  // records vars that appear in both inputs and outputs

  // initialize an empty dependence set for every op
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    op2dependences[op_idx] = std::set<int>();
  }

  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    remove_duplicate.clear();
    // step1: update the op2dependences structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (var2recent_write_op.count(var))
          op2dependences[op_idx].insert(var2recent_write_op[var]);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        if (var2min_rw_op.count(var)) {
          for (auto dep_op : var2min_rw_op[var]) {
            op2dependences[op_idx].insert(dep_op);
          }
        }
      }
    }

    // step2: update 2 var2xxxx data structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        remove_duplicate.insert(var);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        var2recent_write_op[var] = op_idx;
        if (remove_duplicate.count(var) ==
            0) {  // vars in both input and output lists were updated above
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }

    // NOTE(zhiqiu): an inplace op with `transfer` also changes the original
    // output afterwards, so add the original output as well.
    // original: a -> op -> a
    // after:    a -> data_transfer -> a' -> op -> a' -> transfer_back -> a
    // which means the op writes both a and a'.
    if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
      auto& m = vec_instruction[op_idx].InplaceBackMap();
      for (auto& p : m) {
        auto var = p.second;
        var2recent_write_op[var] = op_idx;
        // var in input list and in output list, so remove it.
        if (remove_duplicate.count(var) == 0) {
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
  }

  // add dependences for random ops to make sure that they are scheduled
  // sequentially
  const std::set<std::string> random_op_set = {
      "bernoulli", "poisson", "multinomial", "gaussian_random",
      "truncated_gaussian_random", "uniform_random", "randint", "randperm",
      "exponential", "sampling_id", "dropout", "class_center_sample",
  };

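  // The loop below keeps a sliding dependence_op_idx: each random op gains a
  // dependence on the previous random op, forming a chain r0 <- r1 <- r2.
  // The same pattern is reused for communication ops further down.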
  int dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        op2dependences[op_idx].insert(dependence_op_idx);
      }
      dependence_op_idx = op_idx;
    }
  }

  // add dependency for communication op
  auto is_comm_op = [](const std::string& op) -> bool {
    const std::set<std::string> special_comm_op_set = {
        "send", "recv", "send_v2", "recv_v2",
    };
    const std::string communication_op_prefix = "c_";
    if (op.find(communication_op_prefix) != std::string::npos ||
        special_comm_op_set.count(op)) {
      return true;
    }
    return false;
  };

  dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        op2dependences[op_idx].insert(dependence_op_idx);
        VLOG(4) << "Add depend from "
                << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
                << vec_instruction[op_idx].OpBase()->Type();
      }
      dependence_op_idx = op_idx;
    }
  }

  // TODO(zhiqiu): there are still some cases not handled
  // add dependency for c_sync_comm_stream

  // in program, we can add only one c_sync_comm_stream to sync all
  // communication ops.
  // c_allreduce_sum(a)
  // c_allreduce_sum(b)
  // c_allreduce_sum(c)
  // c_sync_comm_stream(a)
  const std::string kSyncComm = "c_sync_comm_stream";
  dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (vec_instruction[op_idx].OpBase()->Type() == kSyncComm) {
      dependence_op_idx = op_idx;
    } else {
      if (dependence_op_idx != -1) {
        VLOG(4) << "Add depend from "
                << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
                << vec_instruction[op_idx].OpBase()->Type();
        op2dependences[op_idx].insert(dependence_op_idx);
      }
    }
  }

  // add dependency for coalesce_tensor
  const std::string kCoalesceTensor = "coalesce_tensor";
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (vec_instruction[op_idx].OpBase()->Type() == kCoalesceTensor) {
      VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx;
      auto fused_out = vec_instruction[op_idx].Outputs().at("FusedOutput")[0];
      auto outputs = vec_instruction[op_idx].Outputs().at("Output");

      auto is_read = [](const Instruction& inst, int var_id) -> bool {
        for (auto pair : inst.Inputs()) {
          for (auto item : pair.second) {
            if (item == var_id) {
              return true;
            }
          }
        }
        return false;
      };

      auto is_write = [](const Instruction& inst, int var_id) -> bool {
        for (auto pair : inst.Outputs()) {
          for (auto item : pair.second) {
            if (item == var_id) {
              return true;
            }
          }
        }
        return false;
      };

      // find first op that reads fused_out
      auto first_read_fused_out_op = -1;
      for (auto j = op_idx + 1; j < vec_instruction.size(); ++j) {
        if (is_read(vec_instruction[j], fused_out)) {
          first_read_fused_out_op = j;
          break;
        }
      }

      if (UNLIKELY(first_read_fused_out_op == -1)) {
        VLOG(4) << "No op read FusedOutput";
        continue;
      }

      // find ops that write 'outputs' between (op_index,
      // first_read_fused_out_op)
      // add depend: them->first_read_fused_out_op
      for (auto j = op_idx + 1;
           j < static_cast<size_t>(first_read_fused_out_op); ++j) {
        for (auto var_id : outputs) {
          if (is_write(vec_instruction[j], var_id)) {
            op2dependences[first_read_fused_out_op].insert(j);
            VLOG(4) << j << " -> " << first_read_fused_out_op;
            VLOG(4)
                << "Add depend from " << vec_instruction[j].OpBase()->Type()
                << " to "
                << vec_instruction[first_read_fused_out_op].OpBase()->Type();
          }
        }
      }

      // find the first op that reads 'outputs' between
      // (first_read_fused_out_op, end)
      // add depend: first_read_fused_out_op -> the first op that reads
      // 'outputs'

      // special case for consecutive communication ops, for example,
      // FusedOutput = c_sync_calc_stream(FusedOutput)
      // FusedOutput= c_allreduce_sum(FusedOutput)
      // FusedOutput = c_sync_comm_stream(FusedOutput)
      // we should take the last one to add the dependence instead of
      // 'first_read_fused_out_op'
      size_t target = first_read_fused_out_op;
      for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
           ++j) {
        if (j == target + 1 &&
            is_comm_op(vec_instruction[target].OpBase()->Type()) &&
            is_comm_op(vec_instruction[j].OpBase()->Type())) {
          VLOG(4) << "Found consecutive communication ops, "
                  << vec_instruction[target].OpBase()->Type() << " -> "
                  << vec_instruction[j].OpBase()->Type();
          target = j;
          continue;
        }

        for (auto var_id : outputs) {
          if (is_read(vec_instruction[j], var_id)) {
            op2dependences[j].insert(target);
            VLOG(4) << target << " -> " << j;
            VLOG(4) << "Add depend from "
                    << vec_instruction[target].OpBase()->Type() << " to "
                    << vec_instruction[j].OpBase()->Type();
          }
        }
      }
    }
  }
  for (const auto& pair : op2dependences) {
    VLOG(10) << pair.first << " Depends on " << pair.second.size();
    std::ostringstream oss;
    std::copy(pair.second.begin(), pair.second.end(),
              std::ostream_iterator<int>(oss, " "));
    VLOG(10) << oss.str();
  }
  return get_downstream_map(op2dependences);
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle