// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include <algorithm>

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_factory.h"

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_sequential_run, false,
    "Enable sequential execution for standalone executor, used for debug");

namespace paddle {
namespace framework {
namespace interpreter {

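// Dispatch a task to the queue matching its OpFuncType. When
// FLAGS_new_executor_sequential_run is set, every task goes to the same
// queue, so execution is effectively serialized for debugging.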
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  // NOTE(zhiqiu): use the second queue so that only one thread is used.
  if (FLAGS_new_executor_sequential_run) {
    VLOG(4) << "FLAGS_new_executor_sequential_run:"
            << FLAGS_new_executor_sequential_run;
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

using VariableIdMap = std::map<std::string, std::vector<int>>;

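// Asynchronously builds one atomic counter per op from the dependency counts;
// the result is exposed through the atomic_deps_ future.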
void AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";
  auto p = std::make_shared<
      std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
  atomic_deps_ = p->get_future();
  queue_group_->AddTask(2, [&dependency_count, p] {
    auto* op_deps =
        new std::vector<std::atomic<size_t>>(dependency_count.size());
    for (size_t i = 0; i < dependency_count.size(); ++i) {
      (*op_deps)[i] = dependency_count[i];
    }
    VLOG(4) << "AtomicDeps:" << op_deps << " " << (*op_deps).size();
    p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(op_deps));
  });
}

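// Asynchronously builds one atomic reference counter per variable from the
// meta info; the result is exposed through the atomic_var_ref_ future.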
void AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  auto p = std::make_shared<
      std::promise<std::unique_ptr<std::vector<std::atomic<size_t>>>>>();
  atomic_var_ref_ = p->get_future();
  queue_group_->AddTask(2, [&vec_meta_info, p] {
    auto* var_ref = new std::vector<std::atomic<size_t>>(vec_meta_info.size());
    for (size_t i = 0; i < vec_meta_info.size(); ++i) {
      (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
    }
    VLOG(4) << "AtomicVarRef:" << var_ref << " " << (*var_ref).size();
    p->set_value(std::unique_ptr<std::vector<std::atomic<size_t>>>(var_ref));
  });
}

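// A variable may be garbage-collected only if it is declared in the block,
// is not persistable, and holds a tensor-like type.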
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}

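// Maps each op to the variables whose last use is that op, i.e. the
// variables that can be released as soon as the op finishes running.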
std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}

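// Creates every variable declared in the block, either in the outer (global)
// scope or in the local scope, and registers its VarDesc in var_scope.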
void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope, bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, and the other variables are created in the
  // local scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): a user may create a variable whose name already exists.
    // Under such circumstances we should raise an error. Currently we cannot
    // get the var_desc of the startup_program, so this is left for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name
              << " globally, whose pointer is " << ptr << ", type is "
              << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name
              << " locally, whose pointer is " << ptr << ", type is "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->SetVarDesc(var_name, var_desc);
  }
}

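// Instantiates an OperatorBase for every OpDesc in the block, running the
// attribute checker when one is registered for the op type.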
void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();
    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

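// Resolves a VariableNameMap into parallel maps of Variable pointers and
// variable ids; optionally skips names that are missing from the scope.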
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map, VariableScope* var_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!enforce_exist && !var_scope->HasVar(var_name)) {
        // skip non-existent variables, such as recurrent_grad
        VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
        continue;
      }
      auto var_id = var_scope->VarId(var_name);
      auto* in_var = var_scope->Var(var_id);
      vars.push_back(in_var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}

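// Honors the "op_device" attribute set by device_guard: overrides the place
// in expected_kernel_key, falling back to CPUPlace when the op has no
// suitable device kernel.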
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               (platform::is_gpu_place(place) ||
                platform::is_npu_place(place))) {
      // when an Op that has only a CPU kernel is assigned to GPU, the CPU
      // kernel will be executed and a warning will be given at the same time.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}

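// Handles ops that are plain OperatorBase (no kernel): runs them eagerly on
// the current place and fills in the remaining OpFuncNode fields.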
void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node, Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // inputs and outputs are prepared; set the other attributes.
  op_func_node->operator_base_ = op_base;
  if (platform::is_gpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transform.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // no input index needs data transform
  op_func_node->dev_ctx_ = dev_ctx;
}

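// Creates all ops of the block, chooses a kernel for each one, runs it, and
// records the result as an OpFuncNode, releasing unused variables along the
// way.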
void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope, bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved into `ops` below
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);
  // If gc is enabled and block size > 1, prepare for safe eager deletion on
  // the control-flow ops (conditional_block, while, recurrent).
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) =
        build_variable_map(outputs_names, var_scope, enforce_exist);

    // Step 2: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
      VLOG(4) << "End run " << place << " "
              << op_func_node.operator_base_->DebugStringEx(local_scope);
    } else {
      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(op));
      // construct RuntimeContext and analyze the KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      Scope scope;
      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
          ExecutionContext(*op, scope, *dev_ctx, runtime_context));
      op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));

      // change the device according to device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(3) << "expected_kernel_key : " << expected_kernel_key;

      // Step 3: apply data transforms and insert data transfer ops
      VariableValueMap& ins_map_temp = runtime_context.inputs;
      VariableValueMap& outs_map_temp = runtime_context.outputs;

      // NOTE(zhiqiu): op_func_node->operator_base_ may be changed in
      // ApplyDataTransform
      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp,
                         &outs_map_temp, var_scope, &op_func_node,
                         vec_func_list, use_local_scope);
      op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(
              op_func_node.operator_base_.get()));

      // Step 4: run the op kernel
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      if (platform::is_gpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             expected_kernel_key.place_));
      }
      if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
        dev_ctx = pool.Get(expected_kernel_key.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
      auto exec_ctx =
          ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);

      // see OperatorWithKernel::RunImpl in operator.cc for why
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): In case of control flow ops, they are NOT
        // inherited from OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

      auto run_phi_kernel = false;
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;

        if (op_with_kernel->PhiKernel()->IsValid()) {
          run_phi_kernel = true;
        } else {
          auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
          if (kernels_iter == all_op_kernels.end() ||
              kernels_iter->second.find(expected_kernel_key) ==
                  kernels_iter->second.end()) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPhiKernel(
                new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PhiKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PhiKernel());
              run_phi_kernel = true;
            }
          }
        }
      }
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
      if (run_phi_kernel) {
        phi::KernelContext pt_kernel_context;
        op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
                                              &pt_kernel_context);
        op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();

        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        auto kernels_iter = all_op_kernels.find(op->Type());
        PADDLE_ENFORCE_NE(
            kernels_iter, all_op_kernels.end(),
            platform::errors::Unavailable(
                "There are no kernels which are registered in the %s operator.",
                op->Type()));
        OpKernelMap& kernels = kernels_iter->second;

        auto kernel_iter = kernels.find(expected_kernel_key);
        PADDLE_ENFORCE_NE(
            kernel_iter, kernels.end(),
            platform::errors::NotFound(
                "Operator (%s) does not have kernel for %s.", op->Type(),
                KernelTypeToString(expected_kernel_key)));
        // TODO(zhiqiu): add fallback logic
        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
        op_func_node.kernel_func_(exec_ctx);
      }

      // post-process grad_op.outputs if we need to cast the complex grad into
      // a real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(
            op_func_node, place, outputs_names, &runtime_context.outputs,
            var_scope, vec_func_list, local_scope);
      }
      if (!op_func_node.inplace_back_map.empty()) {
        auto& m = op_func_node.inplace_back_map;
        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
        for (auto& p : m) {
          auto* transformed_tensor =
              GetMutableLoDTensorOrSelectedRowsValueFromVar(
                  var_scope->Var(p.first));
          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
              var_scope->Var(p.second));
          original_tensor->ShareDataWith(*transformed_tensor);
          VLOG(4) << "Transfer inplace variable back from "
                  << var_scope->GetNameById(p.first) << " to "
                  << var_scope->GetNameById(p.second);
        }
      }
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = var_scope->FindVar(var_name);
      if (var == nullptr) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<phi::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s does not support eager deletion.",
            framework::ToTypeName(var->Type()), var_name));
      }
    }
    delete garbages;  // free mem
  }
}

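// Appends one fetch_v2 op per fetch target to the block; all of them write
// into the shared kFetchVarName variable, one column per target.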
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}

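// Merges two sorted index vectors into one sorted vector with duplicates
// removed.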
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}

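// Records cur_op as the latest reader/writer of rw_var, dropping ops that
// cur_op already depends on since they are covered transitively.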
void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op;
  // this function updates the var2min_rw_op map.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
    (*var2min_rw_op)[rw_var] = std::list<int>();
  }
  for (auto dep_op : op2dependences.at(cur_op)) {
    var2min_rw_op->at(rw_var).remove(dep_op);
  }
  var2min_rw_op->at(rw_var).push_back(cur_op);
}

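// Inverts the dependency map: for every edge "op depends on dep_op", records
// op as a downstream successor of dep_op.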
std::map<int, std::list<int>> get_downstream_map(
    const std::map<int, std::set<int>>& op2dependences) {
  // op2dependences maps an op to the ops it depends on. We want the
  // op -> [ops] map, where ops are the next instructions of op.
  std::map<int, std::list<int>> result;
  for (auto& item : op2dependences) {
    int op = item.first;
    for (auto dep_op : item.second) {
      if (result.find(dep_op) == result.end())
        result[dep_op] = std::list<int>();
      result[dep_op].push_back(op);
    }
  }
  return result;
}

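// Builds the op -> downstream-ops map for the instruction list by tracking,
// per variable, the most recent writer and the pending readers/writers
// (covering read-after-write, write-after-read, and write-after-write).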
std::map<int, std::list<int>> build_op_downstream_map(
    const std::vector<Instruction>& vec_instruction) {
  auto var2min_rw_op = std::map<
      int, std::list<int>>();  // map from variable id to read / write op ids.
  auto var2recent_write_op =
      std::map<int, int>();  // map from variable to its most recent write op.
  auto op2dependences =
      std::map<int, std::set<int>>();  // map from op to its dependence list;
                                       // the op must run after its dependences.
  std::set<int> remove_duplicate;  // avoids updating a var twice when it is
                                   // both an input and an output

  // reserve an empty dependence set for every op
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    op2dependences[op_idx] = std::set<int>();
  }

  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    remove_duplicate.clear();
    // step1: update the op2dependences structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (var2recent_write_op.count(var))
          op2dependences[op_idx].insert(var2recent_write_op[var]);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        if (var2min_rw_op.count(var)) {
          for (auto dep_op : var2min_rw_op[var]) {
            op2dependences[op_idx].insert(dep_op);
          }
        }
      }
    }

    // step2: update the two var2xxx data structures
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        remove_duplicate.insert(var);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        var2recent_write_op[var] = op_idx;
        if (remove_duplicate.count(var) ==
            0) {  // a var that is both an input and an output was already
                  // updated when processing the inputs, so skip it here.
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }

    // NOTE(zhiqiu): an inplace op with `transfer` also changes the original
    // output afterwards, so add the original output as well.
    // original: a->op->a
    // after: a->data_transfer->a'->op->a'->transfer_back->a
    // which means op writes a and a'
    if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
      auto& m = vec_instruction[op_idx].InplaceBackMap();
      for (auto& p : m) {
        auto var = p.second;
        var2recent_write_op[var] = op_idx;
        // var in input list and in output list, so remove it.
        if (remove_duplicate.count(var) == 0) {
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
  }
  return get_downstream_map(op2dependences);
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle