diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index 65fa6587ecb68f18b72a03c7f54433252ea1608a..92dca7eeba53c2fa23020526faa83a19a38633b6 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -119,7 +119,7 @@ int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) { LOG(ERROR) << "Predictor Creation Failed"; return -1; } - _api.thrd_initialize(); + // _api.thrd_initialize(); return 0; } @@ -130,7 +130,7 @@ int PredictorClient::create_predictor() { LOG(ERROR) << "Predictor Creation Failed"; return -1; } - _api.thrd_initialize(); + // _api.thrd_initialize(); return 0; } @@ -152,7 +152,7 @@ int PredictorClient::batch_predict( int fetch_name_num = fetch_name.size(); - _api.thrd_clear(); + _api.thrd_initialize(); std::string variant_tag; _predictor = _api.fetch_predictor("general_model", &variant_tag); predict_res_batch.set_variant_tag(variant_tag); @@ -247,8 +247,9 @@ int PredictorClient::batch_predict( } else { client_infer_end = timeline.TimeStampUS(); postprocess_start = client_infer_end; - + VLOG(2) << "get model output num"; uint32_t model_num = res.outputs_size(); + VLOG(2) << "model num: " << model_num; for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) { VLOG(2) << "process model output index: " << m_idx; auto output = res.outputs(m_idx); @@ -326,6 +327,8 @@ int PredictorClient::batch_predict( fprintf(stderr, "%s\n", oss.str().c_str()); } + + _api.thrd_clear(); return 0; } diff --git a/core/predictor/common/macros.h b/core/predictor/common/macros.h index fa4a068668cb1a37c37a2726634c24be26a3fb40..ba3ac0dae3b22e68198c9ca9995c56a3ba31a55c 100644 --- a/core/predictor/common/macros.h +++ b/core/predictor/common/macros.h @@ -27,9 +27,9 @@ namespace predictor { } #endif -#ifdef WITH_GPU -#define USE_PTHREAD -#endif +// #ifdef WITH_GPU +// #define USE_PTHREAD +// #endif #ifdef USE_PTHREAD diff --git a/core/sdk-cpp/include/macros.h b/core/sdk-cpp/include/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..66eaef445f3b54f7d0209c11667aafaed5522569 --- /dev/null +++ b/core/sdk-cpp/include/macros.h @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "core/sdk-cpp/include/common.h" + +namespace baidu { +namespace paddle_serving { +namespace sdk_cpp { + +#ifndef CATCH_ANY_AND_RET +#define CATCH_ANY_AND_RET(errno) \ + catch (...) { \ + LOG(ERROR) << "exception catched"; \ + return errno; \ + } +#endif + +#define USE_PTHREAD + +#ifdef USE_PTHREAD + +#define THREAD_T pthread_t +#define THREAD_KEY_T pthread_key_t +#define THREAD_MUTEX_T pthread_mutex_t +#define THREAD_KEY_CREATE pthread_key_create +#define THREAD_SETSPECIFIC pthread_setspecific +#define THREAD_GETSPECIFIC pthread_getspecific +#define THREAD_CREATE pthread_create +#define THREAD_CANCEL pthread_cancel +#define THREAD_JOIN pthread_join +#define THREAD_KEY_DELETE pthread_key_delete +#define THREAD_MUTEX_INIT pthread_mutex_init +#define THREAD_MUTEX_LOCK pthread_mutex_lock +#define THREAD_MUTEX_UNLOCK pthread_mutex_unlock +#define THREAD_MUTEX_DESTROY pthread_mutex_destroy +#define THREAD_COND_T pthread_cond_t +#define THREAD_COND_INIT pthread_cond_init +#define THREAD_COND_SIGNAL pthread_cond_signal +#define THREAD_COND_WAIT pthread_cond_wait +#define THREAD_COND_DESTROY pthread_cond_destroy + +#else + +#define THREAD_T bthread_t +#define THREAD_KEY_T bthread_key_t +#define THREAD_MUTEX_T bthread_mutex_t +#define THREAD_KEY_CREATE bthread_key_create +#define THREAD_SETSPECIFIC bthread_setspecific +#define THREAD_GETSPECIFIC bthread_getspecific +#define THREAD_CREATE bthread_start_background +#define THREAD_CANCEL bthread_stop +#define THREAD_JOIN bthread_join +#define THREAD_KEY_DELETE bthread_key_delete +#define THREAD_MUTEX_INIT bthread_mutex_init +#define THREAD_MUTEX_LOCK bthread_mutex_lock +#define THREAD_MUTEX_UNLOCK bthread_mutex_unlock +#define THREAD_MUTEX_DESTROY bthread_mutex_destroy +#define THREAD_COND_T bthread_cond_t +#define THREAD_COND_INIT bthread_cond_init +#define THREAD_COND_SIGNAL bthread_cond_signal +#define THREAD_COND_WAIT bthread_cond_wait +#define THREAD_COND_DESTROY bthread_cond_destroy + +#endif + +} // namespace sdk_cpp +} // namespace paddle_serving +} // namespace baidu diff --git a/core/sdk-cpp/include/stub_impl.h b/core/sdk-cpp/include/stub_impl.h index dc3c16ca6414b915ba3fd5d4feaac501bbe07cba..a112ddf25a2451e1bcffd62654bc0c6d043c9d80 100644 --- a/core/sdk-cpp/include/stub_impl.h +++ b/core/sdk-cpp/include/stub_impl.h @@ -19,6 +19,7 @@ #include #include "core/sdk-cpp/include/common.h" #include "core/sdk-cpp/include/endpoint_config.h" +#include "core/sdk-cpp/include/macros.h" #include "core/sdk-cpp/include/predictor.h" #include "core/sdk-cpp/include/stub.h" @@ -245,7 +246,7 @@ class StubImpl : public Stub { const brpc::ChannelOptions& options); StubTLS* get_tls() { - return static_cast(bthread_getspecific(_bthread_key)); + return static_cast(THREAD_GETSPECIFIC(_bthread_key)); } private: @@ -262,7 +263,8 @@ class StubImpl : public Stub { uint32_t _package_size; // tls handlers - bthread_key_t _bthread_key; + // bthread_key_t _bthread_key; + THREAD_KEY_T _bthread_key; // bvar variables std::map _ltc_bvars; diff --git a/core/sdk-cpp/include/stub_impl.hpp b/core/sdk-cpp/include/stub_impl.hpp index 6fad5b5e2c702652126bc159333046790fcefc69..756c12893393f10a1c2ebfa83bf3a94adac7a4bc 100644 --- a/core/sdk-cpp/include/stub_impl.hpp +++ b/core/sdk-cpp/include/stub_impl.hpp @@ -70,7 +70,7 @@ int StubImpl::initialize(const VariantInfo& var, _endpoint = ep; - if (bthread_key_create(&_bthread_key, NULL) != 0) { + if (THREAD_KEY_CREATE(&_bthread_key, NULL) != 0) { LOG(FATAL) << "Failed create key for stub tls"; return -1; } @@ -132,13 +132,13 @@ int StubImpl::initialize(const VariantInfo& var, template int StubImpl::thrd_initialize() { - if (bthread_getspecific(_bthread_key) != NULL) { + if (THREAD_GETSPECIFIC(_bthread_key) != NULL) { LOG(WARNING) << "Already thread initialized for stub"; return 0; } StubTLS* tls = new (std::nothrow) StubTLS(); - if (!tls || bthread_setspecific(_bthread_key, tls) != 0) { + if (!tls || THREAD_SETSPECIFIC(_bthread_key, tls) != 0) { LOG(FATAL) << "Failed binding tls data to bthread_key"; return -1; }