// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/program.h"
#include <algorithm>
#include <map>
#include <set>
#include "lite/model_parser/cpp_desc.h"
#include "lite/operators/conditional_block_op.h"
#include "lite/operators/subgraph_op.h"
#include "lite/operators/while_op.h"
#ifdef LITE_WITH_PRECISION_PROFILE
#include "lite/core/profile/precision_profiler.h"
#endif

namespace paddle {
namespace lite {

// Rewrites the op list of block 0 of `desc` from the runtime instructions.
// Each saved op desc is a copy of the instruction's op info plus the
// serialized kernel type (stored under `kKernelTypeAttr`).
void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
  CHECK(desc);
  // NOTE: RuntimeProgram does not hold all the meta info, so saving the model
  // only updates the ops on top of the original model desc.
  CHECK(desc->BlocksSize());
  auto main_block = desc->GetBlock<cpp::BlockDesc>(0);
  main_block->ClearOps();
  for (auto& node : instructions_) {
    const auto& op_type = node.op()->op_info()->Type();
    if (op_type == "subgraph") {
      auto subgraph_op = const_cast<operators::SubgraphOp*>(
          static_cast<const operators::SubgraphOp*>(node.op()));
      int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
      if (sub_block_idx < 0) {
        // A negative sub_block marks a new subgraph op whose sub-block desc
        // is not yet part of `desc`. Append a copy of that desc as a new
        // block, then point the op's "sub_block" attribute and its sub-block
        // at the appended block.
        sub_block_idx = desc->BlocksSize();
        auto sub_block_desc = subgraph_op->GetSubBlock();
        CHECK(sub_block_desc);
        auto new_block_desc = desc->AddBlock<cpp::BlockDesc>();
        *new_block_desc = *sub_block_desc;
        // NOTE(review): the standalone desc is freed here, which assumes this
        // function owns it once copied -- confirm against SubgraphOp.
        delete sub_block_desc;
        subgraph_op->mutable_op_info()->SetAttr<int32_t>("sub_block",
                                                         sub_block_idx);
        subgraph_op->SetSubBlock(new_block_desc);
        // Re-fetch the main block after adding a block; the old pointer is
        // presumably not guaranteed to stay valid across AddBlock.
        main_block = desc->GetBlock<cpp::BlockDesc>(0);
      }
    }
    auto op = main_block->AddOp<cpp::OpDesc>();
    *op = *node.op()->op_info();
    op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
  }
}

// `UpdateVarsOfProgram` rebuilds the var descs of block 0 from the vars the
// runtime instructions actually use: unused var descs are dropped and descs
// are added for newly created vars. For now, a newly created var can only be
// a LOD_TENSOR.
void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
  CHECK(desc);
  CHECK(desc->BlocksSize());
  auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);

  // Snapshot the original var descs of block 0, keyed by name, before the
  // block's var list is cleared.
  std::map<std::string, cpp::VarDesc> origin_var_maps;
  for (size_t i = 0; i < main_block.VarsSize(); ++i) {
    auto* v = main_block.GetVar<cpp::VarDesc>(i);
    origin_var_maps.emplace(v->Name(), *v);
  }
  main_block.ClearVars();

  // Names that already received a desc. A var shared by several instructions
  // (e.g. one op's output feeding the next op) must be added only once.
  std::set<std::string> added_var_names;
  for (auto& node : instructions_) {
    auto* op = const_cast<lite::OpLite*>(node.op());
    auto* kernel = node.kernel();
    auto* scope = op->scope();
    // Sorted, de-duplicated names of every var this op reads or writes.
    auto var_names = op->op_info()->input_names();
    auto out_names = op->op_info()->output_names();
    var_names.insert(var_names.end(), out_names.begin(), out_names.end());
    std::stable_sort(var_names.begin(), var_names.end());
    var_names.erase(std::unique(var_names.begin(), var_names.end()),
                    var_names.end());

    for (const auto& var_name : var_names) {
      if (!added_var_names.insert(var_name).second) continue;

      auto it = origin_var_maps.find(var_name);
      if (it != origin_var_maps.end()) {
        // Existing var: copy its original desc. Shape and data type are not
        // copied for feed/fetch.
        auto& origin_var = it->second;
        auto* v = main_block.AddVar<cpp::VarDesc>();
        v->SetName(origin_var.Name());
        v->SetType(origin_var.GetType());
        v->SetPersistable(origin_var.Persistable());
        if (origin_var.Name() != "feed" && origin_var.Name() != "fetch") {
          v->SetShape(origin_var.GetShape());
          v->SetDataType(origin_var.GetDataType());
        }
        continue;
      }

      // New created vars must be LOD_TENSOR.
      auto* v = main_block.AddVar<cpp::VarDesc>();
      v->SetName(var_name);
      v->SetType(cpp::VarDesc::Type::LOD_TENSOR);

      // Look up the type the kernel declares for the argument bound to this
      // var: as an input if the op reads it, otherwise as an output.
      std::string arg_name;
      const Type* type = nullptr;
      if (op->op_info()->GetInputArgname(var_name, &arg_name)) {
        type = kernel->GetInputDeclType(arg_name);
      } else {
        op->op_info()->GetOutputArgname(var_name, &arg_name);
        type = kernel->GetOutputDeclType(arg_name);
      }
      CHECK(type) << "no declared type found for var " << var_name;
      CHECK(type->IsTensor()) << "unsupported var type";

      auto* var = scope->FindVar(var_name);
      CHECK(var) << "var " << var_name << " not found in scope";
      auto* tensor = var->GetMutable<Tensor>();
      v->SetPersistable(tensor->persistable());
      if (var_name == "feed" || var_name == "fetch") continue;

      v->SetShape(tensor->dims().data());
      switch (tensor->precision()) {
// NOTE: this branch handles vars absent from `origin_var_maps`, so the
// iterator `it` equals end() here and must not be dereferenced; log the var
// name instead.
#define SET_DATATYPE(precision__, data_type)         \
  case PrecisionType::precision__:                   \
    v->SetDataType(data_type);                       \
    LOG(INFO) << "update var" << var_name << "done"; \
    break
        SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
        SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
        SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
        SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
        SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
        SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
        SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE
        default:
          VLOG(4) << "warning! unknown precision type";
      }
    }
  }
}
void RuntimeProgram::Run() {
146 147 148 149 150 151
#ifdef LITE_WITH_PRECISION_PROFILE
  auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler();
  std::string precision_profiler_summary =
      inst_precision_profiler.GetSummaryHeader();
#endif

152 153 154 155 156 157 158 159 160
#ifdef LITE_WITH_NVTX
  const NVTXAnnotator& annotator = NVTXAnnotator::Global();
  NVTXRangeAnnotation annotation_one_loop = annotator.AnnotateBlock();
  if (annotator.IsEnabled()) {
    annotation_one_loop.generate(register_layer_names_.back(),
                                 lite::Color::Engine);
  }
#endif
  int idx = -1;
Y
Yan Chunwei 已提交
161
  for (auto& inst : instructions_) {
162
    ++idx;
163
#ifndef LITE_WITH_FPGA
164
    if (inst.is_feed_fetch_op()) continue;
165
#endif
166 167 168 169 170 171 172
#ifdef LITE_WITH_NVTX
    NVTXRangeAnnotation annotation = annotator.AnnotateBlock();
    nvtxStringHandle_t registered_name = register_layer_names_[idx];
    if (annotator.IsEnabled()) {
      annotation.generate(registered_name, lite::Color::Runner);
    }
#endif
173 174 175 176
#ifdef LITE_WITH_CUDA
    if (inst.need_sync()) {
      inst.Sync();
    }
177
#endif
Y
Yan Chunwei 已提交
178
    inst.Run();
179
#ifdef LITE_WITH_PRECISION_PROFILE
180
#ifndef LITE_WITH_FPGA
181 182
    precision_profiler_summary +=
        inst_precision_profiler.GetInstPrecision(&inst);
183
#endif
184
#endif  // LITE_WITH_PRECISION_PROFILE
Y
Yan Chunwei 已提交
185
  }
186
#ifdef LITE_WITH_PROFILE
187
  LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 1);
188
#endif
189 190
#ifdef LITE_WITH_PRECISION_PROFILE
  LOG(INFO) << "\n" << precision_profiler_summary;
191
#endif
Y
Yan Chunwei 已提交
192 193 194 195 196 197
}

// Creates one operator per op desc of the main block (block 0) of `prog`, in
// order, and attaches each to `exec_scope_`. feed/fetch ops are created like
// any other op. Must be called at most once per Program.
void Program::Build(const cpp::ProgramDesc& prog) {
  CHECK(ops_.empty()) << "Executor duplicate Build found";

  CHECK(prog.BlocksSize());
  auto& main_block = *prog.GetBlock<cpp::BlockDesc>(0);
  for (size_t i = 0; i < main_block.OpsSize(); ++i) {
    auto& op_desc = *main_block.GetOp<cpp::OpDesc>(i);
    auto op_type = op_desc.Type();
    VLOG(4) << "create Op [" << op_type << "]";
    auto op = LiteOpRegistry::Global().Create(op_type);
    CHECK(op) << "no Op found for " << op_type;

    if (op_type == "while" || op_type == "conditional_block" ||
        op_type == "subgraph") {
      // These ops reference another block of `prog` through their
      // "sub_block" attribute; validate the index and hand them that block.
      auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
      CHECK(sub_block_idx >= 0 &&
            static_cast<size_t>(sub_block_idx) < prog.BlocksSize())
          << "Invalid attribute sub_block(" << sub_block_idx << ") for "
          << op_type;
      // `prog` is const, but SetSubBlock is given a mutable block pointer,
      // hence the const_cast.
      auto sub_block_desc =
          const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
              sub_block_idx);
      CHECK(sub_block_desc);
      if (op_type == "while") {
        static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(
            sub_block_desc);
      } else if (op_type == "conditional_block") {
        static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock(
            sub_block_desc);
      } else {  // "subgraph"
        static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(
            sub_block_desc);
      }
    }
    ops_.emplace_back(std::move(op));
    ops_.back()->Attach(op_desc, exec_scope_);
  }
}

// Sets up the variables needed to run `prog`:
//  - creates `exec_scope_` as a new child of `scope_` (must not exist yet);
//  - creates the "feed"/"fetch" vars in `scope_`;
//  - for every block, creates non-persistable vars in `exec_scope_` (recorded
//    in `tmp_vars_`) and persistable vars in `scope_` (recorded in
//    `weights_`);
//  - for each name in `var_names`, creates it via `exec_scope_->LocalVar` and
//    copies the tensor data of the same-named var in `scope_` into it.
void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
                               const std::vector<std::string>& var_names) {
  CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
  exec_scope_ = &scope_->NewScope();
  // Create Feed and Fetch var.
  scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
  scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
  tmp_vars_.push_back("feed");
  tmp_vars_.push_back("fetch");

  // Maps a var desc data type to the kernel precision; kUnk when unmapped.
  auto VarPrecision2KernelPrecision =
      [](const lite::VarDescAPI::Type& type) -> PrecisionType {
    switch (type) {
      case lite::VarDescAPI::Type::FP32:
        return PRECISION(kFloat);
      case lite::VarDescAPI::Type::FP16:
        return PRECISION(kFP16);
      case lite::VarDescAPI::Type::INT8:
        return PRECISION(kInt8);
      case lite::VarDescAPI::Type::INT16:
        return PRECISION(kInt16);
      case lite::VarDescAPI::Type::INT32:
        return PRECISION(kInt32);
      case lite::VarDescAPI::Type::INT64:
        return PRECISION(kInt64);
      default:
        // LOG(FATAL) << "not supported type: " << static_cast<int>(type);
        return PRECISION(kUnk);
    }
  };

  CHECK(prog.BlocksSize());
  for (size_t b = 0; b < prog.BlocksSize(); ++b) {
    auto& block = *prog.GetBlock<cpp::BlockDesc>(b);
    for (size_t i = 0; i < block.VarsSize(); ++i) {
      auto& var_desc = *block.GetVar<cpp::VarDesc>(i);
      if (!var_desc.Persistable()) {
        // Record the precision of LOD_TENSOR vars whose data type maps to a
        // known kernel precision. The data type is only read for LOD_TENSOR
        // vars here.
        if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR) {
          const PrecisionType precision =
              VarPrecision2KernelPrecision(var_desc.GetDataType());
          if (precision != PRECISION(kUnk)) {
            var_data_type_[var_desc.Name()] = precision;
          }
        }
        tmp_vars_.push_back(var_desc.Name());
        VLOG(4) << "var name: " << var_desc.Name() << " type is "
                << static_cast<int>(var_desc.GetType()) << " data type is "
                << static_cast<int>(var_desc.GetDataType());
        exec_scope_->Var(var_desc.Name());
        if (b > 0) {
          VLOG(4) << "var: " << var_desc.Name();
        }
      } else {
        // Persistable vars are created in `scope_`; feed/fetch were already
        // created above.
        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
        weights_.push_back(var_desc.Name());
        scope_->Var(var_desc.Name());
      }
    }
  }

  for (const auto& name : var_names) {
    exec_scope_->LocalVar(name);
    auto* tensor = scope_->Var(name)->GetMutable<lite::Tensor>();
    auto* sub_tensor = exec_scope_->Var(name)->GetMutable<lite::Tensor>();
    sub_tensor->CopyDataFrom(*tensor);
  }
}

void Instruction::Run() {
303 304 305 306 307 308 309
#ifdef LITE_WITH_PROFILE
  CHECK(profiler_) << "Profiler pointer of kernel can not be nullptr. "
                      "When LITE_WITH_PROFILE is defined, please set a "
                      "Profiler for Instruction.";
  profiler_->StartTiming(
      profile::Type::kCreate, profile_id_, kernel_->mutable_context());
#endif
Y
Yan Chunwei 已提交
310 311
  CHECK(op_) << "op null";
  CHECK(kernel_) << "kernel null";
312

Y
Yan Chunwei 已提交
313 314 315 316 317
  if (first_epoch_) {
    first_epoch_ = false;
    CHECK(op_->CheckShape());
  }

318 319 320
  if (op_->run_once() && has_run_) {
    return;
  }
321

322
  op_->InferShape();
Y
Yan Chunwei 已提交
323 324
  kernel_->Launch();
  has_run_ = true;
325 326 327

#ifdef LITE_WITH_PROFILE
  if (first_epoch_for_profiler_) {
328
    kernel_->SetIsKernelTest(false);
329 330 331 332
    SetProfileRuntimeOpInfo(profiler_->GetOpCharacter(profile_id_));
    first_epoch_for_profiler_ = false;
  }
#endif
Y
Yan Chunwei 已提交
333 334 335 336 337 338 339 340 341
}

STL::ostream& operator<<(STL::ostream& os, const Instruction& other) {
  // Writes "<kernel summary>\t(<kernel doc>)" and returns the same stream.
  os << other.kernel_->summary();
  return os << "\t(" << other.kernel_->doc() << ")";
}

}  // namespace lite
}  // namespace paddle