提交 a1e33e46 编写于 作者: H HexToString

Fix async wait loop: signal task completion with a mutex/condition variable instead of a pipe

上级 550852a7
...@@ -86,14 +86,15 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) { ...@@ -86,14 +86,15 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
// 此时 lod 为空。 // 此时 lod 为空。
tensor_out.lod = batchTask._batch_out[fetchvar_index].lod; tensor_out.lod = batchTask._batch_out[fetchvar_index].lod;
// resize all batch memory at one time // resize all batch memory at one time
size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index; size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
void* databuf_data = MempoolWrapper::instance().malloc(databuf_size,memoryPtr); void* databuf_data =
MempoolWrapper::instance().malloc(databuf_size, memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size); paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
tensor_out.data = paddleBuf; tensor_out.data = paddleBuf;
//tensor_out.data.Resize(databuf_size); // tensor_out.data.Resize(databuf_size);
} else { } else {
// 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task // 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task
// 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy // 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy
...@@ -211,48 +212,38 @@ void TaskExecutor<TaskT>::stop() { ...@@ -211,48 +212,38 @@ void TaskExecutor<TaskT>::stop() {
} }
// Side-by-side diff of TaskExecutor<TaskT>::schedule: on every line the old
// text comes first and the new text second; lines touched by only one version
// appear once.
//
// Old version: took (in, out, memoryPtr), opened a pipe(), stored both ends on
// the task as read_fd / write_fd and returned a TaskHandler built from the
// task. Failures returned TaskHandler<TaskT>::valid_handle().
//
// New version: additionally takes the caller's mutex, condition variable and
// TaskManager pointers and stores them on the task. It returns an int:
//   -1  when no TaskT could be obtained from the object pool, or when
//       task->task_init() fails;
//    0  after the task has been pushed onto _task_queue (under _mut) and
//       _cond has been signalled.
//
// NOTE(review): on the task_init() failure path the task obtained from
// butil::get_object<TaskT>() is not handed back with butil::return_object()
// in the lines shown here -- possible pool leak, confirm.
template <typename TaskT> template <typename TaskT>
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule( int TaskExecutor<TaskT>::schedule(
const void* inVectorT_ptr, const void* inVectorT_ptr,
void* outVectorT_ptr, MempoolRegion* memoryPtr) { // NOLINT void* outVectorT_ptr,
MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr,
TaskManager<InType, OutType>* task_manager_ptr) { // NOLINT
TaskT* task = butil::get_object<TaskT>(); TaskT* task = butil::get_object<TaskT>();
if (!task) { if (!task) {
LOG(ERROR) << "Failed get TaskT from object pool"; LOG(ERROR) << "Failed get TaskT from object pool";
return TaskHandler<TaskT>::valid_handle(); return -1;
} }
task->clear(); task->clear();
/* task->task_manager_ptr = task_manager_ptr;
if (!BatchTasks<TaskT>::check_valid(in, out, _overrun)) { task->thread_mutex_ptr = thread_mutex_ptr;
LOG(ERROR) << "Invalid input & output"; task->thread_cond_ptr = thread_cond_ptr;
return TaskHandler<TaskT>::valid_handle();
}
*/
int fds[2];
int rc = pipe(fds);
if (rc != 0) {
LOG(ERROR) << "call pipe() failed, errno=" << errno << ":"
<< strerror(errno);
return TaskHandler<TaskT>::valid_handle();
}
task->read_fd = fds[0];
task->write_fd = fds[1];
task->owner_tid = ::syscall(SYS_gettid); task->owner_tid = ::syscall(SYS_gettid);
task->memoryPtr = memoryPtr; task->memoryPtr = memoryPtr;
//task->_bspec_key = _bspec_key; // task->_bspec_key = _bspec_key;
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr; task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr; task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
if (!task->task_init()) { if (!task->task_init()) {
LOG(ERROR) << "task->init() failed"; LOG(ERROR) << "task->init() failed";
return -1;
} }
task->rem = task->batch_size(); task->rem = task->batch_size();
task->index.store(0, butil::memory_order_relaxed); task->index.store(0, butil::memory_order_relaxed);
AutoMutex lock(_mut); AutoMutex lock(_mut);
_task_queue.push_back(task); _task_queue.push_back(task);
THREAD_COND_SIGNAL(&_cond); THREAD_COND_SIGNAL(&_cond);
return 0;
return TaskHandler<TaskT>(*task);
} }
// this function is accessed by multi thread. // this function is accessed by multi thread.
...@@ -407,13 +398,19 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) { ...@@ -407,13 +398,19 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
} }
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
bool TaskManager<InItemT, OutItemT>::schedule(const void* in, bool TaskManager<InItemT, OutItemT>::schedule(
void* out, MempoolRegion* memoryPtr) { // NOLINT const void* in,
TaskHandler<TaskT> handler = void* out,
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out, memoryPtr); MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
if (handler.valid()) { THREAD_COND_T* thread_cond_ptr) { // NOLINT
_task_owned = handler; int error_no = TaskExecutorVector<TaskT>::instance()[_model_index].schedule(
in, out, memoryPtr, thread_mutex_ptr, thread_cond_ptr, this);
if (error_no >= 0) {
_task_ready = false;
this->thread_mutex_ptr = thread_mutex_ptr;
this->thread_cond_ptr = thread_cond_ptr;
return true; return true;
} else { } else {
LOG(ERROR) << "failed to schedule task"; LOG(ERROR) << "failed to schedule task";
...@@ -423,17 +420,13 @@ bool TaskManager<InItemT, OutItemT>::schedule(const void* in, ...@@ -423,17 +420,13 @@ bool TaskManager<InItemT, OutItemT>::schedule(const void* in,
// Side-by-side diff of TaskManager<InItemT, OutItemT>::wait: old text first,
// new text second on each line; lines touched by only one version appear once.
//
// Old version: blocked in read() on _task_owned.read_fd, retrying while the
// call failed with EINTR, then closed both pipe descriptors and reset them
// to -1.
//
// New version: locks thread_mutex_ptr, waits on thread_cond_ptr in a loop
// until _task_ready becomes true, then unlocks.
//
// NOTE(review): ~TaskManager() calls wait(), while thread_mutex_ptr and
// thread_cond_ptr are only assigned in schedule() on success and the
// constructor shown initialises just _model_index and _task_ready. If
// schedule() failed or was never called, this would lock an unset pointer and
// wait with _task_ready == false -- confirm and guard if so.
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
void TaskManager<InItemT, OutItemT>::wait() { void TaskManager<InItemT, OutItemT>::wait() {
char buffer[128]; THREAD_MUTEX_LOCK(thread_mutex_ptr);
while (read(_task_owned.read_fd, buffer, sizeof(buffer)) < 0 && while (!_task_ready) {
errno == EINTR) { THREAD_COND_WAIT(thread_cond_ptr, thread_mutex_ptr);
} }
THREAD_MUTEX_UNLOCK(thread_mutex_ptr);
close(_task_owned.read_fd);
close(_task_owned.write_fd);
_task_owned.read_fd = -1;
_task_owned.write_fd = -1;
return; return;
} }
} // namespace bsf } // namespace bsf
} // namespace im } // namespace im
...@@ -52,6 +52,9 @@ typedef baidu::paddle_serving::predictor::MempoolRegion MempoolRegion; ...@@ -52,6 +52,9 @@ typedef baidu::paddle_serving::predictor::MempoolRegion MempoolRegion;
// `put`. // `put`.
template <typename TaskT> template <typename TaskT>
class BatchTasks; class BatchTasks;
template <typename InItemT, typename OutItemT>
class TaskManager;
// size_t `index` records how many batch have been processing completed. // size_t `index` records how many batch have been processing completed.
// `index` need to be atomic, cause the operation 'notify' is asynchronous. // `index` need to be atomic, cause the operation 'notify' is asynchronous.
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
...@@ -65,8 +68,6 @@ struct Task { ...@@ -65,8 +68,6 @@ struct Task {
typedef std::vector<ShapeVector> VectorOfShapeVector; typedef std::vector<ShapeVector> VectorOfShapeVector;
typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper; typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
int read_fd;
int write_fd;
pid_t owner_tid; pid_t owner_tid;
const InVectorT* inVectorT_ptr; const InVectorT* inVectorT_ptr;
OutVectorT* outVectorT_ptr; OutVectorT* outVectorT_ptr;
...@@ -84,13 +85,17 @@ struct Task { ...@@ -84,13 +85,17 @@ struct Task {
// taskmeta_num * set_feed_lod_index.size() // taskmeta_num * set_feed_lod_index.size()
std::vector<OutVectorT> outLodTensorVector; std::vector<OutVectorT> outLodTensorVector;
MempoolRegion* memoryPtr; MempoolRegion* memoryPtr;
TaskManager<InItemT, OutItemT>* task_manager_ptr;
THREAD_MUTEX_T* thread_mutex_ptr;
THREAD_COND_T* thread_cond_ptr;
Task() { Task() {
read_fd = -1;
write_fd = -1;
owner_tid = -1; owner_tid = -1;
inVectorT_ptr = NULL; inVectorT_ptr = NULL;
outVectorT_ptr = NULL; outVectorT_ptr = NULL;
thread_mutex_ptr = NULL;
thread_cond_ptr = NULL;
task_manager_ptr = NULL;
set_feed_lod_index.clear(); set_feed_lod_index.clear();
set_feed_nobatch_index.clear(); set_feed_nobatch_index.clear();
vector_fetch_lod_index.clear(); vector_fetch_lod_index.clear();
...@@ -105,8 +110,9 @@ struct Task { ...@@ -105,8 +110,9 @@ struct Task {
outLodTensorVector.clear(); outLodTensorVector.clear();
} }
~Task() { ~Task() {
read_fd = -1; thread_mutex_ptr = NULL;
write_fd = -1; thread_cond_ptr = NULL;
task_manager_ptr = NULL;
owner_tid = -1; owner_tid = -1;
inVectorT_ptr = NULL; inVectorT_ptr = NULL;
outVectorT_ptr = NULL; outVectorT_ptr = NULL;
...@@ -124,9 +130,10 @@ struct Task { ...@@ -124,9 +130,10 @@ struct Task {
outLodTensorVector.clear(); outLodTensorVector.clear();
} }
void clear(){ void clear() {
read_fd = -1; thread_mutex_ptr = NULL;
write_fd = -1; thread_cond_ptr = NULL;
task_manager_ptr = NULL;
owner_tid = -1; owner_tid = -1;
inVectorT_ptr = NULL; inVectorT_ptr = NULL;
outVectorT_ptr = NULL; outVectorT_ptr = NULL;
...@@ -373,11 +380,12 @@ struct Task { ...@@ -373,11 +380,12 @@ struct Task {
// 一次性扩容PaddleTensor中的data和lod // 一次性扩容PaddleTensor中的data和lod
paddle::PaddleTensor& fetchVarTensor = (*outVectorT_ptr)[feedvar_index]; paddle::PaddleTensor& fetchVarTensor = (*outVectorT_ptr)[feedvar_index];
fetchVarTensor.shape[0] = total_shape0; fetchVarTensor.shape[0] = total_shape0;
void* databuf_data = MempoolWrapper::instance().malloc(data_length,memoryPtr); void* databuf_data =
MempoolWrapper::instance().malloc(data_length, memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, data_length); paddle::PaddleBuf paddleBuf(databuf_data, data_length);
fetchVarTensor.data = paddleBuf; fetchVarTensor.data = paddleBuf;
//fetchVarTensor.data.Resize(data_length); // fetchVarTensor.data.Resize(data_length);
// task中的lod补0 // task中的lod补0
if (fetchVarTensor.lod.size() <= 0) { if (fetchVarTensor.lod.size() <= 0) {
fetchVarTensor.lod.push_back({0}); fetchVarTensor.lod.push_back({0});
...@@ -393,7 +401,7 @@ struct Task { ...@@ -393,7 +401,7 @@ struct Task {
size_t once_lod_length = 0; size_t once_lod_length = 0;
for (size_t taskmeta_index = 0; taskmeta_index < total_taskmeta_num; for (size_t taskmeta_index = 0; taskmeta_index < total_taskmeta_num;
++taskmeta_index) { ++taskmeta_index) {
//process data // process data
void* dst_ptr = fetchVarTensor.data.data() + data_length_offset; void* dst_ptr = fetchVarTensor.data.data() + data_length_offset;
void* source_ptr = void* source_ptr =
outLodTensorVector[taskmeta_index][index].data.data(); outLodTensorVector[taskmeta_index][index].data.data();
...@@ -401,7 +409,7 @@ struct Task { ...@@ -401,7 +409,7 @@ struct Task {
outLodTensorVector[taskmeta_index][index].data.length(); outLodTensorVector[taskmeta_index][index].data.length();
memcpy(dst_ptr, source_ptr, once_data_length); memcpy(dst_ptr, source_ptr, once_data_length);
data_length_offset += once_data_length; data_length_offset += once_data_length;
//process lod // process lod
size_t last_lod_value = fetchVarTensor.lod[0][lod_length_offset]; size_t last_lod_value = fetchVarTensor.lod[0][lod_length_offset];
once_lod_length = once_lod_length =
outLodTensorVector[taskmeta_index][index].lod[0].size(); outLodTensorVector[taskmeta_index][index].lod[0].size();
...@@ -412,7 +420,6 @@ struct Task { ...@@ -412,7 +420,6 @@ struct Task {
outLodTensorVector[taskmeta_index][index].lod[0][once_index]; outLodTensorVector[taskmeta_index][index].lod[0][once_index];
lod_length_offset++; lod_length_offset++;
} }
} }
} }
} }
...@@ -545,8 +552,9 @@ class BatchTasks { ...@@ -545,8 +552,9 @@ class BatchTasks {
TaskMetaT tm(task, start_index, add, task->taskmeta_num); TaskMetaT tm(task, start_index, add, task->taskmeta_num);
task->rem -= add; task->rem -= add;
_rem_size -= add; _rem_size -= add;
if(task->taskmeta_num == 0){ if (task->taskmeta_num == 0) {
task->total_taskmeta_num = 1 + (task->rem + _batch_size - 1)/_batch_size; task->total_taskmeta_num =
1 + (task->rem + _batch_size - 1) / _batch_size;
} }
task->taskmeta_num += 1; task->taskmeta_num += 1;
_taskmeta_vector.push_back(tm); _taskmeta_vector.push_back(tm);
...@@ -643,7 +651,8 @@ class BatchTasks { ...@@ -643,7 +651,8 @@ class BatchTasks {
paddleTensor.lod = _batch_in_lod[feedvar_index]; paddleTensor.lod = _batch_in_lod[feedvar_index];
paddleTensor.shape = feedVarTensor.shape; paddleTensor.shape = feedVarTensor.shape;
paddleTensor.shape[0] = _total_shape0_batch_in[feedvar_index]; paddleTensor.shape[0] = _total_shape0_batch_in[feedvar_index];
size_t databuf_size = feedvar_bytesize * _total_shape0_batch_in[feedvar_index]; size_t databuf_size =
feedvar_bytesize * _total_shape0_batch_in[feedvar_index];
void* databuf_data = MempoolWrapper::instance().malloc(databuf_size); void* databuf_data = MempoolWrapper::instance().malloc(databuf_size);
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size); paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
paddleTensor.data = paddleBuf; paddleTensor.data = paddleBuf;
...@@ -753,25 +762,26 @@ class BatchTasks { ...@@ -753,25 +762,26 @@ class BatchTasks {
// 此时,无法分辨是否是天然nobatch,此时set_fetch_nobatch_index会漏掉 // 此时,无法分辨是否是天然nobatch,此时set_fetch_nobatch_index会漏掉
// 后续希望在其他地方能够区分两者。 // 后续希望在其他地方能够区分两者。
if (fetchvar_batch_size(fetchvar_index) != _total_fetch_batch) { if (fetchvar_batch_size(fetchvar_index) != _total_fetch_batch) {
if(fetchvar_batch_size(fetchvar_index) <= 0){ if (fetchvar_batch_size(fetchvar_index) <= 0) {
// which means error. // which means error.
return false; return false;
}else if(fetchvar_batch_size(fetchvar_index) == 1){ } else if (fetchvar_batch_size(fetchvar_index) == 1) {
// which means fetchvar shape[0] = 1. // which means fetchvar shape[0] = 1.
// shape[0] does not change with batch // shape[0] does not change with batch
set_fetch_nobatch_index.insert(fetchvar_index); set_fetch_nobatch_index.insert(fetchvar_index);
_total_fetch_batch = _total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch); std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
}else if(_total_fetch_batch == 1){ } else if (_total_fetch_batch == 1) {
//这时意味着,之前的fetchvar shape[0] 全部都= 1 // 这时意味着,之前的fetchvar shape[0] 全部都= 1
//当前的fetchvar shape[0] > 1 // 当前的fetchvar shape[0] > 1
//所以,之前的都是no_batch // 所以,之前的都是no_batch
for(size_t temp_index = fetchvar_index-1; temp_index >= 0; --temp_index){ for (size_t temp_index = fetchvar_index - 1; temp_index >= 0;
--temp_index) {
set_fetch_nobatch_index.insert(fetchvar_index); set_fetch_nobatch_index.insert(fetchvar_index);
} }
_total_fetch_batch = _total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch); std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
}else{ } else {
// which means error. // which means error.
return false; return false;
} }
...@@ -856,10 +866,11 @@ class BatchTasks { ...@@ -856,10 +866,11 @@ class BatchTasks {
fetchVarTensor.shape[0] = shape0_length; fetchVarTensor.shape[0] = shape0_length;
fetch_lod_index++; fetch_lod_index++;
void* databuf_data = MempoolWrapper::instance().malloc(length,task->memoryPtr); void* databuf_data =
MempoolWrapper::instance().malloc(length, task->memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, length); paddle::PaddleBuf paddleBuf(databuf_data, length);
fetchVarTensor.data = paddleBuf; fetchVarTensor.data = paddleBuf;
//fetchVarTensor.data.Resize(length); // fetchVarTensor.data.Resize(length);
void* dst_ptr = fetchVarTensor.data.data(); void* dst_ptr = fetchVarTensor.data.data();
void* source_ptr = _batch_out[fetchvar_index].data.data() + void* source_ptr = _batch_out[fetchvar_index].data.data() +
shape0_index_start * fetchvar_bytesize_index; shape0_index_start * fetchvar_bytesize_index;
...@@ -885,12 +896,13 @@ class BatchTasks { ...@@ -885,12 +896,13 @@ class BatchTasks {
(*task->outVectorT_ptr)[fetchvar_index]; (*task->outVectorT_ptr)[fetchvar_index];
size_t length = fetchvar_bytesize_index * shape0_length; size_t length = fetchvar_bytesize_index * shape0_length;
fetchVarTensor.shape[0] = shape0_length; fetchVarTensor.shape[0] = shape0_length;
void* databuf_data = MempoolWrapper::instance().malloc(length,task->memoryPtr); void* databuf_data =
MempoolWrapper::instance().malloc(length, task->memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, length); paddle::PaddleBuf paddleBuf(databuf_data, length);
fetchVarTensor.data = paddleBuf; fetchVarTensor.data = paddleBuf;
//fetchVarTensor.data.Resize(length); // fetchVarTensor.data.Resize(length);
void* dst_ptr = fetchVarTensor.data.data(); void* dst_ptr = fetchVarTensor.data.data();
void* source_ptr = _batch_out[fetchvar_index].data.data() + void* source_ptr = _batch_out[fetchvar_index].data.data() +
shape0_index_start * fetchvar_bytesize_index; shape0_index_start * fetchvar_bytesize_index;
...@@ -940,12 +952,15 @@ class BatchTasks { ...@@ -940,12 +952,15 @@ class BatchTasks {
// index是局部变量,fetch_add是原子操作,成功则返回原值。 // index是局部变量,fetch_add是原子操作,成功则返回原值。
// 只有最后一个taskmeta都完成后,该线程的index+add才能>task->batch_size() // 只有最后一个taskmeta都完成后,该线程的index+add才能>task->batch_size()
// 故只有一个线程能进入if{}内.不会造成多线程竞争的问题。 // 故只有一个线程能进入if{}内.不会造成多线程竞争的问题。
size_t index = task->index.fetch_add(add); size_t index = task->index.fetch_add(add);
if ((index + add) >= task->batch_size()) { if ((index + add) >= task->batch_size()) {
task->combine_taskmeta(); task->combine_taskmeta();
char c = 0; char c = 0;
while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) { THREAD_MUTEX_LOCK(task->thread_mutex_ptr);
} task->task_manager_ptr->_task_ready = true;
THREAD_COND_SIGNAL(task->thread_cond_ptr);
THREAD_MUTEX_UNLOCK(task->thread_mutex_ptr);
butil::return_object(task); butil::return_object(task);
} }
} }
...@@ -985,36 +1000,6 @@ class BatchTasks { ...@@ -985,36 +1000,6 @@ class BatchTasks {
bool _allow_split_request; bool _allow_split_request;
}; };
// TaskHandler: lightweight handle onto a Task's notification pipe.
//
// It carries copies of the two pipe descriptors held by the Task:
//   - `read_fd`  is read by the brpc thread, which loops on it until a
//     completion signal arrives;
//   - `write_fd` is written by the bsf thread once the TaskMeta is done.
// Together the two descriptors let the bsf thread notify the brpc thread.
//
// A default-constructed handle holds -1 in both descriptors and reports
// valid() == false.
template <typename TaskT>
struct TaskHandler {
  int read_fd;
  int write_fd;

  // Builds an empty handle: both descriptors are set to -1.
  TaskHandler() {
    read_fd = -1;
    write_fd = -1;
  }

  // Copies the pipe descriptors out of `task`.
  explicit TaskHandler(TaskT const& task) {
    read_fd = task.read_fd;
    write_fd = task.write_fd;
  }

  // True only when both descriptors are non-negative.
  inline bool valid() const {
    if (read_fd < 0) {
      return false;
    }
    return write_fd >= 0;
  }

  // Returns a reference to one shared, default-constructed handle per TaskT.
  // Unless a caller has modified it through the returned reference, valid()
  // on it is false.
  static TaskHandler<TaskT>& valid_handle() {
    static TaskHandler<TaskT> shared_default;
    return shared_default;
  }
};
// TaskExecutor is a Thread pool. // TaskExecutor is a Thread pool.
template <typename TaskT> template <typename TaskT>
class TaskExecutor; class TaskExecutor;
...@@ -1115,7 +1100,12 @@ class TaskExecutor { ...@@ -1115,7 +1100,12 @@ class TaskExecutor {
int work(ThreadContext<TaskT>* context); int work(ThreadContext<TaskT>* context);
TaskHandler<TaskT> schedule(const void*, void*, MempoolRegion* memoryPtr); int schedule(const void*,
void*,
MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr,
TaskManager<InType, OutType>* task_manager_ptr);
bool move_task_to_batch(BatchTasks<TaskT>& batchTask); // NOLINT bool move_task_to_batch(BatchTasks<TaskT>& batchTask); // NOLINT
...@@ -1194,18 +1184,25 @@ class TaskManager { ...@@ -1194,18 +1184,25 @@ class TaskManager {
typedef typename TaskT::OutVectorT OutVectorT; typedef typename TaskT::OutVectorT OutVectorT;
explicit TaskManager(uint32_t model_index) // NOLINT explicit TaskManager(uint32_t model_index) // NOLINT
: _model_index(model_index) {} : _model_index(model_index),
_task_ready(false) {}
~TaskManager() { wait(); } ~TaskManager() { wait(); }
bool schedule(const void* in, void* out, MempoolRegion* memoryPtr); // NOLINT bool schedule(const void* in,
void* out,
MempoolRegion* memoryPtr,
THREAD_MUTEX_T* thread_mutex_ptr,
THREAD_COND_T* thread_cond_ptr); // NOLINT
void wait(); void wait();
inline void clear() { wait(); } inline void clear() { wait(); }
bool _task_ready = false;
private: private:
TaskHandler<TaskT> _task_owned;
uint32_t _model_index; uint32_t _model_index;
THREAD_MUTEX_T* thread_mutex_ptr;
THREAD_COND_T* thread_cond_ptr;
}; // class TaskManager }; // class TaskManager
class AutoMutex { class AutoMutex {
......
...@@ -98,7 +98,11 @@ int ReloadableInferEngine::infer(const void* in, ...@@ -98,7 +98,11 @@ int ReloadableInferEngine::infer(const void* in,
im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager( im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager(
_model_index); _model_index);
task_manager.schedule(in, out, MempoolWrapper::instance().get_thread_memory_ptr()); task_manager.schedule(in,
out,
MempoolWrapper::instance().get_thread_memory_ptr(),
ThreadMutex::instance().get_thread_mutex_ptr(),
ThreadMutex::instance().get_thread_cond_ptr());
task_manager.wait(); task_manager.wait();
return 0; return 0;
} }
......
...@@ -31,8 +31,9 @@ ...@@ -31,8 +31,9 @@
#include "core/predictor/framework/infer_data.h" #include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
#include "core/predictor/framework/predictor_metric.h" #include "core/predictor/framework/predictor_metric.h"
#include "paddle_inference_api.h" // NOLINT #include "core/predictor/framework/thread_mutex.h"
#include "experimental/float16.h" #include "experimental/float16.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
namespace predictor { namespace predictor {
...@@ -548,7 +549,7 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> { ...@@ -548,7 +549,7 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
int8_t* data = static_cast<int8_t*>(origin_data); int8_t* data = static_cast<int8_t*>(origin_data);
lod_tensor_in->CopyFromCpu(data); lod_tensor_in->CopyFromCpu(data);
} else if ((*tensorVector_in_pointer)[i].dtype == } else if ((*tensorVector_in_pointer)[i].dtype ==
paddle::PaddleDType::FLOAT16) { paddle::PaddleDType::FLOAT16) {
paddle::platform::float16* data = paddle::platform::float16* data =
static_cast<paddle::platform::float16*>(origin_data); static_cast<paddle::platform::float16*>(origin_data);
lod_tensor_in->CopyFromCpu(data); lod_tensor_in->CopyFromCpu(data);
......
...@@ -97,6 +97,12 @@ int Resource::initialize(const std::string& path, const std::string& file) { ...@@ -97,6 +97,12 @@ int Resource::initialize(const std::string& path, const std::string& file) {
} }
LOG(WARNING) << "Successfully proc initialized mempool wrapper"; LOG(WARNING) << "Successfully proc initialized mempool wrapper";
if (ThreadMutex::instance().initialize() != 0) {
LOG(ERROR) << "Failed proc initialized mempool wrapper";
return -1;
}
LOG(WARNING) << "Successfully proc initialized ThreadMutex";
#ifdef WITH_AUTH #ifdef WITH_AUTH
std::string product_name_str = resource_conf.auth_product_name(); std::string product_name_str = resource_conf.auth_product_name();
std::string container_id_str = resource_conf.auth_container_id(); std::string container_id_str = resource_conf.auth_container_id();
...@@ -301,6 +307,12 @@ int Resource::thread_initialize() { ...@@ -301,6 +307,12 @@ int Resource::thread_initialize() {
} }
LOG(WARNING) << "Successfully thread initialized mempool wrapper"; LOG(WARNING) << "Successfully thread initialized mempool wrapper";
if (ThreadMutex::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialized ThreadMutex";
return -1;
}
LOG(WARNING) << "Successfully thread initialized ThreadMutex";
// infer manager // infer manager
if (FLAGS_enable_model_toolkit && if (FLAGS_enable_model_toolkit &&
InferManager::instance().thrd_initialize() != 0) { InferManager::instance().thrd_initialize() != 0) {
...@@ -344,6 +356,12 @@ int Resource::finalize() { ...@@ -344,6 +356,12 @@ int Resource::finalize() {
LOG(ERROR) << "Failed proc finalize infer manager"; LOG(ERROR) << "Failed proc finalize infer manager";
return -1; return -1;
} }
if (ThreadMutex::instance().finalize() != 0) {
LOG(ERROR) << "Failed proc finalize ThreadMutex";
return -1;
}
if (CubeAPI::instance()->destroy() != 0) { if (CubeAPI::instance()->destroy() != 0) {
LOG(ERROR) << "Destory cube api failed "; LOG(ERROR) << "Destory cube api failed ";
return -1; return -1;
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "core/predictor/common/inner_common.h" #include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/infer.h" #include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
#include "core/predictor/framework/thread_mutex.h"
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册