bsf.h 8.4 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
W
wangguibao 已提交
16 17

#include <errno.h>
W
wangguibao 已提交
18
#include <algorithm>
W
wangguibao 已提交
19
#include <deque>
W
wangguibao 已提交
20 21
#include <vector>
#include "butil/atomicops.h"
W
wangguibao 已提交
22 23
#include "common/inner_common.h"

W
wangguibao 已提交
24
#include "boost/function.hpp"
W
wangguibao 已提交
25 26 27 28 29 30

namespace im {
namespace bsf {

// Default number of items per batch; TaskExecutor's constructor uses this
// value to initialise its _batch_size member.
static const size_t DEFAULT_BATCH_SIZE = 100;

W
wangguibao 已提交
31
template <typename InItemT, typename OutItemT>
W
wangguibao 已提交
32
struct Task {
W
wangguibao 已提交
33 34 35 36 37
  typedef std::vector<InItemT> InArrayT;
  typedef std::vector<OutItemT> OutArrayT;
  typedef InItemT InType;
  typedef OutItemT OutType;
  typedef Task<InItemT, OutItemT> TaskT;
W
wangguibao 已提交
38

W
wangguibao 已提交
39 40
  int read_fd;
  int write_fd;
W
wangguibao 已提交
41

W
wangguibao 已提交
42
  pid_t owner_tid;
W
wangguibao 已提交
43

W
wangguibao 已提交
44 45
  const InArrayT* in;
  OutArrayT* out;
W
wangguibao 已提交
46

W
wangguibao 已提交
47 48
  size_t rem;
  size_t size;
W
wangguibao 已提交
49

W
wangguibao 已提交
50
  size_t batch_size() { return in->size(); }
W
wangguibao 已提交
51

W
wangguibao 已提交
52 53 54 55 56 57 58 59 60 61 62 63
  butil::atomic<size_t> index;

  Task() {
    read_fd = -1;
    write_fd = -1;
    owner_tid = -1;
    in = NULL;
    out = NULL;
    rem = -1;
    size = -1;
    index.store(0, butil::memory_order_relaxed);
  }
W
wangguibao 已提交
64 65
};

W
wangguibao 已提交
66
// Records which half-open slice [begin, end) of a task's items was assigned
// to a batch.
template <typename TaskT>
struct TaskMeta {
  // ptr:   task the slice belongs to.
  // start: index of the first item of the slice.
  // add:   number of items in the slice, so end == start + add.
  TaskMeta(TaskT* ptr, size_t start, size_t add) {
    task = ptr;
    begin = start;
    end = start + add;
  }

  TaskT* task;
  size_t begin;
  size_t end;
};

W
wangguibao 已提交
76
template <typename TaskT>
W
wangguibao 已提交
77
class BatchTasks {
W
wangguibao 已提交
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
 public:
  typedef typename TaskT::InType InType;
  typedef typename TaskT::OutType OutType;
  typedef TaskMeta<TaskT> TaskMetaT;

  explicit BatchTasks(size_t batch_size, bool batch_align = true)
      : _batch_size(batch_size),
        _rem_size(batch_size),
        _batch_align(batch_align) {
    _batch_in.clear();
    _batch_out.clear();
    _tasks.clear();
  }

  ~BatchTasks() {
    _batch_in.clear();
    _batch_out.clear();
    _tasks.clear();
  }

  // synchronized operation
  size_t append_task(TaskT* task) {
    size_t add = std::min(task->rem, _rem_size);
    if (!_batch_align) {
      add = task->rem;
W
wangguibao 已提交
103 104
    }

W
wangguibao 已提交
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
    TaskMetaT tm(task, task->in->size() - task->rem, add);
    _tasks.push_back(tm);

    task->rem -= add;
    _rem_size -= add;
    return _rem_size;
  }

  static bool check_valid(const typename TaskT::InArrayT& in,
                          const typename TaskT::OutArrayT& out,
                          bool align) {
    (void)in;
    (void)out;
    (void)align;
    return true;
  }

  void merge_tasks() {
    for (size_t ti = 0; ti < _tasks.size(); ++ti) {
      TaskMetaT& tm = _tasks[ti];
      for (size_t vi = tm.begin; vi < tm.end; ++vi) {
        _batch_in.push_back((*tm.task->in)[vi]);
        _batch_out.push_back((*tm.task->out)[vi]);
      }
W
wangguibao 已提交
129
    }
W
wangguibao 已提交
130
  }
W
wangguibao 已提交
131

W
wangguibao 已提交
132 133 134 135 136
  void notify_tasks() {
    if (_batch_out.size() != _batch_in.size()) {
      LOG(ERROR) << "batch size not consistency: " << _batch_out.size()
                 << " != " << _batch_in.size();
      return;
W
wangguibao 已提交
137 138
    }

W
wangguibao 已提交
139 140 141 142 143 144 145 146 147 148 149
    for (size_t ti = 0, bi = 0; ti < _tasks.size(); ++ti) {
      TaskT* task = _tasks[ti].task;
      size_t begin = _tasks[ti].begin;
      size_t end = _tasks[ti].end;
      size_t add = end - begin;

      for (size_t oi = begin; oi < end; ++oi, ++bi) {
        if (bi >= _batch_in.size()) {
          LOG(ERROR) << "batch index overflow: " << bi << " > "
                     << _batch_in.size();
          return;
W
wangguibao 已提交
150
        }
W
wangguibao 已提交
151 152
        (*task->out)[oi] = _batch_out[bi];
      }
W
wangguibao 已提交
153

W
wangguibao 已提交
154 155 156 157
      size_t index = task->index.fetch_add(add);
      if ((index + add) >= task->in->size()) {
        char c = 0;
        while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) {
W
wangguibao 已提交
158
        }
W
wangguibao 已提交
159 160
        butil::return_object(task);
      }
W
wangguibao 已提交
161
    }
W
wangguibao 已提交
162
  }
W
wangguibao 已提交
163

W
wangguibao 已提交
164
  const typename TaskT::InArrayT& in() const { return _batch_in; }
W
wangguibao 已提交
165

W
wangguibao 已提交
166
  typename TaskT::OutArrayT& out() { return _batch_out; }
W
wangguibao 已提交
167

W
wangguibao 已提交
168 169 170 171 172 173 174 175 176
  size_t task_size() { return _tasks.size(); }

 private:
  std::vector<TaskMetaT> _tasks;
  typename TaskT::InArrayT _batch_in;
  typename TaskT::OutArrayT _batch_out;
  size_t _rem_size;
  size_t _batch_size;
  bool _batch_align;
W
wangguibao 已提交
177 178
};

W
wangguibao 已提交
179 180
// BSF task handle: a copy of a task's read/write descriptor pair.
template <typename TaskT>
struct TaskHandler {
  int read_fd;
  int write_fd;

  // Produces a handle with both descriptors set to -1; valid() is false.
  TaskHandler() {
    read_fd = -1;
    write_fd = -1;
  }

  // Copies the descriptor pair out of `task`.
  explicit TaskHandler(TaskT const& task) {
    read_fd = task.read_fd;
    write_fd = task.write_fd;
  }

  // True only when neither descriptor is negative.
  inline bool valid() const { return !(read_fd < 0 || write_fd < 0); }

  // Returns a reference to one function-local static, default-constructed
  // handle per TaskT instantiation.
  // NOTE(review): despite the name, that default handle has both descriptors
  // at -1, so valid() is false unless a caller modifies it.
  static TaskHandler<TaskT>& valid_handle() {
    static TaskHandler<TaskT> shared_handle;
    return shared_handle;
  }
};

W
wangguibao 已提交
202
// Forward declarations: ThreadContext stores a TaskExecutor pointer and
// TaskExecutor names TaskManager as a friend before either class is defined.
template <typename TaskT> class TaskExecutor;
template <typename InItemT, typename OutItemT> class TaskManager;

W
wangguibao 已提交
208
template <typename TaskT>
W
wangguibao 已提交
209
struct ThreadContext {
W
wangguibao 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
  TaskExecutor<TaskT>* executor;
  void* user_thread_context;
  THREAD_T tid;
  int init_status;

  ThreadContext()
      : executor(NULL), user_thread_context(NULL), tid(-1), init_status(0) {
    // do nothing
  }

  ~ThreadContext() {
    tid = -1;
    executor = NULL;
    user_thread_context = NULL;
    init_status = 0;
  }
W
wangguibao 已提交
226 227
};

W
wangguibao 已提交
228
template <typename TaskT>
W
wangguibao 已提交
229
class TaskExecutor {
W
wangguibao 已提交
230 231 232 233 234 235
 public:
  typedef typename TaskT::InType InType;
  typedef typename TaskT::OutType OutType;
  typedef typename TaskT::InArrayT InArrayT;
  typedef typename TaskT::OutArrayT OutArrayT;
  typedef std::vector<TaskT> TaskArrayT;
W
wangguibao 已提交
236

W
wangguibao 已提交
237 238 239 240 241 242 243 244 245 246 247 248
  TaskExecutor()
      : _stop(false),
        _thread_init_fn(NULL),
        _thread_reset_fn(NULL),
        _user_thread_contexts(NULL),
        _batch_size(DEFAULT_BATCH_SIZE),
        _batch_align(false),
        _fn(NULL) {
    THREAD_MUTEX_INIT(&_mut, NULL);
    THREAD_COND_INIT(&_cond, NULL);
    _task_queue.clear();
  }
W
wangguibao 已提交
249

W
wangguibao 已提交
250 251 252 253
  ~TaskExecutor() {
    THREAD_MUTEX_DESTROY(&_mut);
    THREAD_COND_DESTROY(&_cond);
  }
W
wangguibao 已提交
254

W
wangguibao 已提交
255 256 257 258
  static TaskExecutor<TaskT>* instance() {
    static TaskExecutor<TaskT> singleton;
    return &singleton;
  }
W
wangguibao 已提交
259

W
wangguibao 已提交
260
  void set_batch_size(size_t batch_size) { _batch_size = batch_size; }
W
wangguibao 已提交
261

W
wangguibao 已提交
262
  void set_batch_align(size_t batch_align) { _batch_align = batch_align; }
W
wangguibao 已提交
263

W
wangguibao 已提交
264 265 266 267 268
  void set_thread_init_fn(boost::function<int(void*)> init_fn,
                          void** contexts = NULL) {
    _thread_init_fn = init_fn;
    _user_thread_contexts = contexts;
  }
W
wangguibao 已提交
269

W
wangguibao 已提交
270 271 272 273 274 275 276 277
  void set_thread_reset_fn(boost::function<int(void*)> reset_fn) {
    _thread_reset_fn = reset_fn;
  }

  void set_thread_callback_fn(
      boost::function<void(const InArrayT&, OutArrayT&)> cb) {
    _fn = cb;
  }
W
wangguibao 已提交
278

W
wangguibao 已提交
279 280
  int start(uint32_t thread_num, uint32_t init_timeout_sec = 0);
  void stop();
W
wangguibao 已提交
281

W
wangguibao 已提交
282
  static void* thread_entry(void* args);
W
wangguibao 已提交
283

W
wangguibao 已提交
284 285 286
 private:
  TaskExecutor(TaskExecutor<TaskT> const& other);
  TaskExecutor* operator=(TaskExecutor<TaskT> const& other);
W
wangguibao 已提交
287

W
wangguibao 已提交
288
  int work(ThreadContext<TaskT>* context);
W
wangguibao 已提交
289

W
wangguibao 已提交
290
  TaskHandler<TaskT> schedule(const InArrayT&, OutArrayT&);
W
wangguibao 已提交
291

W
wangguibao 已提交
292
  bool fetch_batch(BatchTasks<TaskT>& batch);  // NOLINT
W
wangguibao 已提交
293

W
wangguibao 已提交
294
  bool _stop;
W
wangguibao 已提交
295

W
wangguibao 已提交
296 297 298
  // can't use boost::mutex, because some stupid macro
  THREAD_MUTEX_T _mut;
  THREAD_COND_T _cond;
W
wangguibao 已提交
299

W
wangguibao 已提交
300
  std::deque<TaskT*> _task_queue;
W
wangguibao 已提交
301

W
wangguibao 已提交
302 303 304
  boost::function<int(void*)> _thread_init_fn;
  boost::function<int(void*)> _thread_reset_fn;
  void** _user_thread_contexts;
W
wangguibao 已提交
305

W
wangguibao 已提交
306 307
  std::vector<ThreadContext<TaskT>*> _thread_contexts;
  friend class TaskManager<InType, OutType>;
W
wangguibao 已提交
308

W
wangguibao 已提交
309 310
  size_t _batch_size;
  bool _batch_align;
W
wangguibao 已提交
311

W
wangguibao 已提交
312
  boost::function<void(const InArrayT&, OutArrayT&)> _fn;
W
wangguibao 已提交
313 314
};

W
wangguibao 已提交
315
template <typename InItemT, typename OutItemT>
W
wangguibao 已提交
316
class TaskManager {
W
wangguibao 已提交
317 318 319 320
 public:
  typedef Task<InItemT, OutItemT> TaskT;
  typedef typename TaskT::InArrayT InArrayT;
  typedef typename TaskT::OutArrayT OutArrayT;
W
wangguibao 已提交
321

W
wangguibao 已提交
322 323
  explicit TaskManager(TaskExecutor<TaskT>& exe, size_t batch_size)  // NOLINT
      : _executor(exe) {}
W
wangguibao 已提交
324

W
wangguibao 已提交
325
  TaskManager() : _executor(*TaskExecutor<TaskT>::instance()) {}
W
wangguibao 已提交
326

W
wangguibao 已提交
327
  ~TaskManager() { wait(); }
W
wangguibao 已提交
328

W
wangguibao 已提交
329 330
  bool schedule(const InArrayT& in, OutArrayT& out);  // NOLINT
  void wait();
W
wangguibao 已提交
331

W
wangguibao 已提交
332
  inline void clear() { wait(); }
W
wangguibao 已提交
333

W
wangguibao 已提交
334 335 336 337
 private:
  TaskExecutor<TaskT>& _executor;
  TaskHandler<TaskT> _task_owned;
};  // class TaskManager
W
wangguibao 已提交
338 339

class AutoMutex {
W
wangguibao 已提交
340 341 342 343
 public:
  explicit AutoMutex(THREAD_MUTEX_T& mut) : _mut(mut) {
    THREAD_MUTEX_LOCK(&_mut);
  }
W
wangguibao 已提交
344

W
wangguibao 已提交
345
  ~AutoMutex() { THREAD_MUTEX_UNLOCK(&_mut); }
W
wangguibao 已提交
346

W
wangguibao 已提交
347 348
 private:
  THREAD_MUTEX_T& _mut;
W
wangguibao 已提交
349 350
};

W
wangguibao 已提交
351 352
}  // namespace bsf
}  // namespace im
W
wangguibao 已提交
353

W
wangguibao 已提交
354 355
#include "predictor/framework/bsf-inl-tensor.h"
#include "predictor/framework/bsf-inl.h"