Unverified commit 60a7b4fd authored by TeslaZhao, committed by GitHub

Merge branch 'PaddlePaddle:develop' into develop

......@@ -22,11 +22,8 @@ message EngineDesc {
required string reloadable_type = 4;
required string model_dir = 5;
repeated int32 gpu_ids = 6;
required int32 runtime_thread_num = 7;
required int32 batch_infer_size = 8;
required int32 enable_batch_align = 9;
optional string version_file = 10;
optional string version_type = 11;
optional string version_file = 7;
optional string version_type = 8;
/*
* Sparse Parameter Service type. Valid types are:
......@@ -39,17 +36,34 @@ message EngineDesc {
LOCAL = 1;
REMOTE = 2;
}
optional SparseParamServiceType sparse_param_service_type = 12;
optional string sparse_param_service_table_name = 13;
optional bool enable_memory_optimization = 14;
optional bool enable_ir_optimization = 15;
optional bool use_trt = 16;
optional bool use_lite = 17;
optional bool use_xpu = 18;
optional bool use_gpu = 19;
optional bool combined_model = 20;
optional bool encrypted_model = 21;
optional bool gpu_multi_stream = 22;
optional SparseParamServiceType sparse_param_service_type = 10;
optional string sparse_param_service_table_name = 11;
optional bool enable_memory_optimization = 12;
optional bool enable_ir_optimization = 13;
optional bool use_trt = 14;
optional bool use_lite = 15;
optional bool use_xpu = 16;
optional bool use_gpu = 17;
optional bool combined_model = 18;
optional bool encrypted_model = 19;
optional bool gpu_multi_stream = 20;
/*
* "runtime_thread_num": n == 0 means asynchronous task scheduling mode is
* not used; n > 0 means the number of Predictors this engine uses in
* asynchronous task scheduling mode.
* "batch_infer_size": the maximum batch for this engine in asynchronous
* task scheduling mode.
* "enable_overrun": always put a whole task into the TaskQueue even if the
* total batch is bigger than "batch_infer_size".
* "allow_split_request": allow splitting a task (each task corresponds to
* one request).
*/
optional int32 runtime_thread_num = 30 [ default = 0 ];
optional int32 batch_infer_size = 31 [ default = 32 ];
optional bool enable_overrun = 32 [ default = false ];
optional bool allow_split_request = 33 [ default = true ];
};
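As a reading aid (not part of this commit), the Python sketch below restates how the four new asynchronous-scheduling fields and their defaults relate to each other; the `EngineDescSketch` dataclass is invented for illustration and is not the generated protobuf class.

``` python
from dataclasses import dataclass


@dataclass
class EngineDescSketch:
    """Illustrative mirror of the async-scheduling fields added above."""
    runtime_thread_num: int = 0       # 0: async task scheduling disabled; >0: number of Predictors
    batch_infer_size: int = 32        # max batch assembled per inference in async mode
    enable_overrun: bool = False      # put a whole task in even if it exceeds batch_infer_size
    allow_split_request: bool = True  # a task (request) may be split across two batches


# Example: enable async scheduling with 4 Predictors and the default max batch of 32.
async_engine = EngineDescSketch(runtime_thread_num=4)
print(async_engine)
```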
// model_toolkit conf
......
......@@ -26,9 +26,90 @@
#include "core/predictor/common/inner_common.h"
#include "core/predictor/framework/memory.h"
// this file is included by bsf.h
namespace im {
namespace bsf {
template <typename InItemT, typename OutItemT>
bool Task<InItemT, OutItemT>::task_fetch_init(BatchTasks<TaskT>& batchTask) {
// Double-checked locking to reduce lock granularity.
if (!fetch_init) {
if (taskmeta_num > 1) {
// If the task has been split into multiple taskmetas, a lock is required.
AutoMutex lock(task_mut);
task_fetch_create(batchTask);
} else {
// If the task has only one taskmeta, no lock is needed.
task_fetch_create(batchTask);
}
}
return true;
}
template <typename InItemT, typename OutItemT>
bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
if (!fetch_init) {
vector_fetch_lod_index = batchTask.vector_fetch_lod_index;
set_fetch_nobatch_index = batchTask.set_fetch_nobatch_index;
OutVectorT taskMetaOutLodTensor;
size_t fetchvar_num = batchTask._batch_out.size();
for (size_t fetchvar_index = 0; fetchvar_index < fetchvar_num;
++fetchvar_index) {
size_t fetchvar_bytesize_index =
batchTask.fetchvar_bytesize(fetchvar_index);
size_t fetchvar_batch = 0;
// 1. The nobatch fetchvar case.
if (set_fetch_nobatch_index.size() > 0 &&
set_fetch_nobatch_index.find(fetchvar_index) !=
set_fetch_nobatch_index.end()) {
fetchvar_batch = 1;
} else if (vector_fetch_lod_index.size() > 0 &&
std::find(vector_fetch_lod_index.begin(),
vector_fetch_lod_index.end(),
fetchvar_index) != vector_fetch_lod_index.end()) {
// 2. The lod fetchvar case: the total shape[0] cannot be determined yet.
// Allocate task_num temporary buffers according to the total task_num of
// the task, copy each lod-type fetchvar into its own temporary buffer,
// then sum up the temporary buffers and merge the fetchvar data and lod.
fetchvar_batch = 0;
} else {
// 3. The ordinary fetchvar case: the total fetchvar_batch of this Task
// equals the total input batch_size().
fetchvar_batch = batch_size();
}
paddle::PaddleTensor tensor_out;
tensor_out.name = batchTask._batch_out[fetchvar_index].name;
tensor_out.dtype =
paddle::PaddleDType(batchTask._batch_out[fetchvar_index].dtype);
tensor_out.shape = batchTask._batch_out[fetchvar_index].shape;
tensor_out.shape[0] = fetchvar_batch;
if (fetchvar_batch != 0) {
// At this point the lod is empty.
tensor_out.lod = batchTask._batch_out[fetchvar_index].lod;
// resize all batch memory at one time
size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
tensor_out.data.Resize(databuf_size);
} else {
// When taskmeta_num == 1, only one taskMeta operates on the task at a
// time, so there is no thread-safety issue and the flow can simply be
// taskMeta->task->resize->copy.
// When the task is split into multiple taskMetas, temporary objects are
// needed to record the partial results, which are merged once all of
// them have been collected.
if (taskmeta_num > 1) {
taskMetaOutLodTensor.push_back(tensor_out);
}
}
outVectorT_ptr->push_back(tensor_out);
}
// outLodTensorVector is effectively a two-level vector with
// shape taskmeta_num * vector_fetch_lod_index.size();
outLodTensorVector.resize(taskmeta_num, taskMetaOutLodTensor);
fetch_init = true;
}
return true;
}
template <typename TaskT>
void* TaskExecutor<TaskT>::thread_entry(void* args) {
ThreadContext<TaskT>* context = static_cast<ThreadContext<TaskT>*>(args);
......@@ -136,7 +217,7 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
}
/*
if (!BatchTasks<TaskT>::check_valid(in, out, _batch_align)) {
if (!BatchTasks<TaskT>::check_valid(in, out, _overrun)) {
LOG(ERROR) << "Invalid input & output";
return TaskHandler<TaskT>::valid_handle();
}
......@@ -156,9 +237,11 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
if (!task->task_init()) {
LOG(ERROR) << "task->init() failed";
}
task->rem = task->batch_size();
task->index.store(0, butil::memory_order_relaxed);
AutoMutex lock(_mut);
_task_queue.push_back(task);
THREAD_COND_SIGNAL(&_cond);
......@@ -168,11 +251,12 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
// this function is accessed by multi thread.
// so AutoMutex at first.
// so batch.append_task is thread safe.
// so batchTask.append_task is thread safe.
// you dont need to add extra lock in append_task()
// task is already init.
template <typename TaskT>
bool TaskExecutor<TaskT>::move_task_to_batch(
BatchTasks<TaskT>& batch) { // NOLINT
BatchTasks<TaskT>& batchTask) { // NOLINT
AutoMutex lock(_mut);
while (_task_queue.empty()) {
THREAD_COND_WAIT(&_cond, &_mut);
......@@ -183,15 +267,65 @@ bool TaskExecutor<TaskT>::move_task_to_batch(
return false;
}
TaskT* previous_task = nullptr;
while (!_task_queue.empty()) {
TaskT* task = _task_queue.front();
size_t rem = batch.append_task(task);
// Since we cannot tell in advance whether a fetchvar is lod (even if the
// input is non-lod, the output may still be lod), the simple approach is
// not to split a task: user requests may be merged and predicted
// together, but a single request is never split into two smaller parts.
// This only requires setting the engine property
// allow_split_request = false.
// The complex approach is to allow splitting a Task, whether or not it
// contains lod.
// The difficulty: before prediction we know how many taskmetas the task
// was split into, but only after prediction do we know how many fetchvars
// there are and how many of them are lod.
// Therefore the task must first create taskmeta_num * fetchvar_num
// (lod-type) temporary PaddleTensors (holding data and lod).
// Because the scheduling unit of the multi-threaded framework is the
// taskmeta, these tensors can only be created in notify_task via
// taskmeta->task. Multiple taskmetas map to one task, so there is
// multi-thread contention and a lock is needed inside the task.
// Atomic operations are not feasible, because every thread must wait until
// the PaddleTensors above have been created before it can continue.
// For ordinary fetchvars, the lock is also needed to create the
// PaddleTensor before data can be copied into it.
// _overrun controls whether the asynchronous BatchTasks may temporarily
// exceed its limit in a single round.
// When _overrun is true, even if only 1 batch slot remains in BatchTasks,
// a whole Task is still put in, temporarily exceeding the limit.
// When _overrun is false, this is not allowed.
// If the model itself has a maximum batch limit, set it to false
// (the default is false).
// If the model has no maximum batch limit but you configured a maximum
// batch for BatchTasks yourself, consider setting it to true.
// _allow_split_request == true allows splitting a task: if BatchTasks has
// 1 batch slot left, 1 batch is split off from the next Task.
// _allow_split_request == false means no task is ever split, so a
// remaining 1-batch slot in BatchTasks is wasted.
// The default is true, splitting tasks to maximize space utilization.
// (An illustrative sketch of this policy follows this function.)
if (!batchTask.get_allow_split_request()) {
if (task->batch_size() > batchTask.get_rem_size() &&
!batchTask.get_overrun()) {
break;
}
}
// combine_task_valid decides whether tasks can be merged.
// Except for the outermost dimension, the inner shapes must be identical
// for a merge; otherwise break out of the loop and put the task into the
// next batchTask. This guarantees that every task passed to
// batchTask.append_task(task) has the same inner shape.
// For the case shape[0] == 1 while != batch, only one of the values is
// kept when merging, so that feedvar must be equal for the merge to
// happen; otherwise break and put the task into the next batchTask.
// PaddleTensor and PaddleBuf currently do not overload operator==, so we
// can only compare memory.
// TODO(HexToString): consider supporting AutoPadding later.
if (previous_task != nullptr) {
if (!task->combine_task_valid(previous_task)) {
break;
}
}
size_t rem = batchTask.append_task(task);
previous_task = task;
if (task->rem <= 0) {
_task_queue.pop_front();
}
if (rem <= 0) break;
}
LOG(INFO) << "Number of tasks remaining in _task_queue is"
<< _task_queue.size();
return true;
}
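As a reading aid (not part of the codebase), the toy Python function below summarizes the placement policy that the comments above describe for `_overrun` and `_allow_split_request`; the function name and return strings are invented for illustration.

``` python
def place_task(task_batch, rem_slots, overrun, allow_split_request):
    """How the scheduler treats a task of size task_batch when the current
    BatchTasks still has rem_slots free batch slots (illustrative only)."""
    if task_batch <= rem_slots:
        return "append the whole task"
    if allow_split_request:
        return "split: take rem_slots now, keep the remainder for the next batch"
    if overrun:
        return "append the whole task anyway (temporary overrun)"
    return "defer the task to the next BatchTasks"


# With the defaults (overrun=False, allow_split_request=True), an oversized
# task is split instead of being deferred.
print(place_task(task_batch=48, rem_slots=32, overrun=False, allow_split_request=True))
```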
......@@ -201,11 +335,12 @@ bool TaskExecutor<TaskT>::move_task_to_batch(
// TaskT comes from the Singleton TaskExecutor's _task_queue.
// Although TaskMeta is a local variable, several TaskMetas may point to
// the same TaskT, which is taken from the Singleton TaskExecutor's _task_queue.
// Put the TaskMeta into the local variable BatchTasks<TaskT> batch.
// Put the TaskMeta into the local variable BatchTasks<TaskT> batchTask.
// batch.merge_tasks() and batch.notify_tasks() take no lock.
// BatchTasks<TaskT> batch itself is a local variable, so it is thread safe.
// If batch.merge_tasks() and batch.notify_tasks() do something to TaskMeta,
// batchTask.merge_tasks() and batchTask.notify_tasks() take no lock.
// BatchTasks<TaskT> batchTask itself is a local variable, so it is thread safe.
// If batchTask.merge_tasks() and batchTask.notify_tasks() do something to
// TaskMeta,
// you need to pay attention to that.
// Multiple threads deal with different TaskMetas (since each is created as
// a local variable).
......@@ -242,11 +377,23 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
return -1;
}
BatchTasks<TaskT> batch(_batch_size, _batch_align);
if (move_task_to_batch(batch)) {
batch.merge_tasks();
_fn(&batch.in(), &batch.out());
batch.notify_tasks();
// move_task_to_batch() takes the original task from the `_task_queue` and
// puts it into its own Vector<taskmeta>.
// The capacity of that Vector<taskmeta> is decided by `_batch_size` or
// `_overrun`.
// merge_tasks() moves the input data from its own Vector<taskmeta> into
// `_batch_in`,
// because the predictor's input is `_batch_in`.
// notify_tasks() moves the output data from `_batch_out` into every single
// taskmeta,
// because the predictor's output is `_batch_out`.
BatchTasks<TaskT> batchTask(_batch_size, _overrun, _allow_split_request);
if (move_task_to_batch(batchTask)) {
batchTask.merge_tasks();
_fn(&batchTask.in(), &batchTask.out());
batchTask.notify_tasks();
}
}
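The gather/merge/predict/scatter flow described in the comment above can also be pictured with a small, self-contained Python model (not the real implementation; the names only mirror the C++ methods):

``` python
def run_one_round(task_queue, batch_size, predict_fn):
    """One scheduler round: gather tasks, merge into one batch, run the
    predictor once, then scatter the outputs back per task (toy model)."""
    taken, batch_in = [], []
    # move_task_to_batch(): take tasks while they still fit into batch_size
    while task_queue and len(batch_in) + len(task_queue[0]) <= batch_size:
        task = task_queue.pop(0)
        taken.append(task)
        batch_in.extend(task)          # merge_tasks(): build _batch_in
    if not batch_in:
        return []
    batch_out = predict_fn(batch_in)   # _fn(&batchTask.in(), &batchTask.out())
    # notify_tasks(): slice _batch_out back into per-task results
    results, offset = [], 0
    for task in taken:
        results.append(batch_out[offset:offset + len(task)])
        offset += len(task)
    return results


# Two requests of sizes 2 and 3 are batched together and answered separately.
print(run_one_round([[1, 2], [3, 4, 5]], batch_size=8, predict_fn=lambda xs: [x * 10 for x in xs]))
```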
......
This diff is collapsed.
......@@ -25,7 +25,8 @@ int ReloadableInferEngine::proc_initialize_impl(
_model_dir = conf.model_dir();
_infer_thread_num = conf.runtime_thread_num();
_infer_batch_size = conf.batch_infer_size();
_infer_batch_align = conf.enable_batch_align();
_infer_overrun = conf.enable_overrun();
_allow_split_request = conf.allow_split_request();
_conf = conf;
......@@ -56,9 +57,6 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
}
// init bsf framework
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_init_fn(
boost::bind(&InferEngine::thrd_initialize_impl, this));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_thread_init_fn(
boost::bind(&InferEngine::thrd_initialize_impl, this));
......@@ -69,8 +67,10 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
boost::bind(&InferEngine::task_infer_impl, this, _1, _2));
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_batch_size(
_infer_batch_size);
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_batch_align(
_infer_batch_align);
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].set_overrun(
_infer_overrun);
im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index]
.set_allow_split_request(_allow_split_request);
if (im::bsf::TaskExecutorVector<TaskT>::instance()[_model_index].start(
_infer_thread_num) != 0) {
LOG(ERROR) << "Failed start bsf executor, threads:" << _infer_thread_num;
......@@ -79,7 +79,8 @@ int ReloadableInferEngine::proc_initialize(const configure::EngineDesc& conf,
LOG(WARNING) << "Enable batch schedule framework, thread_num:"
<< _infer_thread_num << ", batch_size:" << _infer_batch_size
<< ", enable_batch_align:" << _infer_batch_align;
<< ", enable_overrun:" << _infer_overrun
<< ", allow_split_request:" << _allow_split_request;
return 0;
}
......@@ -391,6 +392,11 @@ int VersionedInferEngine::task_infer_impl(const void* in,
return -1;
}
int InferManager::set_taskexecutor_num(size_t total_engine_num) {
im::bsf::TaskExecutorVector<TaskT>::instance().resize(total_engine_num);
return 0;
}
int InferManager::proc_initialize(const char* path,
const char* file,
std::shared_ptr<int> engine_index_ptr) {
......@@ -400,8 +406,6 @@ int InferManager::proc_initialize(const char* path,
return -1;
}
uint32_t engine_num = model_toolkit_conf.engines_size();
im::bsf::TaskExecutorVector<TaskT>::instance().resize(*engine_index_ptr +
engine_num);
for (uint32_t ei = 0; ei < engine_num; ++ei) {
LOG(INFO) << "model_toolkit_conf.engines(" << ei
<< ").name: " << model_toolkit_conf.engines(ei).name();
......
......@@ -169,8 +169,10 @@ class ReloadableInferEngine : public InferEngine {
uint32_t _infer_batch_size;
// Whether batch_size needs to be aligned during inference
bool _infer_batch_align;
bool _infer_overrun;
// Allow splitting a request during inference
bool _allow_split_request;
// model version
uint64_t _version;
};
......@@ -645,6 +647,8 @@ class InferManager {
const char* file,
std::shared_ptr<int> engine_index_ptr);
int set_taskexecutor_num(size_t total_engine_num);
int thrd_initialize();
int thrd_clear();
......
......@@ -135,6 +135,17 @@ int Resource::initialize(const std::string& path, const std::string& file) {
if (FLAGS_enable_model_toolkit) {
size_t model_toolkit_num = resource_conf.model_toolkit_path_size();
// For now we assume each model_toolkit contains exactly one engine,
// so model_toolkit_num == the total number of engines.
// If a model_toolkit may contain multiple engines in the future, first
// count the total number of engines in a loop and only then call
// set_taskexecutor_num.
// Never resize im::bsf::TaskExecutorVector<TaskT>::instance() dynamically:
// TaskExecutor is a thread pool that contains mutexes, and its work loop
// is already running and taking those locks once the engine process has
// been initialized. Resizing afterwards relocates the memory, so work()
// keeps using the old mutex while the relocated TaskExecutor's mutex
// memory has already changed.
if (InferManager::instance().set_taskexecutor_num(model_toolkit_num) != 0) {
LOG(ERROR) << "failed set_taskexecutor_num";
return -1;
}
std::shared_ptr<int> engine_index_ptr(new int(0));
for (size_t mi = 0; mi < model_toolkit_num; ++mi) {
std::string model_toolkit_path = resource_conf.model_toolkit_path(mi);
......
......@@ -52,7 +52,9 @@ For a Java HttpClient usage example, see [`java/examples/src/main/java/PaddleServingClien
If this does not meet your needs, you can also add features on top of it.
To support HTTPS or a customized Response Status Code, some secondary development of the C++ brpc-Server is required; please refer to https://github.com/apache/incubator-brpc/blob/master/docs/cn/http_service.md. If demand turns out to be high, we will also add this functionality to the Server later, so stay tuned.
To support HTTPS or a customized Response Status Code, some secondary development of the C++ brpc-Server is required; please refer to https://github.com/apache/incubator-brpc/blob/master/docs/cn/http_service.md
If demand turns out to be high, we will also add this functionality to the Server later, so stay tuned.
### Sending HTTP requests with curl (basic principle)
......
......@@ -23,11 +23,9 @@ args = benchmark_args()
reader = ChineseBertReader({"max_seq_len": 128})
fetch = ["pooled_output"]
client = HttpClient(ip='127.0.0.1', port='9292')
endpoint_list = ['127.0.0.1:9292']
client = HttpClient()
client.load_client_config(args.model)
#client.set_ip('127.0.0.1')
#client.set_port('9292')
'''
if you want to use the GRPC client, set_use_grpc_client(True)
or you can directly use client.grpc_client_predict(...)
......@@ -49,6 +47,7 @@ we recommend using the Proto data format in the HTTP body, set True (which is the default)
if you want to use the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
client.connect(endpoint_list)
for line in sys.stdin:
feed_dict = reader.process(line)
......
......@@ -20,8 +20,6 @@ import time
client = HttpClient()
client.load_client_config(sys.argv[1])
#client.set_ip('127.0.0.1')
#client.set_port('9393')
'''
if you want to use the GRPC client, set_use_grpc_client(True)
or you can directly use client.grpc_client_predict(...)
......@@ -43,13 +41,14 @@ we recommend using the Proto data format in the HTTP body, set True (which is the default)
if you want to use the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
client.connect(["127.0.0.1:9393"])
fetch_list = client.get_fetch_names()
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
fetch_list = client.get_fetch_names()
for data in test_reader():
new_data = np.zeros((1, 13)).astype("float32")
new_data[0] = data[0][0]
......
......@@ -18,10 +18,8 @@ from paddle_serving_app.reader import Sequential, URL2Image, Resize
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
import time
client = HttpClient(ip='127.0.0.1', port='9696')
client = HttpClient()
client.load_client_config(sys.argv[1])
#client.set_ip('127.0.0.1')
#client.set_port('9292')
'''
if you want to use the GRPC client, set_use_grpc_client(True)
or you can directly use client.grpc_client_predict(...)
......@@ -43,6 +41,7 @@ we recommend using the Proto data format in the HTTP body, set True (which is the default)
if you want to use the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
client.connect(["127.0.0.1:9696"])
label_dict = {}
label_idx = 0
......
......@@ -17,10 +17,8 @@ from paddle_serving_app.reader.imdb_reader import IMDBDataset
import sys
import numpy as np
client = HttpClient(ip='127.0.0.1', port='9292')
client = HttpClient()
client.load_client_config(sys.argv[1])
#client.set_ip('127.0.0.1')
#client.set_port('9292')
'''
if you want to use the GRPC client, set_use_grpc_client(True)
or you can directly use client.grpc_client_predict(...)
......@@ -42,6 +40,7 @@ we recommend using the Proto data format in the HTTP body, set True (which is the default)
if you want to use the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
client.connect(["127.0.0.1:9292"])
# you can define any english sentence or dataset here
# This example reuses imdb reader in training, you
......
......@@ -21,10 +21,8 @@ import os
import io
import numpy as np
client = HttpClient(ip='127.0.0.1', port='9292')
client = HttpClient()
client.load_client_config(sys.argv[1])
#client.set_ip('127.0.0.1')
#client.set_port('9292')
'''
if you want to use the GRPC client, set_use_grpc_client(True)
or you can directly use client.grpc_client_predict(...)
......@@ -46,6 +44,7 @@ we recommend using the Proto data format in the HTTP body, set True (which is the default)
if you want to use the JSON data format in the HTTP body, set False
'''
#client.set_http_proto(True)
client.connect(["127.0.0.1:9292"])
reader = LACReader()
for line in sys.stdin:
......
# coding=utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server.web_service import WebService
from paddle_serving_app.reader import ChineseBertReader
import sys
import os
import numpy as np
class BertService(WebService):
def load(self):
self.reader = ChineseBertReader({
"vocab_file": "vocab.txt",
"max_seq_len": 128
})
def preprocess(self, feed=[], fetch=[]):
feed_res = []
is_batch = False
for ins in feed:
feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape(
(len(feed_dict[key]), 1))
feed_res.append(feed_dict)
return feed_res, fetch, is_batch
bert_service = BertService(name="bert")
bert_service.load()
bert_service.load_model_config(sys.argv[1])
bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), use_lite=True, use_xpu=True, ir_optim=True)
bert_service.run_rpc_service()
bert_service.run_web_service()
# coding=utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server.web_service import WebService
from paddle_serving_app.reader import ChineseBertReader
import sys
import os
import numpy as np
class BertService(WebService):
def load(self):
self.reader = ChineseBertReader({
"vocab_file": "vocab.txt",
"max_seq_len": 128
})
def preprocess(self, feed=[], fetch=[]):
feed_res = []
is_batch = False
for ins in feed:
feed_dict = self.reader.process(ins["words"].encode("utf-8"))
for key in feed_dict.keys():
feed_dict[key] = np.array(feed_dict[key]).reshape(
(len(feed_dict[key]), 1))
feed_res.append(feed_dict)
return feed_res, fetch, is_batch
bert_service = BertService(name="bert")
bert_service.load()
bert_service.load_model_config(sys.argv[1])
bert_service.prepare_server(
workdir="workdir", port=int(sys.argv[2]), use_lite=True, use_xpu=True, ir_optim=True)
bert_service.run_rpc_service()
bert_service.run_web_service()
......@@ -23,18 +23,3 @@ The `paddlepaddle` package is used in `test_client.py`, and you may need to down
``` shell
python3 test_client.py uci_housing_client/serving_client_conf.prototxt
```
## HTTP service
### Start server
Start a web service with default web service hosting modules:
``` shell
python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9393 --use_lite --use_xpu --ir_optim --name uci
```
### Client prediction
``` shell
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://127.0.0.1:9393/uci/prediction
```
......@@ -31,19 +31,3 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
``` shell
python3 test_client.py uci_housing_client/serving_client_conf.prototxt
```
## HTTP service
### Start server
Start the default web service with the following command:
``` shell
python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9393 --use_lite --use_xpu --ir_optim --name uci
```
### Client prediction
``` shell
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795, -0.0332]}], "fetch":["price"]}' http://127.0.0.1:9393/uci/prediction
```
......@@ -289,6 +289,7 @@ class Client(object):
log_id=0):
self.profile_.record('py_prepro_0')
# fetch may be empty; in that case all outputs will be returned
if feed is None:
raise ValueError("You should specify feed for prediction")
......@@ -297,6 +298,7 @@ class Client(object):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
# fetch may be empty; in that case all outputs will be returned
elif fetch == None:
pass
else:
......@@ -441,6 +443,7 @@ class Client(object):
model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names):
result_map = {}
# if fetch is empty, all outputs will be returned (see the usage sketch below)
if len(fetch_names) == 0:
fetch_names = result_batch_handle.get_tensor_alias_names(mi)
# result map needs to be a numpy array
......
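A minimal usage sketch of the behavior noted in the comments above (the config path and feed values are placeholders, not from this commit):

``` python
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.connect(["127.0.0.1:9393"])

# No fetch list is given, so the client returns every output of the model.
result = client.predict(feed={"x": [0.0] * 13}, batch=False)
print(list(result.keys()))
```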
......@@ -22,6 +22,7 @@ import gzip
from collections import Iterable
import base64
import sys
import re
import grpc
from .proto import general_model_service_pb2
......@@ -98,7 +99,7 @@ class HttpClient(object):
self.headers["Content-Type"] = "application/proto"
self.max_body_size = 512 * 1024 * 1024
self.use_grpc_client = False
self.url = None
self.http_s = "http://"
# Using a connection pool avoids re-establishing the connection repeatedly
self.requests_session = requests.session()
......@@ -170,7 +171,6 @@ class HttpClient(object):
def set_max_body_size(self, max_body_size):
self.max_body_size = max_body_size
self.init_grpc_stub()
def set_timeout_ms(self, timeout_ms):
if not isinstance(timeout_ms, int):
......@@ -183,25 +183,46 @@ class HttpClient(object):
raise ValueError("retry_times must be int type.")
else:
self.requests_session.mount(
'http://', HTTPAdapter(max_retries=retry_times))
def set_ip(self, ip):
self.ip = ip
self.init_grpc_stub()
self.http_s, HTTPAdapter(max_retries=retry_times))
def set_service_name(self, service_name):
self.service_name = service_name
def set_port(self, port):
self.port = port
self.server_port = port
self.init_grpc_stub()
def set_url(self, url):
def connect(self, url=None, encryption=False):
if isinstance(url, (list, tuple)):
if len(url) > 1:
raise ValueError("HttpClient only support 1 endpoint")
else:
url = url[0]
if isinstance(url, str):
self.url = url
if url.startswith("https://"):
url = url[8:]
self.http_s = "https://"
if url.startswith("http://"):
url = url[7:]
self.http_s = "http://"
url_parts = url.split(':')
if len(url_parts) != 2 or self.check_ip(url_parts[0]) == False:
raise ValueError(
"url not right, it should be like 127.0.0.1:9393 or http://127.0.0.1:9393"
)
else:
self.ip = url_parts[0]
self.port = url_parts[1]
self.server_port = url_parts[1]
if encryption:
self.get_serving_port()
if self.use_grpc_client:
self.init_grpc_stub()
def check_ip(self, ipAddr):
compile_ip = re.compile(
'^(1\d{2}|2[0-4]\d|25[0-5]|[1-9]\d|[1-9])\.(1\d{2}|2[0-4]\d|25[0-5]|[1-9]\d|\d)\.(1\d{2}|2[0-4]\d|25[0-5]|[1-9]\d|\d)\.(1\d{2}|2[0-4]\d|25[0-5]|[1-9]\d|\d)$'
)
if compile_ip.match(ipAddr):
return True
else:
print("url must be str")
return False
def add_http_headers(self, headers):
if isinstance(headers, dict):
......@@ -229,10 +250,9 @@ class HttpClient(object):
def use_key(self, key_filename):
with open(key_filename, "rb") as f:
self.key = f.read()
self.get_serving_port()
def get_serving_port(self):
encrypt_url = "http://" + str(self.ip) + ":" + str(self.port)
encrypt_url = self.http_s + str(self.ip) + ":" + str(self.port)
if self.key is not None:
req = json.dumps({"key": base64.b64encode(self.key).decode()})
else:
......@@ -481,13 +501,7 @@ class HttpClient(object):
postData = self.process_json_data(feed_dict, fetch_list, batch,
log_id)
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
if self.url != None:
if "http" not in self.url:
self.url = "http://" + self.url
if "self.service_name" not in self.url:
self.url = self.url + self.service_name
web_url = self.url
web_url = self.http_s + self.ip + ":" + self.server_port + self.service_name
# Compress only when the data section is larger than 512 bytes.
self.headers.pop("Content-Encoding", "nokey")
try:
......
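For reference, a brief usage sketch of the reworked HttpClient connection API shown in this file (the import path, config path and endpoint are illustrative; the optional calls mirror the comments in the examples above):

``` python
from paddle_serving_client import HttpClient  # import path as used in the examples (assumption)

client = HttpClient()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
# client.set_use_grpc_client(True)  # switch to the gRPC code path instead of HTTP
# client.set_http_proto(True)       # Proto body (default) vs JSON body
# client.use_key("./key")           # only when the server side is encrypted
client.connect(["127.0.0.1:9393"])  # or: client.connect("http://127.0.0.1:9393")
```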
......@@ -228,7 +228,8 @@ class Server(object):
engine.batch_infer_size = self.op_max_batch[index %
len(self.op_max_batch)]
engine.enable_batch_align = 1
engine.enable_overrun = False
engine.allow_split_request = True
engine.model_dir = model_config_path
engine.enable_memory_optimization = self.memory_optimization
engine.enable_ir_optimization = self.ir_optimization
......
......@@ -40,9 +40,9 @@ go env -w GO111MODULE=auto
build_whl_list=(build_cpu_server build_gpu_server build_client build_app)
rpc_model_list=(grpc_fit_a_line grpc_yolov4 pipeline_imagenet bert_rpc_gpu bert_rpc_cpu ResNet50_rpc \
lac_rpc cnn_rpc bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \
lac_rpc_asyn cnn_rpc_asyn bow_rpc lstm_rpc fit_a_line_rpc deeplabv3_rpc mobilenet_rpc unet_rpc resnetv2_rpc \
criteo_ctr_rpc_cpu criteo_ctr_rpc_gpu ocr_rpc yolov4_rpc_gpu faster_rcnn_hrnetv2p_w18_1x_encrypt \
faster_rcnn_model_rpc low_precision_resnet50_int8 ocr_c++_service)
faster_rcnn_model_rpc low_precision_resnet50_int8 ocr_c++_service ocr_c++_service_asyn)
http_model_list=(fit_a_line_http lac_http imdb_http_proto imdb_http_json imdb_grpc ResNet50_http bert_http \
pipeline_ocr_cpu_http)
......@@ -492,7 +492,7 @@ function ResNet101_rpc() {
kill_server_process
}
function cnn_rpc() {
function cnn_rpc_asyn() {
dir=${log_dir}rpc_model/cnn_rpc/
check_dir ${dir}
unsetproxy
......@@ -500,8 +500,9 @@ function cnn_rpc() {
data_dir=${data}imdb/
link_data ${data_dir}
sed -i 's/9292/8865/g' test_client.py
${py_version} -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 > ${dir}server_log.txt 2>&1 &
check_result server 5
${py_version} -m paddle_serving_server.serve --model imdb_cnn_model/ --port 8865 --op_num 4 --thread 10 --gpu_ids 0 > ${dir}server_log.txt 2>&1 &
check_result server 8
check_gpu_memory 0
head test_data/part-0 | ${py_version} test_client.py imdb_cnn_client_conf/serving_client_conf.prototxt imdb.vocab > ${dir}client_log.txt 2>&1
check_result client "cnn_CPU_RPC server test completed"
kill_server_process
......@@ -537,7 +538,7 @@ function lstm_rpc() {
kill_server_process
}
function lac_rpc() {
function lac_rpc_asyn() {
dir=${log_dir}rpc_model/lac_rpc/
check_dir ${dir}
unsetproxy
......@@ -545,8 +546,9 @@ function lac_rpc() {
data_dir=${data}lac/
link_data ${data_dir}
sed -i 's/9292/8868/g' lac_client.py
${py_version} -m paddle_serving_server.serve --model lac_model/ --port 8868 > ${dir}server_log.txt 2>&1 &
check_result server 5
${py_version} -m paddle_serving_server.serve --model lac_model/ --port 8868 --gpu_ids 0 --op_num 2 > ${dir}server_log.txt 2>&1 &
check_result server 8
check_gpu_memory 0
echo "我爱北京天安门" | ${py_version} lac_client.py lac_client/serving_client_conf.prototxt lac_dict/ > ${dir}client_log.txt 2>&1
check_result client "lac_CPU_RPC server test completed"
kill_server_process
......@@ -923,6 +925,23 @@ function ocr_c++_service() {
kill_server_process
}
function ocr_c++_service_asyn() {
dir=${log_dir}rpc_model/ocr_c++_serving/
cd ${build_path}/python/examples/ocr
check_dir ${dir}
echo -e "${GREEN_COLOR}OCR_C++_Service_GPU_RPC asyn_server started${RES}"
$py_version -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293 --gpu_id 0 --op_num 4 > ${dir}server_log.txt 2>&1 &
check_result server 8
check_gpu_memory 0
echo -e "${GREEN_COLOR}OCR_C++_Service_GPU_RPC client started${RES}"
echo "------------------first:"
$py_version ocr_cpp_client.py ocr_det_client ocr_rec_client
echo "------------------second:"
$py_version ocr_cpp_client.py ocr_det_client ocr_rec_client > ${dir}client_log.txt 2>&1
check_result client "OCR_C++_Service_GPU_RPC server test completed"
kill_server_process
}
function build_all_whl() {
for whl in ${build_whl_list[@]}
do
......