cxx_api.h 4.7 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
S
superjomn 已提交
16 17 18 19
#include <memory>
#include <string>
#include <utility>
#include <vector>
S
superjomn 已提交
20
#include "paddle/fluid/lite/core/op_lite.h"
S
Superjomn 已提交
21 22
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
23
#include "paddle/fluid/lite/core/types.h"
S
superjomn 已提交
24 25 26
#include "paddle/fluid/lite/model_parser/model_parser.h"

namespace paddle {
S
superjomn 已提交
27 28 29 30
namespace lite {

// Placeholder configuration type; it currently declares no members.
struct Config {};

Y
Yan Chunwei 已提交
31
class ExecutorLite {
S
superjomn 已提交
32
 public:
Y
Yan Chunwei 已提交
33 34 35 36
  ExecutorLite() { scope_ = std::make_shared<Scope>(); }
  explicit ExecutorLite(const std::shared_ptr<lite::Scope>& root_scope) {
    scope_ = root_scope;
  }
S
superjomn 已提交
37

S
superjomn 已提交
38
  void Build(const std::string& model_path, const Place& prefer_place,
39
             const std::vector<Place>& valid_places) {
S
Superjomn 已提交
40
    LoadModel(model_path, scope_.get(), &program_desc_);
Y
Yan Chunwei 已提交
41 42 43 44 45 46 47 48
    Build(program_desc_, prefer_place, valid_places);
  }

  void Build(const framework::proto::ProgramDesc& desc,
             const Place& prefer_place,
             const std::vector<Place>& valid_places) {
    program_desc_ = desc;
    Program program(desc, scope_, valid_places);
S
superjomn 已提交
49

S
Superjomn 已提交
50
    optimizer_.KernelPickPreferPlace(prefer_place);
51 52
    core::KernelPickFactor factor;
    factor.ConsiderTarget();
S
Superjomn 已提交
53 54
    optimizer_.Run(std::move(program), valid_places, factor);
    program_ = optimizer_.GenRuntimeProgram();
S
superjomn 已提交
55 56
  }

57 58
// This method is disabled in mobile, or unnecessary dependencies required.
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
S
Superjomn 已提交
59
  void SaveModel(const std::string& dir);
60
#endif
S
Superjomn 已提交
61

62
  // Get offset-th col of feed.
63
  lite::Tensor* GetInput(size_t offset) {
64 65
    auto* _feed_list = program_->exec_scope()->FindVar("feed");
    CHECK(_feed_list) << "no feed variable in exec_scope";
66
    auto* feed_list = _feed_list->GetMutable<std::vector<lite::Tensor>>();
67 68 69 70 71 72
    if (offset >= feed_list->size()) {
      feed_list->resize(offset + 1);
    }
    return &feed_list->at(offset);
  }

73
  const lite::Tensor* GetOutput(size_t offset) {
74 75
    auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
    CHECK(_fetch_list) << "no fatch variable in exec_scope";
76
    auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
77 78 79 80
    CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
    return &fetch_list.at(offset);
  }

C
Chunwei 已提交
81 82 83 84 85
  const lite::Tensor* GetTensor(const std::string& name) const {
    auto* var = program_->exec_scope()->FindVar(name);
    return &var->Get<lite::Tensor>();
  }

S
Superjomn 已提交
86
  void Run() { program_->Run(); }
S
superjomn 已提交
87

S
Superjomn 已提交
88 89 90 91
  const framework::proto::ProgramDesc& program_desc() const {
    return program_desc_;
  }

S
superjomn 已提交
92
 private:
S
Superjomn 已提交
93 94
  Optimizer optimizer_;
  framework::proto::ProgramDesc program_desc_;
S
superjomn 已提交
95
  std::shared_ptr<Scope> scope_;
S
Superjomn 已提交
96
  std::unique_ptr<RuntimeProgram> program_;
S
superjomn 已提交
97 98
};

Y
Yan Chunwei 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
/*
 * Training-side wrapper that drives ExecutorLite instances over one shared
 * root scope.
 *
 * Usage:
 *
 *   CXXTrainer trainer(...);
 *   trainer.RunStartupProgram(...);
 *   auto& exe = trainer.BuildMainProgramExecutor(...);
 *
 *   for (auto& epoch : epochs) {
 *     auto* tensor0 = exe.GetInput(...);
 *     // fill data for tensor0
 *     exe.Run();
 *   }
 */
class CXXTrainer {
 public:
  // Keeps copies of the preferred place and the valid places, shares
  // `root_scope`, and creates the main-program executor on that same scope.
  CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
             const Place& preferred_place,
             const std::vector<Place>& valid_places)
      : scope_(root_scope),
        preferred_place_(preferred_place),
        valid_places_(valid_places),
        main_program_executor_(scope_) {}

  // Builds the main program once and returns the executor that holds it, so
  // the same executor can be run repeatedly across epochs.
  // NOTE `block_id` is accepted but not used here; only the 0-th block is
  // supported currently.
  ExecutorLite& BuildMainProgramExecutor(
      const framework::proto::ProgramDesc& desc, int block_id = 0) {
    main_program_executor_.Build(desc, preferred_place_, valid_places_);
    return main_program_executor_;
  }

  // Builds and runs the startup program a single time on a temporary
  // executor that shares the trainer's scope; nothing is kept afterwards.
  // NOTE `block_id` is accepted but not used here.
  void RunStartupProgram(const framework::proto::ProgramDesc& desc,
                         int block_id = 0) {
    ExecutorLite startup_executor(scope_);
    startup_executor.Build(desc, preferred_place_, valid_places_);
    startup_executor.Run();
  }

 private:
  // Root scope shared by the startup and main executors.
  std::shared_ptr<lite::Scope> scope_;

  Place preferred_place_;
  std::vector<Place> valid_places_;

  // Executor holding the training (main) program.
  ExecutorLite main_program_executor_;
};

S
superjomn 已提交
151
}  // namespace lite
S
superjomn 已提交
152
}  // namespace paddle