predictor_sdk.h 2.4 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <map>
#include <string>
G
guru4elephant 已提交
18 19 20 21 22
#include "core/sdk-cpp/include/config_manager.h"
#include "core/sdk-cpp/include/endpoint.h"
#include "core/sdk-cpp/include/endpoint_config.h"
#include "core/sdk-cpp/include/predictor.h"
#include "core/sdk-cpp/include/stub.h"
W
sdk-cpp  
wangguibao 已提交
23 24 25 26 27 28

namespace baidu {
namespace paddle_serving {
namespace sdk_cpp {

class PredictorApi {
W
wangguibao 已提交
29 30
 public:
  PredictorApi() {}
W
sdk-cpp  
wangguibao 已提交
31

W
wangguibao 已提交
32
  int register_all();
W
sdk-cpp  
wangguibao 已提交
33

B
barrierye 已提交
34
  int create(const std::string& sdk_desc_str);
G
guru4elephant 已提交
35

W
wangguibao 已提交
36
  int create(const char* path, const char* file);
W
sdk-cpp  
wangguibao 已提交
37

W
wangguibao 已提交
38
  int thrd_initialize();
W
sdk-cpp  
wangguibao 已提交
39

W
wangguibao 已提交
40
  int thrd_clear();
W
sdk-cpp  
wangguibao 已提交
41

W
wangguibao 已提交
42
  int thrd_finalize();
W
sdk-cpp  
wangguibao 已提交
43

W
wangguibao 已提交
44
  void destroy();
W
sdk-cpp  
wangguibao 已提交
45

W
wangguibao 已提交
46 47 48 49
  static PredictorApi& instance() {
    static PredictorApi api;
    return api;
  }
W
sdk-cpp  
wangguibao 已提交
50

W
wangguibao 已提交
51 52 53 54 55 56
  Predictor* fetch_predictor(std::string ep_name) {
    std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
    if (it == _endpoints.end() || !it->second) {
      LOG(ERROR) << "Failed fetch predictor:"
                 << ", ep_name: " << ep_name;
      return NULL;
W
sdk-cpp  
wangguibao 已提交
57
    }
W
wangguibao 已提交
58 59 60 61 62 63 64 65 66
    return it->second->get_predictor();
  }

  Predictor* fetch_predictor(std::string ep_name, const void* params) {
    std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
    if (it == _endpoints.end() || !it->second) {
      LOG(ERROR) << "Failed fetch predictor:"
                 << ", ep_name: " << ep_name;
      return NULL;
W
sdk-cpp  
wangguibao 已提交
67
    }
W
wangguibao 已提交
68 69 70 71 72 73 74 75
    return it->second->get_predictor(params);
  }

  int free_predictor(Predictor* predictor) {
    const Stub* stub = predictor->stub();
    if (!stub || stub->return_predictor(predictor) != 0) {
      LOG(ERROR) << "Failed return predictor via stub";
      return -1;
W
sdk-cpp  
wangguibao 已提交
76 77
    }

W
wangguibao 已提交
78 79
    return 0;
  }
W
sdk-cpp  
wangguibao 已提交
80

W
wangguibao 已提交
81 82 83 84
 private:
  EndpointConfigManager _config_manager;
  std::map<std::string, Endpoint*> _endpoints;
};
W
sdk-cpp  
wangguibao 已提交
85

W
wangguibao 已提交
86 87 88
}  // namespace sdk_cpp
}  // namespace paddle_serving
}  // namespace baidu