// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"

#include <algorithm>

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_factory.h"

PADDLE_DEFINE_EXPORTED_bool(
    new_executor_sequential_run, false,
    "Enable sequential execution for standalone executor, used for debug");

namespace paddle {
namespace framework {
namespace interpreter {

constexpr size_t kPrepareWorkQueueIdx = 2;
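// NOTE: the work queue group is assumed to hold three queues: index 0 for
// synchronous (host) ops, index 1 for asynchronous (device) ops, and index 2,
// kPrepareWorkQueueIdx, reserved for the preparation tasks submitted below.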

void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
                             std::function<void()> fn) {
  // NOTE(zhiqiu): use the second queue, whose size is 1, so only one thread
  // is used and ops run sequentially.
  if (FLAGS_new_executor_sequential_run) {
    VLOG(4) << "FLAGS_new_executor_sequential_run:"
            << FLAGS_new_executor_sequential_run;
    queue_group_->AddTask(static_cast<size_t>(OpFuncType::kQueueAsync),
                          std::move(fn));
  } else {
    queue_group_->AddTask(static_cast<size_t>(op_func_type), std::move(fn));
  }
}

using VariableIdMap = std::map<std::string, std::vector<int>>;

void AsyncWorkQueue::PrepareAtomicDeps(
    const std::vector<size_t>& dependency_count) {
  VLOG(4) << "PrepareAtomicDeps";
  atomic_deps_ =
      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&dependency_count] {
        auto op_deps = std::make_unique<std::vector<std::atomic<size_t>>>(
            dependency_count.size());
        for (size_t i = 0; i < dependency_count.size(); ++i) {
          (*op_deps)[i] = dependency_count[i];
        }
        VLOG(4) << "AtomicDeps:" << op_deps.get() << " " << op_deps->size();
        return op_deps;
      });
}

void AsyncWorkQueue::PrepareAtomicVarRef(
    const std::vector<VariableMetaInfo>& vec_meta_info) {
  VLOG(4) << "PrepareAtomicVarRef";
  atomic_var_ref_ =
      queue_group_->AddAwaitableTask(kPrepareWorkQueueIdx, [&vec_meta_info] {
        auto var_ref = std::make_unique<std::vector<std::atomic<size_t>>>(
            vec_meta_info.size());
        for (size_t i = 0; i < vec_meta_info.size(); ++i) {
          (*var_ref)[i] = vec_meta_info[i].var_ref_count_;
        }
        VLOG(4) << "AtomicVarRef:" << var_ref.get() << " " << var_ref->size();
        return var_ref;
      });
}
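
// NOTE: both helpers above copy plain size_t counters into std::atomic<size_t>
// arrays on the dedicated preparation queue, presumably so that the
// interpreter can decrement them concurrently while instructions execute.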

bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
  auto* var_desc = block.FindVar(name);
  if (var_desc == nullptr || var_desc->Persistable()) {
    return false;
  }

  auto type = var_desc->Proto()->type().type();

  return type == proto::VarType::LOD_TENSOR ||
         type == proto::VarType::SELECTED_ROWS ||
         type == proto::VarType::LOD_TENSOR_ARRAY;
}
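
// NOTE: only LOD_TENSOR, SELECTED_ROWS and LOD_TENSOR_ARRAY variables are
// candidates for eager deletion; persistable variables and all other types
// (e.g. readers) are never collected here.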

std::unordered_map<const paddle::framework::OperatorBase*,
                   std::vector<std::string>>
get_unused_vars(const BlockDesc& block,
                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
  std::unordered_map<std::string, size_t> var_op_idx_map;

  for (size_t i = 0; i < ops.size(); ++i) {
    const auto& op = ops[i];

    OpInOutInfo info;
    for (auto& name_pair : op->Inputs()) {
      for (auto& name : name_pair.second) {
        if (!var_can_be_deleted(name, block)) {
          continue;
        }

        // var can be gc-ed
        if (!info.IsBuilt()) {
          info.Build(op.get());
        }

        if (info.IsInArgBufferNeeded(name)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        } else {
          VLOG(10) << "Skip reference count computing of variable "
                   << name_pair.first << "(" << name << ") in Operator "
                   << op->Type();
        }
      }
    }

    for (auto& name_pair : op->Outputs()) {
      for (auto& name : name_pair.second) {
        if (var_can_be_deleted(name, block)) {
          // Update the last living op of variable to current op
          var_op_idx_map[name] = i;
        }
      }
    }
  }

  std::unordered_map<const OperatorBase*, std::vector<std::string>> result;
  for (auto& name_op_idx_pair : var_op_idx_map) {
    auto& name = name_op_idx_pair.first;
    size_t op_idx = name_op_idx_pair.second;

    result[ops[op_idx].get()].emplace_back(name);
    VLOG(4) << ops[op_idx].get()->Type() << " " << name;
  }
  VLOG(4) << "gc map size:" << result.size();
  return result;
}
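
// NOTE: the returned map answers "after which op may each variable be
// released?"; build_op_func_list below consults it to free tensor memory as
// soon as a variable's last reader or writer has run.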

void build_variable_scope(const framework::BlockDesc& block,
                          VariableScope* var_scope, bool use_local_scope) {
  VLOG(3) << "Creating Variables";
  auto inner_scope = var_scope->GetMutableScope();

  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
  // created in var_scope.scope_, and all other variables in the local scope.
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();

  for (auto& var_desc : block.AllVars()) {
    auto var_name = var_desc->Name();
    // TODO(xiongkun): users may create a variable whose name already exists;
    // under such circumstances we should raise an error. Currently we cannot
    // get the var_desc of the startup_program, so this is left for later.
    if (var_name == framework::kEmptyVarName) {
      continue;
    }
    if (var_desc->Persistable()) {
      auto* ptr = inner_scope->Var(var_name);

      VLOG(3) << "Initialize Variable " << var_name;
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    } else {
      auto* ptr = local_scope->Var(var_name);
      InitializeVariable(ptr, var_desc->GetType());
      VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
              << ptr << " type is " << static_cast<int>(var_desc->GetType());
    }
    var_scope->SetVarDesc(var_name, var_desc);
  }
}

void create_all_ops(const framework::BlockDesc& block,
                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

    auto& info = OpInfoMap::Instance().Get(op->Type());

    const VariableNameMap& inputs_names = op->Inputs();
    const VariableNameMap& outputs_names = op->Outputs();
    AttributeMap op_attr_map = op->GetAttrMap();

    if (info.Checker() != nullptr) {
      info.Checker()->Check(&op_attr_map);
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
    const VariableNameMap& var_name_map, VariableScope* var_scope,
    bool enforce_exist = true) {
  VariableValueMap name2var;
  VariableIdMap name2id;
  for (auto& item : var_name_map) {
    std::vector<Variable*> vars;
    std::vector<int> ids;
    vars.reserve(item.second.size());

    for (auto& var_name : item.second) {
      if (!enforce_exist && !var_scope->HasVar(var_name)) {
        // skip non-existent variables, e.g. those consumed by recurrent_grad
        VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
        continue;
      }
      auto var_id = var_scope->VarId(var_name);
      auto* in_var = var_scope->Var(var_id);
      vars.push_back(in_var);
      ids.push_back(var_id);
    }
    name2var[item.first] = std::move(vars);
    name2id[item.first] = std::move(ids);
  }
  return std::make_tuple(name2var, name2id);
}
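
// For example, an op with Inputs() = {{"X", {"a", "b"}}} yields name2var
// mapping "X" to the Variable pointers of "a" and "b", and name2id mapping
// "X" to their ids in the variable scope.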

void apply_device_guard(const OperatorBase* op_base,
                        const platform::Place& place,
                        OpKernelType* expected_kernel_key) {
  bool need_change_place =
      (op_base->HasAttr("op_device") &&
       (op_base->Attr<std::string>("op_device").length() > 0));
  if (need_change_place) {
    auto& op_device = op_base->Attr<std::string>("op_device");
    if (op_device == "cpu" || platform::is_cpu_place(place)) {
      VLOG(3) << "Switch into CPUPlace by device_guard.";
      expected_kernel_key->place_ = platform::CPUPlace();
    } else if (op_device.find("gpu") != std::string::npos &&
               (platform::is_gpu_place(place) ||
                platform::is_npu_place(place))) {
      // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
      // will be executed and a warning will be given at the same time.
      if (op_base->SupportGPU()) {
        expected_kernel_key->place_ = place;
      } else if (op_base->SupportNPU()) {
        expected_kernel_key->place_ = place;
      } else {
        expected_kernel_key->place_ = platform::CPUPlace();
        LOG_FIRST_N(WARNING, 1)
            << "Op(" << op_base->Type()
            << ") has no CUDA implementation. It will be assigned to CPUPlace.";
      }
      VLOG(3) << "Switch into " << expected_kernel_key->place_
              << " by device_guard.";
    } else {
      PADDLE_THROW(
          platform::errors::Fatal("Unsupported current place %s", op_device));
    }
  }
}
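
// NOTE: `op_device` is typically set via the device_guard() context manager
// on the Python side; an op pinned to "cpu" always runs on CPUPlace, while an
// op pinned to "gpu" falls back to CPU (with a one-time warning) when it has
// no CUDA implementation.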

void deal_operator_base(const platform::Place& place,
                        const VariableScope* var_scope,
                        std::shared_ptr<OperatorBase> op_base,
                        OpFuncNode* op_func_node, Scope* local_scope) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // inputs and outputs are prepared; set the other attributes
  op_func_node->operator_base_ = op_base;
  if (platform::is_gpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueAsync;
  } else if (platform::is_cpu_place(place)) {
    op_func_node->type_ = OpFuncType::kQueueSync;
  } else {
    PADDLE_THROW(
        platform::errors::Fatal("Unsupported current place %s", place));
  }

  op_func_node->kernel_func_ = nullptr;
  op_base->Run(*local_scope, place);  // Run without data transform.

  std::unordered_set<int> no_data_transform_index;
  for (auto& it : op_func_node->input_index) {
    for (auto& id : it.second) {
      no_data_transform_index.emplace(id);
    }
  }
  op_func_node->no_data_transform_index =
      no_data_transform_index;  // no input of this op needs data transform
  op_func_node->dev_ctx_ = dev_ctx;
}
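
// NOTE: deal_operator_base() above actually executes the op once while the
// func node is being built; kernel-based ops in build_op_func_list below are
// likewise run eagerly during this build pass.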

void build_op_func_list(const platform::Place& place,
                        const framework::BlockDesc& block,
                        std::vector<OpFuncNode>* vec_func_list,
                        VariableScope* var_scope, bool use_local_scope) {
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  std::vector<std::unique_ptr<OperatorBase>>
      ops_unique;  // its elements will be moved to vec_func_list
  // Step 1: create all ops for current block.
  create_all_ops(block, &ops_unique);
  // If gc is enabled and the program has sub-blocks, prepare safe eager
  // deletion for the control-flow ops:
  const ProgramDesc& main_program = *block.Program();
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      main_program, block.ID(), ops_unique);
  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
      main_program, block.ID(), ops_unique);

  // its elements will be moved to vec_func_list
  std::vector<std::shared_ptr<OperatorBase>> ops;
  for (auto& op_unique : ops_unique) {
    ops.emplace_back(std::move(op_unique));
  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
    auto op = ops[i].get();
    VLOG(6) << "Build OpFuncNode from : " << op->Type();

    auto inputs_names = op->Inputs();
    auto outputs_names = op->Outputs();

    VariableValueMap ins_map;
    VariableIdMap ins_name2id;
    bool enforce_exist = true;
    if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
        op->Type() == "rnn_memory_helper_grad" ||
        op->Type() == "conditional_block" ||
        op->Type() == "conditional_block_grad" || op->Type() == "while" ||
        op->Type() == "while_grad") {
      enforce_exist = false;
    }
    std::tie(ins_map, ins_name2id) =
        build_variable_map(inputs_names, var_scope, enforce_exist);

    VariableValueMap outs_map;
    VariableIdMap outs_name2id;
    std::tie(outs_map, outs_name2id) =
        build_variable_map(outputs_names, var_scope, enforce_exist);

    // step 2: build OpFuncNode
    OpFuncNode op_func_node;
    op_func_node.operator_base_ = ops[i];
    op_func_node.input_index = ins_name2id;
    op_func_node.output_index = outs_name2id;
    VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);

    if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
      // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
      VLOG(4) << "End run " << place << " "
              << op_func_node.operator_base_->DebugStringEx(local_scope);
    } else {
      auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(op));
      // construct RuntimeContext and analyze KernelType
      RuntimeContext runtime_context({}, {});
      runtime_context.inputs.swap(ins_map);
      runtime_context.outputs.swap(outs_map);

      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      auto* dev_ctx = pool.Get(place);
      Scope scope;
      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
          ExecutionContext(*op, scope, *dev_ctx, runtime_context));
      op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));

      // change the device according to device_guard()
      apply_device_guard(op, place, &expected_kernel_key);
      VLOG(3) << "expected_kernel_key : " << expected_kernel_key;

      // step 3. apply data transforms and insert data transfer ops
      VariableValueMap& ins_map_temp = runtime_context.inputs;
      VariableValueMap& outs_map_temp = runtime_context.outputs;

      // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
      // ApplyDataTransform
      ApplyDataTransform(expected_kernel_key, place, &ins_map_temp,
                         &outs_map_temp, var_scope, &op_func_node,
                         vec_func_list, use_local_scope);
      op_with_kernel = const_cast<framework::OperatorWithKernel*>(
          static_cast<const framework::OperatorWithKernel*>(
              op_func_node.operator_base_.get()));

      // step 4. Run op kernel
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;

      if (platform::is_gpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueAsync;
      } else if (platform::is_cpu_place(expected_kernel_key.place_)) {
        op_func_node.type_ = OpFuncType::kQueueSync;
      } else {
        PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
                                             expected_kernel_key.place_));
      }
      if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
        dev_ctx = pool.Get(expected_kernel_key.place_);
      }
      op_func_node.dev_ctx_ = dev_ctx;
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
      auto exec_ctx =
          ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);

      // see OperatorWithKernel::RunImpl in operator.cc for why
      if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
            op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
        InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
        // TODO(Aurelius84): control flow ops are NOT inherited from
        // OperatorWithKernel.
        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
      }

      auto run_phi_kernel = false;
      if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
              op_with_kernel->Type())) {
        auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
        auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;

        if (op_with_kernel->PhiKernel()->IsValid()) {
          run_phi_kernel = true;
        } else {
          auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
          if (kernels_iter == all_op_kernels.end() ||
              kernels_iter->second.find(expected_kernel_key) ==
                  kernels_iter->second.end()) {
            auto pt_cpu_kernel_key = FallBackToCpu(
                expected_kernel_key, pt_kernel_key, *op_with_kernel);
            op_with_kernel->ResetPhiKernel(
                new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
                    pt_kernel_name, pt_cpu_kernel_key)));
            if (op_with_kernel->PhiKernel()->IsValid()) {
              VLOG(6) << "Static mode PrepareImpl - kernel name: "
                      << pt_kernel_name
                      << " | kernel key: " << pt_cpu_kernel_key
                      << " | kernel: " << *(op_with_kernel->PhiKernel());
              run_phi_kernel = true;
            }
          }
        }
      }
      VLOG(3) << op_with_kernel->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
      if (run_phi_kernel) {
        phi::KernelContext pt_kernel_context;
        op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
                                              &pt_kernel_context);
        op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();

        (*op_func_node.pt_kernel_)(&pt_kernel_context);
      } else {
        auto kernels_iter = all_op_kernels.find(op->Type());
        PADDLE_ENFORCE_NE(
            kernels_iter, all_op_kernels.end(),
            platform::errors::Unavailable(
                "There are no kernels which are registered in the %s operator.",
                op->Type()));
        OpKernelMap& kernels = kernels_iter->second;

        auto kernel_iter = kernels.find(expected_kernel_key);
        PADDLE_ENFORCE_NE(
            kernel_iter, kernels.end(),
            platform::errors::NotFound(
                "Operator (%s) does not have kernel for %s.", op->Type(),
                KernelTypeToString(expected_kernel_key)));
        // TODO(zhiqiu): add fallback logic
        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
        op_func_node.kernel_func_(exec_ctx);
      }

      // post-process grad_op.outputs if need cast complex grad into real grad.
      // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
      if (framework::IsComplexType(expected_kernel_key.data_type_)) {
        interpreter::HandleComplexGradToRealGrad(
            op_func_node, place, outputs_names, &runtime_context.outputs,
            var_scope, vec_func_list, local_scope);
      }
      if (!op_func_node.inplace_back_map.empty()) {
        auto& m = op_func_node.inplace_back_map;
        // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
        for (auto& p : m) {
          auto* transformed_tensor =
              GetMutableLoDTensorOrSelectedRowsValueFromVar(
                  var_scope->Var(p.first));
          auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
              var_scope->Var(p.second));
          original_tensor->ShareDataWith(*transformed_tensor);
          VLOG(4) << "Transfer inplace variable back form "
                  << var_scope->GetNameById(p.first) << " to "
                  << var_scope->GetNameById(p.second);
        }
      }
    }

    VLOG(4) << "End run " << place << " "
            << op_func_node.operator_base_->DebugStringEx(local_scope);

    vec_func_list->emplace_back(op_func_node);

    // gc---------------------------------------------------------------------------
    auto iter = unused_var_map.find(op);
    if (iter == unused_var_map.end()) {
      continue;
    }

    auto& delete_vars = iter->second;
    std::deque<std::shared_ptr<memory::Allocation>>* garbages =
        new std::deque<std::shared_ptr<memory::Allocation>>();

    for (auto& var_name : delete_vars) {
      auto* var = var_scope->FindVar(var_name);
      if (var == nullptr) {
        continue;
      }

      VLOG(6) << "Erase variable " << var_name;
      if (var->IsType<LoDTensor>()) {
        garbages->emplace_back(
            var->GetMutable<LoDTensor>()->MoveMemoryHolder());
      } else if (var->IsType<phi::SelectedRows>()) {
        garbages->emplace_back(var->GetMutable<phi::SelectedRows>()
                                   ->mutable_value()
                                   ->MoveMemoryHolder());
      } else if (var->IsType<LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
        for (auto& t : *lod_tensor_arr) {
          garbages->emplace_back(t.MoveMemoryHolder());
        }
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Type %s of variable %s is not supported eager deletion.",
            framework::ToTypeName(var->Type()), var_name));
      }
    }
    delete garbages;  // free mem
  }
}

void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* fetch_holder = block->Var(kFetchVarName);
  fetch_holder->SetType(proto::VarType::FETCH_LIST);
  fetch_holder->SetPersistable(true);

  int i = 0;
  for (auto& fetch_name : fetch_names) {
    // append fetch op
    auto* op = block->AppendOp();
    op->SetType("fetch_v2");
    op->SetInput("X", {fetch_name});
    op->SetOutput("Out", {kFetchVarName});
    op->SetAttr("col", {static_cast<int>(i)});
    op->CheckAttrs();
    i++;
  }
}
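
// For example, add_fetch({"loss", "acc"}, block) appends two ops to the block:
//   fetch_v2(X="loss", Out=kFetchVarName, col=0)
//   fetch_v2(X="acc",  Out=kFetchVarName, col=1)
// with kFetchVarName created as a persistable FETCH_LIST variable.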

std::vector<size_t> merge_vector(const std::vector<size_t>& first,
                                 const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());

  std::vector<size_t>::iterator it;
  it = std::unique(out.begin(), out.end());

  out.resize(std::distance(out.begin(), it));

  return out;
}
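
// For example, merge_vector({1, 3, 5}, {2, 3, 6}) yields {1, 2, 3, 5, 6}:
// std::merge assumes both inputs are sorted, and std::unique then drops the
// duplicates from the merged result.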

void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op;
  // this function updates the var2min_rw_op map.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
    (*var2min_rw_op)[rw_var] = std::list<int>();
  }
  for (auto dep_op : op2dependences.at(cur_op)) {
    var2min_rw_op->at(rw_var).remove(dep_op);
  }
  var2min_rw_op->at(rw_var).push_back(cur_op);
}
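
// NOTE: removing cur_op's dependences keeps each list minimal: it retains
// only the "frontier" readers/writers that the next writer of rw_var must
// actually wait for; everything else is already ordered transitively.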

std::map<int, std::list<int>> get_downstream_map(
    const std::map<int, std::set<int>>& op2dependences) {
  // op2dependences maps each op to its dependences. We want the op -> [ops]
  // map, where [ops] are the next instructions of op.
  std::map<int, std::list<int>> result;
  for (auto& item : op2dependences) {
    int op = item.first;
    for (auto dep_op : item.second) {
      if (result.find(dep_op) == result.end())
        result[dep_op] = std::list<int>();
      result[dep_op].push_back(op);
    }
  }
  return result;
}
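
// For example, op2dependences = {2: {0, 1}} (op2 must run after op0 and op1)
// produces the downstream map {0: [2], 1: [2]}.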

std::map<int, std::list<int>> build_op_downstream_map(
    const std::vector<Instruction>& vec_instruction) {
  // map from variable id to the ops that last read/wrote it
  auto var2min_rw_op = std::map<int, std::list<int>>();
  // map from variable id to its most recent write op
  auto var2recent_write_op = std::map<int, int>();
  // map from op to its dependence list; an op must run after its dependences
  auto op2dependences = std::map<int, std::set<int>>();
  // vars that appear in both the inputs and outputs of one op
  std::set<int> remove_duplicate;

  // pre-create an empty dependence set for every op
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    op2dependences[op_idx] = std::set<int>();
  }

  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    remove_duplicate.clear();
    // step1: update the op2dependences structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        if (var2recent_write_op.count(var))
          op2dependences[op_idx].insert(var2recent_write_op[var]);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        if (var2min_rw_op.count(var)) {
          for (auto dep_op : var2min_rw_op[var]) {
            op2dependences[op_idx].insert(dep_op);
          }
        }
      }
    }

    // step2: update 2 var2xxxx data structure
    for (auto& item :
         vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
      for (auto var : item.second) {
        update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        remove_duplicate.insert(var);
      }
    }

    for (auto& item :
         vec_instruction[op_idx].Outputs()) {  // for all write vars
      for (auto var : item.second) {
        var2recent_write_op[var] = op_idx;
        if (remove_duplicate.count(var) ==
            0) {  // vars in both inputs and outputs were already updated above
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }

    // NOTE(zhiqiu): an inplace op with `transfer` inserted also changes the
    // original output afterwards, so record the original output as written.
    // original: a->op->a
    // after: a->data_transfer->a'->op->a'->transfer_back->a
    // which means op writes both a and a'
    if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
      auto& m = vec_instruction[op_idx].InplaceBackMap();
      for (auto& p : m) {
        auto var = p.second;
        var2recent_write_op[var] = op_idx;
        // vars in both inputs and outputs were already updated above
        if (remove_duplicate.count(var) == 0) {
          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
  }

  // add dependences among random ops to make sure that they are scheduled
  // sequentially
  const std::set<std::string> random_op_set = {
      "bernoulli",      "poisson", "multinomial", "gaussian_random",
      "uniform_random", "randint", "randperm",    "exponential"};
  int dependence_op_idx = -1;
  for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
    if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
      if (dependence_op_idx != -1) {
        op2dependences[op_idx].insert(dependence_op_idx);
      }
      dependence_op_idx = op_idx;
    }
  }
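
  // e.g. with random ops at indices 3, 7 and 9, op7 gains a dependence on
  // op3 and op9 on op7, chaining them into a fixed execution order.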

  return get_downstream_map(op2dependences);
}

}  // namespace interpreter
}  // namespace framework
}  // namespace paddle