infer.h 30.4 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
W
wangguibao 已提交
16
#include <sys/stat.h>
W
wangguibao 已提交
17
#include <sys/types.h>
W
wangguibao 已提交
18
#include <unistd.h>
Z
zhangjun 已提交
19
#include <pthread.h>
W
wangguibao 已提交
20
#include <string>
M
MRXLT 已提交
21
#include <utility>
W
wangguibao 已提交
22
#include <vector>
H
HexToString 已提交
23
#include <numeric>
G
guru4elephant 已提交
24
#include "core/predictor/common/inner_common.h"
H
HexToString 已提交
25
#include "core/predictor/framework/bsf.h"
G
guru4elephant 已提交
26 27
#include "core/predictor/framework/factory.h"
#include "core/predictor/framework/infer_data.h"
W
wangjiawei04 已提交
28
#include "paddle_inference_api.h"  // NOLINT
W
wangguibao 已提交
29 30 31 32
namespace baidu {
namespace paddle_serving {
namespace predictor {

W
wangguibao 已提交
33 34
using configure::ModelToolkitConf;

Z
zhangjun 已提交
35 36 37 38 39 40 41 42 43 44 45
// RAII guard for a pthread mutex: acquires the given mutex on
// construction and releases it when the guard goes out of scope.
class AutoLock {
 public:
  explicit AutoLock(pthread_mutex_t& mutex) : _mut(mutex) {
    pthread_mutex_lock(&mutex);
  }
  ~AutoLock() { pthread_mutex_unlock(&_mut); }

  // Non-copyable: a copied guard would unlock the same mutex twice.
  AutoLock(const AutoLock&) = delete;
  AutoLock& operator=(const AutoLock&) = delete;

 private:
  pthread_mutex_t& _mut;  // guarded mutex, held for the guard's lifetime
};

Z
update  
zhangjun 已提交
46
class GlobalCreateMutex {
Z
zhangjun 已提交
47 48 49 50
 public:
  pthread_mutex_t& mutex() { return _mut; }

  static pthread_mutex_t& instance() {
Z
update  
zhangjun 已提交
51
    static GlobalCreateMutex gmutex;
Z
zhangjun 已提交
52 53 54 55
    return gmutex.mutex();
  }

 private:
Z
update  
zhangjun 已提交
56
  GlobalCreateMutex() { pthread_mutex_init(&_mut, NULL); }
Z
zhangjun 已提交
57 58 59
  pthread_mutex_t _mut;
};

W
wangguibao 已提交
60
class InferEngine {
W
wangguibao 已提交
61 62 63 64 65 66 67 68 69 70
 public:
  virtual ~InferEngine() {}

  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    return proc_initialize_impl(conf, version);
  }
  virtual int proc_finalize() { return proc_finalize_impl(); }
  virtual int thrd_initialize() { return thrd_initialize_impl(); }
  virtual int thrd_clear() { return thrd_clear_impl(); }
  virtual int thrd_finalize() { return thrd_finalize_impl(); }
H
HexToString 已提交
71
  virtual int infer(const void* in, void* out, uint32_t batch_size = -1) { return infer_impl(in, out, batch_size); }
W
wangguibao 已提交
72 73 74 75 76 77 78 79 80 81 82 83

  virtual int reload() = 0;

  virtual uint64_t version() const = 0;

  // begin: framework inner call
  virtual int proc_initialize_impl(const configure::EngineDesc& conf,
                                   bool version) = 0;
  virtual int thrd_initialize_impl() = 0;
  virtual int thrd_finalize_impl() = 0;
  virtual int thrd_clear_impl() = 0;
  virtual int proc_finalize_impl() = 0;
H
HexToString 已提交
84
  virtual int infer_impl(const void* in,
H
HexToString 已提交
85 86
                          void* out,
                          uint32_t batch_size = -1) = 0;
H
HexToString 已提交
87
  virtual int task_infer_impl(const BatchTensor& in,
H
HexToString 已提交
88 89
                          BatchTensor& out) = 0;  // NOLINT

W
wangguibao 已提交
90 91 92 93 94 95
  // end: framework inner call
};

class ReloadableInferEngine : public InferEngine {
 public:
  virtual ~ReloadableInferEngine() {}
W
wangguibao 已提交
96

W
wangguibao 已提交
97 98 99 100 101
  // Snapshot of the last observed reload trigger. Only one member is
  // meaningful at a time, selected by the configured reload-check mode
  // (timestamp_ne / timestamp_gt / md5sum / revision).
  union last_check_status {
    time_t last_timestamp;
    uint64_t last_md5sum;
    uint64_t last_revision;
  };
W
wangguibao 已提交
102

Z
update  
zhangjun 已提交
103
  virtual int load(const configure::EngineDesc& conf) = 0;
H
HexToString 已提交
104
  typedef im::bsf::Task<Tensor, Tensor> TaskT;
W
wangguibao 已提交
105 106 107 108

  int proc_initialize_impl(const configure::EngineDesc& conf, bool version) {
    _reload_tag_file = conf.reloadable_meta();
    _reload_mode_tag = conf.reloadable_type();
Z
zhangjun 已提交
109
    _model_data_path = conf.model_dir();
W
wangguibao 已提交
110 111 112
    _infer_thread_num = conf.runtime_thread_num();
    _infer_batch_size = conf.batch_infer_size();
    _infer_batch_align = conf.enable_batch_align();
113

Z
update  
zhangjun 已提交
114
    _conf = conf;
Z
zhangjun 已提交
115

Z
update  
zhangjun 已提交
116
    if (!check_need_reload() || load(conf) != 0) {
W
wangguibao 已提交
117 118
      LOG(ERROR) << "Failed load model_data_path" << _model_data_path;
      return -1;
W
wangguibao 已提交
119
    }
W
wangguibao 已提交
120 121 122 123

    if (parse_version_info(conf, version) != 0) {
      LOG(ERROR) << "Failed parse version info";
      return -1;
W
wangguibao 已提交
124
    }
W
wangguibao 已提交
125 126 127 128 129 130 131 132 133

    LOG(WARNING) << "Succ load model_data_path" << _model_data_path;
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    if (proc_initialize_impl(conf, version) != 0) {
      LOG(ERROR) << "Failed proc initialize impl";
      return -1;
W
wangguibao 已提交
134
    }
H
HexToString 已提交
135 136 137 138 139 140 141 142 143 144 145 146

    // init bsf framework
    if (_infer_thread_num <= 0) {
      return 0;
    }

    // init bsf framework
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_init_fn(
        boost::bind(&InferEngine::thrd_initialize_impl, this));
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_reset_fn(
        boost::bind(&InferEngine::thrd_clear_impl, this));
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_callback_fn(
H
HexToString 已提交
147
        boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
H
HexToString 已提交
148 149 150 151 152 153 154 155 156 157 158 159
    im::bsf::TaskExecutor<TaskT>::instance()->set_batch_size(_infer_batch_size);
    im::bsf::TaskExecutor<TaskT>::instance()->set_batch_align(
        _infer_batch_align);
    if (im::bsf::TaskExecutor<TaskT>::instance()->start(_infer_thread_num) !=
        0) {
      LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
      return -1;
    }

    LOG(WARNING) << "Enable batch schedule framework, thread_num:"
                 << _infer_thread_num << ", batch_size:" << _infer_batch_size
                 << ", enable_batch_align:" << _infer_batch_align;
W
wangguibao 已提交
160 161
    return 0;
  }
W
wangguibao 已提交
162

H
HexToString 已提交
163 164
  int infer(const void* in, void* out, uint32_t batch_size = -1) {
    if (_infer_thread_num <= 0) {
H
HexToString 已提交
165
      return infer_impl(in, out, batch_size);
H
HexToString 已提交
166 167 168 169 170 171 172 173
    }

    im::bsf::TaskManager<Tensor, Tensor> task_manager;
    task_manager.schedule(*(reinterpret_cast<const BatchTensor*>(in)),
                          *(reinterpret_cast<BatchTensor*>(out)));
    task_manager.wait();
    return 0;
  }
W
wangguibao 已提交
174

W
wangguibao 已提交
175 176 177 178 179 180
  // Per-thread initialization. When the bsf batch-schedule framework is
  // enabled (_infer_thread_num > 0) its worker threads run the init
  // themselves, so nothing is done here; otherwise init this thread.
  int thrd_initialize() {
    return _infer_thread_num > 0 ? 0 : thrd_initialize_impl();
  }
W
wangguibao 已提交
181

W
wangguibao 已提交
182 183 184 185
  int thrd_clear() {
    if (_infer_thread_num > 0) {
      return 0;
    }
W
wangguibao 已提交
186

W
wangguibao 已提交
187 188
    return thrd_clear_impl();
  }
W
wangguibao 已提交
189

W
wangguibao 已提交
190 191 192 193 194
  int proc_finalize() {
    if (proc_finalize_impl() != 0) {
      LOG(ERROR) << "Failed proc finalize impl";
      return -1;
    }
W
wangguibao 已提交
195

H
HexToString 已提交
196 197 198
    if (_infer_thread_num > 0) {
      im::bsf::TaskExecutor<TaskT>::instance()->stop();
    }
W
wangguibao 已提交
199 200
    return 0;
  }
W
wangguibao 已提交
201

W
wangguibao 已提交
202 203 204
  int reload() {
    if (check_need_reload()) {
      LOG(WARNING) << "begin reload model[" << _model_data_path << "].";
Z
zhangjun 已提交
205
      return load(_conf);
W
wangguibao 已提交
206 207 208 209 210
    }
    return 0;
  }

  uint64_t version() const { return _version; }
Z
update  
zhangjun 已提交
211
  
W
wangguibao 已提交
212
  uint32_t thread_num() const { return _infer_thread_num; }
W
wangguibao 已提交
213

W
wangguibao 已提交
214 215 216 217 218
 private:
  // Stub: version parsing is not implemented; every engine gets the
  // sentinel version uint64_t(-1). Both parameters are intentionally
  // unused. Always returns 0 (success).
  int parse_version_info(const configure::EngineDesc& config, bool version) {
    _version = uint64_t(-1);
    return 0;
  }
W
wangguibao 已提交
219

W
wangguibao 已提交
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
  // Decide whether the model should be reloaded, dispatching on the
  // reload-check mode configured via reloadable_type. Unknown modes are
  // logged and treated as "no reload".
  bool check_need_reload() {
    if (_reload_mode_tag == "timestamp_ne") {
      return check_timestamp_ne();
    }
    if (_reload_mode_tag == "timestamp_gt") {
      return check_timestamp_gt();
    }
    if (_reload_mode_tag == "md5sum") {
      return check_md5sum();
    }
    if (_reload_mode_tag == "revision") {
      return check_revision();
    }
    if (_reload_mode_tag == "none") {
      return false;
    }
    LOG(ERROR) << "Not support check type: " << _reload_mode_tag;
    return false;
  }

  bool check_timestamp_ne() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
243

W
wangguibao 已提交
244 245 246
    if ((st.st_mode & S_IFREG) && st.st_mtime != _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
247 248
    }

W
wangguibao 已提交
249 250
    return false;
  }
W
wangguibao 已提交
251

W
wangguibao 已提交
252 253 254 255 256 257
  bool check_timestamp_gt() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
258

W
wangguibao 已提交
259 260 261
    if ((st.st_mode & S_IFREG) && st.st_mtime > _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
262 263
    }

W
wangguibao 已提交
264 265 266 267 268 269 270 271 272
    return false;
  }

  // md5sum / revision reload checks are not implemented; they never
  // trigger a reload.
  bool check_md5sum() { return false; }

  bool check_revision() { return false; }

 protected:
  std::string _model_data_path;
Z
update  
zhangjun 已提交
273
  configure::EngineDesc _conf;
W
wangguibao 已提交
274 275 276 277 278 279 280 281 282 283

 private:
  std::string _reload_tag_file;
  std::string _reload_mode_tag;
  last_check_status _last_status;
  uint32_t _infer_thread_num;
  uint32_t _infer_batch_size;
  bool _infer_batch_align;
  uint64_t _version;
};
W
wangguibao 已提交
284

W
wangguibao 已提交
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
// Double-buffered holder for a pair of EngineCore instances. One slot
// (cores[current_idx]) serves inference while the other is used to stage
// a freshly (re)loaded model; a successful load flips current_idx.
template <typename EngineCore>
struct ModelData {
  ModelData() : current_idx(1) {
    cores[0] = NULL;
    cores[1] = NULL;
  }

  // Owns both cores; copying would double-delete them.
  ModelData(const ModelData&) = delete;
  ModelData& operator=(const ModelData&) = delete;

  ~ModelData() {
    delete cores[0];
    delete cores[1];
  }

  EngineCore* cores[2];  // owned; either slot may be NULL
  uint32_t current_idx;  // index of the slot currently serving inference
};

template <typename EngineCore>
class DBReloadableInferEngine : public ReloadableInferEngine {
 public:
  virtual ~DBReloadableInferEngine() {}

  // Process-level init: create the TLS key used to hand each worker
  // thread its own ModelData, and the mutex guarding _reload_vec, then
  // delegate to the base class (model load + version parsing).
  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    THREAD_KEY_CREATE(&_skey, NULL);
    THREAD_MUTEX_INIT(&_mutex, NULL);
    return ReloadableInferEngine::proc_initialize(conf, version);
  }

Z
update  
zhangjun 已提交
312
  virtual int load(const configure::EngineDesc& conf) {
W
wangguibao 已提交
313 314
    if (_reload_vec.empty()) {
      return 0;
W
wangguibao 已提交
315 316
    }

W
wangguibao 已提交
317
    for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
Z
update  
zhangjun 已提交
318
      if (load_data(_reload_vec[ti], conf) != 0) {
W
wangguibao 已提交
319 320 321 322 323
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
    }

Z
update  
zhangjun 已提交
324
    LOG(WARNING) << "Succ load engine, path: " << conf.model_dir();
W
wangguibao 已提交
325

W
wangguibao 已提交
326 327
    return 0;
  }
W
wangguibao 已提交
328

329
  int load_data(ModelData<EngineCore>* md,
Z
update  
zhangjun 已提交
330
                const configure::EngineDesc& conf) {
W
wangguibao 已提交
331 332 333
    uint32_t next_idx = (md->current_idx + 1) % 2;
    if (md->cores[next_idx]) {
      delete md->cores[next_idx];
W
wangguibao 已提交
334 335
    }

W
wangguibao 已提交
336
    md->cores[next_idx] = new (std::nothrow) EngineCore;
337

Z
update  
zhangjun 已提交
338 339 340
    //params.dump();
    if (!md->cores[next_idx] || md->cores[next_idx]->create(conf) != 0) {
      LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
W
wangguibao 已提交
341
      return -1;
W
wangguibao 已提交
342
    }
W
wangguibao 已提交
343 344 345
    md->current_idx = next_idx;
    return 0;
  }
W
wangguibao 已提交
346

W
wangguibao 已提交
347 348
  virtual int thrd_initialize_impl() {
    // memory pool to be inited in non-serving-threads
H
HexToString 已提交
349 350 351 352
    if (MempoolWrapper::instance().thread_initialize() != 0) {
      LOG(ERROR) << "Failed thread initialize mempool";
      return -1;
    }
W
wangguibao 已提交
353

W
wangguibao 已提交
354
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
Z
update  
zhangjun 已提交
355
    if (!md || load_data(md, _conf) != 0) {
356
      LOG(ERROR) << "Failed create thread data from "
Z
zhangjun 已提交
357
                 << _conf.model_dir();
W
wangguibao 已提交
358
      return -1;
W
wangguibao 已提交
359 360
    }

W
wangguibao 已提交
361
    THREAD_SETSPECIFIC(_skey, md);
H
HexToString 已提交
362
    im::bsf::AutoMutex lock(_mutex);
W
wangguibao 已提交
363 364 365 366 367 368
    _reload_vec.push_back(md);
    return 0;
  }

  int thrd_clear_impl() {
    // for non-serving-threads
H
HexToString 已提交
369 370 371 372
    if (MempoolWrapper::instance().thread_clear() != 0) {
      LOG(ERROR) << "Failed thread clear mempool";
      return -1;
    }
W
wangguibao 已提交
373 374 375 376
    return 0;
  }

  int thrd_finalize_impl() { return 0; }
W
wangguibao 已提交
377

W
wangguibao 已提交
378 379 380 381 382
  // Process-level teardown: release the TLS key and the reload mutex
  // created in proc_initialize. Always succeeds.
  int proc_finalize_impl() {
    THREAD_KEY_DELETE(_skey);
    THREAD_MUTEX_DESTROY(&_mutex);
    return 0;
  }
W
wangguibao 已提交
383

W
wangguibao 已提交
384 385 386 387 388 389
  EngineCore* get_core() {
    ModelData<EngineCore>* md =
        (ModelData<EngineCore>*)THREAD_GETSPECIFIC(_skey);
    if (!md) {
      LOG(ERROR) << "Failed get thread specific data";
      return NULL;
W
wangguibao 已提交
390
    }
W
wangguibao 已提交
391 392
    return md->cores[md->current_idx];
  }
W
wangguibao 已提交
393

W
wangguibao 已提交
394 395 396 397
 protected:
  THREAD_KEY_T _skey;
  THREAD_MUTEX_T _mutex;
  std::vector<ModelData<EngineCore>*> _reload_vec;
W
wangguibao 已提交
398

W
wangguibao 已提交
399 400
 private:
};
W
wangguibao 已提交
401

W
wangguibao 已提交
402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
// 多个EngineCore共用同一份模型数据
template <typename EngineCore>
class CloneDBReloadableInferEngine
    : public DBReloadableInferEngine<EngineCore> {
 public:
  virtual ~CloneDBReloadableInferEngine() {}

  // Allocate the process-level ModelData (shared by all thread-level
  // clones) before running the regular DB-reloadable initialization.
  // NOTE(review): _pd is never freed — appears to live for the process
  // lifetime; confirm no teardown is expected.
  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    _pd = new (std::nothrow) ModelData<EngineCore>;
    if (!_pd) {
      LOG(ERROR) << "Failed to allocate for ProcData";
      return -1;
    }
    return DBReloadableInferEngine<EngineCore>::proc_initialize(conf, version);
  }

Z
update  
zhangjun 已提交
418
  virtual int load(const configure::EngineDesc& conf) {
W
wangguibao 已提交
419 420
    // 加载进程级模型数据
    if (!_pd ||
Z
update  
zhangjun 已提交
421
        DBReloadableInferEngine<EngineCore>::load_data(_pd, conf) != 0) {
Z
zhangjun 已提交
422
      LOG(ERROR) << "Failed to create common model from [" << conf.model_dir()
W
wangguibao 已提交
423 424 425 426
                 << "].";
      return -1;
    }
    LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx]
Z
update  
zhangjun 已提交
427
                 << "], path[" << conf.model_dir() << "].";
W
wangguibao 已提交
428 429 430 431 432 433 434 435 436 437 438 439 440

    if (DBReloadableInferEngine<EngineCore>::_reload_vec.empty()) {
      return 0;
    }

    for (uint32_t ti = 0;
         ti < DBReloadableInferEngine<EngineCore>::_reload_vec.size();
         ++ti) {
      if (load_data(DBReloadableInferEngine<EngineCore>::_reload_vec[ti],
                    _pd->cores[_pd->current_idx]) != 0) {
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
W
wangguibao 已提交
441 442
    }

Z
update  
zhangjun 已提交
443
    LOG(WARNING) << "Succ load clone model, path[" << conf.model_dir() << "]";
W
wangguibao 已提交
444

W
wangguibao 已提交
445 446
    return 0;
  }
W
wangguibao 已提交
447

W
wangguibao 已提交
448 449 450 451 452
  // 加载线程级对象,多个线程级对象共用pd_core的模型数据
  int load_data(ModelData<EngineCore>* td, EngineCore* pd_core) {
    uint32_t next_idx = (td->current_idx + 1) % 2;
    if (td->cores[next_idx]) {
      delete td->cores[next_idx];
W
wangguibao 已提交
453 454
    }

W
wangguibao 已提交
455 456 457 458 459 460
    td->cores[next_idx] = new (std::nothrow) EngineCore;
    if (!td->cores[next_idx] ||
        td->cores[next_idx]->clone(pd_core->get()) != 0) {
      LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
                 << next_idx << "]";
      return -1;
W
wangguibao 已提交
461
    }
W
wangguibao 已提交
462 463 464 465 466 467
    td->current_idx = next_idx;
    LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
                 << "] clone model from pd_core[" << pd_core
                 << "] succ, cur_idx[" << td->current_idx << "].";
    return 0;
  }
W
wangguibao 已提交
468

W
wangguibao 已提交
469
  virtual int thrd_initialize_impl() {
H
HexToString 已提交
470 471 472 473 474 475
    // memory pool to be inited in non-serving-threads
    if (MempoolWrapper::instance().thread_initialize() != 0) {
      LOG(ERROR) << "Failed thread initialize mempool";
      return -1;
    }

W
wangguibao 已提交
476 477 478 479 480 481 482 483
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
    if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
      LOG(ERROR) << "Failed clone thread data, origin_core["
                 << _pd->cores[_pd->current_idx] << "].";
      return -1;
    }

    THREAD_SETSPECIFIC(DBReloadableInferEngine<EngineCore>::_skey, md);
H
HexToString 已提交
484
    im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
W
wangguibao 已提交
485 486 487
    DBReloadableInferEngine<EngineCore>::_reload_vec.push_back(md);
    return 0;
  }
W
wangguibao 已提交
488

W
wangguibao 已提交
489 490 491
 protected:
  ModelData<EngineCore>*
      _pd;  // 进程级EngineCore,多个线程级EngineCore共用该对象的模型数据
W
wangguibao 已提交
492 493
};

Z
update  
zhangjun 已提交
494
template <typename PaddleInferenceCore>
M
bug fix  
MRXLT 已提交
495
#ifdef WITH_TRT
Z
update  
zhangjun 已提交
496
class FluidInferEngine : public DBReloadableInferEngine<PaddleInferenceCore> {
M
bug fix  
MRXLT 已提交
497
#else
Z
update  
zhangjun 已提交
498
class FluidInferEngine : public CloneDBReloadableInferEngine<PaddleInferenceCore> {
M
bug fix  
MRXLT 已提交
499 500
#endif
 public:  // NOLINT
W
wangguibao 已提交
501 502
  FluidInferEngine() {}
  ~FluidInferEngine() {}
H
HexToString 已提交
503
  typedef std::vector<paddle::PaddleTensor> TensorVector;
H
HexToString 已提交
504
  int infer_impl(const void* in, void* out, uint32_t batch_size = -1) {
505 506
    //First of all, get the real core acording to the template parameter 'PaddleInferenceCore'.
    PaddleInferenceCore* core =DBReloadableInferEngine<PaddleInferenceCore>::get_core();
W
wangguibao 已提交
507 508 509
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in infer_impl()";
      return -1;
W
wangguibao 已提交
510
    }
H
HexToString 已提交
511 512 513
    //We use the for loop to process the input data.
    //Inside each for loop, use the in[i]->name as inputName and call 'core->GetInputHandle(inputName)' to get the pointer of InputData.
    //Set the lod and shape information of InputData first. then copy data from cpu to the core.
H
HexToString 已提交
514
    const TensorVector* tensorVector_in_pointer = reinterpret_cast<const TensorVector*>(in);
H
HexToString 已提交
515
    for (int i=0; i < tensorVector_in_pointer->size(); ++i) {
H
HexToString 已提交
516 517 518 519
      auto lod_tensor_in = core->GetInputHandle((*tensorVector_in_pointer)[i].name);
      lod_tensor_in->SetLoD((*tensorVector_in_pointer)[i].lod);
      lod_tensor_in->Reshape((*tensorVector_in_pointer)[i].shape);
      void* origin_data = (*tensorVector_in_pointer)[i].data.data();
H
HexToString 已提交
520 521
      //Because the core needs to determine the size of memory space according to the data type passed in.
      //The pointer type of data must be one of float *,int64_t*,int32_t* instead void*.
H
HexToString 已提交
522
      if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::FLOAT32) {
H
HexToString 已提交
523
        float* data = static_cast<float*>(origin_data);
H
HexToString 已提交
524
        lod_tensor_in->CopyFromCpu(data);
H
HexToString 已提交
525
      }else if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::INT64) {
H
HexToString 已提交
526
        int64_t* data = static_cast<int64_t*>(origin_data);
H
HexToString 已提交
527
        lod_tensor_in->CopyFromCpu(data);
H
HexToString 已提交
528
      }else if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::INT32) {
H
HexToString 已提交
529
        int32_t* data = static_cast<int32_t*>(origin_data);
H
HexToString 已提交
530
        lod_tensor_in->CopyFromCpu(data);
H
HexToString 已提交
531
      }
W
wangjiawei04 已提交
532
    }
H
HexToString 已提交
533
    //After the input data is passed in, call 'core->Run()' perform the prediction process.
W
wangjiawei04 已提交
534
    if (!core->Run()) {
H
HexToString 已提交
535 536
        LOG(ERROR) << "Failed run fluid family core";
        return -1;
W
wangjiawei04 已提交
537
    }
H
HexToString 已提交
538 539 540
    
    //In order to get the results, first, call the 'core->GetOutputNames()' to get the name of output(which is a dict like {OutputName:pointer of OutputValue}).
    //Then, use for-loop to get OutputValue by calling 'core->GetOutputHandle'.
H
HexToString 已提交
541
    std::vector<std::string> outnames = core->GetOutputNames();
H
HexToString 已提交
542 543 544 545 546 547
    std::vector<int> output_shape;
    int out_num =0;
    int dataType =0;
    void* databuf_data = NULL;
    char* databuf_char = NULL;
    size_t databuf_size = 0;
H
HexToString 已提交
548
    TensorVector* tensorVector_out_pointer = reinterpret_cast<TensorVector*>(out);
H
HexToString 已提交
549
    if (!tensorVector_out_pointer) {
H
HexToString 已提交
550
      LOG(ERROR) << "tensorVector_out_pointer is nullptr,error";
W
wangguibao 已提交
551 552
      return -1;
    }
H
HexToString 已提交
553 554
    //Get the type and shape information of OutputData first. then copy data to cpu from the core.
    //The pointer type of data_out must be one of float *,int64_t*,int32_t* instead void*.
H
HexToString 已提交
555
    for (int i=0; i < outnames.size(); ++i) {
H
HexToString 已提交
556
      auto lod_tensor_out = core->GetOutputHandle(outnames[i]);
H
HexToString 已提交
557 558 559
      output_shape = lod_tensor_out->shape();
      out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
      dataType = lod_tensor_out->type();
H
HexToString 已提交
560
      if (dataType == paddle::PaddleDType::FLOAT32) {
H
HexToString 已提交
561
        databuf_size = out_num*sizeof(float);
H
HexToString 已提交
562
        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
H
HexToString 已提交
563 564 565 566 567
        if (!databuf_data) {
            LOG(ERROR) << "Malloc failed, size: " << databuf_size;
            return -1;
        }
        float* data_out = reinterpret_cast<float*>(databuf_data);
H
HexToString 已提交
568
        lod_tensor_out->CopyToCpu(data_out);
H
HexToString 已提交
569
        databuf_char = reinterpret_cast<char*>(data_out);
H
HexToString 已提交
570
      }else if (dataType == paddle::PaddleDType::INT64) {
H
HexToString 已提交
571
        databuf_size = out_num*sizeof(int64_t);
H
HexToString 已提交
572
        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
H
HexToString 已提交
573 574 575 576
        if (!databuf_data) {
            LOG(ERROR) << "Malloc failed, size: " << databuf_size;
            return -1;
        }
H
HexToString 已提交
577
        int64_t* data_out = reinterpret_cast<int64_t*>(databuf_data);
H
HexToString 已提交
578
        lod_tensor_out->CopyToCpu(data_out);
H
HexToString 已提交
579
        databuf_char = reinterpret_cast<char*>(data_out);
H
HexToString 已提交
580
      }else if (dataType == paddle::PaddleDType::INT32) {
H
HexToString 已提交
581
        databuf_size = out_num*sizeof(int32_t);
H
HexToString 已提交
582
        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
H
HexToString 已提交
583 584 585 586 587 588 589
        if (!databuf_data) {
            LOG(ERROR) << "Malloc failed, size: " << databuf_size;
            return -1;
        }
        int32_t* data_out = reinterpret_cast<int32_t*>(databuf_data);
        lod_tensor_out->CopyToCpu(data_out);
        databuf_char = reinterpret_cast<char*>(data_out);
H
HexToString 已提交
590
      }
H
HexToString 已提交
591 592 593
      //Because task scheduling requires OPs to use 'Channel'(which is a data structure) to transfer data between OPs.
      //We need to copy the processed data to the 'Channel' for the next OP.
      //In this function, it means we should copy the 'databuf_char' to the pointer 'void* out'.(which is also called ‘tensorVector_out_pointer’)
H
HexToString 已提交
594 595 596 597 598
      paddle::PaddleTensor tensor_out;
      tensor_out.name = outnames[i];
      tensor_out.dtype = paddle::PaddleDType(dataType);
      tensor_out.shape.assign(output_shape.begin(), output_shape.end());
      std::vector<std::vector<size_t>> out_lod = lod_tensor_out->lod();
H
HexToString 已提交
599
      for (int li=0; li < out_lod.size(); ++li) {
H
HexToString 已提交
600 601 602 603
        std::vector<size_t> lod_element;
        lod_element.assign(out_lod[li].begin(), out_lod[li].end());
        tensor_out.lod.push_back(lod_element);
      }
H
HexToString 已提交
604
      paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
H
HexToString 已提交
605 606
      tensor_out.data = paddleBuf;
      tensorVector_out_pointer->push_back(tensor_out);
H
HexToString 已提交
607
    }
W
wangguibao 已提交
608 609
    return 0;
  }
H
HexToString 已提交
610

H
HexToString 已提交
611 612
  int task_infer_impl(const BatchTensor& in, BatchTensor& out) {  // NOLINT
    return infer_impl(&in, &out);
H
HexToString 已提交
613 614 615
  }


W
wangguibao 已提交
616 617
};

W
wangguibao 已提交
618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642
typedef FactoryPool<InferEngine> StaticInferFactory;

class VersionedInferEngine : public InferEngine {
 public:
  VersionedInferEngine() { _versions.clear(); }
  ~VersionedInferEngine() {}

  // Convenience overload: initialize this engine without version info.
  // Returns 0 on success, -1 on failure.
  int proc_initialize(const configure::EngineDesc& conf) {
    if (proc_initialize(conf, false) != 0) {
      // fix: log message typo "intialize" -> "initialize"
      LOG(ERROR) << "Failed proc initialize engine: " << conf.name().c_str();
      return -1;
    }

    LOG(WARNING) << "Succ proc initialize engine: " << conf.name().c_str();
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    std::string engine_type = conf.type();
    InferEngine* engine =
        StaticInferFactory::instance().generate_object(engine_type);
    if (!engine) {
      LOG(ERROR) << "Failed generate engine with type:" << engine_type;
      return -1;
    }
643
#ifndef BCLOUD
M
MRXLT 已提交
644
    VLOG(2) << "FLAGS_logtostderr " << FLAGS_logtostderr;
M
MRXLT 已提交
645
    int tmp = FLAGS_logtostderr;
W
wangguibao 已提交
646 647 648 649
    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }
M
bug fix  
MRXLT 已提交
650
    VLOG(2) << "FLAGS_logtostderr " << FLAGS_logtostderr;
M
MRXLT 已提交
651
    FLAGS_logtostderr = tmp;
652 653 654 655 656 657
#else
    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }
#endif
W
wangguibao 已提交
658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682
    auto r = _versions.insert(std::make_pair(engine->version(), engine));
    if (!r.second) {
      LOG(ERROR) << "Failed insert item: " << engine->version()
                 << ", type: " << engine_type;
      return -1;
    }
    LOG(WARNING) << "Succ proc initialize version engine: "
                 << engine->version();
    return 0;
  }

  int proc_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize version engine: " << iter->first;
      }
      LOG(WARNING) << "Succ proc finalize version engine: " << iter->first;
    }
    return 0;
  }

  int thrd_initialize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize version engine: " << iter->first;
W
wangguibao 已提交
683
        return -1;
W
wangguibao 已提交
684 685
      }
      LOG(WARNING) << "Succ thrd initialize version engine: " << iter->first;
W
wangguibao 已提交
686
    }
W
wangguibao 已提交
687 688
    return 0;
  }
W
wangguibao 已提交
689

W
wangguibao 已提交
690 691 692 693
  int thrd_clear() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear version engine: " << iter->first;
W
wangguibao 已提交
694
        return -1;
W
wangguibao 已提交
695
      }
W
wangguibao 已提交
696
    }
W
wangguibao 已提交
697 698
    return 0;
  }
W
wangguibao 已提交
699

W
wangguibao 已提交
700 701 702 703 704 705 706
  int thrd_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize version engine: " << iter->first;
W
wangguibao 已提交
707
    }
W
wangguibao 已提交
708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740
    return 0;
  }

  int reload() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->reload() != 0) {
        LOG(ERROR) << "Failed reload version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ reload version engine: " << iter->first;
    }
    return 0;
  }

  // Version of the single (default) engine, or the sentinel uint64_t(-1)
  // when no unambiguous default engine exists.
  uint64_t version() const {
    InferEngine* engine = default_engine();
    return engine ? engine->version() : uint64_t(-1);
  }

  // inference interface
  // The "default" engine is only defined when exactly one version is
  // registered; any other count is ambiguous and yields NULL.
  InferEngine* default_engine() const {
    if (_versions.size() == 1) {
      return _versions.begin()->second;
    }
    LOG(ERROR) << "Ambiguous default engine version:" << _versions.size();
    return NULL;
  }

H
HexToString 已提交
741
  int infer(const void* in, void* out, uint32_t batch_size) {
W
wangguibao 已提交
742 743 744 745 746
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return -1;
    }
H
HexToString 已提交
747
    return engine->infer(in, out, batch_size);
W
wangguibao 已提交
748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765
  }

  // Fetch the thread-local EngineCore of the default engine, downcast to
  // a DB-reloadable engine of type T; NULL when no default engine exists
  // or the downcast fails.
  template <typename T>
  T* get_core() {
    InferEngine* engine = default_engine();
    if (engine == NULL) {
      LOG(WARNING) << "fail to get core";
      return NULL;
    }
    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(engine);
    if (db_engine == NULL) {
      LOG(WARNING) << "fail to get core";
      return NULL;
    }
    return db_engine->get_core();
  }

  // versioned inference interface
H
HexToString 已提交
766
  int infer(const void* in, void* out, uint32_t batch_size, uint64_t version) {
W
wangguibao 已提交
767 768 769 770 771 772
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return -1;
    }

H
HexToString 已提交
773
    return iter->second->infer(in, out, batch_size);
W
wangguibao 已提交
774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799
  }

  // Fetch the thread-local EngineCore of a specific engine version,
  // downcast to a DB-reloadable engine of type T; NULL when the version
  // is unknown or the downcast fails.
  template <typename T>
  T* get_core(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return NULL;
    }

    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(iter->second);
    if (db_engine == NULL) {
      LOG(WARNING) << "fail to get core for " << version;
      return NULL;
    }
    return db_engine->get_core();
  }

  // --
  // The InferEngine *_impl hooks are never used on this wrapper: all work
  // is delegated to the contained versioned engines, so each stub simply
  // reports failure (-1) if called.
  int proc_initialize_impl(const configure::EngineDesc& conf, bool) {
    return -1;
  }
  int thrd_initialize_impl() { return -1; }
  int thrd_finalize_impl() { return -1; }
  int thrd_clear_impl() { return -1; }
  int proc_finalize_impl() { return -1; }
H
HexToString 已提交
800 801
  int infer_impl(const void* in, void* out, uint32_t batch_size = -1) { return -1; }
  int task_infer_impl(const BatchTensor& in, BatchTensor& out) {  // NOLINT
H
HexToString 已提交
802 803
    return -1;
  }  // NOLINT
W
wangguibao 已提交
804 805 806

 private:
  boost::unordered_map<uint64_t, InferEngine*> _versions;
W
wangguibao 已提交
807 808
};

W
wangguibao 已提交
809 810 811 812 813 814 815 816 817 818 819 820 821 822 823
class InferManager {
 public:
  static InferManager& instance() {
    static InferManager ins;
    return ins;
  }

  int proc_initialize(const char* path, const char* file) {
    ModelToolkitConf model_toolkit_conf;
    if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
      LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
      return -1;
    }
    size_t engine_num = model_toolkit_conf.engines_size();
    for (size_t ei = 0; ei < engine_num; ++ei) {
B
barrierye 已提交
824 825
      LOG(INFO) << "model_toolkit_conf.engines(" << ei
                << ").name: " << model_toolkit_conf.engines(ei).name();
W
wangguibao 已提交
826 827 828 829 830 831 832 833
      std::string engine_name = model_toolkit_conf.engines(ei).name();
      VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
      if (!engine) {
        LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
        return -1;
      }
      if (engine->proc_initialize(model_toolkit_conf.engines(ei)) != 0) {
        LOG(ERROR) << "Failed initialize version engine, name:" << engine_name;
W
wangguibao 已提交
834
        return -1;
W
wangguibao 已提交
835 836 837 838 839 840 841
      }
      auto r = _map.insert(std::make_pair(engine_name, engine));
      if (!r.second) {
        LOG(ERROR) << "Failed insert item: " << engine_name;
        return -1;
      }
      LOG(WARNING) << "Succ proc initialize engine: " << engine_name;
W
wangguibao 已提交
842
    }
W
wangguibao 已提交
843 844 845 846 847 848 849
    return 0;
  }

  int thrd_initialize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
850
        return -1;
W
wangguibao 已提交
851 852
      }
      LOG(WARNING) << "Succ thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
853
    }
W
wangguibao 已提交
854 855
    return 0;
  }
W
wangguibao 已提交
856

W
wangguibao 已提交
857 858 859 860 861 862 863 864 865
  int thrd_clear() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
866

W
wangguibao 已提交
867 868 869 870 871 872 873 874 875
  int reload() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->reload() != 0) {
        LOG(ERROR) << "Failed reload engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
876

W
wangguibao 已提交
877 878 879 880 881 882 883 884 885 886
  int thrd_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize engine, name: " << it->first;
    }
    return 0;
  }
W
wangguibao 已提交
887

W
wangguibao 已提交
888 889 890 891 892 893 894 895
  int proc_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ proc finalize engine, name: " << it->first;
    }
W
wangguibao 已提交
896
    _map.clear();
W
wangguibao 已提交
897 898 899 900
    return 0;
  }

  // Inference interface
H
HexToString 已提交
901 902 903 904
  int infer(const char* model_name,
            const void* in,
            void* out,
            uint32_t batch_size = -1) {
W
wangguibao 已提交
905 906 907 908 909
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
H
HexToString 已提交
910
    return it->second->infer(in, out, batch_size);
W
wangguibao 已提交
911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929
  }

  template <typename T>
  T* get_core(const char* model_name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    auto infer_engine =
        dynamic_cast<DBReloadableInferEngine<T>*>(it->second->default_engine());
    if (infer_engine) {
      return infer_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << model_name;
    return NULL;
  }

  // Versioned inference interface
H
HexToString 已提交
930 931 932 933 934
  int infer(const char* model_name, 
            const void* in,
            void* out,
            uint32_t batch_size,
            uint64_t version) {
W
wangguibao 已提交
935 936 937 938 939
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
H
HexToString 已提交
940
    return it->second->infer(in, out, batch_size, version);
W
wangguibao 已提交
941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971
  }

  template <typename T>
  T* get_core(const char* model_name, uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    return it->second->get_core<T>(version);
  }

  int query_version(const std::string& model, uint64_t& version) {  // NOLINT
    auto it = _map.find(model);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model;
      return -1;
    }
    auto infer_engine = it->second->default_engine();
    if (!infer_engine) {
      LOG(WARNING) << "Cannot get default engine for model:" << model;
      return -1;
    }
    version = infer_engine->version();
    LOG(INFO) << "Succ get version: " << version << " for model: " << model;
    return 0;
  }

 private:
  boost::unordered_map<std::string, VersionedInferEngine*> _map;
};
W
wangguibao 已提交
972

W
wangguibao 已提交
973 974 975
}  // namespace predictor
}  // namespace paddle_serving
}  // namespace baidu