/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

/*
 * This file contains the implementation of the inference API with the Anakin
 * engine embedded. This API supports only Anakin models.
 */

#pragma once

#include <memory>
#include <mutex>  // NOLINT (std::mutex and std::once_flag are used below)
#include <string>
#include <vector>

#include "framework/core/net/net.h"
#include "framework/graph/graph.h"
#include "paddle/fluid/inference/api/paddle_anakin_config.h"
#include "saber/core/shape.h"
#include "saber/saber_types.h"

namespace paddle {

using contrib::AnakinConfig;
using anakin::Precision;
using anakin::OpRunType;

template <typename T, Precision P, OpRunType R>
class PaddleInferenceAnakinPredictor : public PaddlePredictor {
 public:
  PaddleInferenceAnakinPredictor() = default;

  explicit PaddleInferenceAnakinPredictor(const AnakinConfig& config)
      : config_(config) {
    this->InitPredictor();
  }

  // NOTE: Unlike the native engine, the buffers for the Anakin engine's
  // output_data must be allocated by the caller before Run() is called
  // (see the usage sketch after this class).
  bool Run(const std::vector<PaddleTensor>& inputs,
           std::vector<PaddleTensor>* output_data,
           int batch_size = -1) override;

  std::unique_ptr<PaddlePredictor> Clone() override;
  bool Reset(PaddleInferenceAnakinPredictor<T, P, R>* predictor);
  void InitPredictor();
  std::shared_ptr<anakin::graph::Graph<T, P>> GetGraph() {
    return this->graph_p_;
  }
  std::vector<std::string> GetInputNames() override {
    return this->input_names_;
  }
  std::vector<std::string> GetOutputNames() override {
    return this->output_names_;
  }
  const AnakinConfig& GetConfig() const { return this->config_; }

  ~PaddleInferenceAnakinPredictor() override;

 protected:
  void InitEnv();
  void InitGraph();
  virtual void OptimizeGraph();
  virtual void InitNet();
  virtual void SetContext();
  virtual void Predict();
  virtual std::unique_ptr<PaddlePredictor> New();
  static std::mutex mutex_;
  AnakinConfig config_;
  std::shared_ptr<anakin::Context<T>> ctx_p_;
  std::shared_ptr<anakin::graph::Graph<T, P>> graph_p_;
  anakin::Net<T, P, R>* executor_p_{nullptr};
  std::vector<std::string> input_names_;
  std::vector<std::string> output_names_;

 private:
  bool RunImpl(const std::vector<PaddleTensor>& inputs,
               std::vector<PaddleTensor>* output_data);
  static std::once_flag init_anakin_;
};
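
// A minimal usage sketch of the predictor declared above. The template
// arguments <anakin::X86, Precision::FP32, OpRunType::ASYNC>, the tensor
// names "input_0"/"output_0", and the shapes/sizes are illustrative
// placeholders, not prescribed by this header. The point demonstrated is
// that the output buffer is allocated by the caller before Run() is invoked.
//
//   contrib::AnakinConfig config;  // fill in model path, device, etc.
//
//   PaddleInferenceAnakinPredictor<anakin::X86, Precision::FP32,
//                                  OpRunType::ASYNC>
//       predictor(config);
//
//   PaddleTensor input;
//   input.name = "input_0";
//   input.shape = {1, 3, 224, 224};
//   input.dtype = PaddleDType::FLOAT32;
//   input.data.Resize(1 * 3 * 224 * 224 * sizeof(float));
//   // ... copy the actual input values into input.data ...
//
//   PaddleTensor output;
//   output.name = "output_0";
//   output.data.Resize(1000 * sizeof(float));  // pre-allocate output buffer
//
//   std::vector<PaddleTensor> outputs{output};
//   predictor.Run({input}, &outputs);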

#ifdef ANAKIN_MLU_PLACE
template <Precision P, OpRunType R>
class PaddleInferenceAnakinMLUPredictor final
    : public PaddleInferenceAnakinPredictor<anakin::MLU, P, R> {
 public:
  PaddleInferenceAnakinMLUPredictor() = default;
  explicit PaddleInferenceAnakinMLUPredictor(const AnakinConfig& config) {
    this->config_ = config;
    this->InitPredictor();
  }
  std::unique_ptr<PaddlePredictor> New() override;
  void SetContext() override;
  void OptimizeGraph() override;
  void InitNet() override;
  void Predict() override;
};
#endif

#ifdef ANAKIN_BM_PLACE
template <Precision P, OpRunType R>
class PaddleInferenceAnakinBMPredictor final
    : public PaddleInferenceAnakinPredictor<anakin::BM, P, R> {
 public:
  PaddleInferenceAnakinBMPredictor() = default;
  explicit PaddleInferenceAnakinBMPredictor(const AnakinConfig& config) {
    this->config_ = config;
    this->InitPredictor();
  }
  std::unique_ptr<PaddlePredictor> New() override;
  void OptimizeGraph() override;
  void InitNet() override;
  void Predict() override;
};
#endif
}  // namespace paddle