cxx_api.h 4.0 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
S
superjomn 已提交
16 17 18 19
#include <memory>
#include <string>
#include <utility>
#include <vector>
S
superjomn 已提交
20
#include "paddle/fluid/lite/core/op_lite.h"
S
Superjomn 已提交
21 22
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
23
#include "paddle/fluid/lite/core/types.h"
S
superjomn 已提交
24 25 26
#include "paddle/fluid/lite/model_parser/model_parser.h"

namespace paddle {
S
superjomn 已提交
27 28
namespace lite {

C
Chunwei 已提交
29 30 31 32
/*
 * Predictor for inference, input a model, it will optimize and execute it.
 */
class Predictor {
S
superjomn 已提交
33
 public:
C
Chunwei 已提交
34 35 36 37 38
  // Create an empty predictor.
  Predictor() { scope_ = std::make_shared<Scope>(); }
  // Create a predictor with the weight variable scope set.
  explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
      : scope_(root_scope) {}
S
superjomn 已提交
39

C
Chunwei 已提交
40
  // Build from a model, with places set for hardware config.
S
superjomn 已提交
41
  void Build(const std::string& model_path, const Place& prefer_place,
42 43
             const std::vector<Place>& valid_places,
             const std::vector<std::string>& passes = {});
Y
Yan Chunwei 已提交
44 45

  void Build(const framework::proto::ProgramDesc& desc,
46 47
             const Place& prefer_place, const std::vector<Place>& valid_places,
             const std::vector<std::string>& passes = {});
S
superjomn 已提交
48

C
Chunwei 已提交
49 50
  // Run the predictor for a single batch of data.
  void Run() { program_->Run(); }
S
Superjomn 已提交
51

C
Chunwei 已提交
52 53
  // Get offset-th col of feed inputs.
  lite::Tensor* GetInput(size_t offset);
54

C
Chunwei 已提交
55 56
  // Get offset-th col of fetch results.
  const lite::Tensor* GetOutput(size_t offset);
57

C
Chunwei 已提交
58
  const framework::proto::ProgramDesc& program_desc() const;
S
init  
superjomn 已提交
59
  const lite::Tensor* GetTensor(const std::string& name) const;
60
  const RuntimeProgram& runtime_program() const;
C
Chunwei 已提交
61

C
Chunwei 已提交
62 63
  // This method is disabled in mobile, for unnecessary dependencies required.
  void SaveModel(const std::string& dir);
S
Superjomn 已提交
64

S
superjomn 已提交
65
 private:
S
Superjomn 已提交
66 67
  Optimizer optimizer_;
  framework::proto::ProgramDesc program_desc_;
S
superjomn 已提交
68
  std::shared_ptr<Scope> scope_;
S
Superjomn 已提交
69
  std::unique_ptr<RuntimeProgram> program_;
S
superjomn 已提交
70 71
};

C
Chunwei 已提交
72
#ifdef LITE_WITH_X86
Y
Yan Chunwei 已提交
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
/*
 * An executor for training.
 *
 * Usage:
 *
 * CXXTrainer trainer(...);
 * trainer.RunStartupProgram(...);
 * auto exe = BuildMainProgramExecutor(...);
 *
 * for (auto& epoch : epoches) {
 *   auto* tensor0 = exe.GetInput(...);
 *   // fill data for tensor0
 *   exe.Run();
 * }
 */
class CXXTrainer {
 public:
  CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
             const Place& preferred_place,
             const std::vector<Place>& valid_places)
      : scope_(root_scope),
        preferred_place_(preferred_place),
        valid_places_(valid_places),
C
Chunwei 已提交
96
        main_program_executor_(Predictor(scope_)) {}
Y
Yan Chunwei 已提交
97 98 99 100

  // Build the RuntimeProgram cache for the main program. The cache will run
  // multiple times for the epoches.
  // NOTE Just support to execute the 0-th block currently.
C
Chunwei 已提交
101 102
  Predictor& BuildMainProgramExecutor(const framework::proto::ProgramDesc& desc,
                                      int block_id = 0) {
Y
Yan Chunwei 已提交
103 104 105 106 107 108 109
    main_program_executor_.Build(desc, preferred_place_, valid_places_);
    return main_program_executor_;
  }

  // Run the startup program. It just executes once, no cache needed.
  void RunStartupProgram(const framework::proto::ProgramDesc& desc,
                         int block_id = 0) {
C
Chunwei 已提交
110
    Predictor exe(scope_);
Y
Yan Chunwei 已提交
111 112 113 114 115 116 117 118 119 120 121
    exe.Build(desc, preferred_place_, valid_places_);
    exe.Run();
  }

 private:
  std::shared_ptr<lite::Scope> scope_;

  Place preferred_place_;
  std::vector<Place> valid_places_;

  // The training program.
C
Chunwei 已提交
122
  Predictor main_program_executor_;
Y
Yan Chunwei 已提交
123
};
C
Chunwei 已提交
124
#endif
Y
Yan Chunwei 已提交
125

S
superjomn 已提交
126
}  // namespace lite
S
superjomn 已提交
127
}  // namespace paddle