提交 ec38cd8f 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #483 from guru4elephant/grpc_compatible

make serving compatible with grpc
...@@ -119,7 +119,7 @@ int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) { ...@@ -119,7 +119,7 @@ int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
LOG(ERROR) << "Predictor Creation Failed"; LOG(ERROR) << "Predictor Creation Failed";
return -1; return -1;
} }
_api.thrd_initialize(); // _api.thrd_initialize();
return 0; return 0;
} }
...@@ -130,7 +130,7 @@ int PredictorClient::create_predictor() { ...@@ -130,7 +130,7 @@ int PredictorClient::create_predictor() {
LOG(ERROR) << "Predictor Creation Failed"; LOG(ERROR) << "Predictor Creation Failed";
return -1; return -1;
} }
_api.thrd_initialize(); // _api.thrd_initialize();
return 0; return 0;
} }
...@@ -152,7 +152,7 @@ int PredictorClient::batch_predict( ...@@ -152,7 +152,7 @@ int PredictorClient::batch_predict(
int fetch_name_num = fetch_name.size(); int fetch_name_num = fetch_name.size();
_api.thrd_clear(); _api.thrd_initialize();
std::string variant_tag; std::string variant_tag;
_predictor = _api.fetch_predictor("general_model", &variant_tag); _predictor = _api.fetch_predictor("general_model", &variant_tag);
predict_res_batch.set_variant_tag(variant_tag); predict_res_batch.set_variant_tag(variant_tag);
...@@ -247,8 +247,9 @@ int PredictorClient::batch_predict( ...@@ -247,8 +247,9 @@ int PredictorClient::batch_predict(
} else { } else {
client_infer_end = timeline.TimeStampUS(); client_infer_end = timeline.TimeStampUS();
postprocess_start = client_infer_end; postprocess_start = client_infer_end;
VLOG(2) << "get model output num";
uint32_t model_num = res.outputs_size(); uint32_t model_num = res.outputs_size();
VLOG(2) << "model num: " << model_num;
for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) { for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) {
VLOG(2) << "process model output index: " << m_idx; VLOG(2) << "process model output index: " << m_idx;
auto output = res.outputs(m_idx); auto output = res.outputs(m_idx);
...@@ -326,6 +327,8 @@ int PredictorClient::batch_predict( ...@@ -326,6 +327,8 @@ int PredictorClient::batch_predict(
fprintf(stderr, "%s\n", oss.str().c_str()); fprintf(stderr, "%s\n", oss.str().c_str());
} }
_api.thrd_clear();
return 0; return 0;
} }
......
...@@ -27,9 +27,9 @@ namespace predictor { ...@@ -27,9 +27,9 @@ namespace predictor {
} }
#endif #endif
#ifdef WITH_GPU // #ifdef WITH_GPU
#define USE_PTHREAD // #define USE_PTHREAD
#endif // #endif
#ifdef USE_PTHREAD #ifdef USE_PTHREAD
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/sdk-cpp/include/common.h"
namespace baidu {
namespace paddle_serving {
namespace sdk_cpp {
#ifndef CATCH_ANY_AND_RET
#define CATCH_ANY_AND_RET(errno) \
catch (...) { \
LOG(ERROR) << "exception catched"; \
return errno; \
}
#endif
#define USE_PTHREAD
#ifdef USE_PTHREAD
#define THREAD_T pthread_t
#define THREAD_KEY_T pthread_key_t
#define THREAD_MUTEX_T pthread_mutex_t
#define THREAD_KEY_CREATE pthread_key_create
#define THREAD_SETSPECIFIC pthread_setspecific
#define THREAD_GETSPECIFIC pthread_getspecific
#define THREAD_CREATE pthread_create
#define THREAD_CANCEL pthread_cancel
#define THREAD_JOIN pthread_join
#define THREAD_KEY_DELETE pthread_key_delete
#define THREAD_MUTEX_INIT pthread_mutex_init
#define THREAD_MUTEX_LOCK pthread_mutex_lock
#define THREAD_MUTEX_UNLOCK pthread_mutex_unlock
#define THREAD_MUTEX_DESTROY pthread_mutex_destroy
#define THREAD_COND_T pthread_cond_t
#define THREAD_COND_INIT pthread_cond_init
#define THREAD_COND_SIGNAL pthread_cond_signal
#define THREAD_COND_WAIT pthread_cond_wait
#define THREAD_COND_DESTROY pthread_cond_destroy
#else
#define THREAD_T bthread_t
#define THREAD_KEY_T bthread_key_t
#define THREAD_MUTEX_T bthread_mutex_t
#define THREAD_KEY_CREATE bthread_key_create
#define THREAD_SETSPECIFIC bthread_setspecific
#define THREAD_GETSPECIFIC bthread_getspecific
#define THREAD_CREATE bthread_start_background
#define THREAD_CANCEL bthread_stop
#define THREAD_JOIN bthread_join
#define THREAD_KEY_DELETE bthread_key_delete
#define THREAD_MUTEX_INIT bthread_mutex_init
#define THREAD_MUTEX_LOCK bthread_mutex_lock
#define THREAD_MUTEX_UNLOCK bthread_mutex_unlock
#define THREAD_MUTEX_DESTROY bthread_mutex_destroy
#define THREAD_COND_T bthread_cond_t
#define THREAD_COND_INIT bthread_cond_init
#define THREAD_COND_SIGNAL bthread_cond_signal
#define THREAD_COND_WAIT bthread_cond_wait
#define THREAD_COND_DESTROY bthread_cond_destroy
#endif
} // namespace sdk_cpp
} // namespace paddle_serving
} // namespace baidu
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "core/sdk-cpp/include/common.h" #include "core/sdk-cpp/include/common.h"
#include "core/sdk-cpp/include/endpoint_config.h" #include "core/sdk-cpp/include/endpoint_config.h"
#include "core/sdk-cpp/include/macros.h"
#include "core/sdk-cpp/include/predictor.h" #include "core/sdk-cpp/include/predictor.h"
#include "core/sdk-cpp/include/stub.h" #include "core/sdk-cpp/include/stub.h"
...@@ -245,7 +246,7 @@ class StubImpl : public Stub { ...@@ -245,7 +246,7 @@ class StubImpl : public Stub {
const brpc::ChannelOptions& options); const brpc::ChannelOptions& options);
StubTLS* get_tls() { StubTLS* get_tls() {
return static_cast<StubTLS*>(bthread_getspecific(_bthread_key)); return static_cast<StubTLS*>(THREAD_GETSPECIFIC(_bthread_key));
} }
private: private:
...@@ -262,7 +263,8 @@ class StubImpl : public Stub { ...@@ -262,7 +263,8 @@ class StubImpl : public Stub {
uint32_t _package_size; uint32_t _package_size;
// tls handlers // tls handlers
bthread_key_t _bthread_key; // bthread_key_t _bthread_key;
THREAD_KEY_T _bthread_key;
// bvar variables // bvar variables
std::map<std::string, BvarWrapper*> _ltc_bvars; std::map<std::string, BvarWrapper*> _ltc_bvars;
......
...@@ -70,7 +70,7 @@ int StubImpl<T, C, R, I, O>::initialize(const VariantInfo& var, ...@@ -70,7 +70,7 @@ int StubImpl<T, C, R, I, O>::initialize(const VariantInfo& var,
_endpoint = ep; _endpoint = ep;
if (bthread_key_create(&_bthread_key, NULL) != 0) { if (THREAD_KEY_CREATE(&_bthread_key, NULL) != 0) {
LOG(FATAL) << "Failed create key for stub tls"; LOG(FATAL) << "Failed create key for stub tls";
return -1; return -1;
} }
...@@ -132,13 +132,13 @@ int StubImpl<T, C, R, I, O>::initialize(const VariantInfo& var, ...@@ -132,13 +132,13 @@ int StubImpl<T, C, R, I, O>::initialize(const VariantInfo& var,
template <typename T, typename C, typename R, typename I, typename O> template <typename T, typename C, typename R, typename I, typename O>
int StubImpl<T, C, R, I, O>::thrd_initialize() { int StubImpl<T, C, R, I, O>::thrd_initialize() {
if (bthread_getspecific(_bthread_key) != NULL) { if (THREAD_GETSPECIFIC(_bthread_key) != NULL) {
LOG(WARNING) << "Already thread initialized for stub"; LOG(WARNING) << "Already thread initialized for stub";
return 0; return 0;
} }
StubTLS* tls = new (std::nothrow) StubTLS(); StubTLS* tls = new (std::nothrow) StubTLS();
if (!tls || bthread_setspecific(_bthread_key, tls) != 0) { if (!tls || THREAD_SETSPECIFIC(_bthread_key, tls) != 0) {
LOG(FATAL) << "Failed binding tls data to bthread_key"; LOG(FATAL) << "Failed binding tls data to bthread_key";
return -1; return -1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册