// infer.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <string>
#include <vector>

#include "predictor/common/inner_common.h"
#include "predictor/framework/bsf.h"
#include "predictor/framework/factory.h"
#include "predictor/framework/infer_data.h"

namespace baidu {
namespace paddle_serving {
namespace predictor {

using configure::ModelToolkitConf;

// Abstract inference engine interface. The public entry points
// (proc_*/thrd_*/infer) delegate to the *_impl hooks that concrete
// engines implement; the framework calls the impl hooks directly when
// running inside the batch-schedule executor.
class InferEngine {
 public:
  virtual ~InferEngine() {}

  // Process-level initialization from the engine config; `version`
  // selects versioned-engine behavior in subclasses.
  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    return proc_initialize_impl(conf, version);
  }
  virtual int proc_finalize() { return proc_finalize_impl(); }
  virtual int thrd_initialize() { return thrd_initialize_impl(); }
  virtual int thrd_clear() { return thrd_clear_impl(); }
  virtual int thrd_finalize() { return thrd_finalize_impl(); }
  // Runs inference. batch_size = -1 (i.e. uint32_t max) means "unspecified".
  virtual int infer(const void* in, void* out, uint32_t batch_size = -1) {
    return infer_impl1(in, out, batch_size);
  }

  // Reloads model data when the underlying files changed; returns 0 on
  // success, non-zero on failure.
  virtual int reload() = 0;

  virtual uint64_t version() const = 0;

  // begin: framework inner call
  virtual int proc_initialize_impl(const configure::EngineDesc& conf,
                                   bool version) = 0;
  virtual int thrd_initialize_impl() = 0;
  virtual int thrd_finalize_impl() = 0;
  virtual int thrd_clear_impl() = 0;
  virtual int proc_finalize_impl() = 0;
  virtual int infer_impl1(const void* in,
                          void* out,
                          uint32_t batch_size = -1) = 0;
  virtual int infer_impl2(const BatchTensor& in,
                          BatchTensor& out) = 0;  // NOLINT
  // end: framework inner call
};

class ReloadableInferEngine : public InferEngine {
 public:
  virtual ~ReloadableInferEngine() {}
W
wangguibao 已提交
69

W
wangguibao 已提交
70 71 72 73 74
  union last_check_status {
    time_t last_timestamp;
    uint64_t last_md5sum;
    uint64_t last_revision;
  };
W
wangguibao 已提交
75

W
wangguibao 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89
  typedef im::bsf::Task<Tensor, Tensor> TaskT;

  virtual int load(const std::string& data_path) = 0;

  int proc_initialize_impl(const configure::EngineDesc& conf, bool version) {
    _reload_tag_file = conf.reloadable_meta();
    _reload_mode_tag = conf.reloadable_type();
    _model_data_path = conf.model_data_path();
    _infer_thread_num = conf.runtime_thread_num();
    _infer_batch_size = conf.batch_infer_size();
    _infer_batch_align = conf.enable_batch_align();
    if (!check_need_reload() || load(_model_data_path) != 0) {
      LOG(ERROR) << "Failed load model_data_path" << _model_data_path;
      return -1;
W
wangguibao 已提交
90
    }
W
wangguibao 已提交
91 92 93 94

    if (parse_version_info(conf, version) != 0) {
      LOG(ERROR) << "Failed parse version info";
      return -1;
W
wangguibao 已提交
95
    }
W
wangguibao 已提交
96 97 98 99 100 101 102 103 104

    LOG(WARNING) << "Succ load model_data_path" << _model_data_path;
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    if (proc_initialize_impl(conf, version) != 0) {
      LOG(ERROR) << "Failed proc initialize impl";
      return -1;
W
wangguibao 已提交
105
    }
W
wangguibao 已提交
106 107 108 109

    // init bsf framework
    if (_infer_thread_num <= 0) {
      return 0;
W
wangguibao 已提交
110
    }
W
wangguibao 已提交
111 112 113 114 115 116 117 118 119 120 121 122 123 124

    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_init_fn(
        boost::bind(&InferEngine::thrd_initialize_impl, this));
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_reset_fn(
        boost::bind(&InferEngine::thrd_clear_impl, this));
    im::bsf::TaskExecutor<TaskT>::instance()->set_thread_callback_fn(
        boost::bind(&InferEngine::infer_impl2, this, _1, _2));
    im::bsf::TaskExecutor<TaskT>::instance()->set_batch_size(_infer_batch_size);
    im::bsf::TaskExecutor<TaskT>::instance()->set_batch_align(
        _infer_batch_align);
    if (im::bsf::TaskExecutor<TaskT>::instance()->start(_infer_thread_num) !=
        0) {
      LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
      return -1;
W
wangguibao 已提交
125 126
    }

W
wangguibao 已提交
127 128 129
    LOG(WARNING) << "Enable batch schedule framework, thread_num:"
                 << _infer_thread_num << ", batch_size:" << _infer_batch_size
                 << ", enable_batch_align:" << _infer_batch_align;
W
wangguibao 已提交
130

W
wangguibao 已提交
131 132
    return 0;
  }
W
wangguibao 已提交
133

W
wangguibao 已提交
134 135 136 137
  int infer(const void* in, void* out, uint32_t batch_size = -1) {
    if (_infer_thread_num <= 0) {
      return infer_impl1(in, out, batch_size);
    }
W
wangguibao 已提交
138

W
wangguibao 已提交
139 140 141 142 143 144
    im::bsf::TaskManager<Tensor, Tensor> task_manager;
    task_manager.schedule(*(reinterpret_cast<const BatchTensor*>(in)),
                          *(reinterpret_cast<BatchTensor*>(out)));
    task_manager.wait();
    return 0;
  }
W
wangguibao 已提交
145

W
wangguibao 已提交
146 147 148 149
  int thrd_initialize() {
    if (_infer_thread_num > 0) {
      return 0;
    }
W
wangguibao 已提交
150

W
wangguibao 已提交
151 152
    return thrd_initialize_impl();
  }
W
wangguibao 已提交
153

W
wangguibao 已提交
154 155 156 157
  int thrd_clear() {
    if (_infer_thread_num > 0) {
      return 0;
    }
W
wangguibao 已提交
158

W
wangguibao 已提交
159 160
    return thrd_clear_impl();
  }
W
wangguibao 已提交
161

W
wangguibao 已提交
162 163 164 165 166
  int proc_finalize() {
    if (proc_finalize_impl() != 0) {
      LOG(ERROR) << "Failed proc finalize impl";
      return -1;
    }
W
wangguibao 已提交
167

W
wangguibao 已提交
168 169
    if (_infer_thread_num > 0) {
      im::bsf::TaskExecutor<TaskT>::instance()->stop();
W
wangguibao 已提交
170 171
    }

W
wangguibao 已提交
172 173
    return 0;
  }
W
wangguibao 已提交
174

W
wangguibao 已提交
175 176 177 178 179 180 181 182 183 184 185
  int reload() {
    if (check_need_reload()) {
      LOG(WARNING) << "begin reload model[" << _model_data_path << "].";
      return load(_model_data_path);
    }
    return 0;
  }

  uint64_t version() const { return _version; }

  uint32_t thread_num() const { return _infer_thread_num; }
W
wangguibao 已提交
186

W
wangguibao 已提交
187 188 189 190 191
 private:
  int parse_version_info(const configure::EngineDesc& config, bool version) {
    _version = uint64_t(-1);
    return 0;
  }
W
wangguibao 已提交
192

W
wangguibao 已提交
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
  bool check_need_reload() {
    if (_reload_mode_tag == "timestamp_ne") {
      return check_timestamp_ne();
    } else if (_reload_mode_tag == "timestamp_gt") {
      return check_timestamp_gt();
    } else if (_reload_mode_tag == "md5sum") {
      return check_md5sum();
    } else if (_reload_mode_tag == "revision") {
      return check_revision();
    } else if (_reload_mode_tag == "none") {
      return false;
    } else {
      LOG(ERROR) << "Not support check type: " << _reload_mode_tag;
      return false;
    }
  }

  bool check_timestamp_ne() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
216

W
wangguibao 已提交
217 218 219
    if ((st.st_mode & S_IFREG) && st.st_mtime != _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
220 221
    }

W
wangguibao 已提交
222 223
    return false;
  }
W
wangguibao 已提交
224

W
wangguibao 已提交
225 226 227 228 229 230
  bool check_timestamp_gt() {
    struct stat st;
    if (stat(_reload_tag_file.c_str(), &st) != 0) {
      LOG(ERROR) << "Failed stat config file:" << _reload_tag_file;
      return false;
    }
W
wangguibao 已提交
231

W
wangguibao 已提交
232 233 234
    if ((st.st_mode & S_IFREG) && st.st_mtime > _last_status.last_timestamp) {
      _last_status.last_timestamp = st.st_mtime;
      return true;
W
wangguibao 已提交
235 236
    }

W
wangguibao 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
    return false;
  }

  bool check_md5sum() { return false; }

  bool check_revision() { return false; }

 protected:
  std::string _model_data_path;

 private:
  std::string _reload_tag_file;
  std::string _reload_mode_tag;
  last_check_status _last_status;
  uint32_t _infer_thread_num;
  uint32_t _infer_batch_size;
  bool _infer_batch_align;
  uint64_t _version;
};
// Double-buffered holder for a pair of EngineCore instances.
// cores[current_idx] serves traffic while the other slot is the target
// of the next (re)load, so a reload never tears down the active core.
template <typename EngineCore>
struct ModelData {
  ModelData() : current_idx(1) {
    cores[0] = nullptr;
    cores[1] = nullptr;
  }

  // Owns both cores via raw pointers; non-copyable to prevent the
  // double-delete a default copy would cause.
  ModelData(const ModelData&) = delete;
  ModelData& operator=(const ModelData&) = delete;

  ~ModelData() {
    delete cores[0];
    delete cores[1];
  }

  EngineCore* cores[2];  // double buffer of engine instances
  uint32_t current_idx;  // index of the currently-active core
};

template <typename EngineCore>
class DBReloadableInferEngine : public ReloadableInferEngine {
 public:
  virtual ~DBReloadableInferEngine() {}

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    THREAD_KEY_CREATE(&_skey, NULL);
    THREAD_MUTEX_INIT(&_mutex, NULL);
    return ReloadableInferEngine::proc_initialize(conf, version);
  }

  virtual int load(const std::string& model_data_dir) {
    if (_reload_vec.empty()) {
      return 0;
W
wangguibao 已提交
287 288
    }

W
wangguibao 已提交
289 290 291 292 293 294 295 296
    for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
      if (load_data(_reload_vec[ti], model_data_dir) != 0) {
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
    }

    LOG(WARNING) << "Succ load engine, path: " << model_data_dir;
W
wangguibao 已提交
297

W
wangguibao 已提交
298 299
    return 0;
  }
W
wangguibao 已提交
300

W
wangguibao 已提交
301 302 303 304
  int load_data(ModelData<EngineCore>* md, const std::string& data_path) {
    uint32_t next_idx = (md->current_idx + 1) % 2;
    if (md->cores[next_idx]) {
      delete md->cores[next_idx];
W
wangguibao 已提交
305 306
    }

W
wangguibao 已提交
307 308 309 310
    md->cores[next_idx] = new (std::nothrow) EngineCore;
    if (!md->cores[next_idx] || md->cores[next_idx]->create(data_path) != 0) {
      LOG(ERROR) << "Failed create model, path: " << data_path;
      return -1;
W
wangguibao 已提交
311
    }
W
wangguibao 已提交
312 313 314
    md->current_idx = next_idx;
    return 0;
  }
W
wangguibao 已提交
315

W
wangguibao 已提交
316 317 318 319 320
  virtual int thrd_initialize_impl() {
    // memory pool to be inited in non-serving-threads
    if (MempoolWrapper::instance().thread_initialize() != 0) {
      LOG(ERROR) << "Failed thread initialize mempool";
      return -1;
W
wangguibao 已提交
321 322
    }

W
wangguibao 已提交
323 324 325 326
    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
    if (!md || load_data(md, _model_data_path) != 0) {
      LOG(ERROR) << "Failed create thread data from " << _model_data_path;
      return -1;
W
wangguibao 已提交
327 328
    }

W
wangguibao 已提交
329 330 331 332 333 334 335 336 337 338 339
    THREAD_SETSPECIFIC(_skey, md);
    im::bsf::AutoMutex lock(_mutex);
    _reload_vec.push_back(md);
    return 0;
  }

  int thrd_clear_impl() {
    // for non-serving-threads
    if (MempoolWrapper::instance().thread_clear() != 0) {
      LOG(ERROR) << "Failed thread clear mempool";
      return -1;
W
wangguibao 已提交
340
    }
W
wangguibao 已提交
341 342 343 344
    return 0;
  }

  int thrd_finalize_impl() { return 0; }
W
wangguibao 已提交
345

W
wangguibao 已提交
346 347 348 349 350
  int proc_finalize_impl() {
    THREAD_KEY_DELETE(_skey);
    THREAD_MUTEX_DESTROY(&_mutex);
    return 0;
  }
W
wangguibao 已提交
351

W
wangguibao 已提交
352 353 354 355 356 357
  EngineCore* get_core() {
    ModelData<EngineCore>* md =
        (ModelData<EngineCore>*)THREAD_GETSPECIFIC(_skey);
    if (!md) {
      LOG(ERROR) << "Failed get thread specific data";
      return NULL;
W
wangguibao 已提交
358
    }
W
wangguibao 已提交
359 360
    return md->cores[md->current_idx];
  }
W
wangguibao 已提交
361

W
wangguibao 已提交
362 363 364 365
 protected:
  THREAD_KEY_T _skey;
  THREAD_MUTEX_T _mutex;
  std::vector<ModelData<EngineCore>*> _reload_vec;
W
wangguibao 已提交
366

W
wangguibao 已提交
367 368
 private:
};
// Multiple EngineCores share a single copy of the process-level model
// data: one "common" core is created per process and every per-thread
// core is cloned from it instead of loading the model again.
template <typename EngineCore>
class CloneDBReloadableInferEngine
    : public DBReloadableInferEngine<EngineCore> {
 public:
  virtual ~CloneDBReloadableInferEngine() {}

  virtual int proc_initialize(const configure::EngineDesc& conf, bool version) {
    // Allocate the process-level ModelData before the base init runs,
    // since load() (called during base init) dereferences _pd.
    _pd = new (std::nothrow) ModelData<EngineCore>;
    if (!_pd) {
      LOG(ERROR) << "Failed to allocate for ProcData";
      return -1;
    }
    return DBReloadableInferEngine<EngineCore>::proc_initialize(conf, version);
  }

  // Loads the process-level model data, then re-clones every registered
  // per-thread core from the freshly loaded common core.
  virtual int load(const std::string& model_data_dir) {
    if (!_pd ||
        DBReloadableInferEngine<EngineCore>::load_data(_pd, model_data_dir) !=
            0) {
      LOG(ERROR) << "Failed to create common model from [" << model_data_dir
                 << "].";
      return -1;
    }
    LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx]
                 << "], path[" << model_data_dir << "].";

    if (DBReloadableInferEngine<EngineCore>::_reload_vec.empty()) {
      return 0;
    }

    for (uint32_t ti = 0;
         ti < DBReloadableInferEngine<EngineCore>::_reload_vec.size();
         ++ti) {
      if (load_data(DBReloadableInferEngine<EngineCore>::_reload_vec[ti],
                    _pd->cores[_pd->current_idx]) != 0) {
        LOG(ERROR) << "Failed reload engine model: " << ti;
        return -1;
      }
    }

    LOG(WARNING) << "Succ load clone model, path[" << model_data_dir << "]";

    return 0;
  }

  // Loads a thread-level core into td's inactive slot by cloning from
  // pd_core; all thread-level cores share pd_core's model data.
  int load_data(ModelData<EngineCore>* td, EngineCore* pd_core) {
    uint32_t next_idx = (td->current_idx + 1) % 2;
    if (td->cores[next_idx]) {
      delete td->cores[next_idx];
    }

    td->cores[next_idx] = new (std::nothrow) EngineCore;
    if (!td->cores[next_idx] ||
        td->cores[next_idx]->clone(pd_core->get()) != 0) {
      LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
                 << next_idx << "]";
      return -1;
    }
    td->current_idx = next_idx;
    LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
                 << "] clone model from pd_core[" << pd_core
                 << "] succ, cur_idx[" << td->current_idx << "].";
    return 0;
  }

  // Per-thread init: sets up the mempool, then clones this thread's core
  // from the process-level common core (instead of loading from disk).
  virtual int thrd_initialize_impl() {
    // memory pool to be inited in non-serving-threads
    if (MempoolWrapper::instance().thread_initialize() != 0) {
      LOG(ERROR) << "Failed thread initialize mempool";
      return -1;
    }

    ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
    if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
      LOG(ERROR) << "Failed clone thread data, origin_core["
                 << _pd->cores[_pd->current_idx] << "].";
      return -1;
    }

    THREAD_SETSPECIFIC(DBReloadableInferEngine<EngineCore>::_skey, md);
    im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
    DBReloadableInferEngine<EngineCore>::_reload_vec.push_back(md);
    return 0;
  }

 protected:
  // Process-level EngineCore; all thread-level EngineCores share this
  // object's model data.
  ModelData<EngineCore>* _pd;
};

template <typename FluidFamilyCore>
W
Wang Guibao 已提交
464
class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
W
wangguibao 已提交
465 466 467
 public:
  FluidInferEngine() {}
  ~FluidInferEngine() {}
W
wangguibao 已提交
468

W
wangguibao 已提交
469 470 471 472 473 474
  int infer_impl1(const void* in, void* out, uint32_t batch_size = -1) {
    FluidFamilyCore* core =
        DBReloadableInferEngine<FluidFamilyCore>::get_core();
    if (!core || !core->get()) {
      LOG(ERROR) << "Failed get fluid core in infer_impl()";
      return -1;
W
wangguibao 已提交
475 476
    }

W
wangguibao 已提交
477 478 479 480 481 482
    if (!core->Run(in, out)) {
      LOG(ERROR) << "Failed run fluid family core";
      return -1;
    }
    return 0;
  }
W
wangguibao 已提交
483

W
wangguibao 已提交
484 485 486
  int infer_impl2(const BatchTensor& in, BatchTensor& out) {  // NOLINT
    return infer_impl1(&in, &out);
  }
W
wangguibao 已提交
487 488
};

typedef FactoryPool<InferEngine> StaticInferFactory;

class VersionedInferEngine : public InferEngine {
 public:
  VersionedInferEngine() { _versions.clear(); }
  ~VersionedInferEngine() {}

  int proc_initialize(const configure::EngineDesc& conf) {
    if (proc_initialize(conf, false) != 0) {
      LOG(ERROR) << "Failed proc intialize engine: " << conf.name().c_str();
      return -1;
    }

    LOG(WARNING) << "Succ proc initialize engine: " << conf.name().c_str();
    return 0;
  }

  int proc_initialize(const configure::EngineDesc& conf, bool version) {
    std::string engine_type = conf.type();
    InferEngine* engine =
        StaticInferFactory::instance().generate_object(engine_type);
    if (!engine) {
      LOG(ERROR) << "Failed generate engine with type:" << engine_type;
      return -1;
    }

    if (engine->proc_initialize(conf, version) != 0) {
      LOG(ERROR) << "Failed initialize engine, type:" << engine_type;
      return -1;
    }

    auto r = _versions.insert(std::make_pair(engine->version(), engine));
    if (!r.second) {
      LOG(ERROR) << "Failed insert item: " << engine->version()
                 << ", type: " << engine_type;
      return -1;
    }
    LOG(WARNING) << "Succ proc initialize version engine: "
                 << engine->version();
    return 0;
  }

  int proc_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize version engine: " << iter->first;
      }
      LOG(WARNING) << "Succ proc finalize version engine: " << iter->first;
    }
    return 0;
  }

  int thrd_initialize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize version engine: " << iter->first;
W
wangguibao 已提交
545
        return -1;
W
wangguibao 已提交
546 547
      }
      LOG(WARNING) << "Succ thrd initialize version engine: " << iter->first;
W
wangguibao 已提交
548
    }
W
wangguibao 已提交
549 550
    return 0;
  }
W
wangguibao 已提交
551

W
wangguibao 已提交
552 553 554 555
  int thrd_clear() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear version engine: " << iter->first;
W
wangguibao 已提交
556
        return -1;
W
wangguibao 已提交
557 558
      }
      LOG(INFO) << "Succ thrd clear version engine: " << iter->first;
W
wangguibao 已提交
559
    }
W
wangguibao 已提交
560 561
    return 0;
  }
W
wangguibao 已提交
562

W
wangguibao 已提交
563 564 565 566 567 568 569
  int thrd_finalize() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize version engine: " << iter->first;
W
wangguibao 已提交
570
    }
W
wangguibao 已提交
571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671
    return 0;
  }

  int reload() {
    for (auto iter = _versions.begin(); iter != _versions.end(); ++iter) {
      if (iter->second->reload() != 0) {
        LOG(ERROR) << "Failed reload version engine: " << iter->first;
        return -1;
      }
      LOG(WARNING) << "Succ reload version engine: " << iter->first;
    }
    return 0;
  }

  uint64_t version() const {
    InferEngine* engine = default_engine();
    if (engine) {
      return engine->version();
    } else {
      return uint64_t(-1);
    }
  }

  // inference interface
  InferEngine* default_engine() const {
    if (_versions.size() != 1) {
      LOG(ERROR) << "Ambiguous default engine version:" << _versions.size();
      return NULL;
    }

    return _versions.begin()->second;
  }

  int infer(const void* in, void* out, uint32_t batch_size) {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get default engine";
      return -1;
    }
    return engine->infer(in, out, batch_size);
  }

  template <typename T>
  T* get_core() {
    InferEngine* engine = default_engine();
    if (!engine) {
      LOG(WARNING) << "fail to get core";
      return NULL;
    }
    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(engine);
    if (db_engine) {
      return db_engine->get_core();
    }
    LOG(WARNING) << "fail to get core";
    return NULL;
  }

  // versioned inference interface
  int infer(const void* in, void* out, uint32_t batch_size, uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return -1;
    }

    return iter->second->infer(in, out, batch_size);
  }

  template <typename T>
  T* get_core(uint64_t version) {
    auto iter = _versions.find(version);
    if (iter == _versions.end()) {
      LOG(ERROR) << "Not found version engine: " << version;
      return NULL;
    }

    auto db_engine = dynamic_cast<DBReloadableInferEngine<T>*>(iter->second);
    if (db_engine) {
      return db_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << version;
    return NULL;
  }

  // --
  int proc_initialize_impl(const configure::EngineDesc& conf, bool) {
    return -1;
  }
  int thrd_initialize_impl() { return -1; }
  int thrd_finalize_impl() { return -1; }
  int thrd_clear_impl() { return -1; }
  int proc_finalize_impl() { return -1; }
  int infer_impl1(const void* in, void* out, uint32_t batch_size = -1) {
    return -1;
  }
  int infer_impl2(const BatchTensor& in, BatchTensor& out) {  // NOLINT
    return -1;
  }  // NOLINT

 private:
  boost::unordered_map<uint64_t, InferEngine*> _versions;
W
wangguibao 已提交
672 673
};

class InferManager {
 public:
  static InferManager& instance() {
    static InferManager ins;
    return ins;
  }

  int proc_initialize(const char* path, const char* file) {
    ModelToolkitConf model_toolkit_conf;
    if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
      LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
      return -1;
    }

    size_t engine_num = model_toolkit_conf.engines_size();
    for (size_t ei = 0; ei < engine_num; ++ei) {
      std::string engine_name = model_toolkit_conf.engines(ei).name();
      VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
      if (!engine) {
        LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
        return -1;
      }

      if (engine->proc_initialize(model_toolkit_conf.engines(ei)) != 0) {
        LOG(ERROR) << "Failed initialize version engine, name:" << engine_name;
W
wangguibao 已提交
699
        return -1;
W
wangguibao 已提交
700 701 702 703 704 705 706 707
      }

      auto r = _map.insert(std::make_pair(engine_name, engine));
      if (!r.second) {
        LOG(ERROR) << "Failed insert item: " << engine_name;
        return -1;
      }
      LOG(WARNING) << "Succ proc initialize engine: " << engine_name;
W
wangguibao 已提交
708 709
    }

W
wangguibao 已提交
710 711 712 713 714 715 716
    return 0;
  }

  int thrd_initialize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_initialize() != 0) {
        LOG(ERROR) << "Failed thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
717
        return -1;
W
wangguibao 已提交
718 719
      }
      LOG(WARNING) << "Succ thrd initialize engine, name: " << it->first;
W
wangguibao 已提交
720
    }
W
wangguibao 已提交
721 722
    return 0;
  }
W
wangguibao 已提交
723

W
wangguibao 已提交
724 725 726 727 728 729 730 731 732
  int thrd_clear() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_clear() != 0) {
        LOG(ERROR) << "Failed thrd clear engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
733

W
wangguibao 已提交
734 735 736 737 738 739 740 741 742
  int reload() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->reload() != 0) {
        LOG(ERROR) << "Failed reload engine, name: " << it->first;
        return -1;
      }
    }
    return 0;
  }
W
wangguibao 已提交
743

W
wangguibao 已提交
744 745 746 747 748 749 750 751 752 753
  int thrd_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->thrd_finalize() != 0) {
        LOG(ERROR) << "Failed thrd finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ thrd finalize engine, name: " << it->first;
    }
    return 0;
  }
W
wangguibao 已提交
754

W
wangguibao 已提交
755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837
  int proc_finalize() {
    for (auto it = _map.begin(); it != _map.end(); ++it) {
      if (it->second->proc_finalize() != 0) {
        LOG(ERROR) << "Failed proc finalize engine, name: " << it->first;
        return -1;
      }
      LOG(WARNING) << "Succ proc finalize engine, name: " << it->first;
    }
    return 0;
  }

  // Inference interface
  int infer(const char* model_name,
            const void* in,
            void* out,
            uint32_t batch_size = -1) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
    return it->second->infer(in, out, batch_size);
  }

  template <typename T>
  T* get_core(const char* model_name) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    auto infer_engine =
        dynamic_cast<DBReloadableInferEngine<T>*>(it->second->default_engine());
    if (infer_engine) {
      return infer_engine->get_core();
    }
    LOG(WARNING) << "fail to get core for " << model_name;
    return NULL;
  }

  // Versioned inference interface
  int infer(const char* model_name,
            const void* in,
            void* out,
            uint32_t batch_size,
            uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return -1;
    }
    return it->second->infer(in, out, batch_size, version);
  }

  template <typename T>
  T* get_core(const char* model_name, uint64_t version) {
    auto it = _map.find(model_name);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
      return NULL;
    }
    return it->second->get_core<T>(version);
  }

  int query_version(const std::string& model, uint64_t& version) {  // NOLINT
    auto it = _map.find(model);
    if (it == _map.end()) {
      LOG(WARNING) << "Cannot find engine in map, model name:" << model;
      return -1;
    }
    auto infer_engine = it->second->default_engine();
    if (!infer_engine) {
      LOG(WARNING) << "Cannot get default engine for model:" << model;
      return -1;
    }
    version = infer_engine->version();
    LOG(INFO) << "Succ get version: " << version << " for model: " << model;
    return 0;
  }

 private:
  boost::unordered_map<std::string, VersionedInferEngine*> _map;
};
}  // namespace predictor
}  // namespace paddle_serving
}  // namespace baidu