// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/program.h"

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/var_desc.h"
#include "lite/operators/conditional_block_op.h"
#include "lite/operators/subgraph_op.h"
#include "lite/operators/while_op.h"
#ifdef LITE_WITH_PRECISION_PROFILE
#include "lite/core/profile/precision_profiler.h"
#endif

namespace paddle {
namespace lite {

// Rewrites the op list of block 0 in `desc` from this program's instructions.
// Each op desc is copied from the instruction's op info and tagged with the
// serialized type of the kernel that was picked for it (kKernelTypeAttr).
//
// Subgraph ops whose "sub_block" attribute is negative have their sub-block
// desc moved into a new block of `desc` (see the comment in the loop).
void RuntimeProgram::SaveOpInfosToProgram(cpp::ProgramDesc* desc) {
  CHECK(desc);
  // NOTE: RuntimeProgram does not hold all the meta info of the model, so
  // saving only updates the ops on top of the original program desc.
  CHECK(desc->BlocksSize());
  auto main_block = desc->GetBlock<cpp::BlockDesc>(0);
  main_block->ClearOps();
  for (auto& node : instructions_) {
    auto op_type = node.op()->op_info()->Type();
    if (op_type == "subgraph") {
      auto subgraph_op = const_cast<operators::SubgraphOp*>(
          static_cast<const operators::SubgraphOp*>(node.op()));
      int sub_block_idx = subgraph_op->op_info()->GetAttr<int32_t>("sub_block");
      if (sub_block_idx < 0) {
        // A negative sub_block index marks a new subgraph op whose sub-block
        // desc is not yet part of `desc`. Copy it into a new block appended
        // to `desc`, free the original, and point the op and its "sub_block"
        // attribute at the new block (whose index is the old block count).
        sub_block_idx = static_cast<int>(desc->BlocksSize());
        auto sub_block_desc = subgraph_op->GetSubBlock();
        CHECK(sub_block_desc);
        auto new_block_desc = desc->AddBlock<cpp::BlockDesc>();
        *new_block_desc = *sub_block_desc;
        delete sub_block_desc;
        subgraph_op->mutable_op_info()->SetAttr<int32_t>("sub_block",
                                                         sub_block_idx);
        subgraph_op->SetSubBlock(new_block_desc);
        // Adding a block may invalidate previously obtained block pointers
        // (presumably the blocks live in a growable container), so re-fetch
        // block 0 before appending more ops to it.
        main_block = desc->GetBlock<cpp::BlockDesc>(0);
      }
    }
    auto op = main_block->AddOp<cpp::OpDesc>();
    *op = *node.op()->op_info();
    op->SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType());
  }
}

67 68 69 70 71 72 73 74 75
// `UpdateVarsOfProgram` will remove unused var_descs and add new created
// vars' descs in the block 0. Now, the type of a new created var can only
// be LOD_TENSOR.
void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
  CHECK(desc);
  CHECK(desc->BlocksSize());
  std::unordered_map<std::string, cpp::VarDesc> origin_var_maps;
  auto& main_block = *desc->GetBlock<cpp::BlockDesc>(0);
  auto var_size = main_block.VarsSize();
76
  for (int i = 0; i < var_size; i++) {
77 78 79 80 81 82 83 84 85 86 87 88
    auto v = main_block.GetVar<cpp::VarDesc>(i);
    auto name = v->Name();
    origin_var_maps.emplace(name, *v);
  }

  main_block.ClearVars();
  for (auto& node : instructions_) {
    auto* op = const_cast<lite::OpLite*>(node.op());
    auto* kernel = node.kernel();
    auto* scope = op->scope();
    auto in_names = op->op_info()->input_names();
    auto out_names = op->op_info()->output_names();
89 90 91 92 93 94
    in_names.insert(in_names.end(), out_names.begin(), out_names.end());
    std::sort(in_names.begin(), in_names.end());
    in_names.erase(std::unique(in_names.begin(), in_names.end()),
                   in_names.end());
    for (auto& in_name : in_names) {
      auto it = origin_var_maps.find(in_name);
95 96 97 98 99
      if (it != origin_var_maps.end()) {
        auto* v = main_block.AddVar<cpp::VarDesc>();
        v->SetName((it->second).Name());
        v->SetType((it->second).GetType());
        v->SetPersistable((it->second).Persistable());
100 101 102 103
        if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
          v->SetShape((it->second).GetShape());
          v->SetDataType((it->second).GetDataType());
        }
104 105 106
      } else {
        // New created vars must be LOD_TENSOR
        auto* v = main_block.AddVar<cpp::VarDesc>();
107
        v->SetName(in_name);
108 109
        v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
        std::string in_arg_name;
110 111 112 113 114 115 116
        const Type* type;
        if (op->op_info()->GetInputArgname(in_name, &in_arg_name)) {
          type = kernel->GetInputDeclType(in_arg_name);
        } else {
          op->op_info()->GetOutputArgname(in_name, &in_arg_name);
          type = kernel->GetOutputDeclType(in_arg_name);
        }
117
        if (type->IsTensor()) {
118
          auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
119
          v->SetPersistable(tensor->persistable());
120
          if (in_name != "feed" && in_name != "fetch") {
121 122
            v->SetShape(tensor->dims().data());
            switch (tensor->precision()) {
123 124 125 126
#define SET_DATATYPE(precision__, data_type)                    \
  case PrecisionType::precision__:                              \
    v->SetDataType(data_type);                                  \
    LOG(INFO) << "update var" << (it->second).Name() << "done"; \
127
    break
128
              SET_DATATYPE(kBool, VarDescAPI::VarDataType::BOOL);
129
              SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
130
              SET_DATATYPE(kFP16, VarDescAPI::VarDataType::FP16);
131 132 133 134 135 136
              SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
              SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
              SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
              SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE
              default:
137
                VLOG(4) << "warning! unknown precision type";
138 139
            }
          }
140 141 142 143 144 145 146
        } else {
          CHECK(false) << "unsupported var type";
        }
      }
    }
  }
}
Y
Yan Chunwei 已提交
147
void RuntimeProgram::Run() {
148 149 150 151 152 153
#ifdef LITE_WITH_PRECISION_PROFILE
  auto inst_precision_profiler = paddle::lite::profile::PrecisionProfiler();
  std::string precision_profiler_summary =
      inst_precision_profiler.GetSummaryHeader();
#endif

Y
Yan Chunwei 已提交
154
  for (auto& inst : instructions_) {
155
#ifndef LITE_WITH_FPGA
156
    if (inst.is_feed_fetch_op()) continue;
157 158 159 160 161
#endif
#ifdef LITE_WITH_CUDA
    if (inst.need_sync()) {
      inst.Sync();
    }
162
#endif
Y
Yan Chunwei 已提交
163
    inst.Run();
164
#ifdef LITE_WITH_PRECISION_PROFILE
165
#ifndef LITE_WITH_FPGA
166 167
    precision_profiler_summary +=
        inst_precision_profiler.GetInstPrecision(&inst);
168
#endif
169
#endif  // LITE_WITH_PRECISION_PROFILE
Y
Yan Chunwei 已提交
170
  }
171
#ifdef LITE_WITH_PROFILE
172
  LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 1);
173
#endif
174 175
#ifdef LITE_WITH_PRECISION_PROFILE
  LOG(INFO) << "\n" << precision_profiler_summary;
176
#endif
Y
Yan Chunwei 已提交
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
}

// Creates one operator per op desc in block 0 of `prog`, in order, and
// attaches each to exec_scope_. "while", "conditional_block" and "subgraph"
// ops are also given a pointer to the block named by their "sub_block"
// attribute.
//
// The sub-block pointer is taken from `prog` itself, not from the local copy
// `program`, because the copy is destroyed when this function returns.
void Program::Build(const cpp::ProgramDesc& prog) {
  CHECK(ops_.empty()) << "Executor duplicate Build found";

  // Create operators. Block 0 is read through a local copy; the const_cast
  // used for sub-blocks below suggests GetBlock has no const overload.
  auto program = prog;
  CHECK(program.BlocksSize());
  auto& main_block = *program.GetBlock<cpp::BlockDesc>(0);
  for (size_t i = 0; i < main_block.OpsSize(); ++i) {
    auto& op_desc = *main_block.GetOp<cpp::OpDesc>(i);
    auto op_type = op_desc.Type();
    // Feed/fetch ops are created like any other op.
    VLOG(4) << "create Op [" << op_type << "]";
    auto op = LiteOpRegistry::Global().Create(op_type);
    CHECK(op) << "no Op found for " << op_type;
    if (op_type == "while" || op_type == "conditional_block" ||
        op_type == "subgraph") {
      auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
      // Check the sign first so the unsigned comparison below is valid.
      CHECK(sub_block_idx >= 0 &&
            static_cast<size_t>(sub_block_idx) < program.BlocksSize())
          << "Invalid attribute sub_block(" << sub_block_idx << ") for "
          << op_type;
      auto sub_block_desc =
          const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
              sub_block_idx);
      CHECK(sub_block_desc);
      if (op_type == "while") {
        static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(
            sub_block_desc);
      } else if (op_type == "conditional_block") {
        static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock(
            sub_block_desc);
      } else if (op_type == "subgraph") {
        static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(
            sub_block_desc);
      }
    }
    ops_.emplace_back(std::move(op));
    ops_.back()->Attach(op_desc, exec_scope_);
  }
}

// Creates the variables needed to run `prog`:
//  - "feed" and "fetch" (each a std::vector<lite::Tensor>) in scope_, with
//    their names added to tmp_vars_;
//  - every non-persistable var of every block in the child scope
//    exec_scope_, with its name added to tmp_vars_. For a LOD_TENSOR var
//    whose data type maps to a known precision, that precision is recorded
//    in var_data_type_;
//  - every persistable var except "feed"/"fetch" in scope_, with its name
//    added to weights_.
// Must be called at most once (exec_scope_ must still be null).
void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
  CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
  exec_scope_ = &scope_->NewScope();
  // Create Feed and Fetch var.
  scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
  scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
  tmp_vars_.push_back("feed");
  tmp_vars_.push_back("fetch");

  // Maps a var desc data type to a kernel precision; kUnk when unsupported.
  // NOTE(review): BOOL is not mapped here although UpdateVarsOfProgram maps
  // kBool to BOOL — confirm whether BOOL vars should be recorded too.
  auto VarPrecision2KernelPrecision =
      [](const lite::VarDescAPI::Type& type) -> PrecisionType {
    switch (type) {
      case lite::VarDescAPI::Type::FP32:
        return PRECISION(kFloat);
      case lite::VarDescAPI::Type::FP16:
        return PRECISION(kFP16);
      case lite::VarDescAPI::Type::INT8:
        return PRECISION(kInt8);
      case lite::VarDescAPI::Type::INT16:
        return PRECISION(kInt16);
      case lite::VarDescAPI::Type::INT32:
        return PRECISION(kInt32);
      case lite::VarDescAPI::Type::INT64:
        return PRECISION(kInt64);
      default:
        // LOG(FATAL) << "not supported type: " << static_cast<int>(type);
        return PRECISION(kUnk);
    }
  };

  auto program = prog;
  CHECK(program.BlocksSize());
  for (size_t b = 0; b < program.BlocksSize(); ++b) {
    auto& block = *program.GetBlock<cpp::BlockDesc>(b);
    for (size_t i = 0; i < block.VarsSize(); ++i) {
      auto& var_desc = *block.GetVar<cpp::VarDesc>(i);
      if (!var_desc.Persistable()) {
        if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR) {
          const auto precision =
              VarPrecision2KernelPrecision(var_desc.GetDataType());
          if (precision != PRECISION(kUnk)) {
            var_data_type_[var_desc.Name()] = precision;
          }
        }
        tmp_vars_.push_back(var_desc.Name());
        VLOG(4) << "var name: " << var_desc.Name() << " type is "
                << static_cast<int>(var_desc.GetType()) << " data type is "
                << static_cast<int>(var_desc.GetDataType());
        exec_scope_->Var(var_desc.Name());
        if (b > 0) {
          VLOG(4) << "var: " << var_desc.Name();
        }
      } else {
        // Persistable vars (weights) live in the root scope.
        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
        weights_.push_back(var_desc.Name());
        scope_->Var(var_desc.Name());
      }
    }
  }
}

void Instruction::Run() {
280 281 282 283 284 285 286
#ifdef LITE_WITH_PROFILE
  CHECK(profiler_) << "Profiler pointer of kernel can not be nullptr. "
                      "When LITE_WITH_PROFILE is defined, please set a "
                      "Profiler for Instruction.";
  profiler_->StartTiming(
      profile::Type::kCreate, profile_id_, kernel_->mutable_context());
#endif
Y
Yan Chunwei 已提交
287 288
  CHECK(op_) << "op null";
  CHECK(kernel_) << "kernel null";
289

Y
Yan Chunwei 已提交
290 291 292 293 294
  if (first_epoch_) {
    first_epoch_ = false;
    CHECK(op_->CheckShape());
  }

295 296 297
  if (op_->run_once() && has_run_) {
    return;
  }
298

299
  op_->InferShape();
Y
Yan Chunwei 已提交
300 301
  kernel_->Launch();
  has_run_ = true;
302 303 304 305 306 307 308

#ifdef LITE_WITH_PROFILE
  if (first_epoch_for_profiler_) {
    SetProfileRuntimeOpInfo(profiler_->GetOpCharacter(profile_id_));
    first_epoch_for_profiler_ = false;
  }
#endif
Y
Yan Chunwei 已提交
309 310 311 312 313 314 315 316 317
}

// Writes "<kernel summary>\t(<kernel doc>)" for `other` to `os` and returns
// `os`. Assumes other.kernel_ is set (it is dereferenced unchecked).
STL::ostream& operator<<(STL::ostream& os, const Instruction& other) {
  const auto& kernel = other.kernel_;
  os << kernel->summary();
  os << "\t(" << kernel->doc() << ")";
  return os;
}

}  // namespace lite
}  // namespace paddle