// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/naive_executor.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/denormal.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
#endif
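
// A minimal usage sketch (illustrative only; variable names and the fetched
// tensor name below are placeholders, see naive_executor.h for the exact
// interface):
//
//   NaiveExecutor executor(platform::CPUPlace());
//   executor.Prepare(scope, program, /*block_id=*/0,
//                    /*with_feed_fetch_ops=*/false);
//   executor.CreateVariables(program, /*block_id=*/0, /*persistable=*/true,
//                            scope);
//   executor.Run();
//   auto *out = executor.FindTensor("output_name");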

namespace paddle {
namespace framework {
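// Binds the executor to a scope (a new one is created when `scope` is null)
// and instantiates the operators of block `block_id` from `program_desc`;
// feed/fetch ops are skipped unless `with_feed_fetch_ops` is true.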
void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
                            int block_id, bool with_feed_fetch_ops) {
  if (!scope) {
    scope_ = new framework::Scope;
  } else {
    scope_ = scope;
  }

  VLOG(3) << "NaiveExecutor init with scope " << scope;
  CreateOps(program_desc, block_id, with_feed_fetch_ops);
}

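// Runs the prepared operators sequentially on the bound scope and place,
// with denormal floats flushed to zero for the duration of the run.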
void NaiveExecutor::Run() {
#ifdef PADDLE_WITH_MKLDNN
  platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
  platform::ScopedFlushDenormal flush;
  for (auto &op : ops_) {
    VLOG(4) << std::this_thread::get_id() << " run "
            << op->DebugStringEx(scope_) << " on scope " << scope_;
    op->SetIsCalledByExecutor(false);
    op->Run(*scope_, place_);
  }
}

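// Creates the variables of block `block_id` whose Persistable() flag matches
// `persistable`: persistable variables go into the root (ancestor) scope so
// they can be shared, non-persistable ones into the given `scope`.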
void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
                                    bool persistable, Scope *scope) {
  PADDLE_ENFORCE_NOT_NULL(scope,
                          platform::errors::InvalidArgument(
                              "The Scope to hold variables is nullptr."));

  auto &global_block = desc.Block(block_id);

  const auto *anc = scope;
  PADDLE_ENFORCE_NE(
      anc->parent(), anc,
      platform::errors::InvalidArgument("Input scope should be child scope."));
  while (anc->parent()) {
    anc = anc->parent();
  }

  int num_vars = 0;
  for (auto &var : global_block.AllVars()) {
    if (var->Name() == framework::kEmptyVarName) {
      continue;
    }
    num_vars++;

    if (persistable == var->Persistable()) {
      if (persistable) {
        if (!anc->FindVar(var->Name())) {
          auto *ptr = const_cast<Scope *>(anc)->Var(var->Name());
          VLOG(3) << scope << " Create persistable variable " << var->Name()
                  << ", which pointer is " << ptr;
          InitializeVariable(ptr, var->GetType());
        }
      } else {
        auto *ptr = const_cast<Scope *>(scope)->Var(var->Name());
        VLOG(3) << scope << " Create variable " << var->Name()
                << ", which pointer is " << ptr;
        InitializeVariable(ptr, var->GetType());
      }
    }
  }
  VLOG(4) << "naive executor create " << num_vars << " vars";
}

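// Instantiates operators from the block's OpDescs. feed/fetch ops are
// skipped (and logged) when `with_feed_fetch_ops` is false.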
void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,
                              bool with_feed_fetch_ops) {
  for (const auto &op_desc : desc.Block(block_id).AllOps()) {
    if (!with_feed_fetch_ops &&
        (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
      LOG(INFO) << "---  skip [" << op_desc->Input("X")[0] << "], "
                << op_desc->Type() << " -> " << op_desc->Output("Out")[0];
      continue;
    }
    ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
  }
}

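// Looks up variable `name` in the bound scope and returns its LoDTensor.
// Fails if the executor has no scope or the variable does not exist.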
LoDTensor *NaiveExecutor::FindTensor(const std::string &name) {
  PADDLE_ENFORCE_NOT_NULL(scope_,
                          platform::errors::PreconditionNotMet(
                              "Need to init scope in NaiveExecutor firstly."));
  auto *var = scope_->FindVar(name);
  PADDLE_ENFORCE_NOT_NULL(var, platform::errors::NotFound(
                                   "No variable [%s] in current scope.", name));
  auto *tensor = const_cast<LoDTensor *>(&var->Get<LoDTensor>());
  return tensor;
}

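// Drops all feed/fetch operators from the prepared op list, keeping the
// relative order of the remaining ops.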
void NaiveExecutor::CleanFeedFetchOps() {
  std::vector<std::unique_ptr<OperatorBase>> ops;
  for (auto &op : ops_) {
    if (op->Type() != "feed" && op->Type() != "fetch") {
      ops.emplace_back(std::move(op));
    }
  }
  ops_.swap(ops);
}

NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN
  // Clear the MKL-DNN cache; this is needed for the MKL-DNN unit tests to
  // work correctly.
  platform::ClearMKLDNNCache(place_, this);
#endif
}

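// Rebuilds every dynamic-shape TensorRT engine held by a tensorrt_engine op
// with `num` optimization profiles. The engine is re-prepared against the
// root scope, which can be slow. A no-op unless built with TensorRT.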
void NaiveExecutor::ResetTrtOps(int num) {
#if PADDLE_WITH_TENSORRT
  for (auto &op : ops_) {
    if (op->Type() == "tensorrt_engine") {
      operators::TensorRTEngineOp *trtop =
          dynamic_cast<operators::TensorRTEngineOp *>(op.get());
      if (!trtop) return;
      std::string engine_key = trtop->Attr<std::string>("engine_key");
      int engine_predictor_id = trtop->Attr<int>("predictor_id");
      std::string engine_name =
          engine_key + std::to_string(engine_predictor_id);
      operators::TensorRTEngine *trt_engine =
          paddle::inference::Singleton<
              inference::tensorrt::TRTEngineManager>::Global()
              .Get(engine_name);
      if (trt_engine->with_dynamic_shape()) {
        LOG(INFO) << "rebuild trt engine, this may cost a lot of time!";
        trt_engine->ResetContext();
        trt_engine->ClearTensorMap();
        trt_engine->SetProfileNum(num);
        auto *anc = scope_->parent();
        while (anc && anc->parent()) {
          anc = anc->parent();
        }
        if (anc == nullptr) {
          anc = scope_;
        }
        trtop->PrepareTRTEngine(*anc, trt_engine);
      }
    }
  }
#endif
}
}  // namespace framework
}  // namespace paddle