interpretercore_util.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include <algorithm>

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/pten/core/kernel_factory.h"

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_sequential_run, false,
    "Enable sequential execution for standalone executor, used for debugging");

namespace paddle {
namespace framework {
namespace interpreter {

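// Dispatch a task to the underlying work queue group. Under
// FLAGS_new_executor_sequential_run every task goes to a single queue, so ops
// are executed one at a time.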
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  // NOTE(zhiqiu): under FLAGS_new_executor_sequential_run, route every task
  // to the second queue so only one thread is used.
  if (FLAGS_new_executor_sequential_run) {
    VLOG(4) << "FLAGS_new_executor_sequential_run:"
            << FLAGS_new_executor_sequential_run;
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

using VariableIdMap = std::map<std::string, std::vector<int>>;

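// (Re)build the per-op atomic dependency counters: the vector of atomics is
// recreated only when the op count changes; otherwise the existing atomics
// are reused and simply reset to the static dependency counts.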
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  if (atomic_deps_.size() != dependency_count.size()) {
    atomic_deps_.clear();
    std::generate_n(std::back_inserter(atomic_deps_), dependency_count.size(),
                    [] { return std::make_unique<std::atomic<size_t>>(0); });
  }

  for (size_t i = 0; i < dependency_count.size(); ++i) {
    atomic_deps_[i]->store(dependency_count[i]);
  }
  return atomic_deps_;
}

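// Same pattern as PrepareAtomicDeps, but for the per-variable reference
// counts consumed by the garbage collector.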
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  if (atomic_var_ref_.size() != vec_meta_info.size()) {
    atomic_var_ref_.clear();
    std::generate_n(std::back_inserter(atomic_var_ref_), vec_meta_info.size(),
                    [] { return std::make_unique<std::atomic<size_t>>(0); });
  }

  for (size_t i = 0; i < vec_meta_info.size(); ++i) {
    atomic_var_ref_[i]->store(vec_meta_info[i].var_ref_count_);
  }
  return atomic_var_ref_;
}

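// A variable may be garbage-collected only if it is non-persistable and holds
// one of the tensor-like types the interpreter knows how to release.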
bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}

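// For every op, collect the variables whose last use is that op; such
// variables can be released as soon as the op finishes running.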
std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
  }
  return result;
}

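// Create all variables declared in the block: persistable ones go into the
// outer scope, the rest into the local scope when use_local_scope is set.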
void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope, bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, while the others are created in the local
  // scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): a user may create a variable whose name already exists;
    // under such circumstances, we should raise an error. Currently we can't
    // get the var_desc of the startup_program, so leave it for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, whose pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name
              << " locally, whose pointer is " << ptr << ", variable type "
              << static_cast<int>(var_desc->GetType());
    }
    var_scope->SetVarDesc(var_name, var_desc);
  }
}

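// Instantiate an OperatorBase for every OpDesc in the block, running the
// attribute checker first so that default attributes are filled in.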
void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();
    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

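// For each slot name, collect the Variable pointers and variable ids it
// refers to. With enforce_exist == false, missing variables are skipped
// instead of raising an error.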
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map, VariableScope* var_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!enforce_exist && !var_scope->HasVar(var_name)) {
        // skip non-existent variables, such as recurrent_grad
        VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
        continue;
      }
      auto var_id = var_scope->VarId(var_name);
      auto* in_var = var_scope->Var(var_id);
      vars.push_back(in_var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}

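// Honor the op_device attribute set by device_guard(): rewrite the expected
// kernel place, falling back to CPUPlace when the op has no kernel for the
// requested device.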
void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               (platform::is_gpu_place(place) ||
                platform::is_npu_place(place))) {
      // when an op that only has a CPU kernel is assigned to GPU, the CPU
      // kernel is executed and a warning is emitted.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}

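// Handle ops that are plain OperatorBase (no kernel): run them directly and
// mark all of their inputs as needing no data transform.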
void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node, Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // Inputs and outputs are prepared; set the remaining attributes.
  op_func_node->operator_base_ = op_base;
  if (platform::is_gpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transform.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // no input needs data transform
  op_func_node->dev_ctx_ = dev_ctx;
}

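// Build an OpFuncNode for every op in the block: prepare inputs/outputs,
// choose a kernel (pten or fluid), insert data-transfer ops where needed,
// run the kernel, and garbage-collect variables after their last use.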
void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope, bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);
  // If gc is enabled and block size > 1, prepare safe eager deletion for
  // control-flow ops.
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

  std::vector<std::shared_ptr<OperatorBase>>
      ops;  // its elements will be moved to vec_func_list
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) =
        build_variable_map(outputs_names, var_scope, enforce_exist);

    // step 2: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;

    if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
    } else {
      auto op_with_kernel =
          static_cast<const framework::OperatorWithKernel*>(op);
      // construct RuntimeContext and analyze the KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      // see OperatorWithKernel::RunImpl in operator.cc for why
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): In case of control flow ops, they are NOT
        // inherited from OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      Scope scope;
      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
          ExecutionContext(*op, scope, *dev_ctx, runtime_context));

      // change the device according to device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(3) << "expected_kernel_key : " << expected_kernel_key;

      // step 3. apply data transforms and insert data transfer ops
      VariableValueMap& ins_map_temp = runtime_context.inputs;

      // NOTE(zhiqiu): op_func_node->operator_base_ may be changed by
      // ApplyDataTransform
      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
                         &op_func_node, vec_func_list, use_local_scope);
      op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
          op_func_node.operator_base_.get());

      // step 4. Run op kernel
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      if (platform::is_gpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             expected_kernel_key.place_));
      }
      if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
        dev_ctx = pool.Get(expected_kernel_key.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      auto exec_ctx =
          ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);

      auto run_pten_kernel = false;
      if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePtenKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PtenKernelSignature()->name;

        if (op_with_kernel->PtenKernel()->IsValid()) {
          run_pten_kernel = true;
        } else {
          auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
          if (kernels_iter == all_op_kernels.end() ||
              kernels_iter->second.find(expected_kernel_key) ==
                  kernels_iter->second.end()) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPtenKernel(
                new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PtenKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PtenKernel());
              run_pten_kernel = true;
            }
          }
        }
      }
      if (run_pten_kernel) {
        pten::KernelContext pt_kernel_context;
        op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
                                               &pt_kernel_context);
        op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();

        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        auto kernels_iter = all_op_kernels.find(op->Type());
        PADDLE_ENFORCE_NE(
            kernels_iter, all_op_kernels.end(),
            platform::errors::Unavailable(
                "There are no kernels which are registered in the %s operator.",
                op->Type()));
        OpKernelMap& kernels = kernels_iter->second;

        auto kernel_iter = kernels.find(expected_kernel_key);
        PADDLE_ENFORCE_NE(
            kernel_iter, kernels.end(),
            platform::errors::NotFound(
                "Operator (%s) does not have kernel for %s.", op->Type(),
                KernelTypeToString(expected_kernel_key)));
        // TODO(zhiqiu): add fallback logic
        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
        op_func_node.kernel_func_(exec_ctx);
      }

      // post-process grad_op outputs if we need to cast complex grad into
      // real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(
            op_func_node, place, outputs_names, &runtime_context.outputs,
            var_scope, vec_func_list, local_scope);
      }
    }

    vec_func_list->emplace_back(op_func_node);
    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = var_scope->FindVar(var_name);
      if (var == nullptr) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<pten::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<pten::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s is not supported for eager deletion.",
            framework::ToTypeName(var->Type()), var_name));
      }
    }

    delete garbages;  // free mem

    VLOG(3) << "run " << op->Type() << " done.";
  }
}

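// Append one fetch_v2 op per fetch target; all of them write into a shared
// FETCH_LIST variable, one column per target. For example,
// add_fetch({"loss"}, block) appends a single fetch_v2 op with col = 0.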
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}

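// Merge two sorted index vectors into a sorted, deduplicated one, e.g.
// merge_vector({1, 3, 5}, {2, 3, 6}) yields {1, 2, 3, 5, 6}.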
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());

  auto it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}

void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op; this function updates the
  // var2min_rw_op map accordingly.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
    (*var2min_rw_op)[rw_var] = std::list<int>();
  }
  for (auto dep_op : op2dependences.at(cur_op)) {
    var2min_rw_op->at(rw_var).remove(dep_op);
  }
  var2min_rw_op->at(rw_var).push_back(cur_op);
}

std::map<int, std::list<int>> get_downstream_map(
    const std::map<int, std::set<int>>& op2dependences) {
  // op2dependences maps each op to its dependencies. We want the reverse,
  // op -> [ops], where ops are the instructions that must run after op.
  std::map<int, std::list<int>> result;
  for (auto& item : op2dependences) {
    int op = item.first;
    for (auto dep_op : item.second) {
      if (result.find(dep_op) == result.end())
        result[dep_op] = std::list<int>();
      result[dep_op].push_back(op);
    }
  }
  return result;  // NRVO applies; no need for std::move here
}

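// Derive the op -> downstream-ops map from read/write sets: an op must wait
// for the most recent writer of each of its inputs (read-after-write) and
// for earlier readers/writers of each of its outputs
// (write-after-read / write-after-write).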
std::map<int, std::list<int>> build_op_downstream_map(
    const std::vector<Instruction>& vec_instruction) {
  auto var2min_rw_op = std::map<
      int, std::list<int>>();  // map from variable id to read / write op ids.
  auto var2recent_write_op =
      std::map<int, int>();  // map from variable to most recent write op.
  auto op2dependences =
      std::map<int, std::set<int>>();  // map from op to its dependence list;
                                       // op must run after its dependences.
  std::set<int>
      remove_duplicate;  // remove the duplicates between inputs and outputs

  // reserve
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    op2dependences[op_idx] = std::set<int>();
  }

  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    remove_duplicate.clear();
    // step1: update the op2dependences structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (var2recent_write_op.count(var))
          op2dependences[op_idx].insert(var2recent_write_op[var]);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        if (var2min_rw_op.count(var)) {
          for (auto dep_op : var2min_rw_op[var]) {
            op2dependences[op_idx].insert(dep_op);
          }
        }
      }
    }

    // step2: update the two var2xxxx data structures
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        remove_duplicate.insert(var);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        var2recent_write_op[var] = op_idx;
        if (remove_duplicate.count(var) == 0) {
          // vars appearing in both the input and output lists were already
          // updated above, so skip them here.
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
  }
  return get_downstream_map(op2dependences);
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle