未验证 提交 7ab15dcc 编写于 作者: T Thomas Young 提交者: GitHub

Revert "fix async while"

上级 5eb6946b
...@@ -89,12 +89,11 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) { ...@@ -89,12 +89,11 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index; size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
void* databuf_data = void* databuf_data = MempoolWrapper::instance().malloc(databuf_size,memoryPtr);
MempoolWrapper::instance().malloc(databuf_size, memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size); paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
tensor_out.data = paddleBuf; tensor_out.data = paddleBuf;
// tensor_out.data.Resize(databuf_size); //tensor_out.data.Resize(databuf_size);
} else { } else {
// 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task // 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task
// 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy // 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy
...@@ -212,38 +211,48 @@ void TaskExecutor<TaskT>::stop() { ...@@ -212,38 +211,48 @@ void TaskExecutor<TaskT>::stop() {
} }
template <typename TaskT> template <typename TaskT>
int TaskExecutor<TaskT>::schedule( TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
const void* inVectorT_ptr, const void* inVectorT_ptr,
void* outVectorT_ptr, void* outVectorT_ptr, MempoolRegion* memoryPtr) { // NOLINT
MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr,
TaskManager<InType, OutType>* task_manager_ptr) { // NOLINT
TaskT* task = butil::get_object<TaskT>(); TaskT* task = butil::get_object<TaskT>();
if (!task) { if (!task) {
LOG(ERROR) << "Failed get TaskT from object pool"; LOG(ERROR) << "Failed get TaskT from object pool";
return -1; return TaskHandler<TaskT>::valid_handle();
} }
task->clear(); task->clear();
task->task_manager_ptr = task_manager_ptr; /*
task->thread_mutex_ptr = thread_mutex_ptr; if (!BatchTasks<TaskT>::check_valid(in, out, _overrun)) {
task->thread_cond_ptr = thread_cond_ptr; LOG(ERROR) << "Invalid input & output";
return TaskHandler<TaskT>::valid_handle();
}
*/
int fds[2];
int rc = pipe(fds);
if (rc != 0) {
LOG(ERROR) << "call pipe() failed, errno=" << errno << ":"
<< strerror(errno);
return TaskHandler<TaskT>::valid_handle();
}
task->read_fd = fds[0];
task->write_fd = fds[1];
task->owner_tid = ::syscall(SYS_gettid); task->owner_tid = ::syscall(SYS_gettid);
task->memoryPtr = memoryPtr; task->memoryPtr = memoryPtr;
// task->_bspec_key = _bspec_key; //task->_bspec_key = _bspec_key;
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr; task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr; task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
if (!task->task_init()) { if (!task->task_init()) {
LOG(ERROR) << "task->init() failed"; LOG(ERROR) << "task->init() failed";
return -1;
} }
task->rem = task->batch_size(); task->rem = task->batch_size();
task->index.store(0, butil::memory_order_relaxed); task->index.store(0, butil::memory_order_relaxed);
AutoMutex lock(_mut); AutoMutex lock(_mut);
_task_queue.push_back(task); _task_queue.push_back(task);
THREAD_COND_SIGNAL(&_cond); THREAD_COND_SIGNAL(&_cond);
return 0;
return TaskHandler<TaskT>(*task);
} }
// this function is accessed by multi thread. // this function is accessed by multi thread.
...@@ -398,19 +407,13 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) { ...@@ -398,19 +407,13 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
} }
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
bool TaskManager<InItemT, OutItemT>::schedule( bool TaskManager<InItemT, OutItemT>::schedule(const void* in,
const void* in, void* out, MempoolRegion* memoryPtr) { // NOLINT
void* out, TaskHandler<TaskT> handler =
MempoolRegion* memoryPtr, TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out, memoryPtr);
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr) { // NOLINT if (handler.valid()) {
int error_no = TaskExecutorVector<TaskT>::instance()[_model_index].schedule( _task_owned = handler;
in, out, memoryPtr, thread_mutex_ptr, thread_cond_ptr, this);
if (error_no >= 0) {
_task_ready = false;
this->thread_mutex_ptr = thread_mutex_ptr;
this->thread_cond_ptr = thread_cond_ptr;
return true; return true;
} else { } else {
LOG(ERROR) << "failed to schedule task"; LOG(ERROR) << "failed to schedule task";
...@@ -420,13 +423,17 @@ bool TaskManager<InItemT, OutItemT>::schedule( ...@@ -420,13 +423,17 @@ bool TaskManager<InItemT, OutItemT>::schedule(
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
void TaskManager<InItemT, OutItemT>::wait() { void TaskManager<InItemT, OutItemT>::wait() {
THREAD_MUTEX_LOCK(thread_mutex_ptr); char buffer[128];
while (!_task_ready) { while (read(_task_owned.read_fd, buffer, sizeof(buffer)) < 0 &&
THREAD_COND_WAIT(thread_cond_ptr, thread_mutex_ptr); errno == EINTR) {
} }
THREAD_MUTEX_UNLOCK(thread_mutex_ptr);
close(_task_owned.read_fd);
close(_task_owned.write_fd);
_task_owned.read_fd = -1;
_task_owned.write_fd = -1;
return; return;
} }
} // namespace bsf } // namespace bsf
} // namespace im } // namespace im
...@@ -52,9 +52,6 @@ typedef baidu::paddle_serving::predictor::MempoolRegion MempoolRegion; ...@@ -52,9 +52,6 @@ typedef baidu::paddle_serving::predictor::MempoolRegion MempoolRegion;
// `put`. // `put`.
template <typename TaskT> template <typename TaskT>
class BatchTasks; class BatchTasks;
template <typename InItemT, typename OutItemT>
class TaskManager;
// size_t `index` records how many batch have been processing completed. // size_t `index` records how many batch have been processing completed.
// `index` need to be atomic, cause the operation 'notify' is asynchronous. // `index` need to be atomic, cause the operation 'notify' is asynchronous.
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
...@@ -68,6 +65,8 @@ struct Task { ...@@ -68,6 +65,8 @@ struct Task {
typedef std::vector<ShapeVector> VectorOfShapeVector; typedef std::vector<ShapeVector> VectorOfShapeVector;
typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper; typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
int read_fd;
int write_fd;
pid_t owner_tid; pid_t owner_tid;
const InVectorT* inVectorT_ptr; const InVectorT* inVectorT_ptr;
OutVectorT* outVectorT_ptr; OutVectorT* outVectorT_ptr;
...@@ -85,17 +84,13 @@ struct Task { ...@@ -85,17 +84,13 @@ struct Task {
// taskmeta_num * set_feed_lod_index.size() // taskmeta_num * set_feed_lod_index.size()
std::vector<OutVectorT> outLodTensorVector; std::vector<OutVectorT> outLodTensorVector;
MempoolRegion* memoryPtr; MempoolRegion* memoryPtr;
TaskManager<InItemT, OutItemT>* task_manager_ptr;
THREAD_MUTEX_T* thread_mutex_ptr;
THREAD_COND_T* thread_cond_ptr;
Task() { Task() {
read_fd = -1;
write_fd = -1;
owner_tid = -1; owner_tid = -1;
inVectorT_ptr = NULL; inVectorT_ptr = NULL;
outVectorT_ptr = NULL; outVectorT_ptr = NULL;
thread_mutex_ptr = NULL;
thread_cond_ptr = NULL;
task_manager_ptr = NULL;
set_feed_lod_index.clear(); set_feed_lod_index.clear();
set_feed_nobatch_index.clear(); set_feed_nobatch_index.clear();
vector_fetch_lod_index.clear(); vector_fetch_lod_index.clear();
...@@ -110,9 +105,8 @@ struct Task { ...@@ -110,9 +105,8 @@ struct Task {
outLodTensorVector.clear(); outLodTensorVector.clear();
} }
~Task() { ~Task() {
thread_mutex_ptr = NULL; read_fd = -1;
thread_cond_ptr = NULL; write_fd = -1;
task_manager_ptr = NULL;
owner_tid = -1; owner_tid = -1;
inVectorT_ptr = NULL; inVectorT_ptr = NULL;
outVectorT_ptr = NULL; outVectorT_ptr = NULL;
...@@ -130,10 +124,9 @@ struct Task { ...@@ -130,10 +124,9 @@ struct Task {
outLodTensorVector.clear(); outLodTensorVector.clear();
} }
void clear() { void clear(){
thread_mutex_ptr = NULL; read_fd = -1;
thread_cond_ptr = NULL; write_fd = -1;
task_manager_ptr = NULL;
owner_tid = -1; owner_tid = -1;
inVectorT_ptr = NULL; inVectorT_ptr = NULL;
outVectorT_ptr = NULL; outVectorT_ptr = NULL;
...@@ -380,12 +373,11 @@ struct Task { ...@@ -380,12 +373,11 @@ struct Task {
// 一次性扩容PaddleTensor中的data和lod // 一次性扩容PaddleTensor中的data和lod
paddle::PaddleTensor& fetchVarTensor = (*outVectorT_ptr)[feedvar_index]; paddle::PaddleTensor& fetchVarTensor = (*outVectorT_ptr)[feedvar_index];
fetchVarTensor.shape[0] = total_shape0; fetchVarTensor.shape[0] = total_shape0;
void* databuf_data = void* databuf_data = MempoolWrapper::instance().malloc(data_length,memoryPtr);
MempoolWrapper::instance().malloc(data_length, memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, data_length); paddle::PaddleBuf paddleBuf(databuf_data, data_length);
fetchVarTensor.data = paddleBuf; fetchVarTensor.data = paddleBuf;
// fetchVarTensor.data.Resize(data_length); //fetchVarTensor.data.Resize(data_length);
// task中的lod补0 // task中的lod补0
if (fetchVarTensor.lod.size() <= 0) { if (fetchVarTensor.lod.size() <= 0) {
fetchVarTensor.lod.push_back({0}); fetchVarTensor.lod.push_back({0});
...@@ -401,7 +393,7 @@ struct Task { ...@@ -401,7 +393,7 @@ struct Task {
size_t once_lod_length = 0; size_t once_lod_length = 0;
for (size_t taskmeta_index = 0; taskmeta_index < total_taskmeta_num; for (size_t taskmeta_index = 0; taskmeta_index < total_taskmeta_num;
++taskmeta_index) { ++taskmeta_index) {
// process data //process data
void* dst_ptr = fetchVarTensor.data.data() + data_length_offset; void* dst_ptr = fetchVarTensor.data.data() + data_length_offset;
void* source_ptr = void* source_ptr =
outLodTensorVector[taskmeta_index][index].data.data(); outLodTensorVector[taskmeta_index][index].data.data();
...@@ -409,7 +401,7 @@ struct Task { ...@@ -409,7 +401,7 @@ struct Task {
outLodTensorVector[taskmeta_index][index].data.length(); outLodTensorVector[taskmeta_index][index].data.length();
memcpy(dst_ptr, source_ptr, once_data_length); memcpy(dst_ptr, source_ptr, once_data_length);
data_length_offset += once_data_length; data_length_offset += once_data_length;
// process lod //process lod
size_t last_lod_value = fetchVarTensor.lod[0][lod_length_offset]; size_t last_lod_value = fetchVarTensor.lod[0][lod_length_offset];
once_lod_length = once_lod_length =
outLodTensorVector[taskmeta_index][index].lod[0].size(); outLodTensorVector[taskmeta_index][index].lod[0].size();
...@@ -420,6 +412,7 @@ struct Task { ...@@ -420,6 +412,7 @@ struct Task {
outLodTensorVector[taskmeta_index][index].lod[0][once_index]; outLodTensorVector[taskmeta_index][index].lod[0][once_index];
lod_length_offset++; lod_length_offset++;
} }
} }
} }
} }
...@@ -552,9 +545,8 @@ class BatchTasks { ...@@ -552,9 +545,8 @@ class BatchTasks {
TaskMetaT tm(task, start_index, add, task->taskmeta_num); TaskMetaT tm(task, start_index, add, task->taskmeta_num);
task->rem -= add; task->rem -= add;
_rem_size -= add; _rem_size -= add;
if (task->taskmeta_num == 0) { if(task->taskmeta_num == 0){
task->total_taskmeta_num = task->total_taskmeta_num = 1 + (task->rem + _batch_size - 1)/_batch_size;
1 + (task->rem + _batch_size - 1) / _batch_size;
} }
task->taskmeta_num += 1; task->taskmeta_num += 1;
_taskmeta_vector.push_back(tm); _taskmeta_vector.push_back(tm);
...@@ -651,8 +643,7 @@ class BatchTasks { ...@@ -651,8 +643,7 @@ class BatchTasks {
paddleTensor.lod = _batch_in_lod[feedvar_index]; paddleTensor.lod = _batch_in_lod[feedvar_index];
paddleTensor.shape = feedVarTensor.shape; paddleTensor.shape = feedVarTensor.shape;
paddleTensor.shape[0] = _total_shape0_batch_in[feedvar_index]; paddleTensor.shape[0] = _total_shape0_batch_in[feedvar_index];
size_t databuf_size = size_t databuf_size = feedvar_bytesize * _total_shape0_batch_in[feedvar_index];
feedvar_bytesize * _total_shape0_batch_in[feedvar_index];
void* databuf_data = MempoolWrapper::instance().malloc(databuf_size); void* databuf_data = MempoolWrapper::instance().malloc(databuf_size);
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size); paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
paddleTensor.data = paddleBuf; paddleTensor.data = paddleBuf;
...@@ -762,26 +753,25 @@ class BatchTasks { ...@@ -762,26 +753,25 @@ class BatchTasks {
// 此时,无法分辨是否是天然nobatch,此时set_fetch_nobatch_index会漏掉 // 此时,无法分辨是否是天然nobatch,此时set_fetch_nobatch_index会漏掉
// 后续希望在其他地方能够区分两者。 // 后续希望在其他地方能够区分两者。
if (fetchvar_batch_size(fetchvar_index) != _total_fetch_batch) { if (fetchvar_batch_size(fetchvar_index) != _total_fetch_batch) {
if (fetchvar_batch_size(fetchvar_index) <= 0) { if(fetchvar_batch_size(fetchvar_index) <= 0){
// which means error. // which means error.
return false; return false;
} else if (fetchvar_batch_size(fetchvar_index) == 1) { }else if(fetchvar_batch_size(fetchvar_index) == 1){
// which means fetchvar shape[0] = 1. // which means fetchvar shape[0] = 1.
// shape[0] does not change with batch // shape[0] does not change with batch
set_fetch_nobatch_index.insert(fetchvar_index); set_fetch_nobatch_index.insert(fetchvar_index);
_total_fetch_batch = _total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch); std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
} else if (_total_fetch_batch == 1) { }else if(_total_fetch_batch == 1){
// 这时意味着,之前的fetchvar shape[0] 全部都= 1 //这时意味着,之前的fetchvar shape[0] 全部都= 1
// 当前的fetchvar shape[0] > 1 //当前的fetchvar shape[0] > 1
// 所以,之前的都是no_batch //所以,之前的都是no_batch
for (size_t temp_index = fetchvar_index - 1; temp_index >= 0; for(size_t temp_index = fetchvar_index-1; temp_index >= 0; --temp_index){
--temp_index) {
set_fetch_nobatch_index.insert(fetchvar_index); set_fetch_nobatch_index.insert(fetchvar_index);
} }
_total_fetch_batch = _total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch); std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
} else { }else{
// which means error. // which means error.
return false; return false;
} }
...@@ -866,11 +856,10 @@ class BatchTasks { ...@@ -866,11 +856,10 @@ class BatchTasks {
fetchVarTensor.shape[0] = shape0_length; fetchVarTensor.shape[0] = shape0_length;
fetch_lod_index++; fetch_lod_index++;
void* databuf_data = void* databuf_data = MempoolWrapper::instance().malloc(length,task->memoryPtr);
MempoolWrapper::instance().malloc(length, task->memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, length); paddle::PaddleBuf paddleBuf(databuf_data, length);
fetchVarTensor.data = paddleBuf; fetchVarTensor.data = paddleBuf;
// fetchVarTensor.data.Resize(length); //fetchVarTensor.data.Resize(length);
void* dst_ptr = fetchVarTensor.data.data(); void* dst_ptr = fetchVarTensor.data.data();
void* source_ptr = _batch_out[fetchvar_index].data.data() + void* source_ptr = _batch_out[fetchvar_index].data.data() +
shape0_index_start * fetchvar_bytesize_index; shape0_index_start * fetchvar_bytesize_index;
...@@ -897,12 +886,11 @@ class BatchTasks { ...@@ -897,12 +886,11 @@ class BatchTasks {
size_t length = fetchvar_bytesize_index * shape0_length; size_t length = fetchvar_bytesize_index * shape0_length;
fetchVarTensor.shape[0] = shape0_length; fetchVarTensor.shape[0] = shape0_length;
void* databuf_data = void* databuf_data = MempoolWrapper::instance().malloc(length,task->memoryPtr);
MempoolWrapper::instance().malloc(length, task->memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, length); paddle::PaddleBuf paddleBuf(databuf_data, length);
fetchVarTensor.data = paddleBuf; fetchVarTensor.data = paddleBuf;
// fetchVarTensor.data.Resize(length); //fetchVarTensor.data.Resize(length);
void* dst_ptr = fetchVarTensor.data.data(); void* dst_ptr = fetchVarTensor.data.data();
void* source_ptr = _batch_out[fetchvar_index].data.data() + void* source_ptr = _batch_out[fetchvar_index].data.data() +
shape0_index_start * fetchvar_bytesize_index; shape0_index_start * fetchvar_bytesize_index;
...@@ -952,15 +940,12 @@ class BatchTasks { ...@@ -952,15 +940,12 @@ class BatchTasks {
// index是局部变量,fetch_add是原子操作,成功则返回原值。 // index是局部变量,fetch_add是原子操作,成功则返回原值。
// 只有最后一个taskmeta都完成后,该线程的index+add才能>task->batch_size() // 只有最后一个taskmeta都完成后,该线程的index+add才能>task->batch_size()
// 故只有一个线程能进入if{}内.不会造成多线程竞争的问题。 // 故只有一个线程能进入if{}内.不会造成多线程竞争的问题。
size_t index = task->index.fetch_add(add); size_t index = task->index.fetch_add(add);
if ((index + add) >= task->batch_size()) { if ((index + add) >= task->batch_size()) {
task->combine_taskmeta(); task->combine_taskmeta();
char c = 0; char c = 0;
THREAD_MUTEX_LOCK(task->thread_mutex_ptr); while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) {
task->task_manager_ptr->_task_ready = true; }
THREAD_COND_SIGNAL(task->thread_cond_ptr);
THREAD_MUTEX_UNLOCK(task->thread_mutex_ptr);
butil::return_object(task); butil::return_object(task);
} }
} }
...@@ -1000,6 +985,36 @@ class BatchTasks { ...@@ -1000,6 +985,36 @@ class BatchTasks {
bool _allow_split_request; bool _allow_split_request;
}; };
// BSF task handle
// TaskHandler is the handle of Task.
// `read_fd` is used for receive signal in brpc Thread.
// 'write_fd' is used for write signal in bsf Thread.
// when TaskMeta is done, bsf Thread will write to 'write_fd'.
// brpc Thread is keeping reading 'read_fd' in a while loop.
// brpc Thread will receive signal when TaskMeta is done.
// so `read_fd` and 'write_fd' is used for communicate in different Thread.
template <typename TaskT>
struct TaskHandler {
int read_fd;
int write_fd;
TaskHandler() : read_fd(-1), write_fd(-1) {
// do nothing
}
explicit TaskHandler(TaskT const& task)
: read_fd(task.read_fd), write_fd(task.write_fd) {
// do nothing
}
inline bool valid() const { return read_fd >= 0 && write_fd >= 0; }
static TaskHandler<TaskT>& valid_handle() {
static TaskHandler<TaskT> vhandle;
return vhandle;
}
};
// TaskExecutor is a Thread pool. // TaskExecutor is a Thread pool.
template <typename TaskT> template <typename TaskT>
class TaskExecutor; class TaskExecutor;
...@@ -1100,12 +1115,7 @@ class TaskExecutor { ...@@ -1100,12 +1115,7 @@ class TaskExecutor {
int work(ThreadContext<TaskT>* context); int work(ThreadContext<TaskT>* context);
int schedule(const void*, TaskHandler<TaskT> schedule(const void*, void*, MempoolRegion* memoryPtr);
void*,
MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr,
TaskManager<InType, OutType>* task_manager_ptr);
bool move_task_to_batch(BatchTasks<TaskT>& batchTask); // NOLINT bool move_task_to_batch(BatchTasks<TaskT>& batchTask); // NOLINT
...@@ -1184,25 +1194,18 @@ class TaskManager { ...@@ -1184,25 +1194,18 @@ class TaskManager {
typedef typename TaskT::OutVectorT OutVectorT; typedef typename TaskT::OutVectorT OutVectorT;
explicit TaskManager(uint32_t model_index) // NOLINT explicit TaskManager(uint32_t model_index) // NOLINT
: _model_index(model_index), : _model_index(model_index) {}
_task_ready(false) {}
~TaskManager() { wait(); } ~TaskManager() { wait(); }
bool schedule(const void* in, bool schedule(const void* in, void* out, MempoolRegion* memoryPtr); // NOLINT
void* out,
MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr); // NOLINT
void wait(); void wait();
inline void clear() { wait(); } inline void clear() { wait(); }
bool _task_ready = false;
private: private:
TaskHandler<TaskT> _task_owned;
uint32_t _model_index; uint32_t _model_index;
THREAD_MUTEX_T* thread_mutex_ptr;
THREAD_COND_T* thread_cond_ptr;
}; // class TaskManager }; // class TaskManager
class AutoMutex { class AutoMutex {
......
...@@ -98,11 +98,7 @@ int ReloadableInferEngine::infer(const void* in, ...@@ -98,11 +98,7 @@ int ReloadableInferEngine::infer(const void* in,
im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager( im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager(
_model_index); _model_index);
task_manager.schedule(in, task_manager.schedule(in, out, MempoolWrapper::instance().get_thread_memory_ptr());
out,
MempoolWrapper::instance().get_thread_memory_ptr(),
ThreadMutex::instance().get_thread_mutex_ptr(),
ThreadMutex::instance().get_thread_cond_ptr());
task_manager.wait(); task_manager.wait();
return 0; return 0;
} }
......
...@@ -31,9 +31,8 @@ ...@@ -31,9 +31,8 @@
#include "core/predictor/framework/infer_data.h" #include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
#include "core/predictor/framework/predictor_metric.h" #include "core/predictor/framework/predictor_metric.h"
#include "core/predictor/framework/thread_mutex.h"
#include "experimental/float16.h"
#include "paddle_inference_api.h" // NOLINT #include "paddle_inference_api.h" // NOLINT
#include "experimental/float16.h"
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
namespace predictor { namespace predictor {
......
...@@ -97,12 +97,6 @@ int Resource::initialize(const std::string& path, const std::string& file) { ...@@ -97,12 +97,6 @@ int Resource::initialize(const std::string& path, const std::string& file) {
} }
LOG(WARNING) << "Successfully proc initialized mempool wrapper"; LOG(WARNING) << "Successfully proc initialized mempool wrapper";
if (ThreadMutex::instance().initialize() != 0) {
LOG(ERROR) << "Failed proc initialized mempool wrapper";
return -1;
}
LOG(WARNING) << "Successfully proc initialized ThreadMutex";
#ifdef WITH_AUTH #ifdef WITH_AUTH
std::string product_name_str = resource_conf.auth_product_name(); std::string product_name_str = resource_conf.auth_product_name();
std::string container_id_str = resource_conf.auth_container_id(); std::string container_id_str = resource_conf.auth_container_id();
...@@ -307,12 +301,6 @@ int Resource::thread_initialize() { ...@@ -307,12 +301,6 @@ int Resource::thread_initialize() {
} }
LOG(WARNING) << "Successfully thread initialized mempool wrapper"; LOG(WARNING) << "Successfully thread initialized mempool wrapper";
if (ThreadMutex::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialized ThreadMutex";
return -1;
}
LOG(WARNING) << "Successfully thread initialized ThreadMutex";
// infer manager // infer manager
if (FLAGS_enable_model_toolkit && if (FLAGS_enable_model_toolkit &&
InferManager::instance().thrd_initialize() != 0) { InferManager::instance().thrd_initialize() != 0) {
...@@ -356,12 +344,6 @@ int Resource::finalize() { ...@@ -356,12 +344,6 @@ int Resource::finalize() {
LOG(ERROR) << "Failed proc finalize infer manager"; LOG(ERROR) << "Failed proc finalize infer manager";
return -1; return -1;
} }
if (ThreadMutex::instance().finalize() != 0) {
LOG(ERROR) << "Failed proc finalize ThreadMutex";
return -1;
}
if (CubeAPI::instance()->destroy() != 0) { if (CubeAPI::instance()->destroy() != 0) {
LOG(ERROR) << "Destory cube api failed "; LOG(ERROR) << "Destory cube api failed ";
return -1; return -1;
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "core/predictor/common/inner_common.h" #include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/infer.h" #include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
#include "core/predictor/framework/thread_mutex.h"
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/predictor/framework/thread_mutex.h"
#include "core/predictor/common/inner_common.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
int ThreadMutex::initialize() {
if (THREAD_KEY_CREATE(&_bspec_key_mutex, NULL) != 0) {
LOG(ERROR) << "unable to create thread_key of thrd_data";
return -1;
}
if (THREAD_SETSPECIFIC(_bspec_key_mutex, NULL) != 0) {
LOG(ERROR) << "failed initialize bsepecific key to null";
return -1;
}
if (THREAD_KEY_CREATE(&_bspec_key_cond, NULL) != 0) {
LOG(ERROR) << "unable to create thread_key of thrd_data";
return -1;
}
if (THREAD_SETSPECIFIC(_bspec_key_cond, NULL) != 0) {
LOG(ERROR) << "failed initialize bsepecific key to null";
return -1;
}
return 0;
}
int ThreadMutex::thread_initialize() {
THREAD_MUTEX_T* mutex_ptr = new THREAD_MUTEX_T();
THREAD_MUTEX_INIT(mutex_ptr, NULL);
if (THREAD_SETSPECIFIC(_bspec_key_mutex, mutex_ptr) != 0) {
LOG(ERROR) << "unable to set the thrd_data";
delete mutex_ptr;
return -1;
}
THREAD_COND_T* cont_ptr = new THREAD_COND_T();
THREAD_COND_INIT(cont_ptr, NULL);
if (THREAD_SETSPECIFIC(_bspec_key_cond, cont_ptr) != 0) {
LOG(ERROR) << "unable to set the thrd_data";
delete cont_ptr;
return -1;
}
LOG(WARNING) << "Succ thread initialize ThreadMutex";
return 0;
}
int ThreadMutex::thread_finalize() {
THREAD_MUTEX_T* mutex_ptr =
(THREAD_MUTEX_T*)THREAD_GETSPECIFIC(_bspec_key_mutex);
if (mutex_ptr != NULL) {
THREAD_MUTEX_DESTROY(mutex_ptr);
delete mutex_ptr;
}
THREAD_COND_T* cont_ptr = (THREAD_COND_T*)THREAD_GETSPECIFIC(_bspec_key_cond);
if (cont_ptr != NULL) {
THREAD_COND_DESTROY(cont_ptr);
delete cont_ptr;
}
LOG(WARNING) << "Succ thread initialize ThreadMutex";
return 0;
}
int ThreadMutex::finalize() {
THREAD_KEY_DELETE(_bspec_key_mutex);
THREAD_KEY_DELETE(_bspec_key_cond);
return 0;
}
THREAD_MUTEX_T* ThreadMutex::get_thread_mutex_ptr() {
THREAD_MUTEX_T* mutex_ptr =
(THREAD_MUTEX_T*)THREAD_GETSPECIFIC(_bspec_key_mutex);
return mutex_ptr;
}
THREAD_COND_T* ThreadMutex::get_thread_cond_ptr() {
THREAD_COND_T* cont_ptr = (THREAD_COND_T*)THREAD_GETSPECIFIC(_bspec_key_cond);
return cont_ptr;
}
} // namespace predictor
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/predictor/common/inner_common.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
class ThreadMutex {
public:
ThreadMutex() {}
static ThreadMutex& instance() {
static ThreadMutex thread_mutex;
return thread_mutex;
}
int initialize();
int thread_initialize();
THREAD_MUTEX_T* get_thread_mutex_ptr();
THREAD_COND_T* get_thread_cond_ptr();
int thread_finalize();
int finalize();
private:
THREAD_KEY_T _bspec_key_mutex;
THREAD_KEY_T _bspec_key_cond;
};
} // namespace predictor
} // namespace paddle_serving
} // namespace baidu
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册