// predictor_sdk.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <map>
#include <string>
#include "core/sdk-cpp/include/config_manager.h"
#include "core/sdk-cpp/include/endpoint.h"
#include "core/sdk-cpp/include/endpoint_config.h"
#include "core/sdk-cpp/include/predictor.h"
#include "core/sdk-cpp/include/stub.h"
namespace baidu {
namespace paddle_serving {
namespace sdk_cpp {

class PredictorApi {
W
wangguibao 已提交
29 30
 public:
  PredictorApi() {}
W
sdk-cpp  
wangguibao 已提交
31

W
wangguibao 已提交
32
  int register_all();
W
sdk-cpp  
wangguibao 已提交
33

B
barrierye 已提交
34
  int create(const std::string& sdk_desc_str);
G
guru4elephant 已提交
35

W
wangguibao 已提交
36
  int create(const char* path, const char* file);
W
sdk-cpp  
wangguibao 已提交
37

W
wangguibao 已提交
38
  int thrd_initialize();
W
sdk-cpp  
wangguibao 已提交
39

W
wangguibao 已提交
40
  int thrd_clear();
W
sdk-cpp  
wangguibao 已提交
41

W
wangguibao 已提交
42
  int thrd_finalize();
W
sdk-cpp  
wangguibao 已提交
43

W
wangguibao 已提交
44
  void destroy();
W
sdk-cpp  
wangguibao 已提交
45

W
wangguibao 已提交
46 47 48 49
  static PredictorApi& instance() {
    static PredictorApi api;
    return api;
  }
W
sdk-cpp  
wangguibao 已提交
50

51
  Predictor* fetch_predictor(std::string ep_name, std::string* variant_tag) {
W
wangguibao 已提交
52 53 54 55 56
    std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
    if (it == _endpoints.end() || !it->second) {
      LOG(ERROR) << "Failed fetch predictor:"
                 << ", ep_name: " << ep_name;
      return NULL;
W
sdk-cpp  
wangguibao 已提交
57
    }
58
    return it->second->get_predictor(variant_tag);
W
wangguibao 已提交
59 60
  }

61 62 63
  Predictor* fetch_predictor(std::string ep_name,
                             const void* params,
                             std::string* variant_tag) {
W
wangguibao 已提交
64 65 66 67 68
    std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
    if (it == _endpoints.end() || !it->second) {
      LOG(ERROR) << "Failed fetch predictor:"
                 << ", ep_name: " << ep_name;
      return NULL;
W
sdk-cpp  
wangguibao 已提交
69
    }
70
    return it->second->get_predictor(params, variant_tag);
W
wangguibao 已提交
71 72 73 74 75 76 77
  }

  int free_predictor(Predictor* predictor) {
    const Stub* stub = predictor->stub();
    if (!stub || stub->return_predictor(predictor) != 0) {
      LOG(ERROR) << "Failed return predictor via stub";
      return -1;
W
sdk-cpp  
wangguibao 已提交
78 79
    }

W
wangguibao 已提交
80 81
    return 0;
  }
W
sdk-cpp  
wangguibao 已提交
82

W
wangguibao 已提交
83 84 85 86
 private:
  EndpointConfigManager _config_manager;
  std::map<std::string, Endpoint*> _endpoints;
};
}  // namespace sdk_cpp
}  // namespace paddle_serving
}  // namespace baidu