Commit 62d2d295 authored by D dongdaxiang

make serving compatible with grpc

Parent 372800b8
@@ -119,7 +119,7 @@ int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
     LOG(ERROR) << "Predictor Creation Failed";
     return -1;
   }
-  _api.thrd_initialize();
+  // _api.thrd_initialize();
   return 0;
 }
@@ -130,7 +130,7 @@ int PredictorClient::create_predictor() {
     LOG(ERROR) << "Predictor Creation Failed";
     return -1;
   }
-  _api.thrd_initialize();
+  // _api.thrd_initialize();
   return 0;
 }
@@ -152,7 +152,7 @@ int PredictorClient::batch_predict(
   int fetch_name_num = fetch_name.size();
-  _api.thrd_clear();
+  _api.thrd_initialize();
   std::string variant_tag;
   _predictor = _api.fetch_predictor("general_model", &variant_tag);
   predict_res_batch.set_variant_tag(variant_tag);
@@ -247,8 +247,9 @@ int PredictorClient::batch_predict(
   } else {
     client_infer_end = timeline.TimeStampUS();
     postprocess_start = client_infer_end;
-    VLOG(2) << "get model output num";
     uint32_t model_num = res.outputs_size();
+    VLOG(2) << "model num: " << model_num;
     for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) {
       VLOG(2) << "process model output index: " << m_idx;
       auto output = res.outputs(m_idx);
@@ -326,6 +327,8 @@ int PredictorClient::batch_predict(
     fprintf(stderr, "%s\n", oss.str().c_str());
   }
+  _api.thrd_clear();
   return 0;
 }
......
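The net effect of the client-side changes above is that thread-local SDK state is no longer set up when the predictor is created; each batch_predict() call now runs _api.thrd_initialize() on entry and _api.thrd_clear() on exit, so the call works from whatever worker thread a gRPC frontend happens to use. The standalone sketch below only illustrates that per-call lifecycle; MiniApi and batch_predict_outline are hypothetical stand-ins, and only the ordering of thrd_initialize / fetch_predictor / thrd_clear is taken from the diff.

#include <iostream>
#include <string>

// Hypothetical stand-in for the sdk-cpp API object held by PredictorClient (_api).
// Only the call order in batch_predict_outline mirrors the diff; the class itself
// is illustrative.
struct MiniApi {
  int thrd_initialize() { std::cout << "init TLS for this thread\n"; return 0; }
  void thrd_clear()     { std::cout << "clear TLS for this thread\n"; }
  void* fetch_predictor(const std::string& ep, std::string* tag) {
    *tag = "default";
    return this;  // placeholder handle
  }
};

// Per-request lifecycle after this commit: TLS is set up lazily inside the call
// (the caller may be any gRPC worker thread) and torn down at the end, instead of
// once in create_predictor().
int batch_predict_outline(MiniApi& api) {
  api.thrd_initialize();
  std::string variant_tag;
  void* predictor = api.fetch_predictor("general_model", &variant_tag);
  if (predictor == nullptr) return -1;
  // ... build request, run inference, postprocess ...
  api.thrd_clear();
  return 0;
}

int main() {
  MiniApi api;
  return batch_predict_outline(api);
}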
@@ -27,9 +27,9 @@ namespace predictor {
 }
 #endif
-#ifdef WITH_GPU
-#define USE_PTHREAD
-#endif
+// #ifdef WITH_GPU
+// #define USE_PTHREAD
+// #endif
 #ifdef USE_PTHREAD
......
@@ -19,6 +19,7 @@
 #include <vector>
 #include "core/sdk-cpp/include/common.h"
 #include "core/sdk-cpp/include/endpoint_config.h"
+#include "core/sdk-cpp/include/macros.h"
 #include "core/sdk-cpp/include/predictor.h"
 #include "core/sdk-cpp/include/stub.h"
@@ -245,7 +246,7 @@ class StubImpl : public Stub {
                        const brpc::ChannelOptions& options);
   StubTLS* get_tls() {
-    return static_cast<StubTLS*>(bthread_getspecific(_bthread_key));
+    return static_cast<StubTLS*>(THREAD_GETSPECIFIC(_bthread_key));
   }
  private:
@@ -262,7 +263,8 @@ class StubImpl : public Stub {
   uint32_t _package_size;
   // tls handlers
-  bthread_key_t _bthread_key;
+  // bthread_key_t _bthread_key;
+  THREAD_KEY_T _bthread_key;
   // bvar variables
   std::map<std::string, BvarWrapper*> _ltc_bvars;
......
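The header core/sdk-cpp/include/macros.h added above is not shown in this commit's diff, so the following is only a guess at its shape: a set of THREAD_* aliases that resolve to pthread TLS when USE_PTHREAD is defined (plain OS threads, as used by a gRPC server) and to bthread TLS otherwise (brpc's user-level threads). Treat every line as an assumption rather than the real file contents.

// Hypothetical sketch of core/sdk-cpp/include/macros.h; not taken from the diff.
#pragma once

#ifdef USE_PTHREAD
#include <pthread.h>
#define THREAD_KEY_T       pthread_key_t
#define THREAD_KEY_CREATE  pthread_key_create
#define THREAD_GETSPECIFIC pthread_getspecific
#define THREAD_SETSPECIFIC pthread_setspecific
#else
#include <bthread/bthread.h>
#define THREAD_KEY_T       bthread_key_t
#define THREAD_KEY_CREATE  bthread_key_create
#define THREAD_GETSPECIFIC bthread_getspecific
#define THREAD_SETSPECIFIC bthread_setspecific
#endif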
@@ -70,7 +70,7 @@ int StubImpl<T, C, R, I, O>::initialize(const VariantInfo& var,
   _endpoint = ep;
-  if (bthread_key_create(&_bthread_key, NULL) != 0) {
+  if (THREAD_KEY_CREATE(&_bthread_key, NULL) != 0) {
     LOG(FATAL) << "Failed create key for stub tls";
     return -1;
   }
@@ -132,13 +132,13 @@ int StubImpl<T, C, R, I, O>::initialize(const VariantInfo& var,
 template <typename T, typename C, typename R, typename I, typename O>
 int StubImpl<T, C, R, I, O>::thrd_initialize() {
-  if (bthread_getspecific(_bthread_key) != NULL) {
+  if (THREAD_GETSPECIFIC(_bthread_key) != NULL) {
     LOG(WARNING) << "Already thread initialized for stub";
     return 0;
   }
   StubTLS* tls = new (std::nothrow) StubTLS();
-  if (!tls || bthread_setspecific(_bthread_key, tls) != 0) {
+  if (!tls || THREAD_SETSPECIFIC(_bthread_key, tls) != 0) {
     LOG(FATAL) << "Failed binding tls data to bthread_key";
     return -1;
   }
......
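For readers unfamiliar with the idiom being ported here, the same key-based thread-local storage pattern, create a key once, lazily bind per-thread state on first use, and read it back via the key, can be written directly against pthread TLS as in the self-contained example below. DemoTLS, demo_key_create and thrd_initialize_demo are illustrative names, not code from the repository.

#include <pthread.h>
#include <cstdio>
#include <new>

// Illustrative stand-in for StubTLS: per-thread scratch state.
struct DemoTLS {
  int request_count = 0;
};

static pthread_key_t g_key;

// Done once per stub (cf. StubImpl::initialize creating _bthread_key).
static int demo_key_create() {
  return pthread_key_create(&g_key, nullptr);
}

// Done per request (cf. StubImpl::thrd_initialize): if the calling thread has
// no TLS yet, allocate it and bind it to the key.
static int thrd_initialize_demo() {
  if (pthread_getspecific(g_key) != nullptr) {
    return 0;  // already initialized on this thread
  }
  DemoTLS* tls = new (std::nothrow) DemoTLS();
  if (!tls || pthread_setspecific(g_key, tls) != 0) {
    return -1;
  }
  return 0;
}

// Cf. StubImpl::get_tls().
static DemoTLS* get_tls_demo() {
  return static_cast<DemoTLS*>(pthread_getspecific(g_key));
}

int main() {
  if (demo_key_create() != 0 || thrd_initialize_demo() != 0) return 1;
  get_tls_demo()->request_count++;
  std::printf("requests on this thread: %d\n", get_tls_demo()->request_count);
  return 0;
}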