executor.h 3.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17 18 19
#include <map>
#include <memory>
#include <string>
20
#include <utility>
21
#include <vector>
22
#include "common/types.h"
Refine  
陈后江 已提交
23
#include "common/util.h"
L
liuruilong 已提交
24
#include "framework/lod_tensor.h"
L
liuruilong 已提交
25
#include "framework/operator.h"
26
#include "framework/program/program.h"
L
liuruilong 已提交
27
#include "framework/tensor.h"
28 29

namespace paddle_mobile {
30
namespace framework {
31

32
template <typename Device, typename T = float>
33
class Executor {
W
wangliu 已提交
34
 public:
L
liuruilong 已提交
35 36
  Executor(const Program<Device> &program, paddle_mobile::PaddleMobileConfigInternal config, int batch_size = 1,
           const bool use_optimize = true, const bool lod_mode = false);
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
  Executor(const Program<Device> &program, int batch_size = 1,
           const bool use_optimize = true, const bool lod_mode = false);

  PMStatus Predict(const std::vector<std::pair<std::string, Tensor>> &inputs);
  PMStatus Predict(
      const std::vector<std::pair<std::string, LoDTensor>> &inputs);

  std::vector<T> Predict(const std::vector<T> &input,
                         const std::vector<int64_t> &dims);
  PMStatus Predict();

  void SetInput(const Tensor &input, const std::string &var_name);
  void SetInput(const LoDTensor &input, const std::string &var_name);

  std::shared_ptr<LoDTensor> GetOutput(const std::string &var_name);
52

H
hjchen2 已提交
53
#ifdef PADDLE_MOBILE_FPGA
54 55 56
  void InjectVariable(const Tensor &t, std::string var_name);
  void FeedData(const Tensor &t);
  std::shared_ptr<Tensor> FetchResult(int id = -1);
H
hjchen2 已提交
57 58 59 60 61
  void Predict_From_To(int start = 0, int end = -1);
  void Predict_From(int start);
  void Predict_To(int end);
#endif

W
wangliu 已提交
62
 protected:
63
  Executor() = default;
64

L
liuruilong 已提交
65 66


67 68
  bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc, Variable *var,
                      LoDTensor *tensor) const;
69
  void InitMemory();
L
liuruilong 已提交
70
  void InitCombineMemory();
L
liuruilong 已提交
71
  void InitNoPersistableMemory(const LoDTensor &input_tensor);
72 73
  void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc,
                  LoDTensor *tensor);
L
liuruilong 已提交
74
#ifdef PADDLE_MOBILE_CL
75
  void LoadMemory(const VarDesc var_desc, float *tensorInput, char **data);
L
liuruilong 已提交
76
#endif
77 78 79 80

  int batch_size_;
  bool use_optimize_;
  bool lod_mode_;
L
liuruilong 已提交
81
  PaddleMobileConfigInternal config_ = PaddleMobileConfigInternal();
82 83 84 85 86 87 88
  Program<Device> program_;
  std::shared_ptr<ProgramDesc> program_desc_;
  typedef std::shared_ptr<OperatorBase<Device>> OperatorBasePtr;
  std::vector<std::vector<OperatorBasePtr>> ops_of_block_;
  // operators list
  std::vector<OperatorBasePtr> ops_list_;

L
liuruilong 已提交
89 90 91 92
  // for super resoltion
  DDim input_dim_;


D
dolphin8 已提交
93 94 95 96 97 98 99
#ifdef PADDLE_MOBILE_PROFILE
  struct ProfInfo {
    int tid = 0;
    uint64_t runBegin = 0UL;
    uint64_t runEnd = 0UL;
  };
#endif
100 101
};

102
}  // namespace framework
W
wangliu 已提交
103
}  // namespace paddle_mobile