infer.h 30.6 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
W
wangguibao 已提交
16
#include <sys/stat.h>
W
wangguibao 已提交
17
#include <sys/types.h>
W
wangguibao 已提交
18
#include <unistd.h>
Z
zhangjun 已提交
19
#include <pthread.h>
W
wangguibao 已提交
20
#include <string>
M
MRXLT 已提交
21
#include <utility>
W
wangguibao 已提交
22
#include <vector>
H
HexToString 已提交
23
#include <numeric>
G
guru4elephant 已提交
24
#include "core/predictor/common/inner_common.h"
H
HexToString 已提交
25
#include "core/predictor/framework/bsf.h"
G
guru4elephant 已提交
26 27
#include "core/predictor/framework/factory.h"
#include "core/predictor/framework/infer_data.h"
W
wangjiawei04 已提交
28
#include "paddle_inference_api.h"  // NOLINT
W
wangguibao 已提交
29 30 31 32
namespace baidu {
namespace paddle_serving {
namespace predictor {

W
wangguibao 已提交
33 34
using configure::ModelToolkitConf;

Z
zhangjun 已提交
35 36 37 38 39 40 41 42 43 44 45
// Scoped RAII guard for a pthread mutex: acquires the lock on
// construction and releases it when the guard goes out of scope.
// Non-copyable by convention (holds a reference to the caller's mutex).
class AutoLock {
 public:
  explicit AutoLock(pthread_mutex_t& mutex) : _mut(mutex) {
    // Lock through the stored reference so ctor and dtor are symmetric.
    pthread_mutex_lock(&_mut);
  }

  ~AutoLock() { pthread_mutex_unlock(&_mut); }

 private:
  pthread_mutex_t& _mut;  // borrowed, never owned
};

Z
update  
zhangjun 已提交
46
// Process-wide mutex singleton. instance() lazily constructs a single
// GlobalCreateMutex (initializing the pthread mutex exactly once) and
// hands back a reference to the underlying pthread_mutex_t, intended to
// serialize engine/model creation across threads.
class GlobalCreateMutex {
 public:
  pthread_mutex_t& mutex() { return _mut; }

  static pthread_mutex_t& instance() {
    // Function-local static: constructed on first use, shared by all callers.
    static GlobalCreateMutex singleton;
    return singleton.mutex();
  }

 private:
  GlobalCreateMutex() { pthread_mutex_init(&_mut, NULL); }

  pthread_mutex_t _mut;
};

W
wangguibao 已提交
60
class InferEngine {
W
wangguibao 已提交
61 62 63 64 65 66 67 68 69 70
 public:
  virtual ~InferEngine() {}

  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    return proc_initialize_impl(conf, version);
  }
  virtual int proc_finalize() { return proc_finalize_impl(); }
  virtual int thrd_initialize() { return thrd_initialize_impl(); }
  virtual int thrd_clear() { return thrd_clear_impl(); }
  virtual int thrd_finalize() { return thrd_finalize_impl(); }
H
HexToString 已提交
71 72 73
  virtual int infer(const void* in, void* out, uint32_t batch_size = -1) {
    return infer_impl(in, out, batch_size);
  }
W
wangguibao 已提交
74 75 76 77 78 79 80 81 82 83 84 85

  virtual int reload() = 0;

  virtual uint64_t version() const = 0;

  // begin: framework inner call
  virtual int proc_initialize_impl(const configure::EngineDesc& conf,
                                   bool version) = 0;
  virtual int thrd_initialize_impl() = 0;
  virtual int thrd_finalize_impl() = 0;
  virtual int thrd_clear_impl() = 0;
  virtual int proc_finalize_impl() = 0;
H
HexToString 已提交
86
  virtual int infer_impl(const void* in,
H
HexToString 已提交
87 88
                          void* out,
                          uint32_t batch_size = -1) = 0;
H
HexToString 已提交
89
  virtual int task_infer_impl(const BatchTensor& in,
H
HexToString 已提交
90 91
                          BatchTensor& out) = 0;  // NOLINT

W
wangguibao 已提交
92 93 94 95 96 97
  // end: framework inner call
};

class ReloadableInferEngine : public InferEngine {
 public:
  virtual ~ReloadableInferEngine() {}
W
wangguibao 已提交
98

W
wangguibao 已提交
99 100 101 102 103
  union last_check_status {
    time_t last_timestamp;
    uint64_t last_md5sum;
    uint64_t last_revision;
  };
W
wangguibao 已提交
104

Z
update  
zhangjun 已提交
105
  virtual int load(const configure::EngineDesc& conf) = 0;
H
HexToString 已提交
106
  typedef im::bsf::Task<Tensor, Tensor> TaskT;
W
wangguibao 已提交
107 108 109 110

  int proc_initialize_impl(const configure::EngineDesc& conf, bool version) {
    _reload_tag_file = conf.reloadable_meta();
    _reload_mode_tag = conf.reloadable_type();
Z
zhangjun 已提交
111
    _model_data_path = conf.model_dir();
W
wangguibao 已提交
112 113 114
    _infer_thread_num = conf.runtime_thread_num();
    _infer_batch_size = conf.batch_infer_size();
    _infer_batch_align = conf.enable_batch_align();
115

Z
update  
zhangjun 已提交
116
    _conf = conf;
Z
zhangjun 已提交
117

Z
update  
zhangjun 已提交
118
    if (!check_need_reload() || load(conf) != 0) {
W
wangguibao 已提交
119 120
      LOG(ERROR) << "Failed load model_data_path" << _model_data_path;
      return -1;
W
wangguibao 已提交
121
    }
W
wangguibao 已提交
122 123 124 125

    if (parse_version_info(conf, version) != 0) {
      LOG(ERROR) << "Failed parse version info";
      return -1;
W
wangguibao 已提交
126
    }
W
wangguibao 已提交
127 128 129 130 131 132 133 134 135

    LOG(WARNING) << "Succ load model_data_path" << _model_data_path;
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    if (proc_initialize_impl(conf, version) != 0) {
      LOG(ERROR) << "Failed proc initialize impl";
      return -1;
W
wangguibao 已提交
136
    }
H
HexToString 已提交
137 138 139 140 141 142 143 144 145 146 147 148

    // init bsf framework
    if (_infer_thread_num <= 0) {
      return 0;
    }

    // init bsf framework
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_init_fn(
        boost::bind(&InferEngine::thrd_initialize_impl, this));
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_reset_fn(
        boost::bind(&InferEngine::thrd_clear_impl, this));
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_callback_fn(
H
HexToString 已提交
149
        boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
H
HexToString 已提交
150 151 152 153 154 155 156 157 158 159 160 161
    im::bsf::TaskExecutor<TaskT>::instance()->set_batch_size(_infer_batch_size);
    im::bsf::TaskExecutor<TaskT>::instance()->set_batch_align(
        _infer_batch_align);
    if (im::bsf::TaskExecutor<TaskT>::instance()->start(_infer_thread_num) !=
        0) {
      LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
      return -1;
    }

    LOG(WARNING) << "Enable batch schedule framework, thread_num:"
                 << _infer_thread_num << ", batch_size:" << _infer_batch_size
                 << ", enable_batch_align:" << _infer_batch_align;
W
wangguibao 已提交
162 163
    return 0;
  }
W
wangguibao 已提交
164

H
HexToString 已提交
165 166
  int infer(const void* in, void* out, uint32_t batch_size = -1) {
    if (_infer_thread_num <= 0) {
H
HexToString 已提交
167
      return infer_impl(in, out, batch_size);
H
HexToString 已提交
168 169 170 171 172 173 174 175
    }

    im::bsf::TaskManager<Tensor, Tensor> task_manager;
    task_manager.schedule(*(reinterpret_cast<const BatchTensor*>(in)),
                          *(reinterpret_cast<BatchTensor*>(out)));
    task_manager.wait();
    return 0;
  }
W
wangguibao 已提交
176

W
wangguibao 已提交
177 178 179 180 181 182
  int thrd_initialize() {
    if (_infer_thread_num > 0) {
      return 0;
    }
    return thrd_initialize_impl();
  }
W
wangguibao 已提交
183

W
wangguibao 已提交
184 185 186 187
  int thrd_clear() {
    if (_infer_thread_num > 0) {
      return 0;
    }
W
wangguibao 已提交
188

W
wangguibao 已提交
189 190
    return thrd_clear_impl();
  }
W
wangguibao 已提交
191

W
wangguibao 已提交
192 193 194 195 196
  int proc_finalize() {
    if (proc_finalize_impl() != 0) {
      LOG(ERROR) << "Failed proc finalize impl";
      return -1;
    }
W
wangguibao 已提交
197

H
HexToString 已提交
198 199 200
    if (_infer_thread_num > 0) {
      im::bsf::TaskExecutor<TaskT>::instance()->stop();
    }
W
wangguibao 已提交
201 202
    return 0;
  }
W
wangguibao 已提交
203

W
wangguibao 已提交
204 205 206
  int reload() {
    if (check_need_reload()) {
      LOG(WARNING) << "begin reload model[" << _model_data_path << "].";
Z
zhangjun 已提交
207
      return load(_conf);
W
wangguibao 已提交
208 209 210 211 212 213
    }
    return 0;
  }

  uint64_t version() const { return _version; }
  uint32_t thread_num() const { return _infer_thread_num; }
W
wangguibao 已提交
214

W
wangguibao 已提交
215 216 217 218 219
 private:
  int parse_version_info(const configure::EngineDesc& config, bool version) {
    _version = uint64_t(-1);
    return 0;
  }
W
wangguibao 已提交
220

W
wangguibao 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
  bool check_need_reload() {
    if (_reload_mode_tag == "timestamp_ne") {
      return check_timestamp_ne();
    } else if (_reload_mode_tag == "timestamp_gt") {
      return check_timestamp_gt();
    } else if (_reload_mode_tag == "md5sum") {
      return check_md5sum();
    } else if (_reload_mode_tag == "revision") {
      return check_revision();
    } else if (_reload_mode_tag == "none") {
      return false;
    } else {
      LOG(ERROR) << "Not support check type: " << _reload_mode_tag;
      return false;
    }
  }

  bool check_timestamp_ne() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
244

W
wangguibao 已提交
245 246 247
    if ((st.st_mode & S_IFREG) && st.st_mtime != _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
248 249
    }

W
wangguibao 已提交
250 251
    return false;
  }
W
wangguibao 已提交
252

W
wangguibao 已提交
253 254 255 256 257 258
  bool check_timestamp_gt() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
259

W
wangguibao 已提交
260 261 262
    if ((st.st_mode & S_IFREG) && st.st_mtime > _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
263 264
    }

W
wangguibao 已提交
265 266 267 268 269 270 271 272 273
    return false;
  }

  bool check_md5sum() { return false; }

  bool check_revision() { return false; }

 protected:
  std::string _model_data_path;
Z
update  
zhangjun 已提交
274
  configure::EngineDesc _conf;
W
wangguibao 已提交
275 276 277 278 279 280 281 282 283 284

 private:
  std::string _reload_tag_file;
  std::string _reload_mode_tag;
  last_check_status _last_status;
  uint32_t _infer_thread_num;
  uint32_t _infer_batch_size;
  bool _infer_batch_align;
  uint64_t _version;
};
W
wangguibao 已提交
285

W
wangguibao 已提交
286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
// Double-buffered holder for a pair of engine cores, enabling hot
// model reload: cores[current_idx] serves traffic while the other slot
// is the target of the next load. Owns both cores and deletes them on
// destruction.
template <typename EngineCore>
struct ModelData {
  ModelData() : current_idx(1) {
    for (int i = 0; i < 2; ++i) {
      cores[i] = NULL;
    }
  }

  ~ModelData() {
    for (int i = 0; i < 2; ++i) {
      delete cores[i];
    }
  }

  EngineCore* cores[2];   // owned; either slot may be NULL before loading
  uint32_t current_idx;   // index of the slot currently in service
};

template <typename EngineCore>
class DBReloadableInferEngine : public ReloadableInferEngine {
 public:
  virtual ~DBReloadableInferEngine() {}

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    THREAD_KEY_CREATE(&_skey, NULL);
    THREAD_MUTEX_INIT(&_mutex, NULL);
    return ReloadableInferEngine::proc_initialize(conf, version);
  }

Z
update  
zhangjun 已提交
313
  virtual int load(const configure::EngineDesc& conf) {
W
wangguibao 已提交
314 315
    if (_reload_vec.empty()) {
      return 0;
W
wangguibao 已提交
316 317
    }

W
wangguibao 已提交
318
    for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
Z
update  
zhangjun 已提交
319
      if (load_data(_reload_vec[ti], conf) != 0) {
W
wangguibao 已提交
320 321 322 323 324
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
    }

Z
update  
zhangjun 已提交
325
    LOG(WARNING) << "Succ load engine, path: " << conf.model_dir();
W
wangguibao 已提交
326

W
wangguibao 已提交
327 328
    return 0;
  }
W
wangguibao 已提交
329

330
  int load_data(ModelData<EngineCore>* md,
Z
update  
zhangjun 已提交
331
                const configure::EngineDesc& conf) {
W
wangguibao 已提交
332 333 334
    uint32_t next_idx = (md->current_idx + 1) % 2;
    if (md->cores[next_idx]) {
      delete md->cores[next_idx];
W
wangguibao 已提交
335 336
    }

W
wangguibao 已提交
337
    md->cores[next_idx] = new (std::nothrow) EngineCore;
338

H
HexToString 已提交
339
    // params.dump();
Z
update  
zhangjun 已提交
340 341
    if (!md->cores[next_idx] || md->cores[next_idx]->create(conf) != 0) {
      LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
W
wangguibao 已提交
342
      return -1;
W
wangguibao 已提交
343
    }
W
wangguibao 已提交
344 345 346
    md->current_idx = next_idx;
    return 0;
  }
W
wangguibao 已提交
347

W
wangguibao 已提交
348 349
  virtual int thrd_initialize_impl() {
    // memory pool to be inited in non-serving-threads
H
HexToString 已提交
350 351 352 353
    if (MempoolWrapper::instance().thread_initialize() != 0) {
      LOG(ERROR) << "Failed thread initialize mempool";
      return -1;
    }
W
wangguibao 已提交
354

W
wangguibao 已提交
355
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
Z
update  
zhangjun 已提交
356
    if (!md || load_data(md, _conf) != 0) {
357
      LOG(ERROR) << "Failed create thread data from "
Z
zhangjun 已提交
358
                 << _conf.model_dir();
W
wangguibao 已提交
359
      return -1;
W
wangguibao 已提交
360 361
    }

W
wangguibao 已提交
362
    THREAD_SETSPECIFIC(_skey, md);
H
HexToString 已提交
363
    im::bsf::AutoMutex lock(_mutex);
W
wangguibao 已提交
364 365 366 367 368 369
    _reload_vec.push_back(md);
    return 0;
  }

  int thrd_clear_impl() {
    // for non-serving-threads
H
HexToString 已提交
370 371 372 373
    if (MempoolWrapper::instance().thread_clear() != 0) {
      LOG(ERROR) << "Failed thread clear mempool";
      return -1;
    }
W
wangguibao 已提交
374 375 376 377
    return 0;
  }

  int thrd_finalize_impl() { return 0; }
W
wangguibao 已提交
378

W
wangguibao 已提交
379 380 381 382 383
  int proc_finalize_impl() {
    THREAD_KEY_DELETE(_skey);
    THREAD_MUTEX_DESTROY(&_mutex);
    return 0;
  }
W
wangguibao 已提交
384

W
wangguibao 已提交
385 386 387 388 389 390
  EngineCore* get_core() {
    ModelData<EngineCore>* md =
        (ModelData<EngineCore>*)THREAD_GETSPECIFIC(_skey);
    if (!md) {
      LOG(ERROR) << "Failed get thread specific data";
      return NULL;
W
wangguibao 已提交
391
    }
W
wangguibao 已提交
392 393
    return md->cores[md->current_idx];
  }
W
wangguibao 已提交
394

W
wangguibao 已提交
395 396 397 398
 protected:
  THREAD_KEY_T _skey;
  THREAD_MUTEX_T _mutex;
  std::vector<ModelData<EngineCore>*> _reload_vec;
W
wangguibao 已提交
399

W
wangguibao 已提交
400 401
 private:
};
W
wangguibao 已提交
402

W
wangguibao 已提交
403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
// Variant of DBReloadableInferEngine where all thread-level EngineCores
// share one process-level copy of the model data: the process core is
// loaded once, and each thread's core is a cheap clone() of it.
template <typename EngineCore>
class CloneDBReloadableInferEngine
    : public DBReloadableInferEngine<EngineCore> {
 public:
  virtual ~CloneDBReloadableInferEngine() {}

  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    // Allocate the process-level ModelData before the base class starts
    // loading, since load() below depends on it.
    _pd = new (std::nothrow) ModelData<EngineCore>;
    if (!_pd) {
      LOG(ERROR) << "Failed to allocate for ProcData";
      return -1;
    }
    return DBReloadableInferEngine<EngineCore>::proc_initialize(conf, version);
  }

  virtual int load(const configure::EngineDesc& conf) {
    // Load the process-level model data first.
    if (!_pd ||
        DBReloadableInferEngine<EngineCore>::load_data(_pd, conf) != 0) {
      LOG(ERROR) << "Failed to create common model from [" << conf.model_dir()
                 << "].";
      return -1;
    }
    LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx]
                 << "], path[" << conf.model_dir() << "].";

    if (DBReloadableInferEngine<EngineCore>::_reload_vec.empty()) {
      return 0;
    }

    // Re-clone every registered thread-level core from the freshly
    // loaded process-level core.
    for (uint32_t i = 0;
         i < DBReloadableInferEngine<EngineCore>::_reload_vec.size();
         ++i) {
      if (load_data(DBReloadableInferEngine<EngineCore>::_reload_vec[i],
                    _pd->cores[_pd->current_idx]) != 0) {
        LOG(ERROR) << "Failed reload engine model: " << i;
        return -1;
      }
    }

    LOG(WARNING) << "Succ load clone model, path[" << conf.model_dir() << "]";

    return 0;
  }

  // Build a thread-level core in td's inactive slot by cloning the
  // shared pd_core model data, then flip current_idx.
  int load_data(ModelData<EngineCore>* td, EngineCore* pd_core) {
    uint32_t next_idx = (td->current_idx + 1) % 2;
    if (td->cores[next_idx]) {
      delete td->cores[next_idx];
    }

    td->cores[next_idx] = new (std::nothrow) EngineCore;
    if (!td->cores[next_idx] ||
        td->cores[next_idx]->clone(pd_core->get()) != 0) {
      LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
                 << next_idx << "]";
      return -1;
    }
    td->current_idx = next_idx;
    LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
                 << "] clone model from pd_core[" << pd_core
                 << "] succ, cur_idx[" << td->current_idx << "].";
    return 0;
  }

  virtual int thrd_initialize_impl() {
    // memory pool to be inited in non-serving-threads
    if (MempoolWrapper::instance().thread_initialize() != 0) {
      LOG(ERROR) << "Failed thread initialize mempool";
      return -1;
    }

    // Clone this thread's core from the process-level core and
    // register it for future reloads.
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
    if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
      LOG(ERROR) << "Failed clone thread data, origin_core["
                 << _pd->cores[_pd->current_idx] << "].";
      return -1;
    }

    THREAD_SETSPECIFIC(DBReloadableInferEngine<EngineCore>::_skey, md);
    im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
    DBReloadableInferEngine<EngineCore>::_reload_vec.push_back(md);
    return 0;
  }

 protected:
  // Process-level EngineCore; thread-level cores share this object's
  // model data via clone().
  ModelData<EngineCore>* _pd;
};

H
HexToString 已提交
495
template <typename EngineCore>
M
bug fix  
MRXLT 已提交
496
#ifdef WITH_TRT
H
HexToString 已提交
497
class FluidInferEngine : public DBReloadableInferEngine<EngineCore> {
M
bug fix  
MRXLT 已提交
498
#else
H
HexToString 已提交
499
class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
M
bug fix  
MRXLT 已提交
500 501
#endif
 public:  // NOLINT
W
wangguibao 已提交
502 503
  FluidInferEngine() {}
  ~FluidInferEngine() {}
H
HexToString 已提交
504
  typedef std::vector<paddle::PaddleTensor> TensorVector;
H
HexToString 已提交
505
  int infer_impl(const void* in, void* out, uint32_t batch_size = -1) {
H
HexToString 已提交
506 507 508
    // First of all, get the real core acording to the
    // Template parameter <EngineCore>.
    EngineCore* core = DBReloadableInferEngine<EngineCore>::get_core();
W
wangguibao 已提交
509 510 511
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in infer_impl()";
      return -1;
W
wangguibao 已提交
512
    }
H
HexToString 已提交
513 514 515 516 517 518 519
    // We use the for loop to process the input data.
    // Inside each for loop, use the in[i]->name as inputName and call
    // 'core->GetInputHandle(inputName)' to get the pointer of InputData.
    // Set the lod and shape information of InputData first.
    // Then copy data from cpu to the core.
    const TensorVector* tensorVector_in_pointer =
      reinterpret_cast<const TensorVector*>(in);
H
HexToString 已提交
520
    for (int i=0; i < tensorVector_in_pointer->size(); ++i) {
H
HexToString 已提交
521 522
      auto lod_tensor_in =
        core->GetInputHandle((*tensorVector_in_pointer)[i].name);
H
HexToString 已提交
523 524 525
      lod_tensor_in->SetLoD((*tensorVector_in_pointer)[i].lod);
      lod_tensor_in->Reshape((*tensorVector_in_pointer)[i].shape);
      void* origin_data = (*tensorVector_in_pointer)[i].data.data();
H
HexToString 已提交
526 527 528 529
      // Because the core needs to determine the size of memory space
      // according to the data type passed in.
      // The pointer type of data must be one of
      // float *,int64_t*,int32_t* instead void*.
H
HexToString 已提交
530
      if ((*tensorVector_in_pointer)[i].dtype == paddle::PaddleDType::FLOAT32) {
H
HexToString 已提交
531
        float* data = static_cast<float*>(origin_data);
H
HexToString 已提交
532
        lod_tensor_in->CopyFromCpu(data);
H
HexToString 已提交
533 534
      } else if ((*tensorVector_in_pointer)[i].dtype ==
                paddle::PaddleDType::INT64) {
H
HexToString 已提交
535
        int64_t* data = static_cast<int64_t*>(origin_data);
H
HexToString 已提交
536
        lod_tensor_in->CopyFromCpu(data);
H
HexToString 已提交
537 538
      } else if ((*tensorVector_in_pointer)[i].dtype ==
                paddle::PaddleDType::INT32) {
H
HexToString 已提交
539
        int32_t* data = static_cast<int32_t*>(origin_data);
H
HexToString 已提交
540
        lod_tensor_in->CopyFromCpu(data);
H
HexToString 已提交
541
      }
W
wangjiawei04 已提交
542
    }
H
HexToString 已提交
543 544
    // After the input data is passed in,
    // call 'core->Run()' perform the prediction process.
W
wangjiawei04 已提交
545
    if (!core->Run()) {
H
HexToString 已提交
546 547
        LOG(ERROR) << "Failed run fluid family core";
        return -1;
W
wangjiawei04 已提交
548
    }
H
HexToString 已提交
549 550 551 552
    // In order to get the results,
    // first, call the 'core->GetOutputNames()' to get the name of output
    // (which is a dict like {OutputName:pointer of OutputValue}).
    // Then, use for-loop to get OutputValue by calling 'core->GetOutputHandle'.
H
HexToString 已提交
553
    std::vector<std::string> outnames = core->GetOutputNames();
H
HexToString 已提交
554
    std::vector<int> output_shape;
H
HexToString 已提交
555 556
    int out_num = 0;
    int dataType = 0;
H
HexToString 已提交
557 558 559
    void* databuf_data = NULL;
    char* databuf_char = NULL;
    size_t databuf_size = 0;
H
HexToString 已提交
560 561
    TensorVector* tensorVector_out_pointer =
                  reinterpret_cast<TensorVector*>(out);
H
HexToString 已提交
562
    if (!tensorVector_out_pointer) {
H
HexToString 已提交
563
      LOG(ERROR) << "tensorVector_out_pointer is nullptr,error";
W
wangguibao 已提交
564 565
      return -1;
    }
H
HexToString 已提交
566 567 568 569
    // Get the type and shape information of OutputData first.
    // then copy data to cpu from the core.
    // The pointer type of data_out must be one of
    // float *,int64_t*,int32_t* instead void*.
H
HexToString 已提交
570
    for (int i=0; i < outnames.size(); ++i) {
H
HexToString 已提交
571
      auto lod_tensor_out = core->GetOutputHandle(outnames[i]);
H
HexToString 已提交
572
      output_shape = lod_tensor_out->shape();
H
HexToString 已提交
573 574
      out_num = std::accumulate(
          output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
H
HexToString 已提交
575
      dataType = lod_tensor_out->type();
H
HexToString 已提交
576
      if (dataType == paddle::PaddleDType::FLOAT32) {
H
HexToString 已提交
577
        databuf_size = out_num*sizeof(float);
H
HexToString 已提交
578
        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
H
HexToString 已提交
579 580 581 582 583
        if (!databuf_data) {
            LOG(ERROR) << "Malloc failed, size: " << databuf_size;
            return -1;
        }
        float* data_out = reinterpret_cast<float*>(databuf_data);
H
HexToString 已提交
584
        lod_tensor_out->CopyToCpu(data_out);
H
HexToString 已提交
585
        databuf_char = reinterpret_cast<char*>(data_out);
H
HexToString 已提交
586
      } else if (dataType == paddle::PaddleDType::INT64) {
H
HexToString 已提交
587
        databuf_size = out_num*sizeof(int64_t);
H
HexToString 已提交
588
        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
H
HexToString 已提交
589 590 591 592
        if (!databuf_data) {
            LOG(ERROR) << "Malloc failed, size: " << databuf_size;
            return -1;
        }
H
HexToString 已提交
593
        int64_t* data_out = reinterpret_cast<int64_t*>(databuf_data);
H
HexToString 已提交
594
        lod_tensor_out->CopyToCpu(data_out);
H
HexToString 已提交
595
        databuf_char = reinterpret_cast<char*>(data_out);
H
HexToString 已提交
596
      } else if (dataType == paddle::PaddleDType::INT32) {
H
HexToString 已提交
597
        databuf_size = out_num*sizeof(int32_t);
H
HexToString 已提交
598
        databuf_data = MempoolWrapper::instance().malloc(databuf_size);
H
HexToString 已提交
599 600 601 602 603 604 605
        if (!databuf_data) {
            LOG(ERROR) << "Malloc failed, size: " << databuf_size;
            return -1;
        }
        int32_t* data_out = reinterpret_cast<int32_t*>(databuf_data);
        lod_tensor_out->CopyToCpu(data_out);
        databuf_char = reinterpret_cast<char*>(data_out);
H
HexToString 已提交
606
      }
H
HexToString 已提交
607 608 609 610 611
      // Because task scheduling requires OPs to use 'Channel'
      // (which is a data structure) to transfer data between OPs.
      // We need to copy the processed data to the 'Channel' for the next OP.
      // In this function, it means we should copy the 'databuf_char' to
      // 'void* out'.(which is also called ‘tensorVector_out_pointer’)
H
HexToString 已提交
612 613 614 615 616
      paddle::PaddleTensor tensor_out;
      tensor_out.name = outnames[i];
      tensor_out.dtype = paddle::PaddleDType(dataType);
      tensor_out.shape.assign(output_shape.begin(), output_shape.end());
      std::vector<std::vector<size_t>> out_lod = lod_tensor_out->lod();
H
HexToString 已提交
617
      for (int li=0; li < out_lod.size(); ++li) {
H
HexToString 已提交
618 619 620 621
        std::vector<size_t> lod_element;
        lod_element.assign(out_lod[li].begin(), out_lod[li].end());
        tensor_out.lod.push_back(lod_element);
      }
H
HexToString 已提交
622
      paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
H
HexToString 已提交
623 624
      tensor_out.data = paddleBuf;
      tensorVector_out_pointer->push_back(tensor_out);
H
HexToString 已提交
625
    }
W
wangguibao 已提交
626 627
    return 0;
  }
H
HexToString 已提交
628

H
HexToString 已提交
629 630
  int task_infer_impl(const BatchTensor& in, BatchTensor& out) {  // NOLINT
    return infer_impl(&in, &out);
H
HexToString 已提交
631
  }
W
wangguibao 已提交
632 633
};

W
wangguibao 已提交
634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
typedef FactoryPool<InferEngine> StaticInferFactory;

class VersionedInferEngine : public InferEngine {
 public:
  VersionedInferEngine() { _versions.clear(); }
  ~VersionedInferEngine() {}

  int proc_initialize(const configure::EngineDesc& conf) {
    if (proc_initialize(conf, false) != 0) {
      LOG(ERROR) << "Failed proc intialize engine: " << conf.name().c_str();
      return -1;
    }

    LOG(WARNING) << "Succ proc initialize engine: " << conf.name().c_str();
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    std::string engine_type = conf.type();
    InferEngine* engine =
        StaticInferFactory::instance().generate_object(engine_type);
    if (!engine) {
      LOG(ERROR) << "Failed generate engine with type:" << engine_type;
      return -1;
    }
659
#ifndef BCLOUD
M
MRXLT 已提交
660
    VLOG(2) << "FLAGS_logtostderr " << FLAGS_logtostderr;
M
MRXLT 已提交
661
    int tmp = FLAGS_logtostderr;
W
wangguibao 已提交
662 663 664 665
    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }
M
bug fix  
MRXLT 已提交
666
    VLOG(2) << "FLAGS_logtostderr " << FLAGS_logtostderr;
M
MRXLT 已提交
667
    FLAGS_logtostderr = tmp;
668 669 670 671 672 673
#else
    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }
#endif
W
wangguibao 已提交
674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698
    auto r = _versions.insert(std::make_pair(engine->version(), engine));
    if (!r.second) {
      LOG(ERROR) << "Failed insert item: " << engine->version()
                 << ", type: " << engine_type;
      return -1;
    }
    LOG(WARNING) << "Succ proc initialize version engine: "
                 << engine->version();
    return 0;
  }

  int proc_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize version engine: " << iter->first;
      }
      LOG(WARNING) << "Succ proc finalize version engine: " << iter->first;
    }
    return 0;
  }

  int thrd_initialize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize version engine: " << iter->first;
W
wangguibao 已提交
699
        return -1;
W
wangguibao 已提交
700 701
      }
      LOG(WARNING) << "Succ thrd initialize version engine: " << iter->first;
W
wangguibao 已提交
702
    }
W
wangguibao 已提交
703 704
    return 0;
  }
W
wangguibao 已提交
705

W
wangguibao 已提交
706 707 708 709
  int thrd_clear() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear version engine: " << iter->first;
W
wangguibao 已提交
710
        return -1;
W
wangguibao 已提交
711
      }
W
wangguibao 已提交
712
    }
W
wangguibao 已提交
713 714
    return 0;
  }
W
wangguibao 已提交
715

W
wangguibao 已提交
716 717 718 719 720 721 722
  int thrd_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize version engine: " << iter->first;
W
wangguibao 已提交
723
    }
W
wangguibao 已提交
724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756
    return 0;
  }

  int reload() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->reload() != 0) {
        LOG(ERROR) << "Failed reload version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ reload version engine: " << iter->first;
    }
    return 0;
  }

  uint64_t version() const {
    InferEngine* engine = default_engine();
    if (engine) {
      return engine->version();
    } else {
      return uint64_t(-1);
    }
  }

  // inference interface
  InferEngine* default_engine() const {
    if (_versions.size() != 1) {
      LOG(ERROR) << "Ambiguous default engine version:" << _versions.size();
      return NULL;
    }

    return _versions.begin()->second;
  }

H
HexToString 已提交
757
  int infer(const void* in, void* out, uint32_t batch_size) {
W
wangguibao 已提交
758 759 760 761 762
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return -1;
    }
H
HexToString 已提交
763
    return engine->infer(in, out, batch_size);
W
wangguibao 已提交
764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781
  }

  template <typename T>
  T* get_core() {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get core";
      return NULL;
    }
    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(engine);
    if (db_engine) {
      return db_engine->get_core();
    }
    LOG(WARNING) << "fail to get core";
    return NULL;
  }

  // versioned inference interface
H
HexToString 已提交
782
  int infer(const void* in, void* out, uint32_t batch_size, uint64_t version) {
W
wangguibao 已提交
783 784 785 786 787 788
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return -1;
    }

H
HexToString 已提交
789
    return iter->second->infer(in, out, batch_size);
W
wangguibao 已提交
790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815
  }

  template <typename T>
  T* get_core(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return NULL;
    }

    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(iter->second);
    if (db_engine) {
      return db_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << version;
    return NULL;
  }

  // --
  int proc_initialize_impl(const configure::EngineDesc& conf, bool) {
    return -1;
  }
  int thrd_initialize_impl() { return -1; }
  int thrd_finalize_impl() { return -1; }
  int thrd_clear_impl() { return -1; }
  int proc_finalize_impl() { return -1; }
H
HexToString 已提交
816 817 818
  int infer_impl(const void* in, void* out, uint32_t batch_size = -1) {
    return -1;
  }
H
HexToString 已提交
819
  int task_infer_impl(const BatchTensor& in, BatchTensor& out) {  // NOLINT
H
HexToString 已提交
820 821
    return -1;
  }  // NOLINT
W
wangguibao 已提交
822 823 824

 private:
  boost::unordered_map<uint64_t, InferEngine*> _versions;
W
wangguibao 已提交
825 826
};

W
wangguibao 已提交
827 828 829 830 831 832 833 834 835 836 837 838 839 840 841
class InferManager {
 public:
  static InferManager& instance() {
    static InferManager ins;
    return ins;
  }

  int proc_initialize(const char* path, const char* file) {
    ModelToolkitConf model_toolkit_conf;
    if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
      LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
      return -1;
    }
    size_t engine_num = model_toolkit_conf.engines_size();
    for (size_t ei = 0; ei < engine_num; ++ei) {
B
barrierye 已提交
842 843
      LOG(INFO) << "model_toolkit_conf.engines(" << ei
                << ").name: " << model_toolkit_conf.engines(ei).name();
W
wangguibao 已提交
844 845 846 847 848 849 850 851
      std::string engine_name = model_toolkit_conf.engines(ei).name();
      VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
      if (!engine) {
        LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
        return -1;
      }
      if (engine->proc_initialize(model_toolkit_conf.engines(ei)) != 0) {
        LOG(ERROR) << "Failed initialize version engine, name:" << engine_name;
W
wangguibao 已提交
852
        return -1;
W
wangguibao 已提交
853 854 855 856 857 858 859
      }
      auto r = _map.insert(std::make_pair(engine_name, engine));
      if (!r.second) {
        LOG(ERROR) << "Failed insert item: " << engine_name;
        return -1;
      }
      LOG(WARNING) << "Succ proc initialize engine: " << engine_name;
W
wangguibao 已提交
860
    }
W
wangguibao 已提交
861 862 863 864 865 866 867
    return 0;
  }

  int thrd_initialize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
868
        return -1;
W
wangguibao 已提交
869 870
      }
      LOG(WARNING) << "Succ thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
871
    }
W
wangguibao 已提交
872 873
    return 0;
  }
W
wangguibao 已提交
874

W
wangguibao 已提交
875 876 877 878 879 880 881 882 883
  int thrd_clear() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
884

W
wangguibao 已提交
885 886 887 888 889 890 891 892 893
  int reload() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->reload() != 0) {
        LOG(ERROR) << "Failed reload engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
894

W
wangguibao 已提交
895 896 897 898 899 900 901 902 903 904
  int thrd_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize engine, name: " << it->first;
    }
    return 0;
  }
W
wangguibao 已提交
905

W
wangguibao 已提交
906 907 908 909 910 911 912 913
  int proc_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ proc finalize engine, name: " << it->first;
    }
W
wangguibao 已提交
914
    _map.clear();
W
wangguibao 已提交
915 916 917 918
    return 0;
  }

  // Inference interface
H
HexToString 已提交
919 920 921 922
  int infer(const char* model_name,
            const void* in,
            void* out,
            uint32_t batch_size = -1) {
W
wangguibao 已提交
923 924 925 926 927
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
H
HexToString 已提交
928
    return it->second->infer(in, out, batch_size);
W
wangguibao 已提交
929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947
  }

  template <typename T>
  T* get_core(const char* model_name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    auto infer_engine =
        dynamic_cast<DBReloadableInferEngine<T>*>(it->second->default_engine());
    if (infer_engine) {
      return infer_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << model_name;
    return NULL;
  }

  // Versioned inference interface
H
HexToString 已提交
948
  int infer(const char* model_name,
H
HexToString 已提交
949 950 951 952
            const void* in,
            void* out,
            uint32_t batch_size,
            uint64_t version) {
W
wangguibao 已提交
953 954 955 956 957
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
H
HexToString 已提交
958
    return it->second->infer(in, out, batch_size, version);
W
wangguibao 已提交
959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989
  }

  template <typename T>
  T* get_core(const char* model_name, uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    return it->second->get_core<T>(version);
  }

  int query_version(const std::string& model, uint64_t& version) {  // NOLINT
    auto it = _map.find(model);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model;
      return -1;
    }
    auto infer_engine = it->second->default_engine();
    if (!infer_engine) {
      LOG(WARNING) << "Cannot get default engine for model:" << model;
      return -1;
    }
    version = infer_engine->version();
    LOG(INFO) << "Succ get version: " << version << " for model: " << model;
    return 0;
  }

 private:
  boost::unordered_map<std::string, VersionedInferEngine*> _map;
};
W
wangguibao 已提交
990

W
wangguibao 已提交
991 992 993
}  // namespace predictor
}  // namespace paddle_serving
}  // namespace baidu