/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

/*
 * This file declares the inference API with the Anakin engine embedded; this
 * API only supports Anakin models.
 */

#pragma once

#include <memory>
#include <mutex>
#include <vector>

#include "framework/core/net/net.h"
#include "framework/graph/graph.h"
#include "paddle/fluid/inference/api/paddle_anakin_config.h"
#include "saber/core/shape.h"
#include "saber/saber_types.h"

namespace paddle {

Y
Yan Chunwei 已提交
33
using contrib::AnakinConfig;
34 35
using anakin::Precision;
using anakin::OpRunType;
Y
Yan Chunwei 已提交
36

37
template <typename T, Precision P, OpRunType R>
Y
Yan Chunwei 已提交
38 39
class PaddleInferenceAnakinPredictor : public PaddlePredictor {
 public:
40
  PaddleInferenceAnakinPredictor() = default;
C
cuichaowen 已提交
41

42 43 44 45
  explicit PaddleInferenceAnakinPredictor(const AnakinConfig& config)
      : config_(config) {
    this->InitPredictor();
  }
Y
Yan Chunwei 已提交
46 47 48 49

  // NOTE Unlike the native engine, the buffers of anakin engine's output_data
  // should be allocated first.
  bool Run(const std::vector<PaddleTensor>& inputs,
50 51
           std::vector<PaddleTensor>* output_data,
           int batch_size = -1) override;
Y
Yan Chunwei 已提交
52 53

  std::unique_ptr<PaddlePredictor> Clone() override;
54 55 56 57
  virtual bool ResetConfig(const AnakinConfig& config);
  virtual anakin::Net<T, P, R>& ResetExecuter(
      std::shared_ptr<anakin::graph::Graph<T, P>> graph_p);
  void InitPredictor();
C
cuichaowen 已提交
58

T
Tao Luo 已提交
59
  ~PaddleInferenceAnakinPredictor() override;
C
cuichaowen 已提交
60

61
 protected:
62 63 64 65 66 67
  void InitEnv();
  void InitGraph();
  virtual void OptimizeGraph();
  virtual void InitNet();
  virtual void SetContext();
  virtual void Predict();
68 69 70 71 72
  static std::mutex mutex_;
  AnakinConfig config_;
  std::shared_ptr<anakin::Context<T>> ctx_p_;
  std::shared_ptr<anakin::graph::Graph<T, P>> graph_p_;
  anakin::Net<T, P, R>* executor_p_{nullptr};
73 74 75 76 77

 private:
  bool RunImpl(const std::vector<PaddleTensor>& inputs,
               std::vector<PaddleTensor>* output_data);
  static std::once_flag init_anakin_;
Y
Yan Chunwei 已提交
78 79
};

#ifdef ANAKIN_MLU_PLACE
/// Anakin predictor for the MLU target: PaddleInferenceAnakinPredictor with
/// T = anakin::MLU, overriding the context/graph/net/predict hooks.
///
/// @tparam P Anakin precision.
/// @tparam R Anakin op run type.
template <Precision P, OpRunType R>
class PaddleInferenceAnakinMLUPredictor final
    : public PaddleInferenceAnakinPredictor<anakin::MLU, P, R> {
 public:
  /// Stores @p config and initializes the predictor.
  ///
  /// The base is default-constructed: its config-taking constructor would call
  /// InitPredictor() while only the base exists, when virtual calls cannot
  /// reach the overrides below. Calling InitPredictor() here instead means any
  /// hook it invokes dispatches to the MLU overrides.
  explicit PaddleInferenceAnakinMLUPredictor(const AnakinConfig& config) {
    this->ResetConfig(config);
    this->InitPredictor();
  }

  // NOTE(review): these overrides are public although the base declares the
  // hooks protected; they are kept public so existing callers keep compiling.
  void SetContext() override;
  void OptimizeGraph() override;
  void InitNet() override;
  void Predict() override;
};
#endif

}  // namespace paddle