// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"

#ifdef LITE_WITH_X86
#include "paddle/fluid/framework/program_desc.h"
#endif

namespace paddle {
namespace lite {

/*
 * Predictor for inference, input a model, it will optimize and execute it.
 */
class Predictor {
S
superjomn 已提交
38
 public:
C
Chunwei 已提交
39 40 41 42 43
  // Create an empty predictor.
  Predictor() { scope_ = std::make_shared<Scope>(); }
  // Create a predictor with the weight variable scope set.
  explicit Predictor(const std::shared_ptr<lite::Scope>& root_scope)
      : scope_(root_scope) {}
S
superjomn 已提交
44

C
Chunwei 已提交
45
  // Build from a model, with places set for hardware config.
S
superjomn 已提交
46
  void Build(const std::string& model_path, const Place& prefer_place,
47 48
             const std::vector<Place>& valid_places,
             const std::vector<std::string>& passes = {});
Y
Yan Chunwei 已提交
49 50

  void Build(const framework::proto::ProgramDesc& desc,
51 52
             const Place& prefer_place, const std::vector<Place>& valid_places,
             const std::vector<std::string>& passes = {});
S
superjomn 已提交
53

C
Chunwei 已提交
54 55
  // Run the predictor for a single batch of data.
  void Run() { program_->Run(); }
S
Superjomn 已提交
56

C
Chunwei 已提交
57 58
  // Get offset-th col of feed inputs.
  lite::Tensor* GetInput(size_t offset);
59

C
Chunwei 已提交
60
  // Get offset-th col of fetch results.
C
Chunwei 已提交
61
  const lite::Tensor* GetOutput(size_t offset) const;
62

C
Chunwei 已提交
63
  const framework::proto::ProgramDesc& program_desc() const;
S
init  
superjomn 已提交
64
  const lite::Tensor* GetTensor(const std::string& name) const;
65
  const RuntimeProgram& runtime_program() const;
C
Chunwei 已提交
66

C
Chunwei 已提交
67 68
  // This method is disabled in mobile, for unnecessary dependencies required.
  void SaveModel(const std::string& dir);
S
Superjomn 已提交
69

70 71 72 73 74 75 76 77 78
#ifdef LITE_WITH_X86
  void Run(const std::vector<framework::Tensor>& tensors) {
    FeedVars(tensors);
    program_->Run();
  }

  void FeedVars(const std::vector<framework::Tensor>& tensors);
#endif

S
superjomn 已提交
79
 private:
S
Superjomn 已提交
80 81
  Optimizer optimizer_;
  framework::proto::ProgramDesc program_desc_;
S
superjomn 已提交
82
  std::shared_ptr<Scope> scope_;
S
Superjomn 已提交
83
  std::unique_ptr<RuntimeProgram> program_;
S
superjomn 已提交
84 85
};

#ifdef LITE_WITH_X86
/*
 * An executor for training.
 *
 * Usage:
 *
 *   CXXTrainer trainer(root_scope, preferred_place, valid_places);
 *   trainer.RunStartupProgram(startup_program_desc);
 *   auto& exe = trainer.BuildMainProgramExecutor(main_program_desc);
 *
 *   for (auto& epoch : epochs) {
 *     auto* tensor0 = exe.GetInput(...);
 *     // fill data for tensor0
 *     exe.Run();
 *   }
 *
 * NOTE `exe` must be bound by reference: Predictor holds a std::unique_ptr
 * member and is therefore not copyable.
 */
class CXXTrainer {
 public:
  // root_scope:      scope on which both the startup program and the main
  //                  program executor are built.
  // preferred_place: place to prefer when building either program.
  // valid_places:    places kernels may be selected for.
  CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
             const Place& preferred_place,
             const std::vector<Place>& valid_places)
      : scope_(root_scope),
        preferred_place_(preferred_place),
        valid_places_(valid_places),
        // scope_ is declared (and thus initialized) before this member.
        main_program_executor_(scope_) {}

  // Build the RuntimeProgram cache for the main program. The cache will run
  // multiple times for the epochs, so the returned executor is reused.
  // NOTE Only the 0-th block is supported currently; `block_id` is ignored.
  Predictor& BuildMainProgramExecutor(const framework::proto::ProgramDesc& desc,
                                      int block_id = 0) {
    (void)block_id;  // Unused until multi-block execution is supported.
    main_program_executor_.Build(desc, preferred_place_, valid_places_);
    return main_program_executor_;
  }

  // Convenience overload taking a framework::ProgramDesc and forwarding its
  // underlying proto. Taken by non-const reference (hence NOLINT), presumably
  // because ProgramDesc::Proto() is non-const.
  Predictor& BuildMainProgramExecutor(framework::ProgramDesc& desc) {  // NOLINT
    return BuildMainProgramExecutor(*desc.Proto());
  }

  // Run the startup program. It is executed only once, so no cache is kept:
  // a temporary Predictor on the shared scope is built, run and discarded.
  // NOTE Only the 0-th block is supported currently; `block_id` is ignored.
  void RunStartupProgram(const framework::proto::ProgramDesc& desc,
                         int block_id = 0) {
    (void)block_id;  // Unused until multi-block execution is supported.
    Predictor exe(scope_);
    exe.Build(desc, preferred_place_, valid_places_);
    exe.Run();
  }

  // Convenience overload taking a framework::ProgramDesc; see
  // BuildMainProgramExecutor(framework::ProgramDesc&) for the NOLINT reason.
  void RunStartupProgram(framework::ProgramDesc& desc) {  // NOLINT
    RunStartupProgram(*desc.Proto());
  }

 private:
  // Shared by the startup program run and the main program executor.
  std::shared_ptr<lite::Scope> scope_;

  Place preferred_place_;
  std::vector<Place> valid_places_;

  // The training program; rebuilt by BuildMainProgramExecutor().
  Predictor main_program_executor_;
};
#endif

}  // namespace lite
}  // namespace paddle