infer.h 30.4 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
W
wangguibao 已提交
16
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/factory.h"
#include "core/predictor/framework/infer_data.h"
#include "paddle_inference_api.h"  // NOLINT
W
wangguibao 已提交
26 27 28 29
namespace baidu {
namespace paddle_serving {
namespace predictor {

W
wangguibao 已提交
30 31
using configure::ModelToolkitConf;

32 33 34 35 36
// Bag of options used when creating an inference engine core: the
// model path plus the optimization switches parsed from the engine
// configuration.
class InferEngineCreationParams {
 public:
  InferEngineCreationParams()
      : _path(""),
        _enable_memory_optimization(false),
        _enable_ir_optimization(false),
        _static_optimization(false),
        _force_update_static_cache(false),
        _use_trt(false) {}

  void set_path(const std::string& path) { _path = path; }

  void set_enable_memory_optimization(bool enable_memory_optimization) {
    _enable_memory_optimization = enable_memory_optimization;
  }

  void set_enable_ir_optimization(bool enable_ir_optimization) {
    _enable_ir_optimization = enable_ir_optimization;
  }

  void set_use_trt(bool use_trt) { _use_trt = use_trt; }

  bool enable_memory_optimization() const {
    return _enable_memory_optimization;
  }

  bool enable_ir_optimization() const { return _enable_ir_optimization; }

  bool use_trt() const { return _use_trt; }

  void set_static_optimization(bool static_optimization = false) {
    _static_optimization = static_optimization;
  }

  void set_force_update_static_cache(bool force_update_static_cache = false) {
    _force_update_static_cache = force_update_static_cache;
  }

  bool static_optimization() const { return _static_optimization; }

  bool force_update_static_cache() const { return _force_update_static_cache; }

  std::string get_path() const { return _path; }

  // Log every option at INFO level (useful when debugging engine setup).
  void dump() const {
    LOG(INFO) << "InferEngineCreationParams: "
              << "model_path = " << _path << ", "
              << "enable_memory_optimization = " << _enable_memory_optimization
              << ", "
              << "enable_ir_optimization = " << _enable_ir_optimization << ", "
              << "static_optimization = " << _static_optimization << ", "
              << "force_update_static_cache = " << _force_update_static_cache;
  }

 private:
  std::string _path;
  bool _enable_memory_optimization;
  bool _enable_ir_optimization;
  bool _static_optimization;
  bool _force_update_static_cache;
  bool _use_trt;
};

W
wangguibao 已提交
96
// Abstract inference engine interface.  The serving framework drives
// the proc_*/thrd_* lifecycle hooks; concrete engines implement the
// *_impl variants plus reload()/version().
class InferEngine {
 public:
  virtual ~InferEngine() {}

  // Process-level initialization from the engine configuration.
  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    return proc_initialize_impl(conf, version);
  }
  virtual int proc_finalize() { return proc_finalize_impl(); }
  // Per-thread lifecycle hooks.
  virtual int thrd_initialize() { return thrd_initialize_impl(); }
  virtual int thrd_clear() { return thrd_clear_impl(); }
  virtual int thrd_finalize() { return thrd_finalize_impl(); }
  // Run one inference with the current core.
  virtual int infer() { return infer_impl(); }

  virtual int reload() = 0;

  virtual uint64_t version() const = 0;

  // begin: framework inner call
  virtual int proc_initialize_impl(const configure::EngineDesc& conf,
                                   bool version) = 0;
  virtual int thrd_initialize_impl() = 0;
  virtual int thrd_finalize_impl() = 0;
  virtual int thrd_clear_impl() = 0;
  virtual int proc_finalize_impl() = 0;
  // Tensor-name / tensor-handle accessors of the underlying predictor.
  virtual std::vector<std::string> GetInputNames() = 0;
  virtual std::vector<std::string> GetOutputNames() = 0;
  virtual std::unique_ptr<paddle_infer::Tensor> GetInputHandle(
      const std::string& name) = 0;
  virtual std::unique_ptr<paddle_infer::Tensor> GetOutputHandle(
      const std::string& name) = 0;
  virtual int infer_impl() = 0;
  // end: framework inner call
};

class ReloadableInferEngine : public InferEngine {
 public:
  virtual ~ReloadableInferEngine() {}
W
wangguibao 已提交
133

W
wangguibao 已提交
134 135 136 137 138
  union last_check_status {
    time_t last_timestamp;
    uint64_t last_md5sum;
    uint64_t last_revision;
  };
W
wangguibao 已提交
139

140
  virtual int load(const InferEngineCreationParams& params) = 0;
W
wangguibao 已提交
141 142 143 144 145 146 147 148

  int proc_initialize_impl(const configure::EngineDesc& conf, bool version) {
    _reload_tag_file = conf.reloadable_meta();
    _reload_mode_tag = conf.reloadable_type();
    _model_data_path = conf.model_data_path();
    _infer_thread_num = conf.runtime_thread_num();
    _infer_batch_size = conf.batch_infer_size();
    _infer_batch_align = conf.enable_batch_align();
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164

    bool enable_memory_optimization = false;
    if (conf.has_enable_memory_optimization()) {
      enable_memory_optimization = conf.enable_memory_optimization();
    }

    bool static_optimization = false;
    if (conf.has_static_optimization()) {
      static_optimization = conf.static_optimization();
    }

    bool force_update_static_cache = false;
    if (conf.has_force_update_static_cache()) {
      force_update_static_cache = conf.force_update_static_cache();
    }

M
MRXLT 已提交
165 166 167 168 169
    if (conf.has_enable_ir_optimization()) {
      _infer_engine_params.set_enable_ir_optimization(
          conf.enable_ir_optimization());
    }

170 171 172 173 174 175 176 177
    _infer_engine_params.set_path(_model_data_path);
    if (enable_memory_optimization) {
      _infer_engine_params.set_enable_memory_optimization(true);
      _infer_engine_params.set_static_optimization(static_optimization);
      _infer_engine_params.set_force_update_static_cache(
          force_update_static_cache);
    }

M
MRXLT 已提交
178 179 180 181
    if (conf.has_use_trt()) {
      _infer_engine_params.set_use_trt(conf.use_trt());
    }

182
    if (!check_need_reload() || load(_infer_engine_params) != 0) {
W
wangguibao 已提交
183 184
      LOG(ERROR) << "Failed load model_data_path" << _model_data_path;
      return -1;
W
wangguibao 已提交
185
    }
W
wangguibao 已提交
186 187 188 189

    if (parse_version_info(conf, version) != 0) {
      LOG(ERROR) << "Failed parse version info";
      return -1;
W
wangguibao 已提交
190
    }
W
wangguibao 已提交
191 192 193 194 195 196 197 198 199

    LOG(WARNING) << "Succ load model_data_path" << _model_data_path;
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    if (proc_initialize_impl(conf, version) != 0) {
      LOG(ERROR) << "Failed proc initialize impl";
      return -1;
W
wangguibao 已提交
200
    }
W
wangguibao 已提交
201 202
    return 0;
  }
W
wangguibao 已提交
203

W
wangjiawei04 已提交
204 205
  int infer() {
      return infer_impl();
W
wangguibao 已提交
206
  }
W
wangguibao 已提交
207

W
wangguibao 已提交
208 209 210 211
  int thrd_initialize() {
    if (_infer_thread_num > 0) {
      return 0;
    }
W
wangguibao 已提交
212

W
wangguibao 已提交
213 214
    return thrd_initialize_impl();
  }
W
wangguibao 已提交
215

W
wangguibao 已提交
216 217 218 219
  int thrd_clear() {
    if (_infer_thread_num > 0) {
      return 0;
    }
W
wangguibao 已提交
220

W
wangguibao 已提交
221 222
    return thrd_clear_impl();
  }
W
wangguibao 已提交
223

W
wangguibao 已提交
224 225 226 227 228
  int proc_finalize() {
    if (proc_finalize_impl() != 0) {
      LOG(ERROR) << "Failed proc finalize impl";
      return -1;
    }
W
wangguibao 已提交
229

W
wangguibao 已提交
230 231
    return 0;
  }
W
wangguibao 已提交
232

W
wangguibao 已提交
233 234 235
  int reload() {
    if (check_need_reload()) {
      LOG(WARNING) << "begin reload model[" << _model_data_path << "].";
236
      return load(_infer_engine_params);
W
wangguibao 已提交
237 238 239 240 241 242 243
    }
    return 0;
  }

  uint64_t version() const { return _version; }

  uint32_t thread_num() const { return _infer_thread_num; }
W
wangguibao 已提交
244

W
wangguibao 已提交
245 246 247 248 249
 private:
  int parse_version_info(const configure::EngineDesc& config, bool version) {
    _version = uint64_t(-1);
    return 0;
  }
W
wangguibao 已提交
250

W
wangguibao 已提交
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
  bool check_need_reload() {
    if (_reload_mode_tag == "timestamp_ne") {
      return check_timestamp_ne();
    } else if (_reload_mode_tag == "timestamp_gt") {
      return check_timestamp_gt();
    } else if (_reload_mode_tag == "md5sum") {
      return check_md5sum();
    } else if (_reload_mode_tag == "revision") {
      return check_revision();
    } else if (_reload_mode_tag == "none") {
      return false;
    } else {
      LOG(ERROR) << "Not support check type: " << _reload_mode_tag;
      return false;
    }
  }

  bool check_timestamp_ne() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
274

W
wangguibao 已提交
275 276 277
    if ((st.st_mode & S_IFREG) && st.st_mtime != _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
278 279
    }

W
wangguibao 已提交
280 281
    return false;
  }
W
wangguibao 已提交
282

W
wangguibao 已提交
283 284 285 286 287 288
  bool check_timestamp_gt() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
289

W
wangguibao 已提交
290 291 292
    if ((st.st_mode & S_IFREG) && st.st_mtime > _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
293 294
    }

W
wangguibao 已提交
295 296 297 298 299 300 301 302 303
    return false;
  }

  bool check_md5sum() { return false; }

  bool check_revision() { return false; }

 protected:
  std::string _model_data_path;
304
  InferEngineCreationParams _infer_engine_params;
W
wangguibao 已提交
305 306 307 308 309 310 311 312 313 314

 private:
  std::string _reload_tag_file;
  std::string _reload_mode_tag;
  last_check_status _last_status;
  uint32_t _infer_thread_num;
  uint32_t _infer_batch_size;
  bool _infer_batch_align;
  uint64_t _version;
};
W
wangguibao 已提交
315

W
wangguibao 已提交
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
// Double-buffered holder for a pair of EngineCore instances.
// cores[current_idx] is the active core; the other slot is rebuilt
// during reload and then current_idx is flipped.
template <typename EngineCore>
struct ModelData {
  ModelData() : current_idx(1) {
    cores[0] = NULL;
    cores[1] = NULL;
  }

  // Owns both cores via raw pointers: copying would double-delete in
  // the destructor, so copy construction/assignment are forbidden.
  ModelData(const ModelData&) = delete;
  ModelData& operator=(const ModelData&) = delete;

  ~ModelData() {
    delete cores[0];
    delete cores[1];
  }

  EngineCore* cores[2];  // the two buffered cores (either may be NULL)
  uint32_t current_idx;  // index of the active core in `cores`
};

template <typename EngineCore>
class DBReloadableInferEngine : public ReloadableInferEngine {
 public:
  virtual ~DBReloadableInferEngine() {}

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    THREAD_KEY_CREATE(&_skey, NULL);
    THREAD_MUTEX_INIT(&_mutex, NULL);
    return ReloadableInferEngine::proc_initialize(conf, version);
  }

343
  virtual int load(const InferEngineCreationParams& params) {
W
wangguibao 已提交
344 345
    if (_reload_vec.empty()) {
      return 0;
W
wangguibao 已提交
346 347
    }

W
wangguibao 已提交
348
    for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
349
      if (load_data(_reload_vec[ti], params) != 0) {
W
wangguibao 已提交
350 351 352 353 354
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
    }

355
    LOG(WARNING) << "Succ load engine, path: " << params.get_path();
W
wangguibao 已提交
356

W
wangguibao 已提交
357 358
    return 0;
  }
W
wangguibao 已提交
359

360 361
  int load_data(ModelData<EngineCore>* md,
                const InferEngineCreationParams& params) {
W
wangguibao 已提交
362 363 364
    uint32_t next_idx = (md->current_idx + 1) % 2;
    if (md->cores[next_idx]) {
      delete md->cores[next_idx];
W
wangguibao 已提交
365 366
    }

W
wangguibao 已提交
367
    md->cores[next_idx] = new (std::nothrow) EngineCore;
368 369 370 371

    params.dump();
    if (!md->cores[next_idx] || md->cores[next_idx]->create(params) != 0) {
      LOG(ERROR) << "Failed create model, path: " << params.get_path();
W
wangguibao 已提交
372
      return -1;
W
wangguibao 已提交
373
    }
W
wangguibao 已提交
374 375 376
    md->current_idx = next_idx;
    return 0;
  }
W
wangguibao 已提交
377

W
wangguibao 已提交
378 379
  virtual int thrd_initialize_impl() {
    // memory pool to be inited in non-serving-threads
W
wangguibao 已提交
380

W
wangguibao 已提交
381
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
382 383 384
    if (!md || load_data(md, _infer_engine_params) != 0) {
      LOG(ERROR) << "Failed create thread data from "
                 << _infer_engine_params.get_path();
W
wangguibao 已提交
385
      return -1;
W
wangguibao 已提交
386 387
    }

W
wangguibao 已提交
388 389 390 391 392 393 394 395 396 397 398
    THREAD_SETSPECIFIC(_skey, md);
    _reload_vec.push_back(md);
    return 0;
  }

  int thrd_clear_impl() {
    // for non-serving-threads
    return 0;
  }

  int thrd_finalize_impl() { return 0; }
W
wangguibao 已提交
399

W
wangguibao 已提交
400 401 402 403 404
  int proc_finalize_impl() {
    THREAD_KEY_DELETE(_skey);
    THREAD_MUTEX_DESTROY(&_mutex);
    return 0;
  }
W
wangguibao 已提交
405

W
wangguibao 已提交
406 407 408 409 410 411
  EngineCore* get_core() {
    ModelData<EngineCore>* md =
        (ModelData<EngineCore>*)THREAD_GETSPECIFIC(_skey);
    if (!md) {
      LOG(ERROR) << "Failed get thread specific data";
      return NULL;
W
wangguibao 已提交
412
    }
W
wangguibao 已提交
413 414
    return md->cores[md->current_idx];
  }
W
wangguibao 已提交
415

W
wangguibao 已提交
416 417 418 419
 protected:
  THREAD_KEY_T _skey;
  THREAD_MUTEX_T _mutex;
  std::vector<ModelData<EngineCore>*> _reload_vec;
W
wangguibao 已提交
420

W
wangguibao 已提交
421 422
 private:
};
W
wangguibao 已提交
423

W
wangguibao 已提交
424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439
// Multiple EngineCores share a single copy of the model data: a
// process-level ModelData (_pd) owns the real model, and each
// thread-level EngineCore is a lightweight clone of it.
template <typename EngineCore>
class CloneDBReloadableInferEngine
    : public DBReloadableInferEngine<EngineCore> {
 public:
  // Fix: release the process-level ModelData allocated in
  // proc_initialize (it was previously leaked).
  virtual ~CloneDBReloadableInferEngine() { delete _pd; }

  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    _pd = new (std::nothrow) ModelData<EngineCore>;
    if (!_pd) {
      LOG(ERROR) << "Failed to allocate for ProcData";
      return -1;
    }
    return DBReloadableInferEngine<EngineCore>::proc_initialize(conf, version);
  }

  virtual int load(const InferEngineCreationParams& params) {
    // Load the process-level model data first.
    if (!_pd ||
        DBReloadableInferEngine<EngineCore>::load_data(_pd, params) != 0) {
      LOG(ERROR) << "Failed to create common model from [" << params.get_path()
                 << "].";
      return -1;
    }
    LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx]
                 << "], path[" << params.get_path() << "].";

    if (DBReloadableInferEngine<EngineCore>::_reload_vec.empty()) {
      return 0;
    }

    // Re-clone every registered thread-level core from the fresh copy.
    for (uint32_t ti = 0;
         ti < DBReloadableInferEngine<EngineCore>::_reload_vec.size();
         ++ti) {
      if (load_data(DBReloadableInferEngine<EngineCore>::_reload_vec[ti],
                    _pd->cores[_pd->current_idx]) != 0) {
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
    }

    LOG(WARNING) << "Succ load clone model, path[" << params.get_path() << "]";
    return 0;
  }

  // Build a thread-level core in td's inactive slot that shares
  // pd_core's model data, then flip current_idx to it.
  int load_data(ModelData<EngineCore>* td, EngineCore* pd_core) {
    uint32_t next_idx = (td->current_idx + 1) % 2;
    if (td->cores[next_idx]) {
      delete td->cores[next_idx];
    }

    td->cores[next_idx] = new (std::nothrow) EngineCore;
    if (!td->cores[next_idx] ||
        td->cores[next_idx]->clone(pd_core->get()) != 0) {
      LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
                 << next_idx << "]";
      return -1;
    }
    td->current_idx = next_idx;
    LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
                 << "] clone model from pd_core[" << pd_core
                 << "] succ, cur_idx[" << td->current_idx << "].";
    return 0;
  }

  virtual int thrd_initialize_impl() {
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
    if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
      LOG(ERROR) << "Failed clone thread data, origin_core["
                 << _pd->cores[_pd->current_idx] << "].";
      // Fix: don't leak the ModelData when cloning fails.
      delete md;
      return -1;
    }

    THREAD_SETSPECIFIC(DBReloadableInferEngine<EngineCore>::_skey, md);
    DBReloadableInferEngine<EngineCore>::_reload_vec.push_back(md);
    return 0;
  }

 protected:
  // Process-level ModelData whose model data all thread-level
  // EngineCores share.  Initialized to NULL so the destructor is safe
  // even if proc_initialize() was never called.
  ModelData<EngineCore>* _pd = NULL;
};

W
wangguibao 已提交
510
template <typename FluidFamilyCore>
M
bug fix  
MRXLT 已提交
511
#ifdef WITH_TRT
M
MRXLT 已提交
512
class FluidInferEngine : public DBReloadableInferEngine<FluidFamilyCore> {
M
bug fix  
MRXLT 已提交
513 514 515 516
#else
class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
#endif
 public:  // NOLINT
W
wangguibao 已提交
517 518
  FluidInferEngine() {}
  ~FluidInferEngine() {}
W
wangjiawei04 已提交
519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
  std::vector<std::string> GetInputNames() {
    FluidFamilyCore* core = DBReloadableInferEngine<FluidFamilyCore>::get_core();
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in GetInputHandle()";
    }
    return core->GetInputNames();
  }

  std::vector<std::string> GetOutputNames() {
    FluidFamilyCore* core = DBReloadableInferEngine<FluidFamilyCore>::get_core();
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in GetInputHandle()";
    }
    return core->GetOutputNames();
  }

W
wangjiawei04 已提交
535 536 537 538 539 540 541 542 543 544 545 546 547 548 549
  std::unique_ptr<paddle_infer::Tensor> GetInputHandle(const std::string& name) {
    FluidFamilyCore* core = DBReloadableInferEngine<FluidFamilyCore>::get_core();
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in GetInputHandle()";
    }
    return core->GetInputHandle(name);
  }

  std::unique_ptr<paddle_infer::Tensor> GetOutputHandle(const std::string& name) {
    FluidFamilyCore* core = DBReloadableInferEngine<FluidFamilyCore>::get_core();
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in GetOutputHandle()";
    }
    return core->GetOutputHandle(name);
  }
W
wangguibao 已提交
550

W
wangjiawei04 已提交
551 552
  int infer_impl() {
    FluidFamilyCore* core = DBReloadableInferEngine<FluidFamilyCore>::get_core();
W
wangguibao 已提交
553 554 555
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in infer_impl()";
      return -1;
W
wangguibao 已提交
556 557
    }

W
wangjiawei04 已提交
558
    if (!core->Run()) {
W
wangguibao 已提交
559 560 561 562 563
      LOG(ERROR) << "Failed run fluid family core";
      return -1;
    }
    return 0;
  }
W
wangguibao 已提交
564 565
};

W
wangguibao 已提交
566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590
typedef FactoryPool<InferEngine> StaticInferFactory;

// Dispatches InferEngine calls across one or more model versions.
// With exactly one registered version it behaves as a plain engine;
// callers may also address a specific version explicitly.
class VersionedInferEngine : public InferEngine {
 public:
  VersionedInferEngine() { _versions.clear(); }
  ~VersionedInferEngine() {}

  int proc_initialize(const configure::EngineDesc& conf) {
    if (proc_initialize(conf, false) != 0) {
      LOG(ERROR) << "Failed proc intialize engine: " << conf.name().c_str();
      return -1;
    }

    LOG(WARNING) << "Succ proc initialize engine: " << conf.name().c_str();
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    std::string engine_type = conf.type();
    InferEngine* engine =
        StaticInferFactory::instance().generate_object(engine_type);
    if (!engine) {
      LOG(ERROR) << "Failed generate engine with type:" << engine_type;
      return -1;
    }
#ifndef BCLOUD
    // Engine initialization may clobber FLAGS_logtostderr; save and
    // restore it around the call.
    VLOG(2) << "FLAGS_logtostderr " << FLAGS_logtostderr;
    int tmp = FLAGS_logtostderr;
    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }
    VLOG(2) << "FLAGS_logtostderr " << FLAGS_logtostderr;
    FLAGS_logtostderr = tmp;
#else
    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }
#endif
    auto r = _versions.insert(std::make_pair(engine->version(), engine));
    if (!r.second) {
      LOG(ERROR) << "Failed insert item: " << engine->version()
                 << ", type: " << engine_type;
      return -1;
    }
    LOG(WARNING) << "Succ proc initialize version engine: "
                 << engine->version();
    return 0;
  }

  int proc_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize version engine: " << iter->first;
      }
      LOG(WARNING) << "Succ proc finalize version engine: " << iter->first;
    }
    return 0;
  }

  int thrd_initialize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd initialize version engine: " << iter->first;
    }
    return 0;
  }

  int thrd_clear() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear version engine: " << iter->first;
        return -1;
      }
    }
    return 0;
  }

  int thrd_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize version engine: " << iter->first;
    }
    return 0;
  }

  int reload() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->reload() != 0) {
        LOG(ERROR) << "Failed reload version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ reload version engine: " << iter->first;
    }
    return 0;
  }

  uint64_t version() const {
    InferEngine* engine = default_engine();
    if (engine) {
      return engine->version();
    } else {
      return uint64_t(-1);
    }
  }

  // inference interface
  // The implicit engine used when no version is given; NULL unless
  // exactly one version is registered.
  InferEngine* default_engine() const {
    if (_versions.size() != 1) {
      LOG(ERROR) << "Ambiguous default engine version:" << _versions.size();
      return NULL;
    }

    return _versions.begin()->second;
  }

  int infer() {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return -1;
    }
    return engine->infer();
  }

  // Fix throughout the accessors below: previously a missing engine or
  // unknown version was logged and then dereferenced anyway (crash);
  // now an empty result / NULL handle is returned instead.
  std::vector<std::string> GetInputNames() {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return std::vector<std::string>();
    }
    return engine->GetInputNames();
  }
  std::vector<std::string> GetOutputNames() {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return std::vector<std::string>();
    }
    return engine->GetOutputNames();
  }
  std::unique_ptr<paddle_infer::Tensor> GetInputHandle(
      const std::string& name) {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return nullptr;
    }
    return engine->GetInputHandle(name);
  }

  std::unique_ptr<paddle_infer::Tensor> GetOutputHandle(
      const std::string& name) {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return nullptr;
    }
    return engine->GetOutputHandle(name);
  }

  template <typename T>
  T* get_core() {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get core";
      return NULL;
    }
    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(engine);
    if (db_engine) {
      return db_engine->get_core();
    }
    LOG(WARNING) << "fail to get core";
    return NULL;
  }

  // versioned inference interface
  int infer(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return -1;
    }

    return iter->second->infer();
  }
  std::vector<std::string> GetInputNames(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return std::vector<std::string>();
    }
    return iter->second->GetInputNames();
  }

  std::vector<std::string> GetOutputNames(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return std::vector<std::string>();
    }
    return iter->second->GetOutputNames();
  }

  std::unique_ptr<paddle_infer::Tensor> GetInputHandle(
      uint64_t version, const std::string& name) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return nullptr;
    }
    return iter->second->GetInputHandle(name);
  }

  std::unique_ptr<paddle_infer::Tensor> GetOutputHandle(
      uint64_t version, const std::string& name) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return nullptr;
    }
    return iter->second->GetOutputHandle(name);
  }

  template <typename T>
  T* get_core(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return NULL;
    }

    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(iter->second);
    if (db_engine) {
      return db_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << version;
    return NULL;
  }

  // -- InferEngine hooks: a versioned engine never acts as a leaf
  // engine, so the impl entry points are deliberately unreachable.
  int proc_initialize_impl(const configure::EngineDesc& conf, bool) {
    return -1;
  }
  int thrd_initialize_impl() { return -1; }
  int thrd_finalize_impl() { return -1; }
  int thrd_clear_impl() { return -1; }
  int proc_finalize_impl() { return -1; }
  int infer_impl() { return -1; }

 private:
  boost::unordered_map<uint64_t, InferEngine*> _versions;
};

W
wangguibao 已提交
817 818 819 820 821 822 823 824 825 826 827 828 829 830 831
class InferManager {
 public:
  static InferManager& instance() {
    static InferManager ins;
    return ins;
  }

  int proc_initialize(const char* path, const char* file) {
    ModelToolkitConf model_toolkit_conf;
    if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
      LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
      return -1;
    }
    size_t engine_num = model_toolkit_conf.engines_size();
    for (size_t ei = 0; ei < engine_num; ++ei) {
B
barrierye 已提交
832 833
      LOG(INFO) << "model_toolkit_conf.engines(" << ei
                << ").name: " << model_toolkit_conf.engines(ei).name();
W
wangguibao 已提交
834 835 836 837 838 839 840 841
      std::string engine_name = model_toolkit_conf.engines(ei).name();
      VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
      if (!engine) {
        LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
        return -1;
      }
      if (engine->proc_initialize(model_toolkit_conf.engines(ei)) != 0) {
        LOG(ERROR) << "Failed initialize version engine, name:" << engine_name;
W
wangguibao 已提交
842
        return -1;
W
wangguibao 已提交
843 844 845 846 847 848 849
      }
      auto r = _map.insert(std::make_pair(engine_name, engine));
      if (!r.second) {
        LOG(ERROR) << "Failed insert item: " << engine_name;
        return -1;
      }
      LOG(WARNING) << "Succ proc initialize engine: " << engine_name;
W
wangguibao 已提交
850
    }
W
wangguibao 已提交
851 852 853 854 855 856 857
    return 0;
  }

  int thrd_initialize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
858
        return -1;
W
wangguibao 已提交
859 860
      }
      LOG(WARNING) << "Succ thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
861
    }
W
wangguibao 已提交
862 863
    return 0;
  }
W
wangguibao 已提交
864

W
wangguibao 已提交
865 866 867 868 869 870 871 872 873
  int thrd_clear() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
874

W
wangguibao 已提交
875 876 877 878 879 880 881 882 883
  int reload() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->reload() != 0) {
        LOG(ERROR) << "Failed reload engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
884

W
wangguibao 已提交
885 886 887 888 889 890 891 892 893 894
  int thrd_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize engine, name: " << it->first;
    }
    return 0;
  }
W
wangguibao 已提交
895

W
wangguibao 已提交
896 897 898 899 900 901 902 903
  int proc_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ proc finalize engine, name: " << it->first;
    }
W
wangguibao 已提交
904
    _map.clear();
W
wangguibao 已提交
905 906 907 908
    return 0;
  }

  // Inference interface
W
wangjiawei04 已提交
909
  int infer(const char* model_name) {
W
wangguibao 已提交
910 911 912 913 914
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
W
wangjiawei04 已提交
915 916
    return it->second->infer();
  }
W
wangjiawei04 已提交
917 918 919 920 921 922 923 924 925 926 927 928 929 930 931

  std::vector<std::string> GetInputNames(const char* model_name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetInputNames();
  }
  std::vector<std::string> GetOutputNames(const char* model_name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetOutputNames();
  }
W
wangjiawei04 已提交
932 933 934 935 936 937 938 939 940 941 942 943 944
  std::unique_ptr<paddle_infer::Tensor> GetInputHandle(const char* model_name, const std::string& name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetInputHandle(name);
  }
  std::unique_ptr<paddle_infer::Tensor> GetOutputHandle(const char* model_name, const std::string& name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetOutputHandle(name);
W
wangguibao 已提交
945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970
  }

  template <typename T>
  T* get_core(const char* model_name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    auto infer_engine =
        dynamic_cast<DBReloadableInferEngine<T>*>(it->second->default_engine());
    if (infer_engine) {
      return infer_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << model_name;
    return NULL;
  }

  // Versioned inference interface
  int infer(const char* model_name,
            uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
W
wangjiawei04 已提交
971 972
    return it->second->infer(version);
  }
W
wangjiawei04 已提交
973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988
  std::vector<std::string> GetInputNames(const char* model_name, uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetInputNames(version);
  }

  std::vector<std::string> GetOutputNames(const char* model_name, uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetOutputNames(version);
  }

W
wangjiawei04 已提交
989 990 991 992 993 994 995 996 997 998 999 1000 1001
  std::unique_ptr<paddle_infer::Tensor> GetInputHandle(const char* model_name, uint64_t version, const std::string& name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetInputHandle(version, name);
  }
  std::unique_ptr<paddle_infer::Tensor> GetOutputHandle(const char* model_name, uint64_t version, const std::string& name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
    }
    return it->second->GetOutputHandle(version, name);
W
wangguibao 已提交
1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031
  }
  template <typename T>
  T* get_core(const char* model_name, uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    return it->second->get_core<T>(version);
  }

  int query_version(const std::string& model, uint64_t& version) {  // NOLINT
    auto it = _map.find(model);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model;
      return -1;
    }
    auto infer_engine = it->second->default_engine();
    if (!infer_engine) {
      LOG(WARNING) << "Cannot get default engine for model:" << model;
      return -1;
    }
    version = infer_engine->version();
    LOG(INFO) << "Succ get version: " << version << " for model: " << model;
    return 0;
  }

 private:
  boost::unordered_map<std::string, VersionedInferEngine*> _map;
};
W
wangguibao 已提交
1032

W
wangguibao 已提交
1033 1034 1035
}  // namespace predictor
}  // namespace paddle_serving
}  // namespace baidu