diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index 65fa6587ecb68f18b72a03c7f54433252ea1608a..92dca7eeba53c2fa23020526faa83a19a38633b6 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -119,7 +119,7 @@ int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) { LOG(ERROR) << "Predictor Creation Failed"; return -1; } - _api.thrd_initialize(); + // _api.thrd_initialize(); return 0; } @@ -130,7 +130,7 @@ int PredictorClient::create_predictor() { LOG(ERROR) << "Predictor Creation Failed"; return -1; } - _api.thrd_initialize(); + // _api.thrd_initialize(); return 0; } @@ -152,7 +152,7 @@ int PredictorClient::batch_predict( int fetch_name_num = fetch_name.size(); - _api.thrd_clear(); + _api.thrd_initialize(); std::string variant_tag; _predictor = _api.fetch_predictor("general_model", &variant_tag); predict_res_batch.set_variant_tag(variant_tag); @@ -247,8 +247,9 @@ int PredictorClient::batch_predict( } else { client_infer_end = timeline.TimeStampUS(); postprocess_start = client_infer_end; - + VLOG(2) << "get model output num"; uint32_t model_num = res.outputs_size(); + VLOG(2) << "model num: " << model_num; for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) { VLOG(2) << "process model output index: " << m_idx; auto output = res.outputs(m_idx); @@ -326,6 +327,8 @@ int PredictorClient::batch_predict( fprintf(stderr, "%s\n", oss.str().c_str()); } + + _api.thrd_clear(); return 0; } diff --git a/core/predictor/common/macros.h b/core/predictor/common/macros.h index fa4a068668cb1a37c37a2726634c24be26a3fb40..ba3ac0dae3b22e68198c9ca9995c56a3ba31a55c 100644 --- a/core/predictor/common/macros.h +++ b/core/predictor/common/macros.h @@ -27,9 +27,9 @@ namespace predictor { } #endif -#ifdef WITH_GPU -#define USE_PTHREAD -#endif +// #ifdef WITH_GPU +// #define USE_PTHREAD +// #endif #ifdef USE_PTHREAD diff --git a/core/sdk-cpp/include/stub_impl.h b/core/sdk-cpp/include/stub_impl.h index dc3c16ca6414b915ba3fd5d4feaac501bbe07cba..a112ddf25a2451e1bcffd62654bc0c6d043c9d80 100644 --- a/core/sdk-cpp/include/stub_impl.h +++ b/core/sdk-cpp/include/stub_impl.h @@ -19,6 +19,7 @@ #include #include "core/sdk-cpp/include/common.h" #include "core/sdk-cpp/include/endpoint_config.h" +#include "core/sdk-cpp/include/macros.h" #include "core/sdk-cpp/include/predictor.h" #include "core/sdk-cpp/include/stub.h" @@ -245,7 +246,7 @@ class StubImpl : public Stub { const brpc::ChannelOptions& options); StubTLS* get_tls() { - return static_cast(bthread_getspecific(_bthread_key)); + return static_cast(THREAD_GETSPECIFIC(_bthread_key)); } private: @@ -262,7 +263,8 @@ class StubImpl : public Stub { uint32_t _package_size; // tls handlers - bthread_key_t _bthread_key; + // bthread_key_t _bthread_key; + THREAD_KEY_T _bthread_key; // bvar variables std::map _ltc_bvars; diff --git a/core/sdk-cpp/include/stub_impl.hpp b/core/sdk-cpp/include/stub_impl.hpp index 6fad5b5e2c702652126bc159333046790fcefc69..756c12893393f10a1c2ebfa83bf3a94adac7a4bc 100644 --- a/core/sdk-cpp/include/stub_impl.hpp +++ b/core/sdk-cpp/include/stub_impl.hpp @@ -70,7 +70,7 @@ int StubImpl::initialize(const VariantInfo& var, _endpoint = ep; - if (bthread_key_create(&_bthread_key, NULL) != 0) { + if (THREAD_KEY_CREATE(&_bthread_key, NULL) != 0) { LOG(FATAL) << "Failed create key for stub tls"; return -1; } @@ -132,13 +132,13 @@ int StubImpl::initialize(const VariantInfo& var, template int StubImpl::thrd_initialize() { - if (bthread_getspecific(_bthread_key) != NULL) { + if (THREAD_GETSPECIFIC(_bthread_key) != NULL) { LOG(WARNING) << "Already thread initialized for stub"; return 0; } StubTLS* tls = new (std::nothrow) StubTLS(); - if (!tls || bthread_setspecific(_bthread_key, tls) != 0) { + if (!tls || THREAD_SETSPECIFIC(_bthread_key, tls) != 0) { LOG(FATAL) << "Failed binding tls data to bthread_key"; return -1; }