未验证 提交 fe0ab9c8 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1289 from HexToString/develop-p

C++ Serving更新
...@@ -30,7 +30,7 @@ find_package(Threads REQUIRED) ...@@ -30,7 +30,7 @@ find_package(Threads REQUIRED)
find_package(CUDA QUIET) find_package(CUDA QUIET)
include(simd) include(simd)
# SET(CMAKE_BUILD_TYPE "Debug")
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
......
...@@ -175,9 +175,12 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p ...@@ -175,9 +175,12 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
| Argument | Type | Default | Description | | Argument | Type | Default | Description |
| ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- | | ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- |
| `thread` | int | `4` | Concurrency of current service | | `thread` | int | `2` | Number of brpc service thread |
| `op_num` | int[]| `0` | Thread Number for each model in asynchronous mode |
| `op_max_batch` | int[]| `0` | Batch Number for each model in asynchronous mode |
| `gpu_ids` | str[]| `"-1"` | Gpu card id for each model |
| `port` | int | `9292` | Exposed port of current service to users | | `port` | int | `9292` | Exposed port of current service to users |
| `model` | str | `""` | Path of paddle model directory to be served | | `model` | str[]| `""` | Path of paddle model directory to be served |
| `mem_optim_off` | - | - | Disable memory / graphic memory optimization | | `mem_optim_off` | - | - | Disable memory / graphic memory optimization |
| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph | | `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL | | `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
...@@ -186,7 +189,24 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p ...@@ -186,7 +189,24 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU | | `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 | | `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
| `use_calib` | bool | False | Only for deployment with TensorRT | | `use_calib` | bool | False | Only for deployment with TensorRT |
| `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
#### Description of asynchronous model
Asynchronous mode is suitable for 1. When the number of requests is very large, 2. When multiple models are concatenated and you want to specify the concurrency number of each model.
Asynchronous mode helps to improve the throughput (QPS) of service, but for a single request, the delay will increase slightly.
In asynchronous mode, each model will start n threads of the number you specify, and each thread contains a model instance. In other words, each model is equivalent to a thread pool containing N threads, and the task is taken from the task queue of the thread pool to execute.
In asynchronous mode, each RPC server thread is only responsible for putting the request into the task queue of the model thread pool. After the task is executed, the completed task is removed from the task queue.
In the above table, the number of RPC server threads is specified by --thread, and the default value is 2.
--op_num specifies the number of threads in the thread pool of each model. The default value is 0, indicating that asynchronous mode is not used.
--op_max_batch specifies the number of batches for each model. The default value is 32. It takes effect when --op_num is not 0.
#### When you want a model to use multiple GPU cards.
python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292 --gpu_ids 0,1,2
#### When you want 2 models.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292
#### When you want 2 models, and want each of them use multiple GPU cards.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2
#### When a service contains two models, and each model needs to specify multiple GPU cards, and needs asynchronous mode, each model specifies different concurrency number.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2 --op_num 4 8
</center> </center>
```python ```python
......
...@@ -172,19 +172,40 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p ...@@ -172,19 +172,40 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
``` ```
<center> <center>
| Argument | Type | Default | Description | | Argument | Type | Default | Description |
| ---------------------------------------------- | ---- | ------- | ------------------------------------------------------ | | ---------------------------------------------- | ---- | ------- | ----------------------------------------------------- |
| `thread` | int | `4` | Concurrency of current service | | `thread` | int | `2` | Number of brpc service thread |
| `port` | int | `9292` | Exposed port of current service to users | | `op_num` | int[]| `0` | Thread Number for each model in asynchronous mode |
| `name` | str | `""` | Service name, can be used to generate HTTP request url | | `op_max_batch` | int[]| `32` | Batch Number for each model in asynchronous mode |
| `model` | str | `""` | Path of paddle model directory to be served | | `gpu_ids` | str[]| `"-1"` | Gpu card id for each model |
| `mem_optim_off` | - | - | Disable memory optimization | | `port` | int | `9292` | Exposed port of current service to users |
| `ir_optim` | bool | False | Enable analysis and optimization of calculation graph | | `model` | str[]| `""` | Path of paddle model directory to be served |
| `use_mkl` (Only for cpu version) | - | - | Run inference with MKL | | `mem_optim_off` | - | - | Disable memory / graphic memory optimization |
| `use_trt` (Only for Cuda>=10.1 version) | - | - | Run inference with TensorRT | | `ir_optim` | bool | False | Enable analysis and optimization of calculation graph |
| `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference | | `use_mkl` (Only for cpu version) | - | - | Run inference with MKL |
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU | | `use_trt` (Only for trt version) | - | - | Run inference with TensorRT |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 | | `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
| `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
| `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
| `use_calib` | bool | False | Only for deployment with TensorRT |
| `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
#### 异步模型的说明
异步模式适用于1、请求数量非常大的情况,2、多模型串联,想要分别指定每个模型的并发数的情况。
异步模式有助于提高Service服务的吞吐(QPS),但对于单次请求而言,时延会有少量增加。
异步模式中,每个模型会启动您指定个数的N个线程,每个线程中包含一个模型实例,换句话说每个模型相当于包含N个线程的线程池,从线程池的任务队列中取任务来执行。
异步模式中,各个RPC Server的线程只负责将Request请求放入模型线程池的任务队列中,等任务被执行完毕后,再从任务队列中取出已完成的任务。
上表中通过 --thread 10 指定的是RPC Server的线程数量,默认值为2,--op_num 指定的是各个模型的线程池中线程数N,默认值为0,表示不使用异步模式。
--op_max_batch 指定的各个模型的batch数量,默认值为32,该参数只有当--op_num不为0时才生效。
#### 当您的某个模型想使用多张GPU卡部署时.
python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292 --gpu_ids 0,1,2
#### 当您的一个服务包含两个模型部署时.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292
#### 当您的一个服务包含两个模型,且每个模型都需要指定多张GPU卡部署时.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2
#### 当您的一个服务包含两个模型,且每个模型都需要指定多张GPU卡,且需要异步模式每个模型指定不同的并发数时.
python3 -m paddle_serving_server.serve --model uci_housing_model_1 uci_housing_model_2 --thread 10 --port 9292 --gpu_ids 0,1 1,2 --op_num 4 8
</center> </center>
......
...@@ -21,11 +21,12 @@ message EngineDesc { ...@@ -21,11 +21,12 @@ message EngineDesc {
required string reloadable_meta = 3; required string reloadable_meta = 3;
required string reloadable_type = 4; required string reloadable_type = 4;
required string model_dir = 5; required string model_dir = 5;
required int32 runtime_thread_num = 6; repeated int32 gpu_ids = 6;
required int32 batch_infer_size = 7; required int32 runtime_thread_num = 7;
required int32 enable_batch_align = 8; required int32 batch_infer_size = 8;
optional string version_file = 9; required int32 enable_batch_align = 9;
optional string version_type = 10; optional string version_file = 10;
optional string version_type = 11;
/* /*
* Sparse Parameter Service type. Valid types are: * Sparse Parameter Service type. Valid types are:
...@@ -38,16 +39,17 @@ message EngineDesc { ...@@ -38,16 +39,17 @@ message EngineDesc {
LOCAL = 1; LOCAL = 1;
REMOTE = 2; REMOTE = 2;
} }
optional SparseParamServiceType sparse_param_service_type = 11; optional SparseParamServiceType sparse_param_service_type = 12;
optional string sparse_param_service_table_name = 12; optional string sparse_param_service_table_name = 13;
optional bool enable_memory_optimization = 13; optional bool enable_memory_optimization = 14;
optional bool enable_ir_optimization = 14; optional bool enable_ir_optimization = 15;
optional bool use_trt = 15; optional bool use_trt = 16;
optional bool use_lite = 16; optional bool use_lite = 17;
optional bool use_xpu = 17; optional bool use_xpu = 18;
optional bool use_gpu = 18; optional bool use_gpu = 19;
optional bool combined_model = 19; optional bool combined_model = 20;
optional bool encrypted_model = 20; optional bool encrypted_model = 21;
optional bool gpu_multi_stream = 22;
}; };
// model_toolkit conf // model_toolkit conf
......
...@@ -166,6 +166,8 @@ int PredictorClient::numpy_predict( ...@@ -166,6 +166,8 @@ int PredictorClient::numpy_predict(
batch_size = batch_size > string_feed_batch.size() ? batch_size batch_size = batch_size > string_feed_batch.size() ? batch_size
: string_feed_batch.size(); : string_feed_batch.size();
VLOG(2) << "batch size: " << batch_size; VLOG(2) << "batch size: " << batch_size;
// batch_size must be 1, cause batch is already in Tensor.
// I suggest to remove the outside vector<>.
predict_res_batch.clear(); predict_res_batch.clear();
Timer timeline; Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS(); int64_t preprocess_start = timeline.TimeStampUS();
...@@ -188,6 +190,8 @@ int PredictorClient::numpy_predict( ...@@ -188,6 +190,8 @@ int PredictorClient::numpy_predict(
} }
int vec_idx = 0; int vec_idx = 0;
// batch_size can only be 1, cause batch is already in Tensor.
// if batch_size is not 1, error will occur in C++ part.
for (int bi = 0; bi < batch_size; bi++) { for (int bi = 0; bi < batch_size; bi++) {
VLOG(2) << "prepare batch " << bi; VLOG(2) << "prepare batch " << bi;
std::vector<Tensor *> tensor_vec; std::vector<Tensor *> tensor_vec;
......
...@@ -93,6 +93,9 @@ int GeneralReaderOp::inference() { ...@@ -93,6 +93,9 @@ int GeneralReaderOp::inference() {
res->SetLogId(log_id); res->SetLogId(log_id);
Timer timeline; Timer timeline;
int64_t start = timeline.TimeStampUS(); int64_t start = timeline.TimeStampUS();
// only get insts(0), cause batch is already in Tensor.
// req can only include 1 inst.
// var_num means the number of feed_var.
int var_num = req->insts(0).tensor_array_size(); int var_num = req->insts(0).tensor_array_size();
VLOG(2) << "(logid=" << log_id << ") var num: " << var_num VLOG(2) << "(logid=" << log_id << ") var num: " << var_num
...@@ -178,7 +181,10 @@ int GeneralReaderOp::inference() { ...@@ -178,7 +181,10 @@ int GeneralReaderOp::inference() {
VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i VLOG(2) << "(logid=" << log_id << ") tensor size for var[" << i
<< "]: " << data_len; << "]: " << data_len;
databuf_size = data_len * elem_size; databuf_size = data_len * elem_size;
out->at(i).data.Resize(databuf_size); void *databuf_char = MempoolWrapper::instance().malloc(databuf_size);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
out->at(i).data = paddleBuf;
// out->at(i).data.Resize(databuf_size);
if (out->at(i).lod.size() > 0) { if (out->at(i).lod.size() > 0) {
VLOG(2) << "(logid=" << log_id << ") var[" << i VLOG(2) << "(logid=" << log_id << ") var[" << i
<< "] has lod_tensor and len=" << out->at(i).lod[0].back(); << "] has lod_tensor and len=" << out->at(i).lod[0].back();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef BCLOUD
#include <base/atomicops.h>
#else
#include <butil/atomicops.h>
#endif
#include <errno.h>
#include <algorithm>
#include <deque>
#include <vector>
#include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h"
#include <boost/function.hpp>
namespace im {
namespace bsf {
template <>
struct Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor> {
typedef Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor>
TaskT;
typedef baidu::paddle_serving::predictor::Tensor Tensor;
typedef baidu::paddle_serving::predictor::Tensor InType;
typedef baidu::paddle_serving::predictor::Tensor OutType;
typedef baidu::paddle_serving::predictor::BatchTensor BatchTensor;
typedef baidu::paddle_serving::predictor::BatchTensor InArrayT;
typedef baidu::paddle_serving::predictor::BatchTensor OutArrayT;
struct Segment {
Segment(void* p, size_t b, size_t s) : ptr(p), begin(b), size(s) {}
void* ptr;
size_t begin;
size_t size;
};
int read_fd;
int write_fd;
pid_t owner_tid;
const InArrayT* in;
OutArrayT* out;
size_t rem;
size_t size;
butil::atomic<size_t> index;
const BatchTensor* get(bool is_in) const {
if (is_in) {
return in;
} else {
return out;
}
}
BatchTensor* get(bool is_in) {
if (is_in) {
return const_cast<BatchTensor*>(in);
} else {
return out;
}
}
Task() {
read_fd = -1;
write_fd = -1;
owner_tid = -1;
in = NULL;
out = NULL;
rem = -1;
size = -1;
index.store(0, butil::memory_order_relaxed);
}
};
template <>
class BatchTasks<Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor>> {
public:
typedef baidu::paddle_serving::predictor::Tensor Tensor;
typedef baidu::paddle_serving::predictor::Tensor InType;
typedef baidu::paddle_serving::predictor::Tensor OutType;
typedef baidu::paddle_serving::predictor::DataBuf DataBuf;
typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
typedef Task<baidu::paddle_serving::predictor::Tensor,
baidu::paddle_serving::predictor::Tensor>
TaskT;
typedef TaskMeta<TaskT> TaskMetaT;
typedef TaskT::InArrayT InArrayT;
typedef TaskT::OutArrayT OutArrayT;
explicit BatchTasks(size_t batch_size, bool batch_align = false)
: _batch_size(batch_size),
_rem_size(batch_size),
_batch_align(batch_align) {
_batch_in.clear();
_batch_out.clear();
_tasks.clear();
}
~BatchTasks() {
_batch_in.clear();
_batch_out.clear();
_tasks.clear();
}
static bool check_valid(const InArrayT& in,
OutArrayT& out, // NOLINT
bool align) { // NOLINT
if (align) {
if (out.count() <= 0 || out.size() <= 0) {
LOG(ERROR) << "Out tensor is empty, when aligned";
return false;
}
if (out.size() != in.size()) {
LOG(ERROR) << "In/Out tensor size not eq: " << out.size()
<< "!=" << in.size();
return false;
}
for (size_t fi = 0, shape0 = 0; fi < out.count(); ++fi) {
if (!out[fi].valid()) {
LOG(ERROR) << "Out[" << fi << "] tensor not valid";
return false;
}
if (out.size() != out[fi].shape0()) {
LOG(ERROR) << "Shape0 not consistency, " << out.size()
<< "!=" << out[fi].shape0() << ", " << fi;
return false;
}
}
}
return true;
}
size_t append_task(TaskT* task) {
size_t add = std::min(task->rem, _rem_size);
if (!_batch_align) {
add = task->rem;
}
TaskMetaT tm(task, task->in->size() - task->rem, add);
_tasks.push_back(tm);
task->rem -= add;
_rem_size -= add;
return _rem_size;
}
void merge_tasks() {
merge_input();
merge_output();
}
void merge_input() {
if (_tasks.size() <= 0 || _tasks[0].task->in->count() <= 0) {
return;
}
if (_tasks.size() == 1 && !_batch_align) {
TaskMetaT& tm = _tasks[0];
_batch_in = *(tm.task->in);
return;
}
merge_tensor(true);
}
void merge_output() {
if (_batch_align) {
if (_tasks.size() <= 0 || _tasks[0].task->out->count() <= 0) {
return;
}
}
if (_tasks.size() <= 0 || _tasks[0].task->out->count() <= 0) {
return;
}
TaskMetaT& tm = _tasks[0];
if (_tasks.size() == 1 && !_batch_align) {
_batch_out = *(tm.task->out);
return;
}
if (tm.task->out->size() <= 0) {
// shape is empty
_batch_out = *(tm.task->out);
return;
}
if ((*tm.task->out)[0].data.data() == 0 ||
(*tm.task->out)[0].data.size() == 0) {
_batch_out = *(tm.task->out);
return;
}
merge_tensor(false);
}
void merge_tensor(bool is_in) {
// accumulate batch size from fetched tasks
size_t batch_size = 0;
for (size_t ti = 0; ti < _tasks.size(); ++ti) {
TaskMetaT& tm = _tasks[ti];
size_t add = tm.end - tm.begin;
batch_size += add;
}
// merge all instanses in each tensor data
size_t tensor_count = _tasks[0].task->get(is_in)->count();
for (size_t fi = 0; fi < tensor_count; ++fi) {
const Tensor& head = (*(_tasks[0].task->get(is_in)))[fi];
Tensor batch_tensor;
batch_tensor.name = head.name;
batch_tensor.type = head.type;
batch_tensor.shape.push_back(batch_size);
size_t ins_ele_count = 1;
for (size_t si = 1; si < head.shape.size(); ++si) {
batch_tensor.shape.push_back(head.shape[si]);
ins_ele_count *= head.shape[si];
}
size_t tensor_ele_count = ins_ele_count * batch_size;
size_t ins_byte = ins_ele_count * head.ele_byte();
size_t tensor_byte = tensor_ele_count * head.ele_byte();
void* data_buf = MempoolWrapper::instance().malloc(tensor_byte);
if (!data_buf) {
LOG(ERROR) << "Malloc failed, size: " << tensor_byte;
return;
}
size_t data_byte = 0;
for (size_t ti = 0; ti < _tasks.size(); ++ti) {
TaskMetaT& tm = _tasks[ti];
size_t acc_byte = ins_byte * (tm.end - tm.begin);
if (data_byte + acc_byte > tensor_byte) {
LOG(ERROR) << "Invalid bytes: " << data_byte << " + " << acc_byte
<< " >= " << tensor_byte;
return;
}
const Tensor& tensor = (*(tm.task->get(is_in)))[fi];
memcpy(
reinterpret_cast<char*>(data_buf) + data_byte,
reinterpret_cast<char*>(tensor.data.data()) + tm.begin * ins_byte,
acc_byte);
data_byte += acc_byte;
}
if (data_byte != tensor_byte) {
LOG(ERROR) << "Invalid tensor byte: " << data_byte
<< " != " << tensor_byte;
return;
}
batch_tensor.data =
DataBuf(reinterpret_cast<char*>(data_buf), tensor_byte);
if (is_in) {
_batch_in.push_back(batch_tensor);
} else {
_batch_out.push_back(batch_tensor);
}
}
LOG(INFO) << "merge input(" << is_in << ") samples: " << batch_size
<< " from " << _tasks.size() << " pvs";
}
void notify_tasks() {
if (_batch_out.size() != _batch_in.size()) {
LOG(ERROR) << "batch size not consistency: " << _batch_out.size()
<< " != " << _batch_in.size();
return;
}
size_t tensor_count = _batch_out.count();
size_t batch_size = _batch_out.size();
for (size_t fi = 0; fi < tensor_count; ++fi) {
const Tensor& tensor = _batch_out[fi];
size_t ins_byte = tensor.ele_byte();
for (size_t si = 1; si < tensor.shape.size(); ++si) {
ins_byte *= tensor.shape[si];
}
for (size_t ti = 0, bi = 0, add = 0; ti < _tasks.size();
++ti, bi += add) {
OutArrayT* dst = _tasks[ti].task->out;
add = _tasks[ti].end - _tasks[ti].begin;
size_t offset_src = ins_byte * bi;
size_t add_byte = add * ins_byte;
if (_batch_align) { // merge all batchs
size_t offset_dst = ins_byte * _tasks[ti].begin;
void* ptr = const_cast<void*>((*dst)[fi].data.data());
memcpy(
reinterpret_cast<char*>(ptr) + offset_dst,
reinterpret_cast<char*>(_batch_out[fi].data.data()) + offset_src,
add_byte);
} else { // overwrite
if (dst->count() <= 0) {
dst->push_back(_batch_out[fi]);
} else {
(*dst)[fi] = _batch_out[fi];
}
(*dst)[fi].shape[0] = add;
(*dst)[fi].data = DataBuf(
reinterpret_cast<char*>(_batch_out[fi].data.data()) + offset_src,
add_byte);
}
}
}
for (size_t ti = 0; ti < _tasks.size(); ++ti) {
TaskT* task = _tasks[ti].task;
size_t begin = _tasks[ti].begin;
size_t end = _tasks[ti].end;
size_t add = end - begin;
size_t index = task->index.fetch_add(add);
if ((index + add) >= task->in->size()) {
char c = 0;
while (write(task->write_fd, &c, 1) != 1 && errno == EINTR) {
}
butil::return_object(task);
}
}
}
const typename TaskT::InArrayT& in() const { return _batch_in; }
typename TaskT::OutArrayT& out() { return _batch_out; }
size_t task_size() { return _tasks.size(); }
private:
std::vector<TaskMetaT> _tasks;
InArrayT _batch_in;
OutArrayT _batch_out;
size_t _batch_size;
size_t _rem_size;
bool _batch_align;
};
} // namespace bsf
} // namespace im
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <boost/bind.hpp> #include <boost/bind.hpp>
#include "core/predictor/common/inner_common.h" #include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/memory.h"
namespace im { namespace im {
namespace bsf { namespace bsf {
...@@ -35,7 +36,7 @@ void* TaskExecutor<TaskT>::thread_entry(void* args) { ...@@ -35,7 +36,7 @@ void* TaskExecutor<TaskT>::thread_entry(void* args) {
static_cast<TaskExecutor<TaskT>*>(context->executor); static_cast<TaskExecutor<TaskT>*>(context->executor);
executor->work(context); executor->work(context);
return NULL; return nullptr;
} }
template <typename TaskT> template <typename TaskT>
...@@ -125,18 +126,21 @@ void TaskExecutor<TaskT>::stop() { ...@@ -125,18 +126,21 @@ void TaskExecutor<TaskT>::stop() {
} }
template <typename TaskT> template <typename TaskT>
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in, TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
OutArrayT& out) { // NOLINT const void* inVectorT_ptr,
void* outVectorT_ptr) { // NOLINT
TaskT* task = butil::get_object<TaskT>(); TaskT* task = butil::get_object<TaskT>();
if (!task) { if (!task) {
LOG(ERROR) << "Failed get TaskT from object pool"; LOG(ERROR) << "Failed get TaskT from object pool";
return TaskHandler<TaskT>::valid_handle(); return TaskHandler<TaskT>::valid_handle();
} }
/*
if (!BatchTasks<TaskT>::check_valid(in, out, _batch_align)) { if (!BatchTasks<TaskT>::check_valid(in, out, _batch_align)) {
LOG(ERROR) << "Invalid input & output"; LOG(ERROR) << "Invalid input & output";
return TaskHandler<TaskT>::valid_handle(); return TaskHandler<TaskT>::valid_handle();
} }
*/
int fds[2]; int fds[2];
int rc = pipe(fds); int rc = pipe(fds);
...@@ -150,10 +154,9 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in, ...@@ -150,10 +154,9 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in,
task->write_fd = fds[1]; task->write_fd = fds[1];
task->owner_tid = ::syscall(SYS_gettid); task->owner_tid = ::syscall(SYS_gettid);
task->in = &in; task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
task->out = &out; task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
task->rem = in.size(); task->rem = task->batch_size();
task->size = in.size();
task->index.store(0, butil::memory_order_relaxed); task->index.store(0, butil::memory_order_relaxed);
AutoMutex lock(_mut); AutoMutex lock(_mut);
...@@ -163,8 +166,13 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in, ...@@ -163,8 +166,13 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(const InArrayT& in,
return TaskHandler<TaskT>(*task); return TaskHandler<TaskT>(*task);
} }
// this function is accessed by multi thread.
// so AutoMutex at first.
// so batch.append_task is thread safe.
// you dont need to add extra lock in append_task()
template <typename TaskT> template <typename TaskT>
bool TaskExecutor<TaskT>::fetch_batch(BatchTasks<TaskT>& batch) { // NOLINT bool TaskExecutor<TaskT>::move_task_to_batch(
BatchTasks<TaskT>& batch) { // NOLINT
AutoMutex lock(_mut); AutoMutex lock(_mut);
while (_task_queue.empty()) { while (_task_queue.empty()) {
THREAD_COND_WAIT(&_cond, &_mut); THREAD_COND_WAIT(&_cond, &_mut);
...@@ -187,8 +195,30 @@ bool TaskExecutor<TaskT>::fetch_batch(BatchTasks<TaskT>& batch) { // NOLINT ...@@ -187,8 +195,30 @@ bool TaskExecutor<TaskT>::fetch_batch(BatchTasks<TaskT>& batch) { // NOLINT
return true; return true;
} }
// this function is accessed by multi thread.
// move_task_to_batch have add lock inside the function.
// Packaging 1 TaskT as 1 or Several TaskMeta.
// TaskT is from the SingleTon TaskExecutor`s _task_queue
// although TaskMeta is a local variable, but several TaskMeta may points to
// the same TaskT which is get from the SingleTon TaskExecutor`s _task_queue.
// put TaskMeta to the local variable BatchTasks<TaskT> batch.
// batch.merge_tasks() and batch.notify_tasks() has no lock.
// BatchTasks<TaskT> batch itself is a local variable, it`s thread safe.
// If batch.merge_tasks() and batch.notify_tasks() do something to TaskMeta
// you need to pay attention to that.
// Multi-Thread deal with different TaskMeta(cause it`s created as local
// variable)
// But different TaskMeta may points to the same TaskT
// which is get from the SingleTon TaskExecutor`s _task_queue.
template <typename TaskT> template <typename TaskT>
int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) { int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
if (MempoolWrapper::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialize mempool";
return -1;
}
if (_thread_init_fn != NULL) { if (_thread_init_fn != NULL) {
if (_thread_init_fn(context->user_thread_context) != 0) { if (_thread_init_fn(context->user_thread_context) != 0) {
LOG(ERROR) << "execute thread init thunk failed, BSF thread will exit"; LOG(ERROR) << "execute thread init thunk failed, BSF thread will exit";
...@@ -207,10 +237,15 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) { ...@@ -207,10 +237,15 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
} }
} }
if (MempoolWrapper::instance().thread_clear() != 0) {
LOG(ERROR) << "Failed thread clear mempool";
return -1;
}
BatchTasks<TaskT> batch(_batch_size, _batch_align); BatchTasks<TaskT> batch(_batch_size, _batch_align);
if (fetch_batch(batch)) { if (move_task_to_batch(batch)) {
batch.merge_tasks(); batch.merge_tasks();
_fn(batch.in(), batch.out()); _fn(&batch.in(), &batch.out());
batch.notify_tasks(); batch.notify_tasks();
} }
} }
...@@ -219,9 +254,10 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) { ...@@ -219,9 +254,10 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
} }
template <typename InItemT, typename OutItemT> template <typename InItemT, typename OutItemT>
bool TaskManager<InItemT, OutItemT>::schedule(const InArrayT& in, bool TaskManager<InItemT, OutItemT>::schedule(const void* in,
OutArrayT& out) { // NOLINT void* out) { // NOLINT
TaskHandler<TaskT> handler = _executor.schedule(in, out); TaskHandler<TaskT> handler =
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out);
if (handler.valid()) { if (handler.valid()) {
_task_owned = handler; _task_owned = handler;
......
此差异已折叠。
...@@ -56,15 +56,23 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf, ...@@ -56,15 +56,23 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
} }
// init bsf framework // init bsf framework
im::bsf::TaskExecutor<TaskT>::instance()->set_thread_init_fn( im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
boost::bind(&InferEngine::thrd_initialize_impl, this)); .set_thread_init_fn(
im::bsf::TaskExecutor<TaskT>::instance()->set_thread_reset_fn( boost::bind(&InferEngine::thrd_initialize_impl, this));
boost::bind(&InferEngine::thrd_clear_impl, this)); im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
im::bsf::TaskExecutor<TaskT>::instance()->set_thread_callback_fn( .set_thread_init_fn(
boost::bind(&InferEngine::task_infer_impl, this, _1, _2)); boost::bind(&InferEngine::thrd_initialize_impl, this));
im::bsf::TaskExecutor<TaskT>::instance()->set_batch_size(_infer_batch_size); im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
im::bsf::TaskExecutor<TaskT>::instance()->set_batch_align(_infer_batch_align); .set_thread_reset_fn(boost::bind(&InferEngine::thrd_clear_impl, this));
if (im::bsf::TaskExecutor<TaskT>::instance()->start(_infer_thread_num) != 0) { im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_callback_fn(
boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_batch_size(
_infer_batch_size);
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_batch_align(
_infer_batch_align);
if (im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].start(
_infer_thread_num) != 0) {
LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num; LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
return -1; return -1;
} }
...@@ -75,6 +83,11 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf, ...@@ -75,6 +83,11 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
return 0; return 0;
} }
// Multiple threads will enter this method of the same object
// One Model corresponds to One ReloadableInferEngine object.
// ReloadableInferEngine object is Process object.
// One ReloadableInferEngine object can have several ModelData<EngineCore>
// ModelData<EngineCore> is Thread object.
int ReloadableInferEngine::infer(const void* in, int ReloadableInferEngine::infer(const void* in,
void* out, void* out,
uint32_t batch_size) { uint32_t batch_size) {
...@@ -82,9 +95,10 @@ int ReloadableInferEngine::infer(const void* in, ...@@ -82,9 +95,10 @@ int ReloadableInferEngine::infer(const void* in,
return infer_impl(in, out, batch_size); return infer_impl(in, out, batch_size);
} }
im::bsf::TaskManager<Tensor, Tensor> task_manager; im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager(
task_manager.schedule(*(reinterpret_cast<const BatchTensor*>(in)), _model_index);
*(reinterpret_cast<BatchTensor*>(out)));
task_manager.schedule(in, out);
task_manager.wait(); task_manager.wait();
return 0; return 0;
} }
...@@ -110,7 +124,7 @@ int ReloadableInferEngine::proc_finalize() { ...@@ -110,7 +124,7 @@ int ReloadableInferEngine::proc_finalize() {
} }
if (_infer_thread_num > 0) { if (_infer_thread_num > 0) {
im::bsf::TaskExecutor<TaskT>::instance()->stop(); im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].stop();
} }
return 0; return 0;
} }
...@@ -191,6 +205,7 @@ int VersionedInferEngine::proc_initialize(const configure::EngineDesc& conf, ...@@ -191,6 +205,7 @@ int VersionedInferEngine::proc_initialize(const configure::EngineDesc& conf,
std::string engine_type = conf.type(); std::string engine_type = conf.type();
InferEngine* engine = InferEngine* engine =
StaticInferFactory::instance().generate_object(engine_type); StaticInferFactory::instance().generate_object(engine_type);
engine->set_model_index(_model_index);
if (!engine) { if (!engine) {
LOG(ERROR) << "Failed generate engine with type:" << engine_type; LOG(ERROR) << "Failed generate engine with type:" << engine_type;
return -1; return -1;
...@@ -362,8 +377,8 @@ int VersionedInferEngine::infer_impl(const void* in, ...@@ -362,8 +377,8 @@ int VersionedInferEngine::infer_impl(const void* in,
uint32_t batch_size) { uint32_t batch_size) {
return -1; return -1;
} }
int VersionedInferEngine::task_infer_impl(const BatchTensor& in, int VersionedInferEngine::task_infer_impl(const void* in,
BatchTensor& out) { // NOLINT void* out) { // NOLINT
return -1; return -1;
} }
...@@ -373,12 +388,14 @@ int InferManager::proc_initialize(const char* path, const char* file) { ...@@ -373,12 +388,14 @@ int InferManager::proc_initialize(const char* path, const char* file) {
LOG(ERROR) << "failed load infer config, path: " << path << "/" << file; LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
return -1; return -1;
} }
size_t engine_num = model_toolkit_conf.engines_size(); uint32_t engine_num = model_toolkit_conf.engines_size();
for (size_t ei = 0; ei < engine_num; ++ei) { im::bsf::TaskExecutorVector<TaskT>::instance().resize(engine_num);
for (uint32_t ei = 0; ei < engine_num; ++ei) {
LOG(INFO) << "model_toolkit_conf.engines(" << ei LOG(INFO) << "model_toolkit_conf.engines(" << ei
<< ").name: " << model_toolkit_conf.engines(ei).name(); << ").name: " << model_toolkit_conf.engines(ei).name();
std::string engine_name = model_toolkit_conf.engines(ei).name(); std::string engine_name = model_toolkit_conf.engines(ei).name();
VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine(); VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
engine->set_model_index(ei);
if (!engine) { if (!engine) {
LOG(ERROR) << "Failed generate versioned engine: " << engine_name; LOG(ERROR) << "Failed generate versioned engine: " << engine_name;
return -1; return -1;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h> #include <sys/types.h>
#include <unistd.h> #include <unistd.h>
#include <functional>
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -25,6 +26,7 @@ ...@@ -25,6 +26,7 @@
#include "core/predictor/framework/bsf.h" #include "core/predictor/framework/bsf.h"
#include "core/predictor/framework/factory.h" #include "core/predictor/framework/factory.h"
#include "core/predictor/framework/infer_data.h" #include "core/predictor/framework/infer_data.h"
#include "core/predictor/framework/memory.h"
#include "paddle_inference_api.h" // NOLINT #include "paddle_inference_api.h" // NOLINT
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
...@@ -71,7 +73,7 @@ class InferEngine { ...@@ -71,7 +73,7 @@ class InferEngine {
virtual int infer(const void* in, void* out, uint32_t batch_size = -1) { virtual int infer(const void* in, void* out, uint32_t batch_size = -1) {
return infer_impl(in, out, batch_size); return infer_impl(in, out, batch_size);
} }
virtual void set_model_index(uint32_t index) { _model_index = index; }
virtual int reload() = 0; virtual int reload() = 0;
virtual uint64_t version() const = 0; virtual uint64_t version() const = 0;
...@@ -86,12 +88,13 @@ class InferEngine { ...@@ -86,12 +88,13 @@ class InferEngine {
virtual int infer_impl(const void* in, virtual int infer_impl(const void* in,
void* out, void* out,
uint32_t batch_size = -1) = 0; uint32_t batch_size = -1) = 0;
virtual int task_infer_impl(const BatchTensor& in, virtual int task_infer_impl(const void* in, void* out) = 0; // NOLINT
BatchTensor& out) = 0; // NOLINT
protected:
uint32_t _model_index;
// end: framework inner call // end: framework inner call
}; };
typedef im::bsf::Task<paddle::PaddleTensor, paddle::PaddleTensor> TaskT;
class ReloadableInferEngine : public InferEngine { class ReloadableInferEngine : public InferEngine {
public: public:
virtual ~ReloadableInferEngine() {} virtual ~ReloadableInferEngine() {}
...@@ -104,7 +107,6 @@ class ReloadableInferEngine : public InferEngine { ...@@ -104,7 +107,6 @@ class ReloadableInferEngine : public InferEngine {
}; };
virtual int load(const configure::EngineDesc& conf) = 0; virtual int load(const configure::EngineDesc& conf) = 0;
typedef im::bsf::Task<Tensor, Tensor> TaskT;
int proc_initialize_impl(const configure::EngineDesc& conf, bool version); int proc_initialize_impl(const configure::EngineDesc& conf, bool version);
...@@ -179,6 +181,8 @@ struct ModelData { ...@@ -179,6 +181,8 @@ struct ModelData {
delete cores[1]; delete cores[1];
} }
void* get() { return cores[current_idx]->get(); }
EngineCore* cores[2]; EngineCore* cores[2];
uint32_t current_idx; uint32_t current_idx;
}; };
...@@ -191,14 +195,20 @@ class DBReloadableInferEngine : public ReloadableInferEngine { ...@@ -191,14 +195,20 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
int proc_initialize(const configure::EngineDesc& conf, bool version) { int proc_initialize(const configure::EngineDesc& conf, bool version) {
THREAD_KEY_CREATE(&_skey, NULL); THREAD_KEY_CREATE(&_skey, NULL);
THREAD_MUTEX_INIT(&_mutex, NULL); THREAD_MUTEX_INIT(&_mutex, NULL);
gpu_index = 0;
return ReloadableInferEngine::proc_initialize(conf, version); return ReloadableInferEngine::proc_initialize(conf, version);
} }
// 进程初始化会调用load,但由于未执行线程初始化,所以_reload_vec为空,不再继续执行。
// 热加载的话会调用load,由于线程已经初始化,_reload_vec不为空,所以继续执行load_data操作加载数据。
// 线程初始化会执行load_data操作加载数据,然后将engine加入_reload_vec中。
// 每个模型只有一个CloneDBReloadableInferEngine对象。
// 但一个CloneDBReloadableInferEngine对象,可以包含N个EngineCore。
virtual int load(const configure::EngineDesc& conf) { virtual int load(const configure::EngineDesc& conf) {
if (_reload_vec.empty()) { if (_reload_vec.empty()) {
return 0; return 0;
} }
gpu_index = 0;
for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) { for (uint32_t ti = 0; ti < _reload_vec.size(); ++ti) {
if (load_data(_reload_vec[ti], conf) != 0) { if (load_data(_reload_vec[ti], conf) != 0) {
LOG(ERROR) << "Failed reload engine model: " << ti; LOG(ERROR) << "Failed reload engine model: " << ti;
...@@ -210,7 +220,8 @@ class DBReloadableInferEngine : public ReloadableInferEngine { ...@@ -210,7 +220,8 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
return 0; return 0;
} }
int load_data(ModelData<EngineCore>* md, const configure::EngineDesc& conf) { virtual int load_data(ModelData<EngineCore>* md,
const configure::EngineDesc& conf) {
uint32_t next_idx = (md->current_idx + 1) % 2; uint32_t next_idx = (md->current_idx + 1) % 2;
if (md->cores[next_idx]) { if (md->cores[next_idx]) {
delete md->cores[next_idx]; delete md->cores[next_idx];
...@@ -219,28 +230,29 @@ class DBReloadableInferEngine : public ReloadableInferEngine { ...@@ -219,28 +230,29 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
md->cores[next_idx] = new (std::nothrow) EngineCore; md->cores[next_idx] = new (std::nothrow) EngineCore;
// params.dump(); // params.dump();
if (!md->cores[next_idx] || md->cores[next_idx]->create(conf) != 0) { size_t gpu_ids_num = conf.gpu_ids_size();
im::bsf::AutoMutex lock(_mutex);
int gpu_id = -1;
if (gpu_ids_num > 0) {
gpu_id = conf.gpu_ids(gpu_index % gpu_ids_num);
}
if (!md->cores[next_idx] ||
md->cores[next_idx]->create(conf, gpu_id) != 0) {
LOG(ERROR) << "Failed create model, path: " << conf.model_dir(); LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
return -1; return -1;
} }
gpu_index++;
md->current_idx = next_idx; md->current_idx = next_idx;
return 0; return 0;
} }
virtual int thrd_initialize_impl() { virtual int thrd_initialize_impl() {
// memory pool to be inited in non-serving-threads
if (MempoolWrapper::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialize mempool";
return -1;
}
ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>; ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
if (!md || load_data(md, _conf) != 0) { if (!md || load_data(md, _conf) != 0) {
LOG(ERROR) << "Failed create thread data from " << _conf.model_dir(); LOG(ERROR) << "Failed create thread data from " << _conf.model_dir();
return -1; return -1;
} }
LOG(ERROR) << "THREAD_SETSPECIFIC _skey = md";
THREAD_SETSPECIFIC(_skey, md); THREAD_SETSPECIFIC(_skey, md);
im::bsf::AutoMutex lock(_mutex); im::bsf::AutoMutex lock(_mutex);
_reload_vec.push_back(md); _reload_vec.push_back(md);
...@@ -248,11 +260,33 @@ class DBReloadableInferEngine : public ReloadableInferEngine { ...@@ -248,11 +260,33 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
} }
int thrd_clear_impl() { int thrd_clear_impl() {
// for non-serving-threads // actually, there are 2 kinds of multi-thread.
if (MempoolWrapper::instance().thread_clear() != 0) { // 1. brpc thread 2. bsf Task thread
LOG(ERROR) << "Failed thread clear mempool"; // each request is in 1-single brpc thread.
return -1; // IF (bsf Task thread is not used)
} // every single brpc thread corresponds to all the DBReloadableInferEngines.
// each request runs all models in 1-single brpc thread.
// every single brpc thread will create or clone N predictor.
// N = the number of Model.
// so if there are 2 models, and --thread 10.
// each brpc thread will create predictor of Model-1 and Model-2.
// there are totally 10 predictors of Model-1 and 10 predictors of Model-2
// cause there are 10 brpc threads.
// IF bsf Task thread is used。
// there will be a ThreadPool called bsf TaskExecutor.
// TaskExecutorVector is the vector of TaskExecutor.
// the number of TaskExecutor equals to the number of Model.
// 1 TaskExecutor corresponding to 1 Model.
// 1 TaskExecutor have N bsf threads.
// 1 bsf thread corresponds to 1 predictor of
// the Model corresponding to the TaskExecutor.
// brpc thread only put the data into the task_queue(which is in
// TaskExecutor)
// EngineCore->infer() is running in bsf Task thread.
// MempoolWrapper::instance() is actually a Thread-Local Mempool.
// so it belongs to a single Thread.
return 0; return 0;
} }
...@@ -278,6 +312,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine { ...@@ -278,6 +312,7 @@ class DBReloadableInferEngine : public ReloadableInferEngine {
THREAD_KEY_T _skey; THREAD_KEY_T _skey;
THREAD_MUTEX_T _mutex; THREAD_MUTEX_T _mutex;
std::vector<ModelData<EngineCore>*> _reload_vec; std::vector<ModelData<EngineCore>*> _reload_vec;
int gpu_index = 0;
}; };
// 多个EngineCore共用同一份模型数据 // 多个EngineCore共用同一份模型数据
...@@ -287,88 +322,72 @@ class CloneDBReloadableInferEngine ...@@ -287,88 +322,72 @@ class CloneDBReloadableInferEngine
public: public:
virtual ~CloneDBReloadableInferEngine() {} virtual ~CloneDBReloadableInferEngine() {}
virtual int proc_initialize(const configure::EngineDesc& conf, bool version) { // 进程初始化会调用load,但由于未执行线程初始化,所以_reload_vec为空,不再继续执行。
_pd = new (std::nothrow) ModelData<EngineCore>; // 热加载的话会调用load,由于线程已经初始化,_reload_vec不为空,所以继续执行load_data操作加载数据。
if (!_pd) { // 线程初始化会执行load_data操作加载数据,然后将engine加入_reload_vec中。
LOG(ERROR) << "Failed to allocate for ProcData"; // 每个模型只有一个CloneDBReloadableInferEngine对象。
return -1; // 但一个CloneDBReloadableInferEngine对象,可以包含N个EngineCore。
}
return DBReloadableInferEngine<EngineCore>::proc_initialize(conf, version);
}
virtual int load(const configure::EngineDesc& conf) { virtual int load_data(ModelData<EngineCore>* md,
// 加载进程级模型数据 const configure::EngineDesc& conf) {
if (!_pd || uint32_t next_idx = (md->current_idx + 1) % 2;
DBReloadableInferEngine<EngineCore>::load_data(_pd, conf) != 0) { if (md->cores[next_idx]) {
LOG(ERROR) << "Failed to create common model from [" << conf.model_dir() delete md->cores[next_idx];
<< "].";
return -1;
} }
LOG(WARNING) << "Succ load common model[" << _pd->cores[_pd->current_idx] md->cores[next_idx] = new (std::nothrow) EngineCore;
<< "], path[" << conf.model_dir() << "].";
if (DBReloadableInferEngine<EngineCore>::_reload_vec.empty()) { // params.dump();
return 0; size_t gpu_ids_num = conf.gpu_ids_size();
im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
int gpu_id = -1;
if (gpu_ids_num > 0) {
gpu_id = conf.gpu_ids(DBReloadableInferEngine<EngineCore>::gpu_index %
gpu_ids_num);
} }
// gpu_index will be set to be 0, when load() or proc_initial() is called.
for (uint32_t ti = 0; // gpu_index < gpu_ids_num, means there are predictors still not create
ti < DBReloadableInferEngine<EngineCore>::_reload_vec.size(); // on some GPU card.
++ti) { // so we need to create the predictor.
if (load_data(DBReloadableInferEngine<EngineCore>::_reload_vec[ti], // gpu_index >= gpu_ids_num, means each GPU card has already create one.
_pd->cores[_pd->current_idx]) != 0) { // so we need to clone the predictor.
LOG(ERROR) << "Failed reload engine model: " << ti; if (DBReloadableInferEngine<EngineCore>::gpu_index < gpu_ids_num) {
if (!md->cores[next_idx] ||
md->cores[next_idx]->create(conf, gpu_id) != 0) {
LOG(ERROR) << "Failed create model, path: " << conf.model_dir();
return -1; return -1;
} }
DBReloadableInferEngine<EngineCore>::gpu_index++;
md->current_idx = next_idx;
if (_cloneTemplate.size() <
DBReloadableInferEngine<EngineCore>::gpu_index) {
_cloneTemplate.push_back(md);
} else {
_cloneTemplate[DBReloadableInferEngine<EngineCore>::gpu_index - 1] = md;
}
} else {
// when gpu_id = -1, means we use cpu, but the index should be 0.
// _cloneTemplate[-1] will occur error.
// actually, when gpu_id = -1, there is only 1 predictor in
// _cloneTemplate.
// so the index should always be 0 when gpu_id = -1.
if (gpu_id == -1) gpu_id = 0;
if (!md->cores[next_idx] ||
md->cores[next_idx]->clone(_cloneTemplate[gpu_id]->get()) != 0) {
LOG(ERROR) << "Failed clone model from core";
return -1;
}
DBReloadableInferEngine<EngineCore>::gpu_index++;
md->current_idx = next_idx;
LOG(WARNING) << "core clone model succ, cur_idx[" << md->current_idx
<< "].";
} }
LOG(WARNING) << "Succ load clone model, path[" << conf.model_dir() << "]";
return 0;
}
// 加载线程级对象,多个线程级对象共用pd_core的模型数据
int load_data(ModelData<EngineCore>* td, EngineCore* pd_core) {
uint32_t next_idx = (td->current_idx + 1) % 2;
if (td->cores[next_idx]) {
delete td->cores[next_idx];
}
td->cores[next_idx] = new (std::nothrow) EngineCore;
if (!td->cores[next_idx] ||
td->cores[next_idx]->clone(pd_core->get()) != 0) {
LOG(ERROR) << "Failed clone model from pd_core[ " << pd_core << "], idx["
<< next_idx << "]";
return -1;
}
td->current_idx = next_idx;
LOG(WARNING) << "td_core[" << td->cores[td->current_idx]
<< "] clone model from pd_core[" << pd_core
<< "] succ, cur_idx[" << td->current_idx << "].";
return 0;
}
virtual int thrd_initialize_impl() {
// memory pool to be inited in non-serving-threads
if (MempoolWrapper::instance().thread_initialize() != 0) {
LOG(ERROR) << "Failed thread initialize mempool";
return -1;
}
ModelData<EngineCore>* md = new (std::nothrow) ModelData<EngineCore>;
if (!md || load_data(md, _pd->cores[_pd->current_idx]) != 0) {
LOG(ERROR) << "Failed clone thread data, origin_core["
<< _pd->cores[_pd->current_idx] << "].";
return -1;
}
THREAD_SETSPECIFIC(DBReloadableInferEngine<EngineCore>::_skey, md);
im::bsf::AutoMutex lock(DBReloadableInferEngine<EngineCore>::_mutex);
DBReloadableInferEngine<EngineCore>::_reload_vec.push_back(md);
return 0; return 0;
} }
protected: protected:
ModelData<EngineCore>* // 模板EngineCore,如果已创建,则多个线程级EngineCore共用该对象的模型数据
_pd; // 进程级EngineCore,多个线程级EngineCore共用该对象的模型数据 std::vector<ModelData<EngineCore>*> _cloneTemplate;
}; };
template <typename EngineCore> template <typename EngineCore>
...@@ -505,8 +524,8 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> { ...@@ -505,8 +524,8 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<EngineCore> {
return 0; return 0;
} }
int task_infer_impl(const BatchTensor& in, BatchTensor& out) { // NOLINT int task_infer_impl(const void* in, void* out) { // NOLINT
return infer_impl(&in, &out); return infer_impl(in, out);
} }
}; };
...@@ -559,7 +578,7 @@ class VersionedInferEngine : public InferEngine { ...@@ -559,7 +578,7 @@ class VersionedInferEngine : public InferEngine {
int infer_impl(const void* in, void* out, uint32_t batch_size = -1); int infer_impl(const void* in, void* out, uint32_t batch_size = -1);
int task_infer_impl(const BatchTensor& in, BatchTensor& out); int task_infer_impl(const void* in, void* out);
private: private:
boost::unordered_map<uint64_t, InferEngine*> _versions; boost::unordered_map<uint64_t, InferEngine*> _versions;
......
...@@ -91,6 +91,7 @@ int ServerManager::start_and_wait() { ...@@ -91,6 +91,7 @@ int ServerManager::start_and_wait() {
} }
} }
// rpc multi-thread start from here.
if (_server.Start(FLAGS_port, &_options) != 0) { if (_server.Start(FLAGS_port, &_options) != 0) {
LOG(ERROR) << "Failed to start Paddle Inference Server"; LOG(ERROR) << "Failed to start Paddle Inference Server";
return -1; return -1;
......
...@@ -24,7 +24,7 @@ namespace fugue { ...@@ -24,7 +24,7 @@ namespace fugue {
namespace memory { namespace memory {
void Region::init() { void Region::init() {
_big_mem_capacity = 64 * 1024 * 1024; // 64MB _big_mem_capacity = 128 * 1024 * 1024; // 128MB
_big_mem_start = new char[_big_mem_capacity]; _big_mem_start = new char[_big_mem_capacity];
} }
......
...@@ -129,7 +129,7 @@ class FreeList { ...@@ -129,7 +129,7 @@ class FreeList {
to get the class Pointer to get the class Pointer
for example for example
T is the member of class Node, T data, 'data' is the name. T is the member of class Node, T data, 'data' is the name.
T* value is the member(pointer type) class Node T* value is the member(pointer type) of class Node
so we can get the Node* by calling container_of(value, Node, data) so we can get the Node* by calling container_of(value, Node, data)
*/ */
Node* node = container_of(value, Node, data); Node* node = container_of(value, Node, data);
...@@ -261,7 +261,11 @@ struct BlockReference { ...@@ -261,7 +261,11 @@ struct BlockReference {
// because BlockFreeList is a threal-safe Singleton. // because BlockFreeList is a threal-safe Singleton.
// so we don`t release Block, it is global memory. // so we don`t release Block, it is global memory.
// total number is 32*1024 // total number is 256*1024.
// the MAX_BLOCK_COUNT of Region(one thread one Region) is 1024.
// so BlockFreeList allow 256 Region(means 256 thread).
// the memory used by BlockFreeListType is sizeof(void*)*256*1024.
// Block(2MB) memory is created only when get() is called.
class BlockFreeList { class BlockFreeList {
public: public:
static const int MAX_BLOCK_COUNT = 256 * 1024; static const int MAX_BLOCK_COUNT = 256 * 1024;
...@@ -341,9 +345,10 @@ class Region { ...@@ -341,9 +345,10 @@ class Region {
2 * 1024 * 2 * 1024 *
1024; // 2MB,means when you need less than 2M, get memory from Block. 1024; // 2MB,means when you need less than 2M, get memory from Block.
// 64MB,means when you need less than 64MB, get memory from BigMemory instead // 128MB,means when you need less than 128MB, get memory from BigMemory
// instead
// of BigNode // of BigNode
static const int BIGNODE_MEM_THRESHOLD = (64 * 1024 * 1024 + 1); static const int BIGNODE_MEM_THRESHOLD = (128 * 1024 * 1024 + 1);
static const int COUNTER_SIZE = static const int COUNTER_SIZE =
BIGNODE_MEM_THRESHOLD / BIG_MEM_THRESHOLD + 1; // this is not used BIGNODE_MEM_THRESHOLD / BIG_MEM_THRESHOLD + 1; // this is not used
...@@ -374,7 +379,8 @@ class Mempool { ...@@ -374,7 +379,8 @@ class Mempool {
void* malloc(size_t size) { void* malloc(size_t size) {
size = _align(size); size = _align(size);
// It does not enter the if statement the first time. // It does not enter the if statement the first time.
// Because the block has not been used up, it will enter. // The if statement may enter after the block is created.
// If the block has not been used up, it will enter.
if (size <= _free_size) { if (size <= _free_size) {
void* p = _free_cursor; void* p = _free_cursor;
_free_size -= size; _free_size -= size;
...@@ -392,7 +398,7 @@ class Mempool { ...@@ -392,7 +398,7 @@ class Mempool {
return; return;
} }
// memory in Block,update the pointer. // memory in _block,update the pointer.
if (_free_cursor - size == static_cast<char*>(p)) { if (_free_cursor - size == static_cast<char*>(p)) {
// for example, you need to release -(8+1)bytes // for example, you need to release -(8+1)bytes
// you can only release -8bytes,cause -(8+2)byte is used by other. // you can only release -8bytes,cause -(8+2)byte is used by other.
...@@ -424,9 +430,8 @@ class Mempool { ...@@ -424,9 +430,8 @@ class Mempool {
} }
// 可能返回的是单独Region中malloc的内存。 // 可能返回的是单独Region中malloc的内存。
// 也可能是Block,例如new_size=1M, old_data原本的指针头就在1.2M处,old_size // 也可能是Block,例如new_size=1M, old_data原本的指针头就在1.2M处
// = // old_size = 0.5M
// 0.5M
// 此时,_free_size = 0.3M,new_size<2M,但是required = 1-0.5 >0.3 // 此时,_free_size = 0.3M,new_size<2M,但是required = 1-0.5 >0.3
// 分配出来的就是Block,但是该Block没有并很完美的利用完全。 // 分配出来的就是Block,但是该Block没有并很完美的利用完全。
void* p = this->malloc_from_region(new_size); void* p = this->malloc_from_region(new_size);
......
...@@ -68,13 +68,14 @@ static bvar::PassiveStatus<std::string> s_predictor_revision( ...@@ -68,13 +68,14 @@ static bvar::PassiveStatus<std::string> s_predictor_revision(
DEFINE_bool(V, false, "print version, bool"); DEFINE_bool(V, false, "print version, bool");
DEFINE_bool(g, false, "user defined gflag path"); DEFINE_bool(g, false, "user defined gflag path");
DECLARE_string(flagfile); DECLARE_string(flagfile);
/*
namespace bthread { namespace bthread {
extern pthread_mutex_t g_task_control_mutex; extern pthread_mutex_t g_task_control_mutex;
} }
pthread_mutex_t g_worker_start_fn_mutex = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_t g_worker_start_fn_mutex = PTHREAD_MUTEX_INITIALIZER;
*/
void pthread_worker_start_fn() { void pthread_worker_start_fn() {
/*
while (pthread_mutex_lock(&g_worker_start_fn_mutex) != 0) { while (pthread_mutex_lock(&g_worker_start_fn_mutex) != 0) {
} }
...@@ -83,15 +84,18 @@ void pthread_worker_start_fn() { ...@@ -83,15 +84,18 @@ void pthread_worker_start_fn() {
if (lock_status == EBUSY || lock_status == EAGAIN) { if (lock_status == EBUSY || lock_status == EAGAIN) {
pthread_mutex_unlock(&bthread::g_task_control_mutex); pthread_mutex_unlock(&bthread::g_task_control_mutex);
} }
*/
Resource::instance().thread_initialize(); Resource::instance().thread_initialize();
// Try to avoid deadlock in bthread // Try to avoid deadlock in bthread
/*
if (lock_status == EBUSY || lock_status == EAGAIN) { if (lock_status == EBUSY || lock_status == EAGAIN) {
while (pthread_mutex_lock(&bthread::g_task_control_mutex) != 0) { while (pthread_mutex_lock(&bthread::g_task_control_mutex) != 0) {
} }
} }
pthread_mutex_unlock(&g_worker_start_fn_mutex); pthread_mutex_unlock(&g_worker_start_fn_mutex);
*/
} }
static void g_change_server_port() { static void g_change_server_port() {
...@@ -126,7 +130,7 @@ int main(int argc, char** argv) { ...@@ -126,7 +130,7 @@ int main(int argc, char** argv) {
return 0; return 0;
} }
//google::ParseCommandLineFlags(&argc, &argv, true); // google::ParseCommandLineFlags(&argc, &argv, true);
g_change_server_port(); g_change_server_port();
...@@ -202,7 +206,7 @@ int main(int argc, char** argv) { ...@@ -202,7 +206,7 @@ int main(int argc, char** argv) {
} }
VLOG(2) << "Succ call pthread worker start function"; VLOG(2) << "Succ call pthread worker start function";
//this is not used by any code segment,which can be cancelled. // this is not used by any code segment,which can be cancelled.
if (Resource::instance().general_model_initialize(FLAGS_resource_path, if (Resource::instance().general_model_initialize(FLAGS_resource_path,
FLAGS_resource_file) != 0) { FLAGS_resource_file) != 0) {
LOG(ERROR) << "Failed to initialize general model conf: " LOG(ERROR) << "Failed to initialize general model conf: "
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "core/configure/include/configure_parser.h" #include "core/configure/include/configure_parser.h"
#include "core/configure/inferencer_configure.pb.h" #include "core/configure/inferencer_configure.pb.h"
...@@ -96,7 +97,7 @@ class EngineCore { ...@@ -96,7 +97,7 @@ class EngineCore {
return true; return true;
} }
virtual int create(const configure::EngineDesc& conf) = 0; virtual int create(const configure::EngineDesc& conf, int gpu_id) = 0;
virtual int clone(void* predictor) { virtual int clone(void* predictor) {
if (predictor == NULL) { if (predictor == NULL) {
...@@ -121,7 +122,7 @@ class EngineCore { ...@@ -121,7 +122,7 @@ class EngineCore {
// Paddle Inference Engine // Paddle Inference Engine
class PaddleInferenceEngine : public EngineCore { class PaddleInferenceEngine : public EngineCore {
public: public:
int create(const configure::EngineDesc& engine_conf) { int create(const configure::EngineDesc& engine_conf, int gpu_id) {
std::string model_path = engine_conf.model_dir(); std::string model_path = engine_conf.model_dir();
if (access(model_path.c_str(), F_OK) == -1) { if (access(model_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path not exits: " LOG(ERROR) << "create paddle predictor failed, path not exits: "
...@@ -162,7 +163,11 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -162,7 +163,11 @@ class PaddleInferenceEngine : public EngineCore {
config.SetCpuMathLibraryNumThreads(1); config.SetCpuMathLibraryNumThreads(1);
if (engine_conf.has_use_gpu() && engine_conf.use_gpu()) { if (engine_conf.has_use_gpu() && engine_conf.use_gpu()) {
// 2000MB GPU memory // 2000MB GPU memory
config.EnableUseGpu(2000, FLAGS_gpuid); config.EnableUseGpu(50, gpu_id);
if (engine_conf.has_gpu_multi_stream() &&
engine_conf.gpu_multi_stream()) {
config.EnableGpuMultiStream();
}
} }
precision_type = GetPrecision(FLAGS_precision); precision_type = GetPrecision(FLAGS_precision);
...@@ -174,8 +179,13 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -174,8 +179,13 @@ class PaddleInferenceEngine : public EngineCore {
} }
if (engine_conf.has_use_trt() && engine_conf.use_trt()) { if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
config.SwitchIrOptim(true);
if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) { if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
config.EnableUseGpu(2000, FLAGS_gpuid); config.EnableUseGpu(50, gpu_id);
if (engine_conf.has_gpu_multi_stream() &&
engine_conf.gpu_multi_stream()) {
config.EnableGpuMultiStream();
}
} }
config.EnableTensorRtEngine(1 << 20, config.EnableTensorRtEngine(1 << 20,
max_batch, max_batch,
...@@ -203,7 +213,7 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -203,7 +213,7 @@ class PaddleInferenceEngine : public EngineCore {
if (precision_type == PrecisionType::kInt8) { if (precision_type == PrecisionType::kInt8) {
config.EnableMkldnnQuantizer(); config.EnableMkldnnQuantizer();
auto quantizer_config = config.mkldnn_quantizer_config(); auto quantizer_config = config.mkldnn_quantizer_config();
// TODO: warmup data // TODO(somebody): warmup data
// quantizer_config -> SetWarmupData(); // quantizer_config -> SetWarmupData();
// quantizer_config -> SetWarmupBatchSize(); // quantizer_config -> SetWarmupBatchSize();
// quantizer_config -> SetEnabledOpTypes(4); // quantizer_config -> SetEnabledOpTypes(4);
......
...@@ -13,7 +13,8 @@ tar xf faster_rcnn_hrnetv2p_w18_1x.tar ...@@ -13,7 +13,8 @@ tar xf faster_rcnn_hrnetv2p_w18_1x.tar
python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
``` ```
This model support TensorRT, if you want a faster inference, please use `--use_trt`. This model support TensorRT, if you want a faster inference, please use `--use_trt`. But you need to do some extra work.
Please reference to https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/c%2B%2B/paddle-trt/trt_dynamic_shape_test.cc#L40
### Prediction ### Prediction
......
...@@ -13,7 +13,8 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ ...@@ -13,7 +13,8 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
tar xf faster_rcnn_hrnetv2p_w18_1x.tar tar xf faster_rcnn_hrnetv2p_w18_1x.tar
python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
``` ```
该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。 该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项,但此时需要额外设置子图的TRT变长最大最小最优shape.
请参考https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/c%2B%2B/paddle-trt/trt_dynamic_shape_test.cc#L40
### 执行预测 ### 执行预测
``` ```
......
...@@ -13,7 +13,8 @@ tar xf faster_rcnn_r50_fpn_1x_coco.tar ...@@ -13,7 +13,8 @@ tar xf faster_rcnn_r50_fpn_1x_coco.tar
python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
``` ```
This model support TensorRT, if you want a faster inference, please use `--use_trt`. This model support TensorRT, if you want a faster inference, please use `--use_trt`. But you need to do some extra work.
Please reference to https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/c%2B%2B/paddle-trt/trt_dynamic_shape_test.cc#L40
### Perform prediction ### Perform prediction
......
...@@ -13,7 +13,8 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/ ...@@ -13,7 +13,8 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
tar xf faster_rcnn_r50_fpn_1x_coco.tar tar xf faster_rcnn_r50_fpn_1x_coco.tar
python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
``` ```
该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项。 该模型支持TensorRT,如果想要更快的预测速度,可以开启`--use_trt`选项,但此时需要额外设置子图的TRT变长最大最小最优shape.
请参考https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/c%2B%2B/paddle-trt/trt_dynamic_shape_test.cc#L40
### 执行预测 ### 执行预测
``` ```
......
...@@ -27,7 +27,7 @@ preprocess = Sequential([ ...@@ -27,7 +27,7 @@ preprocess = Sequential([
postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608]) postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
client = Client() client = Client()
client.connect(['127.0.0.1:9393']) client.connect(['127.0.0.1:9393'])
client.set_rpc_timeout_ms(15000) client.set_rpc_timeout_ms(100000)
im = preprocess(sys.argv[1]) im = preprocess(sys.argv[1])
fetch_map = client.predict( fetch_map = client.predict(
......
...@@ -26,7 +26,7 @@ tar xf test_imgs.tar ...@@ -26,7 +26,7 @@ tar xf test_imgs.tar
python -m paddle_serving_server.serve --model ocr_det_model --port 9293 python -m paddle_serving_server.serve --model ocr_det_model --port 9293
python ocr_web_server.py cpu python ocr_web_server.py cpu
#for gpu user #for gpu user
python -m paddle_serving_server.serve --model ocr_det_model --port 9293 --gpu_id 0 python -m paddle_serving_server.serve --model ocr_det_model --port 9293 --gpu_ids 0
python ocr_web_server.py gpu python ocr_web_server.py gpu
``` ```
...@@ -111,7 +111,7 @@ After the -- model parameter, the folder path of multiple model files is passed ...@@ -111,7 +111,7 @@ After the -- model parameter, the folder path of multiple model files is passed
#for cpu user #for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user #for gpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0 python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_ids 0
``` ```
### Client Prediction ### Client Prediction
......
...@@ -25,7 +25,7 @@ tar xf test_imgs.tar ...@@ -25,7 +25,7 @@ tar xf test_imgs.tar
python -m paddle_serving_server.serve --model ocr_det_model --port 9293 python -m paddle_serving_server.serve --model ocr_det_model --port 9293
python ocr_web_server.py cpu python ocr_web_server.py cpu
#for gpu user #for gpu user
python -m paddle_serving_server.serve --model ocr_det_model --port 9293 --gpu_id 0 python -m paddle_serving_server.serve --model ocr_det_model --port 9293 --gpu_ids 0
python ocr_web_server.py gpu python ocr_web_server.py gpu
``` ```
...@@ -110,7 +110,7 @@ python rec_web_client.py ...@@ -110,7 +110,7 @@ python rec_web_client.py
#for cpu user #for cpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293
#for gpu user #for gpu user
python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0 python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_ids 0
``` ```
### 启动客户端 ### 启动客户端
......
...@@ -307,6 +307,9 @@ class Client(object): ...@@ -307,6 +307,9 @@ class Client(object):
if isinstance(feed, dict): if isinstance(feed, dict):
feed_batch.append(feed) feed_batch.append(feed)
elif isinstance(feed, list): elif isinstance(feed, list):
# batch_size must be 1, cause batch is already in Tensor.
if len(feed) != 1:
raise ValueError("Feed only list = [dict]")
feed_batch = feed feed_batch = feed
else: else:
raise ValueError("Feed only accepts dict and list of dict") raise ValueError("Feed only accepts dict and list of dict")
...@@ -326,6 +329,7 @@ class Client(object): ...@@ -326,6 +329,7 @@ class Client(object):
fetch_names = [] fetch_names = []
counter = 0 counter = 0
# batch_size must be 1, cause batch is already in Tensor.
batch_size = len(feed_batch) batch_size = len(feed_batch)
for key in fetch_list: for key in fetch_list:
......
...@@ -39,7 +39,16 @@ def serve_args(): ...@@ -39,7 +39,16 @@ def serve_args():
"--port", type=int, default=9292, help="Port of the starting gpu") "--port", type=int, default=9292, help="Port of the starting gpu")
parser.add_argument( parser.add_argument(
"--device", type=str, default="gpu", help="Type of device") "--device", type=str, default="gpu", help="Type of device")
parser.add_argument("--gpu_ids", type=str, default="", help="gpu ids") parser.add_argument(
"--gpu_ids", type=str, default="", nargs="+", help="gpu ids")
parser.add_argument(
"--op_num", type=int, default=0, nargs="+", help="Number of each op")
parser.add_argument(
"--op_max_batch",
type=int,
default=32,
nargs="+",
help="Max batch of each op")
parser.add_argument( parser.add_argument(
"--model", type=str, default="", nargs="+", help="Model for serving") "--model", type=str, default="", nargs="+", help="Model for serving")
parser.add_argument( parser.add_argument(
...@@ -99,85 +108,20 @@ def serve_args(): ...@@ -99,85 +108,20 @@ def serve_args():
type=str, type=str,
default=None, default=None,
help="container_id for authentication") help="container_id for authentication")
parser.add_argument(
"--gpu_multi_stream",
default=False,
action="store_true",
help="Use gpu_multi_stream")
return parser.parse_args() return parser.parse_args()
def start_standard_model(serving_port): # pylint: disable=doc-string-missing def start_gpu_card_model(gpu_mode, port, args): # pylint: disable=doc-string-missing
args = serve_args()
thread_num = args.thread
model = args.model
port = serving_port
workdir = args.workdir
device = args.device
mem_optim = args.mem_optim_off is False
ir_optim = args.ir_optim
max_body_size = args.max_body_size
use_mkl = args.use_mkl
use_encryption_model = args.use_encryption_model
use_multilang = args.use_multilang
if model == "":
print("You must specify your serving model")
exit(-1)
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
import paddle_serving_server as serving
op_maker = serving.OpMaker()
op_seq_maker = serving.OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
#Temporary support for OCR model,it will be completely revised later
#If you want to use this, C++ server must compile with WITH_OPENCV option.
if len(model) == 2 and idx == 0 and model[0] == 'ocr_det_model':
infer_op_name = "general_detection"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
server = None
if use_multilang:
server = serving.MultiLangServer()
else:
server = serving.Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
server.set_port(port)
server.set_precision(args.precision)
server.set_use_calib(args.use_calib)
server.use_encryption_model(use_encryption_model)
if args.product_name != None:
server.set_product_name(args.product_name)
if args.container_id != None:
server.set_container_id(args.container_id)
server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device)
server.run_server()
def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-string-missing
workdir = args.workdir
gpuid = int(gpuid)
device = "gpu" device = "gpu"
if gpuid == -1: if gpu_mode == False:
device = "cpu" device = "cpu"
elif gpuid >= 0:
port = port + index
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
mem_optim = args.mem_optim_off is False mem_optim = args.mem_optim_off is False
...@@ -185,8 +129,7 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin ...@@ -185,8 +129,7 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
use_mkl = args.use_mkl use_mkl = args.use_mkl
max_body_size = args.max_body_size max_body_size = args.max_body_size
use_multilang = args.use_multilang use_multilang = args.use_multilang
if gpuid >= 0: workdir = "{}_{}".format(args.workdir, port)
workdir = "{}_{}".format(args.workdir, gpuid)
if model == "": if model == "":
print("You must specify your serving model") print("You must specify your serving model")
...@@ -204,7 +147,11 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin ...@@ -204,7 +147,11 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
op_seq_maker.add_op(read_op) op_seq_maker.add_op(read_op)
for idx, single_model in enumerate(model): for idx, single_model in enumerate(model):
infer_op_name = "general_infer" infer_op_name = "general_infer"
if len(model) == 2 and idx == 0: # 目前由于ocr的节点Det模型依赖于opencv的第三方库
# 只有使用ocr的时候,才会加入opencv的第三方库并编译GeneralDetectionOp
# 故此处做特殊处理,当不满足下述情况时,所添加的op默认为GeneralInferOp
# 以后可能考虑不用python脚本来生成配置
if len(model) == 2 and idx == 0 and single_model == "ocr_det_model":
infer_op_name = "general_detection" infer_op_name = "general_detection"
else: else:
infer_op_name = "general_infer" infer_op_name = "general_infer"
...@@ -226,8 +173,19 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin ...@@ -226,8 +173,19 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
server.set_memory_optimize(mem_optim) server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim) server.set_ir_optimize(ir_optim)
server.set_max_body_size(max_body_size) server.set_max_body_size(max_body_size)
if args.use_trt:
if args.use_trt and device == "gpu":
server.set_trt() server.set_trt()
server.set_ir_optimize(True)
if args.gpu_multi_stream and device == "gpu":
server.set_gpu_multi_stream()
if args.op_num:
server.set_op_num(args.op_num)
if args.op_max_batch:
server.set_op_max_batch(args.op_max_batch)
if args.use_lite: if args.use_lite:
server.set_lite() server.set_lite()
...@@ -247,48 +205,40 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin ...@@ -247,48 +205,40 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
port=port, port=port,
device=device, device=device,
use_encryption_model=args.use_encryption_model) use_encryption_model=args.use_encryption_model)
if gpuid >= 0: if gpu_mode == True:
server.set_gpuid(gpuid) server.set_gpuid(args.gpu_ids)
server.run_server() server.run_server()
def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
gpus = "" gpus = []
if serving_port == None: if serving_port == None:
serving_port = args.port serving_port = args.port
if args.gpu_ids == "": if args.gpu_ids == "":
gpus = [] gpus = []
else: else:
gpus = args.gpu_ids.split(",") #check the gpu_id is valid or not.
gpus = args.gpu_ids
if isinstance(gpus, str):
gpus = [gpus]
if "CUDA_VISIBLE_DEVICES" in os.environ: if "CUDA_VISIBLE_DEVICES" in os.environ:
env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",") env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
for ids in gpus: for op_gpus_str in gpus:
if ids not in env_gpus: op_gpu_list = op_gpus_str.split(",")
print("gpu_ids is not in CUDA_VISIBLE_DEVICES.") for ids in op_gpu_list:
exit(-1) if ids not in env_gpus:
else: print("gpu_ids is not in CUDA_VISIBLE_DEVICES.")
env_gpus = [] exit(-1)
if args.use_lite: if args.use_lite:
print("run using paddle-lite.") print("run using paddle-lite.")
start_gpu_card_model(-1, -1, serving_port, args) start_gpu_card_model(False, serving_port, args)
elif len(gpus) <= 0: elif len(gpus) <= 0:
print("gpu_ids not set, going to run cpu service.") print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(-1, -1, serving_port, args) start_gpu_card_model(False, serving_port, args)
else: else:
gpu_processes = [] start_gpu_card_model(True, serving_port, args)
for i, gpu_id in enumerate(gpus):
p = Process(
target=start_gpu_card_model,
args=(
i,
gpu_id,
serving_port,
args, ))
gpu_processes.append(p)
for p in gpu_processes:
p.start()
for p in gpu_processes:
p.join()
class MainService(BaseHTTPRequestHandler): class MainService(BaseHTTPRequestHandler):
...@@ -395,14 +345,28 @@ if __name__ == "__main__": ...@@ -395,14 +345,28 @@ if __name__ == "__main__":
from .web_service import WebService from .web_service import WebService
web_service = WebService(name=args.name) web_service = WebService(name=args.name)
web_service.load_model_config(args.model) web_service.load_model_config(args.model)
gpu_ids = args.gpu_ids
if gpu_ids == "": if args.gpu_ids == "":
gpus = []
else:
#check the gpu_id is valid or not.
gpus = args.gpu_ids
if isinstance(gpus, str):
gpus = [gpus]
if "CUDA_VISIBLE_DEVICES" in os.environ: if "CUDA_VISIBLE_DEVICES" in os.environ:
gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"] env_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
if len(gpu_ids) > 0: for op_gpus_str in gpus:
web_service.set_gpus(gpu_ids) op_gpu_list = op_gpus_str.split(",")
for ids in op_gpu_list:
if ids not in env_gpus:
print("gpu_ids is not in CUDA_VISIBLE_DEVICES.")
exit(-1)
if len(gpus) > 0:
web_service.set_gpus(gpus)
workdir = "{}_{}".format(args.workdir, args.port)
web_service.prepare_server( web_service.prepare_server(
workdir=args.workdir, workdir=workdir,
port=args.port, port=args.port,
device=args.device, device=args.device,
use_lite=args.use_lite, use_lite=args.use_lite,
...@@ -410,7 +374,11 @@ if __name__ == "__main__": ...@@ -410,7 +374,11 @@ if __name__ == "__main__":
ir_optim=args.ir_optim, ir_optim=args.ir_optim,
thread_num=args.thread, thread_num=args.thread,
precision=args.precision, precision=args.precision,
use_calib=args.use_calib) use_calib=args.use_calib,
use_trt=args.use_trt,
gpu_multi_stream=args.gpu_multi_stream,
op_num=args.op_num,
op_max_batch=args.op_max_batch)
web_service.run_rpc_service() web_service.run_rpc_service()
app_instance = Flask(__name__) app_instance = Flask(__name__)
......
...@@ -81,8 +81,11 @@ class Server(object): ...@@ -81,8 +81,11 @@ class Server(object):
self.use_local_bin = False self.use_local_bin = False
self.mkl_flag = False self.mkl_flag = False
self.device = "cpu" self.device = "cpu"
self.gpuid = 0 self.gpuid = []
self.op_num = [0]
self.op_max_batch = [32]
self.use_trt = False self.use_trt = False
self.gpu_multi_stream = False
self.use_lite = False self.use_lite = False
self.use_xpu = False self.use_xpu = False
self.model_config_paths = collections.OrderedDict() self.model_config_paths = collections.OrderedDict()
...@@ -137,11 +140,13 @@ class Server(object): ...@@ -137,11 +140,13 @@ class Server(object):
def set_ir_optimize(self, flag=False): def set_ir_optimize(self, flag=False):
self.ir_optimization = flag self.ir_optimization = flag
# Multi-Server does not have this Function.
def set_product_name(self, product_name=None): def set_product_name(self, product_name=None):
if product_name == None: if product_name == None:
raise ValueError("product_name can't be None.") raise ValueError("product_name can't be None.")
self.product_name = product_name self.product_name = product_name
# Multi-Server does not have this Function.
def set_container_id(self, container_id): def set_container_id(self, container_id):
if container_id == None: if container_id == None:
raise ValueError("container_id can't be None.") raise ValueError("container_id can't be None.")
...@@ -163,12 +168,21 @@ class Server(object): ...@@ -163,12 +168,21 @@ class Server(object):
def set_device(self, device="cpu"): def set_device(self, device="cpu"):
self.device = device self.device = device
def set_gpuid(self, gpuid=0): def set_gpuid(self, gpuid):
self.gpuid = gpuid self.gpuid = gpuid
def set_op_num(self, op_num):
self.op_num = op_num
def set_op_max_batch(self, op_max_batch):
self.op_max_batch = op_max_batch
def set_trt(self): def set_trt(self):
self.use_trt = True self.use_trt = True
def set_gpu_multi_stream(self):
self.gpu_multi_stream = True
def set_lite(self): def set_lite(self):
self.use_lite = True self.use_lite = True
...@@ -178,6 +192,27 @@ class Server(object): ...@@ -178,6 +192,27 @@ class Server(object):
def _prepare_engine(self, model_config_paths, device, use_encryption_model): def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None: if self.model_toolkit_conf == None:
self.model_toolkit_conf = [] self.model_toolkit_conf = []
self.device = device
if isinstance(self.gpuid, str):
self.gpuid = [self.gpuid]
if len(self.gpuid) == 0:
if self.device == "gpu" or self.use_trt:
self.gpuid.append("0")
else:
self.gpuid.append("-1")
if isinstance(self.op_num, int):
self.op_num = [self.op_num]
if len(self.op_num) == 0:
self.op_num.append(0)
if isinstance(self.op_max_batch, int):
self.op_max_batch = [self.op_max_batch]
if len(self.op_max_batch) == 0:
self.op_max_batch.append(32)
index = 0
for engine_name, model_config_path in model_config_paths.items(): for engine_name, model_config_path in model_config_paths.items():
engine = server_sdk.EngineDesc() engine = server_sdk.EngineDesc()
...@@ -186,19 +221,28 @@ class Server(object): ...@@ -186,19 +221,28 @@ class Server(object):
engine.reloadable_meta = model_config_path + "/fluid_time_file" engine.reloadable_meta = model_config_path + "/fluid_time_file"
os.system("touch {}".format(engine.reloadable_meta)) os.system("touch {}".format(engine.reloadable_meta))
engine.reloadable_type = "timestamp_ne" engine.reloadable_type = "timestamp_ne"
engine.runtime_thread_num = 0 engine.runtime_thread_num = self.op_num[index % len(self.op_num)]
engine.batch_infer_size = 0 engine.batch_infer_size = self.op_max_batch[index %
engine.enable_batch_align = 0 len(self.op_max_batch)]
engine.enable_batch_align = 1
engine.model_dir = model_config_path engine.model_dir = model_config_path
engine.enable_memory_optimization = self.memory_optimization engine.enable_memory_optimization = self.memory_optimization
engine.enable_ir_optimization = self.ir_optimization engine.enable_ir_optimization = self.ir_optimization
engine.use_trt = self.use_trt engine.use_trt = self.use_trt
engine.gpu_multi_stream = self.gpu_multi_stream
engine.use_lite = self.use_lite engine.use_lite = self.use_lite
engine.use_xpu = self.use_xpu engine.use_xpu = self.use_xpu
engine.use_gpu = False engine.use_gpu = False
if self.device == "gpu": if self.device == "gpu" or self.use_trt:
engine.use_gpu = True engine.use_gpu = True
if len(self.gpuid) == 0:
raise ValueError("CPU: self.gpuid = -1, GPU: must set it ")
op_gpu_list = self.gpuid[index % len(self.gpuid)].split(",")
for ids in op_gpu_list:
engine.gpu_ids.extend([int(ids)])
if os.path.exists('{}/__params__'.format(model_config_path)): if os.path.exists('{}/__params__'.format(model_config_path)):
engine.combined_model = True engine.combined_model = True
else: else:
...@@ -208,6 +252,7 @@ class Server(object): ...@@ -208,6 +252,7 @@ class Server(object):
engine.type = "PADDLE_INFER" engine.type = "PADDLE_INFER"
self.model_toolkit_conf.append(server_sdk.ModelToolkitConf()) self.model_toolkit_conf.append(server_sdk.ModelToolkitConf())
self.model_toolkit_conf[-1].engines.extend([engine]) self.model_toolkit_conf[-1].engines.extend([engine])
index = index + 1
def _prepare_infer_service(self, port): def _prepare_infer_service(self, port):
if self.infer_service_conf == None: if self.infer_service_conf == None:
...@@ -332,7 +377,11 @@ class Server(object): ...@@ -332,7 +377,11 @@ class Server(object):
self.mkl_flag = flag self.mkl_flag = flag
def check_avx(self): def check_avx(self):
p = subprocess.Popen(['cat /proc/cpuinfo | grep avx 2>/dev/null'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) p = subprocess.Popen(
['cat /proc/cpuinfo | grep avx 2>/dev/null'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
out, err = p.communicate() out, err = p.communicate()
if err == b'' and len(out) > 0: if err == b'' and len(out) > 0:
return True return True
...@@ -431,6 +480,7 @@ class Server(object): ...@@ -431,6 +480,7 @@ class Server(object):
device="cpu", device="cpu",
use_encryption_model=False, use_encryption_model=False,
cube_conf=None): cube_conf=None):
self.device = device
if workdir == None: if workdir == None:
workdir = "./tmp" workdir = "./tmp"
os.system("mkdir -p {}".format(workdir)) os.system("mkdir -p {}".format(workdir))
...@@ -533,7 +583,6 @@ class Server(object): ...@@ -533,7 +583,6 @@ class Server(object):
"-workflow_path {} " \ "-workflow_path {} " \
"-workflow_file {} " \ "-workflow_file {} " \
"-bthread_concurrency {} " \ "-bthread_concurrency {} " \
"-gpuid {} " \
"-max_body_size {} ".format( "-max_body_size {} ".format(
self.bin_path, self.bin_path,
self.workdir, self.workdir,
...@@ -549,7 +598,6 @@ class Server(object): ...@@ -549,7 +598,6 @@ class Server(object):
self.workdir, self.workdir,
self.workflow_fn, self.workflow_fn,
self.num_threads, self.num_threads,
self.gpuid,
self.max_body_size) self.max_body_size)
print("Going to Run Comand") print("Going to Run Comand")
print(command) print(command)
...@@ -615,9 +663,27 @@ class MultiLangServer(object): ...@@ -615,9 +663,27 @@ class MultiLangServer(object):
def set_ir_optimize(self, flag=False): def set_ir_optimize(self, flag=False):
self.bserver_.set_ir_optimize(flag) self.bserver_.set_ir_optimize(flag)
def set_gpuid(self, gpuid=0): def set_gpuid(self, gpuid):
self.bserver_.set_gpuid(gpuid) self.bserver_.set_gpuid(gpuid)
def set_op_num(self, op_num):
self.bserver_.set_op_num(op_num)
def set_op_max_batch(self, op_max_batch):
self.bserver_.set_op_max_batch(op_max_batch)
def set_trt(self):
self.bserver_.set_trt()
def set_gpu_multi_stream(self):
self.bserver_.set_gpu_multi_stream()
def set_lite(self):
self.bserver_.set_lite()
def set_xpu(self):
self.bserver_.set_xpu()
def load_model_config(self, def load_model_config(self,
server_config_dir_paths, server_config_dir_paths,
client_config_path=None): client_config_path=None):
...@@ -674,6 +740,7 @@ class MultiLangServer(object): ...@@ -674,6 +740,7 @@ class MultiLangServer(object):
device="cpu", device="cpu",
use_encryption_model=False, use_encryption_model=False,
cube_conf=None): cube_conf=None):
self.device = device
if not self._port_is_available(port): if not self._port_is_available(port):
raise SystemExit("Port {} is already used".format(port)) raise SystemExit("Port {} is already used".format(port))
default_port = 12000 default_port = 12000
......
...@@ -105,25 +105,33 @@ class WebService(object): ...@@ -105,25 +105,33 @@ class WebService(object):
def set_gpus(self, gpus): def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
self.gpus = [int(x) for x in gpus.split(",")] self.gpus = gpus
def default_rpc_service(self, def default_rpc_service(self,
workdir="conf", workdir,
port=9292, port=9292,
gpuid=0, gpus=-1,
thread_num=2, thread_num=2,
mem_optim=True, mem_optim=True,
use_lite=False, use_lite=False,
use_xpu=False, use_xpu=False,
ir_optim=False, ir_optim=False,
precision="fp32", precision="fp32",
use_calib=False): use_calib=False,
use_trt=False,
gpu_multi_stream=False,
op_num=None,
op_max_batch=None):
device = "gpu" device = "gpu"
if gpuid == -1: server = Server()
if gpus == -1:
if use_lite: if use_lite:
device = "arm" device = "arm"
else: else:
device = "cpu" device = "cpu"
else:
server.set_gpuid(gpus)
op_maker = OpMaker() op_maker = OpMaker()
op_seq_maker = OpSeqMaker() op_seq_maker = OpSeqMaker()
...@@ -142,7 +150,6 @@ class WebService(object): ...@@ -142,7 +150,6 @@ class WebService(object):
general_response_op = op_maker.create('general_response') general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op) op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num) server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim) server.set_memory_optimize(mem_optim)
...@@ -151,6 +158,19 @@ class WebService(object): ...@@ -151,6 +158,19 @@ class WebService(object):
server.set_precision(precision) server.set_precision(precision)
server.set_use_calib(use_calib) server.set_use_calib(use_calib)
if use_trt and device == "gpu":
server.set_trt()
server.set_ir_optimize(True)
if gpu_multi_stream and device == "gpu":
server.set_gpu_multi_stream()
if op_num:
server.set_op_num(op_num)
if op_max_batch:
server.set_op_max_batch(op_max_batch)
if use_lite: if use_lite:
server.set_lite() server.set_lite()
if use_xpu: if use_xpu:
...@@ -158,8 +178,7 @@ class WebService(object): ...@@ -158,8 +178,7 @@ class WebService(object):
server.load_model_config(self.server_config_dir_paths server.load_model_config(self.server_config_dir_paths
) #brpc Server support server_config_dir_paths ) #brpc Server support server_config_dir_paths
if gpuid >= 0:
server.set_gpuid(gpuid)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(workdir=workdir, port=port, device=device)
return server return server
...@@ -180,24 +199,29 @@ class WebService(object): ...@@ -180,24 +199,29 @@ class WebService(object):
use_xpu=self.use_xpu, use_xpu=self.use_xpu,
ir_optim=self.ir_optim, ir_optim=self.ir_optim,
precision=self.precision, precision=self.precision,
use_calib=self.use_calib)) use_calib=self.use_calib,
op_num=self.op_num,
op_max_batch=self.op_max_batch))
else: else:
for i, gpuid in enumerate(self.gpus): self.rpc_service_list.append(
self.rpc_service_list.append( self.default_rpc_service(
self.default_rpc_service( self.workdir,
"{}_{}".format(self.workdir, i), self.port_list[0],
self.port_list[i], self.gpus,
gpuid, thread_num=self.thread_num,
thread_num=self.thread_num, mem_optim=self.mem_optim,
mem_optim=self.mem_optim, use_lite=self.use_lite,
use_lite=self.use_lite, use_xpu=self.use_xpu,
use_xpu=self.use_xpu, ir_optim=self.ir_optim,
ir_optim=self.ir_optim, precision=self.precision,
precision=self.precision, use_calib=self.use_calib,
use_calib=self.use_calib)) use_trt=self.use_trt,
gpu_multi_stream=self.gpu_multi_stream,
op_num=self.op_num,
op_max_batch=self.op_max_batch))
def prepare_server(self, def prepare_server(self,
workdir="", workdir,
port=9393, port=9393,
device="gpu", device="gpu",
precision="fp32", precision="fp32",
...@@ -205,9 +229,12 @@ class WebService(object): ...@@ -205,9 +229,12 @@ class WebService(object):
use_lite=False, use_lite=False,
use_xpu=False, use_xpu=False,
ir_optim=False, ir_optim=False,
gpuid=0,
thread_num=2, thread_num=2,
mem_optim=True): mem_optim=True,
use_trt=False,
gpu_multi_stream=False,
op_num=None,
op_max_batch=None):
print("This API will be deprecated later. Please do not use it") print("This API will be deprecated later. Please do not use it")
self.workdir = workdir self.workdir = workdir
self.port = port self.port = port
...@@ -219,25 +246,23 @@ class WebService(object): ...@@ -219,25 +246,23 @@ class WebService(object):
self.use_xpu = use_xpu self.use_xpu = use_xpu
self.ir_optim = ir_optim self.ir_optim = ir_optim
self.mem_optim = mem_optim self.mem_optim = mem_optim
self.gpuid = gpuid
self.port_list = [] self.port_list = []
self.use_trt = use_trt
self.gpu_multi_stream = gpu_multi_stream
self.op_num = op_num
self.op_max_batch = op_max_batch
default_port = 12000 default_port = 12000
for i in range(1000): for i in range(1000):
if port_is_available(default_port + i): if port_is_available(default_port + i):
self.port_list.append(default_port + i) self.port_list.append(default_port + i)
if len(self.port_list) > len(self.gpus):
break break
def _launch_web_service(self): def _launch_web_service(self):
gpu_num = len(self.gpus)
self.client = Client() self.client = Client()
self.client.load_client_config(self.client_config_path) self.client.load_client_config(self.client_config_path)
endpoints = "" endpoints = ""
if gpu_num > 0: endpoints = "127.0.0.1:{}".format(self.port_list[0])
for i in range(gpu_num):
endpoints += "127.0.0.1:{},".format(self.port_list[i])
else:
endpoints = "127.0.0.1:{}".format(self.port_list[0])
self.client.connect([endpoints]) self.client.connect([endpoints])
def get_prediction(self, request): def get_prediction(self, request):
...@@ -324,10 +349,10 @@ class WebService(object): ...@@ -324,10 +349,10 @@ class WebService(object):
# default self.gpus = [0]. # default self.gpus = [0].
if len(self.gpus) == 0: if len(self.gpus) == 0:
self.gpus.append(0) self.gpus.append(0)
gpu_id = (self.gpus[0].split(","))[0]
self.client.load_model_config( self.client.load_model_config(
self.server_config_dir_paths[0], self.server_config_dir_paths[0], use_gpu=True, gpu_id=gpu_id)
use_gpu=True,
gpu_id=self.gpus[0])
else: else:
self.client.load_model_config( self.client.load_model_config(
self.server_config_dir_paths[0], use_gpu=False) self.server_config_dir_paths[0], use_gpu=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册