// engine.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/program.h"
#include "lite/core/tensor.h"

namespace paddle {
namespace lite {
namespace subgraph {

class Engine {
 public:
31 32
  Engine(KernelContext *ctx,
         int block_idx,
33 34 35
         cpp::BlockDesc *block_desc,
         const std::vector<std::string> &input_names,
         const std::vector<std::string> &output_names,
36 37
         lite::Scope *scope,
         std::string model_cache_dir = "")
38 39
      : ctx_(ctx),
        block_idx_(block_idx),
40 41 42
        block_desc_(block_desc),
        input_names_(input_names),
        output_names_(output_names),
43 44
        scope_(scope),
        model_cache_dir_(model_cache_dir) {}
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
  virtual ~Engine() = default;

  virtual int Build();
  virtual int Launch();

 private:
  Engine(const Engine &) = delete;

 protected:
  virtual int BuildDeviceProgram();
  virtual int LaunchDeviceProgram();

  virtual int BuildOriginProgram();
  virtual int LaunchOriginProgram();

60
  virtual void InitDeviceTensor();
61 62
  virtual bool InputShapeChanged();

63
  KernelContext *ctx_{nullptr};
64 65 66 67 68 69 70 71 72 73 74 75 76 77
  int block_idx_;
  cpp::BlockDesc *block_desc_;
  std::vector<std::string> input_names_;
  std::vector<std::string> output_names_;
  Scope *scope_{nullptr};
  // SUCCESS: device program build successed. FAILED: device program build
  // failed. REBUILD_WHEN_SHAPE_CHANGED: device program build successed but need
  // to rebuild when input shape changed.
  int build_device_program_status_{0};
  std::vector<DDim> origin_idims_;
  std::vector<DDim> origin_odims_;
  std::vector<Tensor *> origin_itensors_;
  std::vector<Tensor *> origin_otensors_;
  std::vector<Instruction> origin_program_;
78
  std::string model_cache_dir_{""};
79 80 81 82 83
};

}  // namespace subgraph
}  // namespace lite
}  // namespace paddle