executor.h 3.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17 18 19
#include <map>
#include <memory>
#include <string>
20
#include <unordered_map>
21
#include <utility>
22
#include <vector>
23
#include "common/types.h"
Refine  
陈后江 已提交
24
#include "common/util.h"
L
liuruilong 已提交
25
#include "framework/lod_tensor.h"
L
liuruilong 已提交
26
#include "framework/operator.h"
27
#include "framework/program/program.h"
L
liuruilong 已提交
28
#include "framework/tensor.h"
29
#include "framework/type_trait.h"
30
#include "pass/memory_optimize.h"
31 32

namespace paddle_mobile {
33
namespace framework {
34

35
template <typename Device, typename T = float>
36
class Executor {
W
wangliu 已提交
37
 public:
xiebaiyuan's avatar
xiebaiyuan 已提交
38 39
  Executor(const Program<Device> &program,
           paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1,
40 41
           const bool use_optimize = true, const bool lod_mode = false);

42 43
  void SetThreadNum(int thread_num,
                    PowerMode power_mode = PERFORMANCE_PRIORITY);
44

45 46 47 48 49 50 51 52 53 54 55 56
  PMStatus Predict(const std::vector<std::pair<std::string, Tensor>> &inputs);
  PMStatus Predict(
      const std::vector<std::pair<std::string, LoDTensor>> &inputs);

  std::vector<T> Predict(const std::vector<T> &input,
                         const std::vector<int64_t> &dims);
  PMStatus Predict();

  void SetInput(const Tensor &input, const std::string &var_name);
  void SetInput(const LoDTensor &input, const std::string &var_name);

  std::shared_ptr<LoDTensor> GetOutput(const std::string &var_name);
57 58 59
#ifdef PADDLE_MOBILE_CL
  const CLImage *GetOutputImage(const std::string &var_name);
#endif
60

61 62
  void FeedTensorData(const std::vector<framework::Tensor> &v);
  void GetTensorResults(std::vector<framework::Tensor *> *v);
63
  std::string GetExceptionMsg();
64

H
hjchen2 已提交
65
#ifdef PADDLE_MOBILE_FPGA
66 67
  void InjectVariable(const Tensor &t, std::string var_name);
  void FeedData(const Tensor &t);
68 69
  void FeedData(const std::vector<void *> &v);
  void GetResults(std::vector<void *> *v);
70
  framework::Tensor *GetTensorByName(const std::string &name);
71
  std::shared_ptr<Tensor> FetchResult(int id = -1);
H
hjchen2 已提交
72 73 74
  void Predict_From_To(int start = 0, int end = -1);
  void Predict_From(int start);
  void Predict_To(int end);
75 76 77
#ifdef PADDLE_MOBILE_FPGA_V2
  void InitQuantMemory();
#endif
H
hjchen2 已提交
78 79
#endif

W
wangliu 已提交
80
 protected:
81
  Executor() = default;
82

H
update  
hjchen2 已提交
83 84
  bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc,
                      Variable *var) const;
85
  void InitFeedFetchList();
86
  void InitMemory();
L
liuruilong 已提交
87
  void InitCombineMemory();
Z
zhaojiaying01 已提交
88
  void InitNoPersistableMemory(const Tensor &input_tensor);
89 90
  void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc,
                  LoDTensor *tensor);
L
liuruilong 已提交
91
#ifdef PADDLE_MOBILE_CL
92
  void LoadMemory(const VarDesc var_desc, float *tensorInput, char **data);
L
liuruilong 已提交
93
#endif
94 95 96 97

  int batch_size_;
  bool use_optimize_;
  bool lod_mode_;
L
liuruilong 已提交
98
  PaddleMobileConfigInternal config_;
99 100
  Program<Device> program_;
  std::shared_ptr<ProgramDesc> program_desc_;
101
  std::vector<std::shared_ptr<OperatorBase<Device>>> ops_of_block0_;
102 103
  std::unordered_map<std::string, int> feed_indices_;
  std::unordered_map<std::string, int> fetch_indices_;
104
  std::string exception_msg_;
105

106
  // for super resoltion
xiebaiyuan's avatar
xiebaiyuan 已提交
107
  DDim input_dim_last_;
108
  bool input_dim_has_changed_ = true;
L
liuruilong 已提交
109

D
dolphin8 已提交
110
#ifdef PADDLE_MOBILE_PROFILE
111 112
  typedef typename DtypeTensorTrait<Device>::gtype ProfileTensorType;

D
dolphin8 已提交
113 114 115 116 117
  struct ProfInfo {
    int tid = 0;
    uint64_t runBegin = 0UL;
    uint64_t runEnd = 0UL;
  };
118 119

  void PrintProfile(const vector<Executor<Device, T>::ProfInfo> &profile) const;
D
dolphin8 已提交
120
#endif
121 122
};

123
}  // namespace framework
W
wangliu 已提交
124
}  // namespace paddle_mobile