// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pthread.h>
#include <unistd.h>

#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "core/configure/include/configure_parser.h"
#include "core/configure/inferencer_configure.pb.h"
#include "core/predictor/framework/infer.h"
#include "paddle_inference_api.h"  // NOLINT

namespace baidu {
namespace paddle_serving {
namespace inference {

using paddle_infer::Config;
using paddle_infer::Predictor;
using paddle_infer::Tensor;
using paddle_infer::CreatePredictor;

// Passed as the max batch size to EnableTensorRtEngine() when TensorRT is on.
constexpr int max_batch = 32;
// Passed as the minimum subgraph size to EnableTensorRtEngine() when TensorRT
// is on.
constexpr int min_subgraph_size = 3;
// Engine Base
class PaddleEngineBase {
Z
zhangjun 已提交
40
 public:
Z
update  
zhangjun 已提交
41
  virtual ~PaddleEngineBase() {}
Z
zhangjun 已提交
42
  virtual std::vector<std::string> GetInputNames() {
Z
update  
zhangjun 已提交
43
    return _predictor -> GetInputNames();
Z
zhangjun 已提交
44 45 46
  }

  virtual std::unique_ptr<Tensor> GetInputHandle(const std::string& name) {
Z
update  
zhangjun 已提交
47
    return _predictor -> GetInputHandle(name);
Z
zhangjun 已提交
48 49 50
  }

  virtual std::vector<std::string> GetOutputNames() {
Z
update  
zhangjun 已提交
51
    return _predictor -> GetOutputNames();
Z
zhangjun 已提交
52 53 54
  }

  virtual std::unique_ptr<Tensor> GetOutputHandle(const std::string& name) {
Z
update  
zhangjun 已提交
55
    return _predictor -> GetOutputHandle(name);
Z
zhangjun 已提交
56 57 58
  }

  virtual bool Run() {
Z
update  
zhangjun 已提交
59
    if (!_predictor -> Run()) {
Z
zhangjun 已提交
60 61 62 63 64 65
      LOG(ERROR) << "Failed call Run with paddle predictor";
      return false;
    }
    return true;
  }

Z
update  
zhangjun 已提交
66
  virtual int create(const configure::EngineDesc& conf) = 0;
Z
zhangjun 已提交
67

Z
update  
zhangjun 已提交
68 69
  virtual int clone(void* predictor) {
    if (predictor == NULL) {
Z
zhangjun 已提交
70 71 72
      LOG(ERROR) << "origin paddle Predictor is null.";
      return -1;
    }
Z
update  
zhangjun 已提交
73 74 75 76
    Predictor*  prep = static_cast<Predictor*>(predictor);
    _predictor = prep -> Clone();
    if (_predictor.get() == NULL) {
      LOG(ERROR) << "fail to clone paddle predictor: " << predictor;
Z
zhangjun 已提交
77 78 79 80 81
      return -1;
    }
    return 0;
  }

Z
update  
zhangjun 已提交
82
  virtual void* get() { return _predictor.get(); }
Z
zhangjun 已提交
83 84

 protected:
Z
update  
zhangjun 已提交
85
  std::shared_ptr<Predictor> _predictor;
Z
zhangjun 已提交
86 87
};

// Paddle Inference Engine
class PaddleInferenceEngine : public PaddleEngineBase {
Z
zhangjun 已提交
90
 public:
Z
update  
zhangjun 已提交
91 92 93
  int create(const configure::EngineDesc& engine_conf) {
    std::string model_path = engine_conf.model_dir();
    if (access(model_path.c_str(), F_OK) == -1) {
Z
zhangjun 已提交
94
      LOG(ERROR) << "create paddle predictor failed, path not exits: "
Z
update  
zhangjun 已提交
95
                 << model_path;
Z
zhangjun 已提交
96 97 98 99
      return -1;
    }

    Config config;
Z
update  
zhangjun 已提交
100 101 102 103 104 105 106 107 108 109 110
    // todo, auto config(zhangjun)
    if(engine_conf.has_combined_model()) {
      if(!engine_conf.combined_model()) {
        config.SetModel(model_path)
      } else {
        config.SetParamsFile(model_path + "/__params__");
        config.SetProgFile(model_path + "/__model__");
      }
    } else {
      config.SetParamsFile(model_path + "/__params__");
      config.SetProgFile(model_path + "/__model__");
Z
zhangjun 已提交
111
    }
Z
update  
zhangjun 已提交
112
    
Z
zhangjun 已提交
113
    config.SwitchSpecifyInputNames(true);
Z
update  
zhangjun 已提交
114 115 116 117
    config.SetCpuMathLibraryNumThreads(1);
    if (engine_conf.has_use_gpu() && engine_conf.use_gpu()) {
      // 2000MB GPU memory
      config.EnableUseGpu(2000, FLAGS_gpuid);
Z
zhangjun 已提交
118
    }
Z
update  
zhangjun 已提交
119 120 121 122 123 124 125 126 127
  
    if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
      config.EnableTensorRtEngine(1 << 20,
                                  max_batch,
                                  min_subgraph_size,
                                  Config::Precision::kFloat32,
                                  false,
                                  false);
      LOG(INFO) << "create TensorRT predictor";
Z
zhangjun 已提交
128 129
    }

Z
update  
zhangjun 已提交
130 131
    if (engine_conf.has_lite() && engine_conf.use_lite()) {
      config.EnableLiteEngine(PrecisionType::kFloat32, true);
Z
zhangjun 已提交
132 133
    }

Z
update  
zhangjun 已提交
134 135 136 137 138
    if (engine_conf.has_xpu() && engine_conf.use_xpu()) {
      // 2 MB l3 cache
      config.EnableXpu(2 * 1024 * 1024);
    }
    if (engine_conf.has_enable_ir_optimization() && !engine_conf.enable_ir_optimization()) {
Z
zhangjun 已提交
139
      config.SwitchIrOptim(false);
Z
update  
zhangjun 已提交
140 141
    } else {
      config.SwitchIrOptim(true);
Z
zhangjun 已提交
142 143
    }

Z
update  
zhangjun 已提交
144 145
    if (engine_conf.has_enable_memory_optimization() && engine_conf.enable_memory_optimization()) {
      config.EnableMemoryOptim();
Z
zhangjun 已提交
146
    }
Z
update  
zhangjun 已提交
147 148 149 150
    
    if (false) {
      // todo, encrypt model
      //analysis_config.SetModelBuffer();
Z
zhangjun 已提交
151 152 153
    }

    AutoLock lock(GlobalPaddleCreateMutex::instance());
Z
update  
zhangjun 已提交
154 155
    _predictor = CreatePredictor(config);
    if (NULL == _predictor.get()) {
Z
zhangjun 已提交
156 157 158
      LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
      return -1;
    }
Z
update  
zhangjun 已提交
159

Z
zhangjun 已提交
160 161 162 163 164
    VLOG(2) << "create paddle predictor sucess, path: " << data_path;
    return 0;
  }
};

}  // namespace inference
}  // namespace paddle_serving
}  // namespace baidu