// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"

namespace paddle {
namespace lite {

/*
 * Predictor for inference: given a model, it optimizes and executes it.
 */
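/*
 * A usage sketch (the model path, prefer_place and valid_places below are
 * placeholders; the methods are those declared in this class):
 *
 *   Predictor predictor;
 *   predictor.Build("/path/to/model_dir", prefer_place, valid_places);
 *   auto* input = predictor.GetInput(0);
 *   // fill input data ...
 *   predictor.Run();
 *   const auto* output = predictor.GetOutput(0);
 */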
class Predictor {
 public:
  // Create an empty predictor.
  Predictor() { scope_ = std::make_shared<Scope>(); }
  // Create a predictor with the weight variable scope set.
  explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
      : scope_(root_scope) {}

  // Build from a model, with places set for hardware config.
  void Build(const std::string& model_path, const Place& prefer_place,
             const std::vector<Place>& valid_places);

  void Build(const framework::proto::ProgramDesc& desc,
             const Place& prefer_place,
             const std::vector<Place>& valid_places) {
    program_desc_ = desc;
    Program program(desc, scope_, valid_places);

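    // Configure kernel picking: prefer `prefer_place`, and weigh the target
    // and precision of candidate kernels.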
    optimizer_.KernelPickPreferPlace(prefer_place);
    core::KernelPickFactor factor;
    factor.ConsiderTarget();
    factor.ConsiderPrecision();
    optimizer_.Run(std::move(program), valid_places, factor);
    program_ = optimizer_.GenRuntimeProgram();
  }

  // Run the predictor for a single batch of data.
  void Run() { program_->Run(); }

  // Get offset-th col of feed inputs.
  lite::Tensor* GetInput(size_t offset);

  // Get offset-th col of fetch results.
  const lite::Tensor* GetOutput(size_t offset);

  // Return the program desc for debug.
  const framework::proto::ProgramDesc& program_desc() const;
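  // Look up a tensor in the execution scope by variable name (for debugging).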
  const lite::Tensor* GetTensor(const std::string& name) const {
    auto* var = program_->exec_scope()->FindVar(name);
    return &var->Get<lite::Tensor>();
  }

  // This method is disabled on mobile builds to avoid unnecessary dependencies.
  void SaveModel(const std::string& dir);

 private:
  Optimizer optimizer_;
  framework::proto::ProgramDesc program_desc_;
  std::shared_ptr<Scope> scope_;
  std::unique_ptr<RuntimeProgram> program_;
};

#ifdef LITE_WITH_X86
/*
 * An executor for training.
 *
 * Usage:
 *
 * CXXTrainer trainer(...);
 * trainer.RunStartupProgram(...);
 * auto& exe = trainer.BuildMainProgramExecutor(...);
 *
 * for (auto& epoch : epochs) {
 *   auto* tensor0 = exe.GetInput(...);
 *   // fill data for tensor0
 *   exe.Run();
 * }
 */
class CXXTrainer {
 public:
  CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
             const Place& preferred_place,
             const std::vector<Place>& valid_places)
      : scope_(root_scope),
        preferred_place_(preferred_place),
        valid_places_(valid_places),
        main_program_executor_(Predictor(scope_)) {}

  // Build the RuntimeProgram cache for the main program. The cache will be
  // run multiple times, once per epoch.
  // NOTE: only the 0-th block is supported currently.
  Predictor& BuildMainProgramExecutor(const framework::proto::ProgramDesc& desc,
                                      int block_id = 0) {
    main_program_executor_.Build(desc, preferred_place_, valid_places_);
    return main_program_executor_;
  }

  // Run the startup program. It just executes once, no cache needed.
  void RunStartupProgram(const framework::proto::ProgramDesc& desc,
                         int block_id = 0) {
    Predictor exe(scope_);
    exe.Build(desc, preferred_place_, valid_places_);
    exe.Run();
  }

 private:
  std::shared_ptr<lite::Scope> scope_;

  Place preferred_place_;
  std::vector<Place> valid_places_;

  // The training program.
  Predictor main_program_executor_;
};
#endif

}  // namespace lite
}  // namespace paddle