#ifndef BAIDU_PADDLE_SERVING_PREDICTOR_BSF_H #define BAIDU_PADDLE_SERVING_PREDICTOR_BSF_H #include #include #include #include #include "common/inner_common.h" #include namespace im { namespace bsf { static const size_t DEFAULT_BATCH_SIZE = 100; template struct Task { typedef std::vector InArrayT; typedef std::vector OutArrayT; typedef InItemT InType; typedef OutItemT OutType; typedef Task TaskT; int read_fd; int write_fd; pid_t owner_tid; const InArrayT* in; OutArrayT* out; size_t rem; size_t size; size_t batch_size() { return in->size(); } butil::atomic index; Task() { read_fd = -1; write_fd = -1; owner_tid = -1; in = NULL; out = NULL; rem = -1; size = -1; index.store(0, butil::memory_order_relaxed); } }; template struct TaskMeta { TaskMeta(TaskT* ptr, size_t start, size_t add) : task(ptr) , begin(start) , end(start + add) {} TaskT* task; size_t begin; size_t end; }; template class BatchTasks { public: typedef typename TaskT::InType InType; typedef typename TaskT::OutType OutType; typedef TaskMeta TaskMetaT; BatchTasks(size_t batch_size, bool batch_align = true) : _batch_size(batch_size) , _rem_size(batch_size) , _batch_align(batch_align) { _batch_in.clear(); _batch_out.clear(); _tasks.clear(); } ~BatchTasks() { _batch_in.clear(); _batch_out.clear(); _tasks.clear(); } // synchronized operation size_t append_task(TaskT* task) { size_t add = std::min(task->rem, _rem_size); if (!_batch_align) { add = task->rem; } TaskMetaT tm(task, task->in->size() - task->rem, add); _tasks.push_back(tm); task->rem -= add; _rem_size -= add; return _rem_size; } static bool check_valid( const typename TaskT::InArrayT& in, typename TaskT::OutArrayT& out, bool align) { (void)in; (void)out; (void)align; return true; } void merge_tasks() { for (size_t ti = 0; ti < _tasks.size(); ++ti) { TaskMetaT& tm = _tasks[ti]; for (size_t vi = tm.begin; vi < tm.end; ++vi) { _batch_in.push_back((*tm.task->in)[vi]); _batch_out.push_back((*tm.task->out)[vi]); } } } void notify_tasks() { if (_batch_out.size() != _batch_in.size()) { LOG(FATAL) << "batch size not consistency: " << _batch_out.size() << " != " << _batch_in.size(); return ; } for (size_t ti = 0, bi = 0; ti < _tasks.size(); ++ti) { TaskT* task = _tasks[ti].task; size_t begin = _tasks[ti].begin; size_t end = _tasks[ti].end; size_t add = end - begin; for (size_t oi = begin; oi < end; ++oi, ++bi) { if (bi >= _batch_in.size()) { LOG(FATAL) << "batch index overflow: " << bi << " > " <<_batch_in.size(); return ; } (*task->out)[oi] = _batch_out[bi]; } size_t index = task->index.fetch_add(add); if ((index + add) >= task->in->size()) { char c = 0; while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) { ; } butil::return_object(task); } } } const typename TaskT::InArrayT& in() const { return _batch_in; } typename TaskT::OutArrayT& out() { return _batch_out; } size_t task_size() { return _tasks.size(); } private: std::vector _tasks; typename TaskT::InArrayT _batch_in; typename TaskT::OutArrayT _batch_out; size_t _rem_size; size_t _batch_size; bool _batch_align; }; // BSF 任务句柄, 用来等待时指定任务列表 template struct TaskHandler { int read_fd; int write_fd; TaskHandler() : read_fd(-1), write_fd(-1) { // do nothing } TaskHandler(TaskT const& task) : read_fd(task.read_fd) , write_fd(task.write_fd) { // do nothing } inline bool valid() const { return read_fd >= 0 && write_fd >= 0; } static TaskHandler& valid_handle() { static TaskHandler vhandle; return vhandle; } }; template class TaskExecutor; template class TaskManager; template struct ThreadContext { TaskExecutor* executor; void* user_thread_context; THREAD_T tid; int init_status; ThreadContext() : executor(NULL) , user_thread_context(NULL) , tid(-1), init_status(0) { // do nothing } ~ThreadContext() { tid = -1; executor = NULL; user_thread_context = NULL; init_status = 0; } }; template class TaskExecutor { public: typedef typename TaskT::InType InType; typedef typename TaskT::OutType OutType; typedef typename TaskT::InArrayT InArrayT; typedef typename TaskT::OutArrayT OutArrayT; typedef std::vector TaskArrayT; TaskExecutor() : _stop(false) , _thread_init_fn(NULL) , _thread_reset_fn(NULL) , _user_thread_contexts(NULL) , _batch_size(DEFAULT_BATCH_SIZE) , _batch_align(false) , _fn(NULL) { THREAD_MUTEX_INIT(&_mut, NULL); THREAD_COND_INIT(&_cond, NULL); _task_queue.clear(); } ~TaskExecutor() { THREAD_MUTEX_DESTROY(&_mut); THREAD_COND_DESTROY(&_cond); } static TaskExecutor* instance() { static TaskExecutor singleton; return &singleton; } void set_batch_size(size_t batch_size) { _batch_size = batch_size; } void set_batch_align(size_t batch_align) { _batch_align = batch_align; } void set_thread_init_fn(boost::function init_fn, void** contexts = NULL) { _thread_init_fn = init_fn; _user_thread_contexts = contexts; } void set_thread_reset_fn(boost::function reset_fn) { _thread_reset_fn = reset_fn; } void set_thread_callback_fn(boost::function cb) { _fn = cb; } int start(uint32_t thread_num, uint32_t init_timeout_sec = 0); void stop(); static void* thread_entry(void* args); private: TaskExecutor(TaskExecutor const& other); TaskExecutor* operator=(TaskExecutor const& other); int work(ThreadContext* context); TaskHandler schedule(const InArrayT&, OutArrayT&); bool fetch_batch(BatchTasks& batch); bool _stop; // can't use boost::mutex, because some stupid macro THREAD_MUTEX_T _mut; THREAD_COND_T _cond; std::deque _task_queue; boost::function _thread_init_fn; boost::function _thread_reset_fn; void** _user_thread_contexts; std::vector*> _thread_contexts; friend class TaskManager; size_t _batch_size; bool _batch_align; boost::function _fn; }; template class TaskManager { public: typedef Task TaskT; typedef typename TaskT::InArrayT InArrayT; typedef typename TaskT::OutArrayT OutArrayT; explicit TaskManager(TaskExecutor& exe, size_t batch_size) : _executor(exe) { } TaskManager() : _executor(*TaskExecutor::instance()) { } ~TaskManager() { wait(); } bool schedule(const InArrayT& in, OutArrayT& out); void wait(); inline void clear() { wait(); } private: TaskExecutor& _executor; TaskHandler _task_owned; }; // class TaskManager class AutoMutex { public: AutoMutex(THREAD_MUTEX_T& mut) : _mut(mut) { THREAD_MUTEX_LOCK(&_mut); } ~AutoMutex() { THREAD_MUTEX_UNLOCK(&_mut); } private: THREAD_MUTEX_T& _mut; }; } // namespace bsf } // namespace im #include "bsf-inl.h" #include "bsf-inl-tensor.h" #endif //BAIDU_PADDLE_SERVING_PREDICTOR_BSF_H