variant.cpp 4.1 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sdk-cpp/include/variant.h"
#include "sdk-cpp/include/factory.h"
W
sdk-cpp  
wangguibao 已提交
17 18 19 20 21 22

namespace baidu {
namespace paddle_serving {
namespace sdk_cpp {

int Variant::initialize(const EndpointInfo& ep_info,
W
wangguibao 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
                        const VariantInfo& var_info) {
  _endpoint_name = ep_info.endpoint_name;
  _stub_service = ep_info.stub_service;

  _variant_tag = var_info.parameters.route_tag.value;
  _stub_map.clear();

  const SplitParameters& split_info = var_info.splitinfo;
  uint32_t tag_size = split_info.tag_values.size();
  for (uint32_t ti = 0; ti < tag_size; ++ti) {  // split
    Stub* stub = StubFactory::instance().generate_object(_stub_service);
    const std::string& tag_value = split_info.tag_values[ti];
    if (!stub ||
        stub->initialize(var_info,
                         ep_info.endpoint_name,
                         &split_info.split_tag.value,
                         &tag_value) != 0) {
      LOG(ERROR) << "Failed init stub from factory"
                 << ", stub name: " << ep_info.stub_service
                 << ", filter tag: " << tag_value;
      return -1;
W
sdk-cpp  
wangguibao 已提交
44 45
    }

W
wangguibao 已提交
46 47 48 49 50
    // 判重
    std::map<std::string, Stub*>::iterator iter = _stub_map.find(tag_value);
    if (iter != _stub_map.end()) {
      LOG(ERROR) << "duplicated tag value: " << tag_value;
      return -1;
W
sdk-cpp  
wangguibao 已提交
51
    }
W
wangguibao 已提交
52 53
    _stub_map[tag_value] = stub;
  }
W
sdk-cpp  
wangguibao 已提交
54

W
wangguibao 已提交
55 56 57
  if (_stub_map.size() > 0) {
    LOG(INFO) << "Initialize variants from VariantInfo"
              << ", stubs count: " << _stub_map.size();
W
sdk-cpp  
wangguibao 已提交
58
    return 0;
W
wangguibao 已提交
59 60 61 62 63 64 65 66 67 68 69 70
  }

  Stub* stub = StubFactory::instance().generate_object(ep_info.stub_service);
  if (!stub || stub->initialize(var_info, _endpoint_name, NULL, NULL) != 0) {
    LOG(ERROR) << "Failed init stub from factory"
               << ", stub name: " << ep_info.stub_service;
    return -1;
  }

  _default_stub = stub;
  LOG(INFO) << "Succ create default debug";
  return 0;
W
sdk-cpp  
wangguibao 已提交
71 72 73
}

int Variant::thrd_initialize() {
W
wangguibao 已提交
74 75 76 77 78 79 80 81 82 83
  if (_stub_map.size() <= 0) {
    return _default_stub->thrd_initialize();
  }

  std::map<std::string, Stub*>::iterator iter;
  for (iter = _stub_map.begin(); iter != _stub_map.end(); ++iter) {
    Stub* stub = iter->second;
    if (!stub || stub->thrd_initialize() != 0) {
      LOG(ERROR) << "Failed thrd initialize stub: " << iter->first;
      return -1;
W
sdk-cpp  
wangguibao 已提交
84
    }
W
wangguibao 已提交
85 86
    LOG(INFO) << "Succ thrd initialize stub:" << iter->first;
  }
W
sdk-cpp  
wangguibao 已提交
87

W
wangguibao 已提交
88 89
  LOG(WARNING) << "Succ thrd initialize all stubs";
  return 0;
W
sdk-cpp  
wangguibao 已提交
90 91 92
}

int Variant::thrd_clear() {
W
wangguibao 已提交
93 94 95 96 97 98 99 100 101 102
  if (_stub_map.size() <= 0) {
    return _default_stub->thrd_clear();
  }

  std::map<std::string, Stub*>::iterator iter;
  for (iter = _stub_map.begin(); iter != _stub_map.end(); ++iter) {
    Stub* stub = iter->second;
    if (!stub || stub->thrd_clear() != 0) {
      LOG(ERROR) << "Failed thrd clear stub: " << iter->first;
      return -1;
W
sdk-cpp  
wangguibao 已提交
103
    }
W
wangguibao 已提交
104 105
  }
  return 0;
W
sdk-cpp  
wangguibao 已提交
106 107 108
}

int Variant::thrd_finalize() {
W
wangguibao 已提交
109 110 111 112 113 114 115 116 117 118
  if (_stub_map.size() <= 0) {
    return _default_stub->thrd_finalize();
  }

  std::map<std::string, Stub*>::iterator iter;
  for (iter = _stub_map.begin(); iter != _stub_map.end(); ++iter) {
    Stub* stub = iter->second;
    if (!stub || stub->thrd_finalize() != 0) {
      LOG(ERROR) << "Failed thrd finalize stub: " << iter->first;
      return -1;
W
sdk-cpp  
wangguibao 已提交
119
    }
W
wangguibao 已提交
120 121
  }
  return 0;
W
sdk-cpp  
wangguibao 已提交
122 123 124
}

// Fetches a predictor from the default stub.
//
// Only _default_stub is consulted; stubs held in _stub_map are not used here.
//
// @return the predictor returned by _default_stub->fetch_predictor(), or NULL
//         when no default stub is set.
Predictor* Variant::get_predictor() {
  if (!_default_stub) {
    return NULL;
  }
  return _default_stub->fetch_predictor();
}

W
wangguibao 已提交
132 133 134 135
// Fetches a predictor from the default stub; this overload currently behaves
// exactly like get_predictor().
//
// @param params  ignored by this implementation. NOTE(review): presumably
//                meant to select among the tag-keyed stubs in _stub_map;
//                that selection is not implemented here.
// @return the predictor returned by _default_stub->fetch_predictor(), or NULL
//         when no default stub is set.
Predictor* Variant::get_predictor(const void* params) {
  (void)params;  // intentionally unused
  return _default_stub ? _default_stub->fetch_predictor() : NULL;
}

W
wangguibao 已提交
140 141 142
}  // namespace sdk_cpp
}  // namespace paddle_serving
}  // namespace baidu