Commit af6ead49 authored by barriery, committed by GitHub

Merge branch 'develop' into java-sdk

......@@ -184,11 +184,6 @@ Here, `client.predict` function has two arguments. `feed` is a `python dict` wit
<h2 align="center">Community</h2>
### User Group in China
<p align="center"><img width="200" height="300" margin="500" src="./doc/qq.jpeg"/>&#8194;&#8194;&#8194;&#8194;&#8194<img width="200" height="300" src="doc/wechat.jpeg"/></p>
<p align="center">PaddleServing交流QQ群&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;PaddleServing微信群</p>
### Slack
......
......@@ -31,8 +31,9 @@ DEFINE_bool(print_output, false, "print output flag");
DEFINE_int32(thread_num, 1, "thread num");
std::atomic<int> g_concurrency(0);
std::vector<uint64_t> time_list;
std::vector<std::vector<uint64_t>> time_list;
std::vector<uint64_t> request_list;
int turns = 1000000 / FLAGS_batch;
namespace {
inline uint64_t time_diff(const struct timeval& start_time,
......@@ -97,7 +98,7 @@ int run(int argc, char** argv, int thread_id) {
while (g_concurrency.load() >= FLAGS_thread_num) {
}
g_concurrency++;
time_list[thread_id].resize(turns);
while (index < file_size) {
// uint64_t key = strtoul(buffer, NULL, 10);
......@@ -121,47 +122,12 @@ int run(int argc, char** argv, int thread_id) {
}
++seek_counter;
uint64_t seek_cost = time_diff(seek_start, seek_end);
seek_cost_total += seek_cost;
if (seek_cost > seek_cost_max) {
seek_cost_max = seek_cost;
}
if (seek_cost < seek_cost_min) {
seek_cost_min = seek_cost;
}
time_list[thread_id][request - 1] = seek_cost;
keys.clear();
values.clear();
}
}
/*
if (keys.size() > 0) {
int ret = 0;
values.resize(keys.size());
TIME_FLAG(seek_start);
ret = cube->seek(FLAGS_dict, keys, &values);
TIME_FLAG(seek_end);
if (ret != 0) {
LOG(WARNING) << "cube seek failed";
} else if (FLAGS_print_output) {
for (size_t i = 0; i < keys.size(); ++i) {
fprintf(stdout,
"key:%lu value:%s\n",
keys[i],
string_to_hex(values[i].buff).c_str());
}
}
++seek_counter;
uint64_t seek_cost = time_diff(seek_start, seek_end);
seek_cost_total += seek_cost;
if (seek_cost > seek_cost_max) {
seek_cost_max = seek_cost;
}
if (seek_cost < seek_cost_min) {
seek_cost_min = seek_cost;
}
}
*/
g_concurrency--;
// fclose(key_file);
......@@ -171,12 +137,6 @@ int run(int argc, char** argv, int thread_id) {
LOG(WARNING) << "destroy cube api failed err=" << ret;
}
uint64_t seek_cost_avg = seek_cost_total / seek_counter;
LOG(INFO) << "seek cost avg = " << seek_cost_avg;
LOG(INFO) << "seek cost max = " << seek_cost_max;
LOG(INFO) << "seek cost min = " << seek_cost_min;
time_list[thread_id] = seek_cost_avg;
request_list[thread_id] = request;
return 0;
......@@ -188,6 +148,7 @@ int run_m(int argc, char** argv) {
request_list.resize(thread_num);
time_list.resize(thread_num);
std::vector<std::thread*> thread_pool;
TIME_FLAG(main_start);
for (int i = 0; i < thread_num; i++) {
thread_pool.push_back(new std::thread(run, argc, argv, i));
}
......@@ -195,27 +156,33 @@ int run_m(int argc, char** argv) {
thread_pool[i]->join();
delete thread_pool[i];
}
TIME_FLAG(main_end);
uint64_t sum_time = 0;
uint64_t max_time = 0;
uint64_t min_time = 1000000;
uint64_t request_num = 0;
for (int i = 0; i < thread_num; i++) {
sum_time += time_list[i];
if (time_list[i] > max_time) {
max_time = time_list[i];
}
if (time_list[i] < min_time) {
min_time = time_list[i];
for (int j = 0; j < request_list[i]; j++) {
sum_time += time_list[i][j];
if (time_list[i][j] > max_time) {
max_time = time_list[i][j];
}
if (time_list[i][j] < min_time) {
min_time = time_list[i][j];
}
}
request_num += request_list[i];
}
uint64_t mean_time = sum_time / thread_num;
LOG(INFO) << thread_num << " thread seek cost"
<< " avg = " << std::to_string(mean_time)
<< " max = " << std::to_string(max_time)
<< " min = " << std::to_string(min_time);
LOG(INFO) << " total_request = " << std::to_string(request_num) << " speed = "
<< std::to_string(1000000 * thread_num / mean_time) // mean_time us
uint64_t mean_time = sum_time / (thread_num * turns);
uint64_t main_time = time_diff(main_start, main_end);
LOG(INFO) << "\n"
<< thread_num << " thread seek cost"
<< "\navg = " << std::to_string(mean_time)
<< "\nmax = " << std::to_string(max_time)
<< "\nmin = " << std::to_string(min_time);
LOG(INFO) << "\ntotal_request = " << std::to_string(request_num)
<< "\nspeed = " << std::to_string(request_num * 1000000 /
main_time) // mean_time us
<< " query per second";
return 0;
}
......
......@@ -90,6 +90,9 @@ int GeneralDistKVInferOp::inference() {
keys.begin() + key_idx);
key_idx += dataptr_size_pairs[i].second;
}
Timer timeline;
int64_t cube_start = timeline.TimeStampUS();
timeline.Start();
rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
std::vector<std::string> table_names = cube->get_table_names();
if (table_names.size() == 0) {
......@@ -97,7 +100,7 @@ int GeneralDistKVInferOp::inference() {
return -1;
}
int ret = cube->seek(table_names[0], keys, &values);
int64_t cube_end = timeline.TimeStampUS();
if (values.size() != keys.size() || values[0].buff.size() == 0) {
LOG(ERROR) << "cube value return null";
}
......@@ -153,9 +156,7 @@ int GeneralDistKVInferOp::inference() {
VLOG(2) << "infer batch size: " << batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(
engine_name().c_str(), &infer_in, out, batch_size)) {
......@@ -165,6 +166,8 @@ int GeneralDistKVInferOp::inference() {
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, cube_start);
AddBlobInfo(output_blob, cube_end);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
......
......@@ -114,70 +114,48 @@ int GeneralResponseOp::inference() {
for (int j = 0; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j];
}
if (in->at(idx).dtype == paddle::PaddleDType::INT64) {
FetchInst *fetch_p = output->mutable_insts(0);
auto dtype = in->at(idx).dtype;
if (dtype == paddle::PaddleDType::INT64) {
VLOG(2) << "Prepare int64 var [" << model_config->_fetch_name[idx]
<< "].";
int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
}
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]);
}
} else {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]);
}
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} else if (in->at(idx).dtype == paddle::PaddleDType::FLOAT32) {
// from
// https://stackoverflow.com/questions/15499641/copy-a-stdvector-to-a-repeated-field-from-protobuf-with-memcpy
// `Swap` method is faster than `{}` method.
google::protobuf::RepeatedField<int64_t> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_int64_data()->Swap(
&tmp_data);
} else if (dtype == paddle::PaddleDType::FLOAT32) {
VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx]
<< "].";
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
}
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
}
} else {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
}
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} else if (in->at(idx).dtype == paddle::PaddleDType::INT32) {
google::protobuf::RepeatedField<float> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_float_data()->Swap(
&tmp_data);
} else if (dtype == paddle::PaddleDType::INT32) {
VLOG(2) << "Prepare int32 var [" << model_config->_fetch_name[idx]
<< "].";
int32_t *data_ptr = static_cast<int32_t *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
}
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int_data(data_ptr[j]);
}
} else {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int_data(data_ptr[j]);
}
google::protobuf::RepeatedField<int32_t> tmp_data(data_ptr,
data_ptr + cap);
fetch_p->mutable_tensor_array(var_idx)->mutable_int_data()->Swap(
&tmp_data);
}
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_lod(
in->at(idx).lod[0][j]);
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
}
}
......
# Paddle Serving
([简体中文](./README_CN.md)|English)
Paddle Serving is PaddlePaddle's online prediction service framework, which helps developers easily deploy remote prediction services that serve deep learning models to mobile and server clients. At present, Paddle Serving mainly supports models trained with PaddlePaddle and can be used together with the Paddle training framework to quickly deploy inference services. Paddle Serving is designed around common industrial deep learning deployment scenarios; its features include multi-model management, model hot loading, [Baidu-rpc](https://github.com/apache/incubator-brpc)-based high-concurrency and low-latency responses, and online model A/B testing. An API that integrates with the Paddle training framework lets users move seamlessly from training to remote deployment, making it easier to put deep learning models into production.
------------
## Quick Start
Paddle Serving's current develop version provides a lightweight Python API for fast predictions and integrates directly with Paddle training. We use the classic Boston house price prediction task as an example to walk through model training on a single machine and model deployment with Paddle Serving.
#### Install
It is highly recommended that you build Paddle Serving inside Docker. Please read [How to run PaddleServing in Docker](RUN_IN_DOCKER.md)
```
pip install paddle-serving-client
pip install paddle-serving-server
```
#### Training Script
``` python
import sys
import paddle
import paddle.fluid as fluid
train_reader = paddle.batch(paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500), batch_size=16)
test_reader = paddle.batch(paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500), batch_size=16)
x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_loss)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
import paddle_serving_client.io as serving_io
for pass_id in range(30):
for data_train in train_reader():
avg_loss_value, = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data_train),
fetch_list=[avg_loss])
serving_io.save_model(
"serving_server_model", "serving_client_conf",
{"x": x}, {"y": y_predict}, fluid.default_main_program())
```
#### Server Side Code
``` python
import sys
from paddle_serving.serving_server import OpMaker
from paddle_serving.serving_server import OpSeqMaker
from paddle_serving.serving_server import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
server.run_server()
```
#### Launch the Server
``` shell
python test_server.py serving_server_model
```
#### Client Prediction
``` python
from paddle_serving_client import Client
import paddle
import sys
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
test_reader = paddle.batch(paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500), batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["y"])
print("{} {}".format(fetch_map["y"][0], data[0][1][0]))
```
### Document
[Design Doc](DESIGN.md)
[FAQ](./deprecated/FAQ.md)
### Senior Developer Guidelines
[Compile Tutorial](COMPILE.md)
## Contribution
If you want to make contributions to Paddle Serving, please refer to [CONTRIBUTE](CONTRIBUTE.md)
# Paddle Serving
(简体中文|[English](./README.md))
Paddle Serving is PaddlePaddle's online prediction service framework, which helps developers easily deploy remote prediction services that serve deep learning models to mobile and server clients. At present, Paddle Serving mainly supports models trained with PaddlePaddle and can be used together with the Paddle training framework to quickly deploy inference services. Paddle Serving is designed around common industrial deep learning deployment scenarios; its features include multi-model management, model hot loading, [Baidu-rpc](https://github.com/apache/incubator-brpc)-based high-concurrency and low-latency responses, and online model A/B testing. An API that integrates with the Paddle training framework lets users move seamlessly from training to remote deployment, making it easier to put deep learning models into production.
------------
## Quick Start Guide
Paddle Serving's current develop version provides a lightweight Python API for fast predictions and integrates directly with Paddle training. We use the classic Boston house price prediction task as an example to walk through model training on a single machine and model deployment with Paddle Serving.
#### Install
It is highly recommended that you build Paddle Serving inside Docker. Please read [How to run PaddleServing in Docker](RUN_IN_DOCKER_CN.md)
```
pip install paddle-serving-client
pip install paddle-serving-server
```
#### Training Script
``` python
import sys
import paddle
import paddle.fluid as fluid
train_reader = paddle.batch(paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500), batch_size=16)
test_reader = paddle.batch(paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500), batch_size=16)
x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_loss)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
import paddle_serving_client.io as serving_io
for pass_id in range(30):
for data_train in train_reader():
avg_loss_value, = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data_train),
fetch_list=[avg_loss])
serving_io.save_model(
"serving_server_model", "serving_client_conf",
{"x": x}, {"y": y_predict}, fluid.default_main_program())
```
#### Server Side Code
``` python
import sys
from paddle_serving.serving_server import OpMaker
from paddle_serving.serving_server import OpSeqMaker
from paddle_serving.serving_server import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
server.run_server()
```
#### Launch the Server
``` shell
python test_server.py serving_server_model
```
#### Client Prediction
``` python
from paddle_serving_client import Client
import paddle
import sys
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
test_reader = paddle.batch(paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500), batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["y"])
print("{} {}".format(fetch_map["y"][0], data[0][1][0]))
```
### Document
[Design Doc](DESIGN_CN.md)
[FAQ](./deprecated/FAQ.md)
### Senior Developer Guidelines
[Compile Tutorial](COMPILE_CN.md)
## Contribution
If you want to make contributions to Paddle Serving, please refer to the [contribution guidelines](CONTRIBUTE.md)
......@@ -116,8 +116,10 @@ def single_func(idx, resource):
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9292"]
turns = 10
endpoint_list = [
"127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
]
turns = 100
start = time.time()
result = multi_thread_runner.run(
single_func, args.thread, {"endpoint": endpoint_list,
......@@ -130,9 +132,9 @@ if __name__ == '__main__':
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("total cost :{} s".format(total_cost))
print("each thread cost :{} s. ".format(avg_cost))
print("qps :{} samples/s".format(args.batch_size * args.thread * turns /
total_cost))
print("total cost: {}s".format(total_cost))
print("each thread cost: {}s. ".format(avg_cost))
print("qps: {}samples/s".format(args.batch_size * args.thread * turns /
total_cost))
if os.getenv("FLAGS_serving_latency"):
show_latency(result[1])
rm profile_log
rm profile_log*
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_profile_server=1
export FLAGS_profile_client=1
export FLAGS_serving_latency=1
python3 -m paddle_serving_server_gpu.serve --model $1 --port 9292 --thread 4 --gpu_ids 0,1,2,3 --mem_optim False --ir_optim True 2> elog > stdlog &
gpu_id=0
#save cpu and gpu utilization log
if [ -d utilization ];then
rm -rf utilization
else
mkdir utilization
fi
#start server
$PYTHONROOT/bin/python3 -m paddle_serving_server_gpu.serve --model $1 --port 9292 --thread 4 --gpu_ids 0,1,2,3 --mem_optim --ir_optim > elog 2>&1 &
sleep 5
#warm up
python3 benchmark.py --thread 8 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
for thread_num in 4 8 16
$PYTHONROOT/bin/python3 benchmark.py --thread 4 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo -e "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
for thread_num in 1 4 8 16
do
for batch_size in 1 4 16 64 256
for batch_size in 1 4 16 64
do
python3 benchmark.py --thread $thread_num --batch_size $batch_size --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "model name :" $1
echo "thread num :" $thread_num
echo "batch size :" $batch_size
job_bt=`date '+%Y%m%d%H%M%S'`
nvidia-smi --id=0 --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
nvidia-smi --id=0 --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
gpu_memory_pid=$!
$PYTHONROOT/bin/python3 benchmark.py --thread $thread_num --batch_size $batch_size --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
kill ${gpu_memory_pid}
kill `ps -ef|grep used_memory|awk '{print $2}'`
echo "model_name:" $1
echo "thread_num:" $thread_num
echo "batch_size:" $batch_size
echo "=================Done===================="
echo "model name :$1" >> profile_log_$1
echo "batch size :$batch_size" >> profile_log_$1
python3 ../util/show_profile.py profile $thread_num >> profile_log_$1
echo "model_name:$1" >> profile_log_$1
echo "batch_size:$batch_size" >> profile_log_$1
$PYTHONROOT/bin/python3 cpu_utilization.py >> profile_log_$1
job_et=`date '+%Y%m%d%H%M%S'`
awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "MAX_GPU_MEMORY:", max}' gpu_use.log >> profile_log_$1
awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "GPU_UTILIZATION:", max}' gpu_utilization.log >> profile_log_$1
rm -rf gpu_use.log gpu_utilization.log
$PYTHONROOT/bin/python3 ../util/show_profile.py profile $thread_num >> profile_log_$1
tail -n 8 profile >> profile_log_$1
echo "" >> profile_log_$1
done
done
#Divided log
awk 'BEGIN{RS="\n\n"}{i++}{print > "bert_log_"i}' profile_log_$1
mkdir bert_log && mv bert_log_* bert_log
ps -ef|grep 'serving'|grep -v grep|cut -c 9-15 | xargs kill -9
# Blazeface
## Get Model
```
python -m paddle_serving_app.package --get_model blazeface
tar -xzvf blazeface.tar.gz
```
## RPC Service
### Start Service
```
python -m paddle_serving_server.serve --model serving_server --port 9494
```
### Client Prediction
```
python test_client.py serving_client/serving_client_conf.prototxt test.jpg
```
The results are saved in the `output` folder, including a JSON file and an image file with the bounding boxes drawn.
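As a rough sketch of how the saved results could be inspected programmatically (the file name `bbox.json` and the per-detection keys `category_id`, `score`, and `bbox` are assumptions based on the usual detection-output convention, not something guaranteed by this example):

``` python
import json
import os

# Hypothetical result file written by the postprocess step; adjust the
# name if your run produces a different one.
result_path = os.path.join("output", "bbox.json")

with open(result_path) as f:
    detections = json.load(f)

# Print every detection whose confidence exceeds 0.5.
for det in detections:
    if det.get("score", 0.0) > 0.5:
        print("label {} score {:.3f} bbox {}".format(
            det.get("category_id"), det.get("score"), det.get("bbox")))
```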
......@@ -13,19 +13,26 @@
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
from paddle_serving_app.reader import *
import sys
import numpy as np
preprocess = Sequential([
File2Image(),
Normalize([104, 117, 123], [127.502231, 127.502231, 127.502231], False)
])
postprocess = BlazeFacePostprocess("label_list.txt", "output")
client = Client()
client.load_client_config("ocr_rec_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9292"])
image_file_list = ["./test_rec.jpg"]
img = cv2.imread(image_file_list[0])
ocr_reader = OCRReader()
feed = {"image": ocr_reader.preprocess([img])}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch_map = client.predict(feed=feed, fetch=fetch)
rec_res = ocr_reader.postprocess(fetch_map)
print(image_file_list[0])
print(rec_res[0][0])
client.load_client_config(sys.argv[1])
client.connect(['127.0.0.1:9494'])
im_0 = preprocess(sys.argv[2])
tmp = Transpose((2, 0, 1))
im = tmp(im_0)
fetch_map = client.predict(
feed={"image": im}, fetch=["detection_output_0.tmp_0"])
fetch_map["image"] = sys.argv[2]
fetch_map["im_shape"] = im_0.shape
postprocess(fetch_map)
......@@ -24,11 +24,13 @@ from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from paddle_serving_client.metric import auc
py_version = sys.version_info[0]
args = benchmark_args()
def single_func(idx, resource):
client = Client()
print([resource["endpoint"][idx % len(resource["endpoint"])]])
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.connect(['127.0.0.1:9292'])
batch = 1
......@@ -40,27 +42,32 @@ def single_func(idx, resource):
]
reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:],
batch, buf_size)
args.batch_size = 1
if args.request == "rpc":
fetch = ["prob"]
print("Start Time")
start = time.time()
itr = 1000
for ei in range(itr):
if args.batch_size == 1:
data = reader().next()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][i]
result = client.predict(feed=feed_dict, fetch=fetch)
if args.batch_size > 0:
feed_batch = []
for bi in range(args.batch_size):
if py_version == 2:
data = reader().next()
else:
data = reader().__next__()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][
i]
feed_batch.append(feed_dict)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
raise ("Not support http service.")
end = time.time()
qps = itr / (end - start)
qps = itr * args.batch_size / (end - start)
return [[end - start, qps]]
......@@ -68,13 +75,17 @@ if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9292"]
#result = single_func(0, {"endpoint": endpoint_list})
start = time.time()
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
end = time.time()
total_cost = end - start
avg_cost = 0
qps = 0
for i in range(args.thread):
avg_cost += result[0][i * 2 + 0]
qps += result[0][i * 2 + 1]
avg_cost = avg_cost / args.thread
print("total cost: {}".format(total_cost))
print("average total cost {} s.".format(avg_cost))
print("qps {} ins/s".format(qps))
rm profile_log
batch_size=1
for thread_num in 1 2 4 8 16
export FLAGS_profile_client=1
export FLAGS_profile_server=1
wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz --no-check-certificate
tar xf ctr_cube_unittest.tar.gz
mv models/ctr_client_conf ./
mv models/ctr_serving_model_kv ./
mv models/data ./cube/
wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz --no-check-certificate
tar xf cube_app.tar.gz
mv cube_app/cube* ./cube/
sh cube_prepare.sh &
python test_server.py ctr_serving_model_kv > serving_log 2>&1 &
for thread_num in 1 4 16
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --model ctr_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
for batch_size in 1 4 16 64
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "batch size : $batch_size"
echo "thread num : $thread_num"
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 2 profile >> profile_log
tail -n 3 profile >> profile_log
done
done
ps -ef|grep 'serving'|grep -v grep|cut -c 9-15 | xargs kill -9
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
import sys
import os
import criteo as criteo
import time
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from paddle_serving_client.metric import auc
args = benchmark_args()
def single_func(idx, resource):
client = Client()
print([resource["endpoint"][idx % len(resource["endpoint"])]])
client.load_client_config('ctr_client_conf/serving_client_conf.prototxt')
client.connect(['127.0.0.1:9292'])
batch = 1
buf_size = 100
dataset = criteo.CriteoDataset()
dataset.setup(1000001)
test_filelists = [
"./raw_data/part-%d" % x for x in range(len(os.listdir("./raw_data")))
]
reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:],
batch, buf_size)
if args.request == "rpc":
fetch = ["prob"]
start = time.time()
itr = 1000
for ei in range(itr):
if args.batch_size > 1:
feed_batch = []
for bi in range(args.batch_size):
data = reader().next()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
feed_dict["embedding_{}.tmp_0".format(i - 1)] = data[0][
i]
feed_batch.append(feed_dict)
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
raise ("Not support http service.")
end = time.time()
qps = itr * args.batch_size / (end - start)
return [[end - start, qps]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9292"]
#result = single_func(0, {"endpoint": endpoint_list})
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
print(result)
avg_cost = 0
qps = 0
for i in range(args.thread):
avg_cost += result[0][i * 2 + 0]
qps += result[0][i * 2 + 1]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
print("qps {} ins/s".format(qps))
rm profile_log
for thread_num in 1 2 4 8 16
do
for batch_size in 1 2 4 8 16 32 64 128 256 512
do
$PYTHONROOT/bin/python benchmark_batch.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 2 profile >> profile_log
done
done
rm profile_log
wget https://paddle-serving.bj.bcebos.com/unittest/ctr_cube_unittest.tar.gz --no-check-certificate
tar xf ctr_cube_unittest.tar.gz
mv models/ctr_client_conf ./
mv models/ctr_serving_model_kv ./
mv models/data ./cube/
wget https://paddle-serving.bj.bcebos.com/others/cube_app.tar.gz --no-check-certificate
tar xf cube_app.tar.gz
mv cube_app/cube* ./cube/
sh cube_prepare.sh &
cp ../../../build_server/core/cube/cube-api/cube-cli .
python gen_key.py
for thread_num in 1 4 16 32
do
for batch_size in 1000
do
./cube-cli -config_file ./cube/conf/cube.conf -keys key -dict test_dict -thread_num $thread_num --batch $batch_size > profile 2>&1
echo "batch size : $batch_size"
echo "thread num : $thread_num"
echo "========================================"
echo "batch size : $batch_size" >> profile_log
echo "thread num : $thread_num" >> profile_log
tail -n 7 profile | head -n 4 >> profile_log
tail -n 2 profile >> profile_log
done
done
ps -ef|grep 'cube'|grep -v grep|cut -c 9-15 | xargs kill -9
......@@ -16,7 +16,5 @@
mkdir -p cube_model
mkdir -p cube/data
./seq_generator ctr_serving_model/SparseFeatFactors ./cube_model/feature
./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=${PWD}/cube/data -shard_num=1 -only_build=false
mv ./cube/data/0_0/test_dict_part0/* ./cube/data/
cd cube && ./cube
cd cube && ./cube
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
with open("key", "w") as f:
for i in range(1000000):
f.write("{}\n".format(random.randint(0, 999999)))
......@@ -20,6 +20,8 @@ import criteo as criteo
import time
from paddle_serving_client.metric import auc
py_version = sys.version_info[0]
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
......@@ -34,7 +36,10 @@ label_list = []
prob_list = []
start = time.time()
for ei in range(10000):
data = reader().next()
if py_version == 2:
data = reader().next()
else:
data = reader().__next__()
feed_dict = {}
feed_dict['dense_input'] = data[0][0]
for i in range(1, 27):
......
......@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
server.prepare_server(
workdir="work_dir1",
port=9292,
device="cpu",
cube_conf="./cube/conf/cube.conf")
server.run_server()
......@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
server.prepare_server(
workdir="work_dir1",
port=9292,
device="cpu",
cube_conf="./cube/conf/cube.conf")
server.run_server()
......@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1], sys.argv[2])
server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
server.prepare_server(
workdir="work_dir1",
port=9292,
device="cpu",
cube_conf="./cube/conf/cube.conf")
server.run_server()
......@@ -33,5 +33,9 @@ server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.load_model_config(sys.argv[1], sys.argv[2])
server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
server.prepare_server(
workdir="work_dir1",
port=9292,
device="cpu",
cube_conf="./cube/conf/cube.conf")
server.run_server()
......@@ -24,38 +24,43 @@ import json
import base64
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from paddle_serving_app.reader import Sequential, URL2Image, Resize
from paddle_serving_client.utils import benchmark_args, show_latency
from paddle_serving_app.reader import Sequential, File2Image, Resize
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
args = benchmark_args()
seq_preprocess = Sequential([
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
])
def single_func(idx, resource):
file_list = []
turns = resource["turns"]
latency_flags = False
if os.getenv("FLAGS_serving_latency"):
latency_flags = True
latency_list = []
for file_name in os.listdir("./image_data/n01440764"):
file_list.append(file_name)
img_list = []
for i in range(1000):
img_list.append(open("./image_data/n01440764/" + file_list[i]).read())
img_list.append("./image_data/n01440764/" + file_list[i])
profile_flags = False
if "FLAGS_profile_client" in os.environ and os.environ[
"FLAGS_profile_client"]:
profile_flags = True
if args.request == "rpc":
reader = ImageReader()
fetch = ["score"]
client = Client()
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
for i in range(turns):
if args.batch_size >= 1:
l_start = time.time()
feed_batch = []
i_start = time.time()
for bi in range(args.batch_size):
......@@ -69,6 +74,9 @@ def single_func(idx, resource):
int(round(i_end * 1000000))))
result = client.predict(feed=feed_batch, fetch=fetch)
l_end = time.time()
if latency_flags:
latency_list.append(l_end * 1000 - l_start * 1000)
else:
print("unsupport batch size {}".format(args.batch_size))
......@@ -77,7 +85,7 @@ def single_func(idx, resource):
server = "http://" + resource["endpoint"][idx % len(resource[
"endpoint"])] + "/image/prediction"
start = time.time()
for i in range(1000):
for i in range(turns):
if py_version == 2:
image = base64.b64encode(
open("./image_data/n01440764/" + file_list[i]).read())
......@@ -88,18 +96,31 @@ def single_func(idx, resource):
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
end = time.time()
if latency_flags:
return [[end - start], latency_list]
return [[end - start]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9393"]
#endpoint_list = endpoint_list + endpoint_list + endpoint_list
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
endpoint_list = [
"127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
]
turns = 100
start = time.time()
result = multi_thread_runner.run(
single_func, args.thread, {"endpoint": endpoint_list,
"turns": turns})
#result = single_func(0, {"endpoint": endpoint_list})
end = time.time()
total_cost = end - start
avg_cost = 0
for i in range(args.thread):
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
print("total cost: {}s".format(end - start))
print("each thread cost: {}s.".format(avg_cost))
print("qps: {}samples/s".format(args.batch_size * args.thread * turns /
total_cost))
if os.getenv("FLAGS_serving_latency"):
show_latency(result[1])
rm profile_log
rm profile_log*
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_profile_server=1
export FLAGS_profile_client=1
python -m paddle_serving_server_gpu.serve --model $1 --port 9292 --thread 4 --gpu_ids 0,1,2,3 2> elog > stdlog &
python -m paddle_serving_server_gpu.serve --model $1 --port 9292 --thread 4 --gpu_ids 0,1,2,3 --mem_optim --ir_optim 2> elog > stdlog &
sleep 5
gpu_id=0
#save cpu and gpu utilization log
if [ -d utilization ];then
rm -rf utilization
else
mkdir utilization
fi
#warm up
$PYTHONROOT/bin/python benchmark.py --thread 8 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
$PYTHONROOT/bin/python3 benchmark.py --thread 4 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo -e "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
for thread_num in 4 8 16
for thread_num in 1 4 8 16
do
for batch_size in 1 4 16 64 256
for batch_size in 1 4 16 64
do
job_bt=`date '+%Y%m%d%H%M%S'`
nvidia-smi --id=0 --query-compute-apps=used_memory --format=csv -lms 100 > gpu_use.log 2>&1 &
nvidia-smi --id=0 --query-gpu=utilization.gpu --format=csv -lms 100 > gpu_utilization.log 2>&1 &
gpu_memory_pid=$!
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
kill ${gpu_memory_pid}
kill `ps -ef|grep used_memory|awk '{print $2}'`
echo "model name :" $1
echo "thread num :" $thread_num
echo "batch size :" $batch_size
echo "=================Done===================="
echo "model name :$1" >> profile_log
echo "batch size :$batch_size" >> profile_log
job_et=`date '+%Y%m%d%H%M%S'`
awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "MAX_GPU_MEMORY:", max}' gpu_use.log >> profile_log_$1
awk 'BEGIN {max = 0} {if(NR>1){if ($1 > max) max=$1}} END {print "GPU_UTILIZATION:", max}' gpu_utilization.log >> profile_log_$1
rm -rf gpu_use.log gpu_utilization.log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 8 profile >> profile_log
echo "" >> profile_log_$1
done
done
#Divided log
awk 'BEGIN{RS="\n\n"}{i++}{print > "ResNet_log_"i}' profile_log_$1
mkdir $1_log && mv ResNet_log_* $1_log
ps -ef|grep 'serving'|grep -v grep|cut -c 9-15 | xargs kill -9
......@@ -13,13 +13,14 @@
# limitations under the License.
# pylint: disable=doc-string-missing
import os
import sys
import time
import requests
from paddle_serving_app.reader import IMDBDataset
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from paddle_serving_client.utils import MultiThreadRunner, benchmark_args, show_latency
args = benchmark_args()
......@@ -31,6 +32,13 @@ def single_func(idx, resource):
with open("./test_data/part-0") as fin:
for line in fin:
dataset.append(line.strip())
profile_flags = False
latency_flags = False
if os.getenv("FLAGS_profile_client"):
profile_flags = True
if os.getenv("FLAGS_serving_latency"):
latency_flags = True
latency_list = []
start = time.time()
if args.request == "rpc":
client = Client()
......@@ -67,9 +75,26 @@ def single_func(idx, resource):
return [[end - start]]
multi_thread_runner = MultiThreadRunner()
result = multi_thread_runner.run(single_func, args.thread, {})
avg_cost = 0
for cost in result[0]:
avg_cost += cost
print("total cost {} s of each thread".format(avg_cost / args.thread))
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = [
"127.0.0.1:9292", "127.0.0.1:9293", "127.0.0.1:9294", "127.0.0.1:9295"
]
turns = 100
start = time.time()
result = multi_thread_runner.run(
single_func, args.thread, {"endpoint": endpoint_list,
"turns": turns})
end = time.time()
total_cost = end - start
avg_cost = 0
for i in range(args.thread):
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("total cost: {}".format(total_cost))
print("each thread cost: {}".format(avg_cost))
print("qps: {}samples/s".format(args.batch_size * args.thread * turns /
total_cost))
if os.getenv("FLAGS_serving_latency"):
show_latency(result[0])
rm profile_log
for thread_num in 1 2 4 8 16
rm profile_log*
export FLAGS_profile_server=1
export FLAGS_profile_client=1
export FLAGS_serving_latency=1
$PYTHONROOT/bin/python3 -m paddle_serving_server.serve --model $1 --port 9292 --thread 4 --mem_optim --ir_optim 2> elog > stdlog &
hostname=`echo $(hostname)|awk -F '.baidu.com' '{print $1}'`
#save cpu and gpu utilization log
if [ -d utilization ];then
rm -rf utilization
else
mkdir utilization
fi
sleep 5
#warm up
$PYTHONROOT/bin/python3 benchmark.py --thread 4 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo -e "import psutil\ncpu_utilization=psutil.cpu_percent(1,False)\nprint('CPU_UTILIZATION:', cpu_utilization)\n" > cpu_utilization.py
for thread_num in 1 4 8 16
do
for batch_size in 1 2 4 8 16 32 64 128 256 512
for batch_size in 1 4 16 64
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model imdb_bow_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
job_bt=`date '+%Y%m%d%H%M%S'`
$PYTHONROOT/bin/python3 benchmark.py --thread $thread_num --batch_size $batch_size --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "model_name:" $1
echo "thread_num:" $thread_num
echo "batch_size:" $batch_size
echo "=================Done===================="
echo "model_name:$1" >> profile_log_$1
echo "batch_size:$batch_size" >> profile_log_$1
job_et=`date '+%Y%m%d%H%M%S'`
$PYTHONROOT/bin/python3 ../util/show_profile.py profile $thread_num >> profile_log_$1
$PYTHONROOT/bin/python3 cpu_utilization.py >> profile_log_$1
tail -n 8 profile >> profile_log_$1
echo "" >> profile_log_$1
done
done
#Divided log
awk 'BEGIN{RS="\n\n"}{i++}{print > "imdb_log_"i}' profile_log_$1
mkdir $1_log && mv imdb_log_* $1_log
ps -ef|grep 'serving'|grep -v grep|cut -c 9-15 | xargs kill -9
......@@ -4,18 +4,42 @@
```
python -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz
python -m paddle_serving_app.package --get_model ocr_det
tar -xzvf ocr_det.tar.gz
```
## RPC Service
### Start Service
For the following two code blocks, please check your device and pick one.
For a GPU device:
```
python -m paddle_serving_server_gpu.serve --model ocr_rec_model --port 9292 --gpu_id 0
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
```
For a CPU device:
```
python -m paddle_serving_server.serve --model ocr_rec_model --port 9292
python -m paddle_serving_server.serve --model ocr_det_model --port 9293
```
### Client Prediction
```
python test_ocr_rec_client.py
python ocr_rpc_client.py
```
## Web Service
### Start Service
```
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
python ocr_web_server.py
```
### Client Prediction
```
sh ocr_web_client.sh
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
import time
import re
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def get_rotate_crop_image(img, points):
#img = cv2.imread(img)
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def read_det_box_file(filename):
with open(filename, 'r') as f:
line = f.readline()
a, b, c = int(line.split(' ')[0]), int(line.split(' ')[1]), int(
line.split(' ')[2])
dt_boxes = np.zeros((a, b, c)).astype(np.float32)
line = f.readline()
for i in range(a):
for j in range(b):
line = f.readline()
dt_boxes[i, j, 0], dt_boxes[i, j, 1] = float(
line.split(' ')[0]), float(line.split(' ')[1])
line = f.readline()
def resize_norm_img(img, max_wh_ratio):
import math
imgC, imgH, imgW = 3, 32, 320
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def main():
client1 = Client()
client1.load_client_config("ocr_det_client/serving_client_conf.prototxt")
client1.connect(["127.0.0.1:9293"])
client2 = Client()
client2.load_client_config("ocr_rec_client/serving_client_conf.prototxt")
client2.connect(["127.0.0.1:9292"])
read_image_file = File2Image()
preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
filter_func = FilterBoxes(10, 10)
ocr_reader = OCRReader()
files = [
"./imgs/{}".format(f) for f in os.listdir('./imgs')
if re.match(r'[0-9]+.*\.jpg|[0-9]+.*\.png', f)
]
#files = ["2.jpg"]*30
#files = ["rctw/rctw/train/images/image_{}.jpg".format(i) for i in range(500)]
time_all = 0
time_det_all = 0
time_rec_all = 0
for name in files:
#print(name)
im = read_image_file(name)
ori_h, ori_w, _ = im.shape
time1 = time.time()
img = preprocess(im)
_, new_h, new_w = img.shape
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
#print(new_h, new_w, ori_h, ori_w)
time_before_det = time.time()
outputs = client1.predict(feed={"image": img}, fetch=["concat_1.tmp_0"])
time_after_det = time.time()
time_det_all += (time_after_det - time_before_det)
#print(outputs)
dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = resize_norm_img(img, max_wh_ratio)
#norm_img = norm_img[np.newaxis, :]
feed = {"image": norm_img}
feed_list.append(feed)
#fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch = ["ctc_greedy_decoder_0.tmp_0"]
time_before_rec = time.time()
if len(feed_list) == 0:
continue
fetch_map = client2.predict(feed=feed_list, fetch=fetch)
time_after_rec = time.time()
time_rec_all += (time_after_rec - time_before_rec)
rec_res = ocr_reader.postprocess(fetch_map)
#for res in rec_res:
# print(res[0].encode("utf-8"))
time2 = time.time()
time_all += (time2 - time1)
print("rpc+det time: {}".format(time_all / len(files)))
if __name__ == '__main__':
main()
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/others/1.jpg"}], "fetch": ["res"]}' http://127.0.0.1:9292/ocr/prediction
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
import sys
import numpy as np
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
from paddle_serving_server_gpu.web_service import WebService
import time
import re
class OCRService(WebService):
def init_det_client(self, det_port, det_client_config):
self.det_preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
self.det_client = Client()
self.det_client.load_client_config(det_client_config)
self.det_client.connect(["127.0.0.1:{}".format(det_port)])
def preprocess(self, feed=[], fetch=[]):
img_url = feed[0]["image"]
#print(feed, img_url)
read_from_url = URL2Image()
im = read_from_url(img_url)
ori_h, ori_w, _ = im.shape
det_img = self.det_preprocess(im)
#print("det_img", det_img, det_img.shape)
det_out = self.det_client.predict(
feed={"image": det_img}, fetch=["concat_1.tmp_0"])
#print("det_out", det_out)
def sorted_boxes(dt_boxes):
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def get_rotate_crop_image(img, points):
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
pts_std = np.float32([[0, 0], [img_crop_width, 0], \
[img_crop_width, img_crop_height], [0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img_crop,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def resize_norm_img(img, max_wh_ratio):
import math
imgC, imgH, imgW = 3, 32, 320
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
_, new_h, new_w = det_img.shape
filter_func = FilterBoxes(10, 10)
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
dt_boxes_list = post_func(det_out["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
dt_boxes = sorted_boxes(dt_boxes)
feed_list = []
img_list = []
max_wh_ratio = 0
for i, dtbox in enumerate(dt_boxes):
boximg = get_rotate_crop_image(im, dt_boxes[i])
img_list.append(boximg)
h, w = boximg.shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for img in img_list:
norm_img = resize_norm_img(img, max_wh_ratio)
feed = {"image": norm_img}
feed_list.append(feed)
fetch = ["ctc_greedy_decoder_0.tmp_0"]
#print("feed_list", feed_list)
return feed_list, fetch
def postprocess(self, feed={}, fetch=[], fetch_map=None):
#print(fetch_map)
ocr_reader = OCRReader()
rec_res = ocr_reader.postprocess(fetch_map)
res_lst = []
for res in rec_res:
res_lst.append(res[0])
fetch_map["res"] = res_lst
del fetch_map["ctc_greedy_decoder_0.tmp_0"]
del fetch_map["ctc_greedy_decoder_0.tmp_0.lod"]
return fetch_map
ocr_service = OCRService(name="ocr")
ocr_service.load_model_config("ocr_rec_model")
ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_client(
det_port=9293,
det_client_config="ocr_det_client/serving_client_conf.prototxt")
ocr_service.run_rpc_service()
ocr_service.run_web_service()
......@@ -31,7 +31,7 @@ with open(profile_file) as f:
if line[0] == "PROFILE":
prase(line[2])
print("thread num :{}".format(thread_num))
print("thread_num: {}".format(thread_num))
for name in time_dict:
print("{} cost :{} s in each thread ".format(name, time_dict[name] / (
print("{} cost: {}s in each thread ".format(name, time_dict[name] / (
1000000.0 * float(thread_num))))
......@@ -24,14 +24,15 @@ class ServingModels(object):
"SentimentAnalysis"] = ["senta_bilstm", "senta_bow", "senta_cnn"]
self.model_dict["SemanticRepresentation"] = ["ernie"]
self.model_dict["ChineseWordSegmentation"] = ["lac"]
self.model_dict["ObjectDetection"] = ["faster_rcnn", "yolov4"]
self.model_dict[
"ObjectDetection"] = ["faster_rcnn", "yolov4", "blazeface"]
self.model_dict["ImageSegmentation"] = [
"unet", "deeplabv3", "deeplabv3+cityscapes"
]
self.model_dict["ImageClassification"] = [
"resnet_v2_50_imagenet", "mobilenet_v2_imagenet"
]
self.model_dict["TextDetection"] = ["ocr_detection"]
self.model_dict["TextDetection"] = ["ocr_det"]
self.model_dict["OCR"] = ["ocr_rec"]
image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/"
......
......@@ -29,6 +29,7 @@ def normalize(img, mean, std, channel_first):
else:
img_mean = np.array(mean).reshape((1, 1, 3))
img_std = np.array(std).reshape((1, 1, 3))
img = np.array(img).astype("float32")
img -= img_mean
img /= img_std
return img
......
......@@ -440,6 +440,30 @@ class RCNNPostprocess(object):
self.label_file, self.output_dir)
class BlazeFacePostprocess(RCNNPostprocess):
def clip_bbox(self, bbox, im_size=None):
h = 1. if im_size is None else im_size[0]
w = 1. if im_size is None else im_size[1]
xmin = max(min(bbox[0], w), 0.)
ymin = max(min(bbox[1], h), 0.)
xmax = max(min(bbox[2], w), 0.)
ymax = max(min(bbox[3], h), 0.)
return xmin, ymin, xmax, ymax
def _get_bbox_result(self, fetch_map, fetch_name, clsid2catid):
result = {}
is_bbox_normalized = True #for blaze face, set true here
output = fetch_map[fetch_name]
lod = [fetch_map[fetch_name + '.lod']]
lengths = self._offset_to_lengths(lod)
np_data = np.array(output)
result['bbox'] = (np_data, lengths)
result['im_id'] = np.array([[0]])
result["im_shape"] = np.array(fetch_map["im_shape"]).astype(np.int32)
bbox_results = self._bbox2out([result], clsid2catid, is_bbox_normalized)
return bbox_results
class Sequential(object):
"""
Args:
......@@ -653,7 +677,7 @@ class Resize(object):
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
(w, h), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
......
......@@ -182,22 +182,26 @@ class OCRReader(object):
return norm_img_batch[0]
def postprocess(self, outputs):
def postprocess(self, outputs, with_score=False):
rec_res = []
rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = outputs["softmax_0.tmp_0.lod"]
rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
if with_score:
predict_lod = outputs["softmax_0.tmp_0.lod"]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
if with_score:
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_res.append([preds_text])
return rec_res
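A usage sketch for the new with_score switch; the import path is assumed, and fetch_map stands in for the dict returned by the serving client for the rec model.

```python
from paddle_serving_app.reader import OCRReader  # assumed import path

ocr_reader = OCRReader()
# fetch_map: placeholder for the client output containing
# "ctc_greedy_decoder_0.tmp_0", its ".lod", and (if scores are wanted)
# "softmax_0.tmp_0" plus its ".lod"
texts = ocr_reader.postprocess(fetch_map)                               # [["text"], ...]
texts_with_scores = ocr_reader.postprocess(fetch_map, with_score=True)  # [["text", score], ...]
```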
......@@ -39,11 +39,11 @@ def benchmark_args():
def show_latency(latency_list):
latency_array = np.array(latency_list)
info = "latency:\n"
info += "mean :{} ms\n".format(np.mean(latency_array))
info += "median :{} ms\n".format(np.median(latency_array))
info += "80 percent :{} ms\n".format(np.percentile(latency_array, 80))
info += "90 percent :{} ms\n".format(np.percentile(latency_array, 90))
info += "99 percent :{} ms\n".format(np.percentile(latency_array, 99))
info += "mean: {}ms\n".format(np.mean(latency_array))
info += "median: {}ms\n".format(np.median(latency_array))
info += "80 percent: {}ms\n".format(np.percentile(latency_array, 80))
info += "90 percent: {}ms\n".format(np.percentile(latency_array, 90))
info += "99 percent: {}ms\n".format(np.percentile(latency_array, 99))
sys.stderr.write(info)
......
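For reference, a minimal call of the reworded helper, assuming show_latency from the benchmark util above is in scope; the latency values are purely illustrative.

```python
latency_list = [12.3, 15.8, 11.0, 30.2, 14.6]  # per-request latencies in ms (illustrative)
show_latency(latency_list)
# writes mean / median / 80 / 90 / 99 percentile lines to stderr
```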
......@@ -25,6 +25,7 @@ from contextlib import closing
import collections
import fcntl
import shutil
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
......@@ -230,7 +231,7 @@ class Server(object):
infer_service.workflows.extend(["workflow1"])
self.infer_service_conf.services.extend([infer_service])
def _prepare_resource(self, workdir):
def _prepare_resource(self, workdir, cube_conf):
self.workdir = workdir
if self.resource_conf == None:
with open("{}/{}".format(workdir, self.general_model_config_fn),
......@@ -242,6 +243,11 @@ class Server(object):
if "dist_kv" in node.name:
self.resource_conf.cube_config_path = workdir
self.resource_conf.cube_config_file = self.cube_config_fn
if cube_conf == None:
raise ValueError(
"Please set the path of cube.conf while use dist_kv op."
)
shutil.copy(cube_conf, workdir)
if "quant" in node.name:
self.resource_conf.cube_quant_bits = 8
self.resource_conf.model_toolkit_path = workdir
......@@ -366,7 +372,11 @@ class Server(object):
os.chdir(self.cur_path)
self.bin_path = self.server_path + "/serving"
def prepare_server(self, workdir=None, port=9292, device="cpu"):
def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
cube_conf=None):
if workdir == None:
workdir = "./tmp"
os.system("mkdir {}".format(workdir))
......@@ -377,7 +387,7 @@ class Server(object):
if not self.port_is_available(port):
raise SystemExit("Port {} is already used".format(port))
self.set_port(port)
self._prepare_resource(workdir)
self._prepare_resource(workdir, cube_conf)
self._prepare_engine(self.model_config_paths, device)
self._prepare_infer_service(port)
self.workdir = workdir
......@@ -645,7 +655,11 @@ class MultiLangServer(object):
server_config_paths)
self.bclient_config_path_ = client_config_path
def prepare_server(self, workdir=None, port=9292, device="cpu"):
def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
cube_conf=None):
if not self._port_is_available(port):
raise SystemExit("Prot {} is already used".format(port))
default_port = 12000
......@@ -656,7 +670,10 @@ class MultiLangServer(object):
self.port_list_.append(default_port + i)
break
self.bserver_.prepare_server(
workdir=workdir, port=self.port_list_[0], device=device)
workdir=workdir,
port=self.port_list_[0],
device=device,
cube_conf=cube_conf)
self.set_port(port)
def _launch_brpc_service(self, bserver):
......
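A hedged end-to-end sketch of the new cube_conf argument, roughly following the criteo_ctr_with_cube demo: when the model config contains a dist_kv op, prepare_server() now needs the path of cube.conf so it can be copied into workdir. Model and config paths are placeholders; the GPU package mirrors the same signature.

```python
from paddle_serving_server import OpMaker, OpSeqMaker, Server

op_maker = OpMaker()
read_op = op_maker.create('general_reader')
dist_kv_infer_op = op_maker.create('general_dist_kv_infer')
response_op = op_maker.create('general_response')

op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(dist_kv_infer_op)
op_seq_maker.add_op(response_op)

server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config("ctr_serving_model_kv")   # placeholder model dir
server.prepare_server(
    workdir="work_dir1",
    port=9292,
    device="cpu",
    cube_conf="./cube/conf/cube.conf")             # now required for dist_kv ops
server.run_server()
```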
......@@ -26,7 +26,7 @@ from contextlib import closing
import argparse
import collections
import fcntl
import shutil
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
......@@ -285,7 +285,7 @@ class Server(object):
infer_service.workflows.extend(["workflow1"])
self.infer_service_conf.services.extend([infer_service])
def _prepare_resource(self, workdir):
def _prepare_resource(self, workdir, cube_conf):
self.workdir = workdir
if self.resource_conf == None:
with open("{}/{}".format(workdir, self.general_model_config_fn),
......@@ -297,6 +297,11 @@ class Server(object):
if "dist_kv" in node.name:
self.resource_conf.cube_config_path = workdir
self.resource_conf.cube_config_file = self.cube_config_fn
if cube_conf == None:
raise ValueError(
"Please set the path of cube.conf while use dist_kv op."
)
shutil.copy(cube_conf, workdir)
self.resource_conf.model_toolkit_path = workdir
self.resource_conf.model_toolkit_file = self.model_toolkit_fn
self.resource_conf.general_model_path = workdir
......@@ -406,7 +411,11 @@ class Server(object):
os.chdir(self.cur_path)
self.bin_path = self.server_path + "/serving"
def prepare_server(self, workdir=None, port=9292, device="cpu"):
def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
cube_conf=None):
if workdir == None:
workdir = "./tmp"
os.system("mkdir {}".format(workdir))
......@@ -418,7 +427,7 @@ class Server(object):
raise SystemExit("Port {} is already used".format(port))
self.set_port(port)
self._prepare_resource(workdir)
self._prepare_resource(workdir, cube_conf)
self._prepare_engine(self.model_config_paths, device)
self._prepare_infer_service(port)
self.workdir = workdir
......@@ -690,7 +699,11 @@ class MultiLangServer(object):
server_config_paths)
self.bclient_config_path_ = client_config_path
def prepare_server(self, workdir=None, port=9292, device="cpu"):
def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
cube_conf=None):
if not self._port_is_available(port):
raise SystemExit("Prot {} is already used".format(port))
default_port = 12000
......@@ -701,7 +714,10 @@ class MultiLangServer(object):
self.port_list_.append(default_port + i)
break
self.bserver_.prepare_server(
workdir=workdir, port=self.port_list_[0], device=device)
workdir=workdir,
port=self.port_list_[0],
device=device,
cube_conf=cube_conf)
self.set_port(port)
def _launch_brpc_service(self, bserver):
......
......@@ -27,7 +27,7 @@ import logging
import enum
import copy
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class ChannelDataEcode(enum.Enum):
......@@ -92,7 +92,16 @@ class ChannelData(object):
def check_dictdata(dictdata):
ecode = ChannelDataEcode.OK.value
error_info = None
if not isinstance(dictdata, dict):
if isinstance(dictdata, list):
# batch data
for sample in dictdata:
if not isinstance(sample, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(sample))
break
elif not isinstance(dictdata, dict):
# batch size = 1
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(dictdata))
......@@ -102,12 +111,32 @@ class ChannelData(object):
def check_npdata(npdata):
ecode = ChannelDataEcode.OK.value
error_info = None
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
break
if isinstance(npdata, list):
# batch data
for sample in npdata:
if not isinstance(sample, dict):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(sample))
break
for _, value in sample.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
return ecode, error_info
elif isinstance(npdata, dict):
# batch_size = 1
for _, value in npdata.items():
if not isinstance(value, np.ndarray):
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be np.ndarray, but get {}.".format(type(value))
break
else:
ecode = ChannelDataEcode.TYPE_ERROR.value
error_info = "the value of data must " \
"be dict, but get {}.".format(type(npdata))
return ecode, error_info
def parse(self):
......
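A small sketch of the relaxed checks: a single sample dict and a list of sample dicts (a batch) now both pass, while any other type yields TYPE_ERROR. The import path is an assumption, and check_npdata is taken to be a staticmethod on ChannelData, as in the surrounding class.

```python
import numpy as np
from paddle_serving_server.pipeline.channel import ChannelData  # assumed import path

single = {"image": np.zeros((1, 3, 224, 224), dtype="float32")}
batch = [single, single]

print(ChannelData.check_npdata(single))  # (ChannelDataEcode.OK.value, None)
print(ChannelData.check_npdata(batch))   # (ChannelDataEcode.OK.value, None)
print(ChannelData.check_npdata("oops"))  # (ChannelDataEcode.TYPE_ERROR.value, "the value of data must be dict, ...")
```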
......@@ -19,13 +19,14 @@ from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
import os
from numpy import *
from .proto import pipeline_service_pb2
from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")
......@@ -59,6 +60,10 @@ class Op(object):
self._outputs = []
self._profiler = None
# only for multithread
self._for_init_op_lock = threading.Lock()
self._succ_init_op = False
def init_profiler(self, profiler):
self._profiler = profiler
......@@ -71,18 +76,19 @@ class Op(object):
fetch_names):
if self.with_serving == False:
_LOGGER.debug("{} no client".format(self.name))
return
return None
_LOGGER.debug("{} client_config: {}".format(self.name, client_config))
_LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
if client_type == 'brpc':
self._client = Client()
self._client.load_client_config(client_config)
client = Client()
client.load_client_config(client_config)
elif client_type == 'grpc':
self._client = MultiLangClient()
client = MultiLangClient()
else:
raise ValueError("unknow client type: {}".format(client_type))
self._client.connect(server_endpoints)
client.connect(server_endpoints)
self._fetch_names = fetch_names
return client
def _get_input_channel(self):
return self._input
......@@ -130,19 +136,17 @@ class Op(object):
(_, input_dict), = input_dicts.items()
return input_dict
def process(self, feed_dict):
def process(self, client_predict_handler, feed_dict):
err, err_info = ChannelData.check_npdata(feed_dict)
if err != 0:
raise NotImplementedError(
"{} Please override preprocess func.".format(err_info))
_LOGGER.debug(self._log('feed_dict: {}'.format(feed_dict)))
_LOGGER.debug(self._log('fetch: {}'.format(self._fetch_names)))
call_result = self._client.predict(
call_result = client_predict_handler(
feed=feed_dict, fetch=self._fetch_names)
_LOGGER.debug(self._log("get call_result"))
return call_result
def postprocess(self, fetch_dict):
def postprocess(self, input_dict, fetch_dict):
return fetch_dict
def stop(self):
......@@ -174,7 +178,7 @@ class Op(object):
p = multiprocessing.Process(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
self._get_output_channels(), client_type, False))
p.start()
proces.append(p)
return proces
......@@ -185,12 +189,12 @@ class Op(object):
t = threading.Thread(
target=self._run,
args=(concurrency_idx, self._get_input_channel(),
self._get_output_channels(), client_type))
self._get_output_channels(), client_type, True))
t.start()
threads.append(t)
return threads
def load_user_resources(self):
def init_op(self):
pass
def _run_preprocess(self, parsed_data, data_id, log_func):
......@@ -222,13 +226,15 @@ class Op(object):
data_id=data_id)
return preped_data, error_channeldata
def _run_process(self, preped_data, data_id, log_func):
def _run_process(self, client_predict_handler, preped_data, data_id,
log_func):
midped_data, error_channeldata = None, None
if self.with_serving:
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
try:
midped_data = self.process(preped_data)
midped_data = self.process(client_predict_handler,
preped_data)
except Exception as e:
ecode = ChannelDataEcode.UNKNOW.value
error_info = log_func(e)
......@@ -237,7 +243,11 @@ class Op(object):
for i in range(self._retry):
try:
midped_data = func_timeout.func_timeout(
self._timeout, self.process, args=(preped_data, ))
self._timeout,
self.process,
args=(
client_predict_handler,
preped_data, ))
except func_timeout.FunctionTimedOut as e:
if i + 1 >= self._retry:
ecode = ChannelDataEcode.TIMEOUT.value
......@@ -267,10 +277,10 @@ class Op(object):
midped_data = preped_data
return midped_data, error_channeldata
def _run_postprocess(self, midped_data, data_id, log_func):
def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
output_data, error_channeldata = None, None
try:
postped_data = self.postprocess(midped_data)
postped_data = self.postprocess(input_dict, midped_data)
except Exception as e:
error_info = log_func(e)
_LOGGER.error(error_info)
......@@ -303,8 +313,8 @@ class Op(object):
data_id=data_id)
return output_data, error_channeldata
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_multithread):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......@@ -315,12 +325,30 @@ class Op(object):
log = get_log_func(op_info_prefix)
tid = threading.current_thread().ident
client = None
client_predict_handler = None
# create client based on client_type
self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
try:
client = self.init_client(client_type, self._client_config,
self._server_endpoints, self._fetch_names)
if client is not None:
client_predict_handler = client.predict
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
# load user resources
self.load_user_resources()
try:
if use_multithread:
with self._for_init_op_lock:
if not self._succ_init_op:
self.init_op()
self._succ_init_op = True
else:
self.init_op()
except Exception as e:
_LOGGER.error(log(e))
os._exit(-1)
self._is_run = True
while self._is_run:
......@@ -349,8 +377,8 @@ class Op(object):
# process
self._profiler_record("{}-midp#{}_0".format(op_info_prefix, tid))
midped_data, error_channeldata = self._run_process(preped_data,
data_id, log)
midped_data, error_channeldata = self._run_process(
client_predict_handler, preped_data, data_id, log)
self._profiler_record("{}-midp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -359,8 +387,8 @@ class Op(object):
# postprocess
self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
output_data, error_channeldata = self._run_postprocess(midped_data,
data_id, log)
output_data, error_channeldata = self._run_postprocess(
parsed_data, midped_data, data_id, log)
self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))
if error_channeldata is not None:
self._push_to_output_channels(error_channeldata,
......@@ -384,7 +412,11 @@ class RequestOp(Op):
super(RequestOp, self).__init__(
name="#G", input_ops=[], concurrency=concurrency)
# load user resources
self.load_user_resources()
try:
self.init_op()
except Exception as e:
_LOGGER.error(e)
os._exit(-1)
def unpack_request_package(self, request):
dictdata = {}
......@@ -405,7 +437,11 @@ class ResponseOp(Op):
super(ResponseOp, self).__init__(
name="#R", input_ops=input_ops, concurrency=concurrency)
# load user resources
self.load_user_resources()
try:
self.init_op()
except Exception as e:
_LOGGER.error(e)
os._exit(-1)
def pack_response_package(self, channeldata):
resp = pipeline_service_pb2.Response()
......@@ -450,17 +486,26 @@ class VirtualOp(Op):
def add_virtual_pred_op(self, op):
self._virtual_pred_ops.append(op)
def _actual_pred_op_names(self, op):
if not isinstance(op, VirtualOp):
return [op.name]
names = []
for x in op._virtual_pred_ops:
names.extend(self._actual_pred_op_names(x))
return names
def add_output_channel(self, channel):
if not isinstance(channel, (ThreadChannel, ProcessChannel)):
raise TypeError(
self._log('output channel must be Channel type, not {}'.format(
type(channel))))
for op in self._virtual_pred_ops:
channel.add_producer(op.name)
for op_name in self._actual_pred_op_names(op):
channel.add_producer(op_name)
self._outputs.append(channel)
def _run(self, concurrency_idx, input_channel, output_channels,
client_type):
def _run(self, concurrency_idx, input_channel, output_channels, client_type,
use_multithread):
def get_log_func(op_info_prefix):
def log_func(info_str):
return "{} {}".format(op_info_prefix, info_str)
......
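To summarize the interface changes to user-defined ops (init_op() replacing load_user_resources(), and postprocess() now receiving the original input dict alongside the fetch dict), here is a hedged subclass sketch; the import path and the op itself are illustrative only.

```python
from paddle_serving_server.pipeline.operator import Op  # assumed import path

class EchoOp(Op):
    def init_op(self):
        # one-time setup; under multithread mode this is guarded by
        # _for_init_op_lock so it runs only once per process
        self.suffix = "!"

    def preprocess(self, input_dicts):
        (_, input_dict), = input_dicts.items()
        return input_dict

    def postprocess(self, input_dict, fetch_dict):
        # the original request data is now passed in alongside the fetch map
        fetch_dict["echo"] = str(input_dict.get("words", "")) + self.suffix
        return fetch_dict
```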
......@@ -20,7 +20,7 @@ import functools
from .proto import pipeline_service_pb2
from .proto import pipeline_service_pb2_grpc
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class PipelineClient(object):
......@@ -52,7 +52,7 @@ class PipelineClient(object):
return {"ecode": resp.ecode, "error_info": resp.error_info}
fetch_map = {"ecode": resp.ecode}
for idx, key in enumerate(resp.key):
if key not in fetch:
if fetch is not None and key not in fetch:
continue
data = resp.value[idx]
try:
......@@ -62,16 +62,16 @@ class PipelineClient(object):
fetch_map[key] = data
return fetch_map
def predict(self, feed_dict, fetch, asyn=False):
def predict(self, feed_dict, fetch=None, asyn=False):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if not isinstance(fetch, list):
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict)
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp)
return self._unpack_response_package(resp, fetch)
else:
call_future = self._stub.inference.future(req)
return PipelinePredictFuture(
......
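A client-side sketch of the now-optional fetch argument: when fetch is omitted, _unpack_response_package keeps every key returned by the server. The import path and the shape of connect() are assumptions; the endpoint and feed keys are placeholders.

```python
from paddle_serving_server.pipeline import PipelineClient  # assumed import path

client = PipelineClient()
client.connect('127.0.0.1:18080')  # assumed endpoint format

everything = client.predict(feed_dict={"words": "hello"})                    # fetch=None -> all keys
subset = client.predict(feed_dict={"words": "hello"}, fetch=["prediction"])  # only "prediction"
```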
......@@ -45,7 +45,7 @@ from .channel import ThreadChannel, ProcessChannel, ChannelData, ChannelDataEcod
from .profiler import TimeProfiler
from .util import NameGenerator
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
_profiler = TimeProfiler()
......@@ -235,6 +235,10 @@ class PipelineServer(object):
return use_ops, succ_ops_of_use_op
use_ops, out_degree_ops = get_use_ops(response_op)
_LOGGER.info("================= use op ==================")
for op in use_ops:
_LOGGER.info(op.name)
_LOGGER.info("===========================================")
if len(use_ops) <= 1:
raise Exception(
"Besides RequestOp and ResponseOp, there should be at least one Op in DAG."
......
......@@ -24,7 +24,7 @@ else:
raise Exception("Error Python version")
import time
_LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger()
class TimeProfiler(object):
......@@ -58,7 +58,7 @@ class TimeProfiler(object):
print_str += "{}_{}:{} ".format(name, tag, timestamp)
else:
tmp[name] = (tag, timestamp)
print_str += "\n"
print_str = "\n{}\n".format(print_str)
sys.stderr.write(print_str)
for name, item in tmp.items():
tag, timestamp = item
......
......@@ -229,10 +229,7 @@ function python_run_criteo_ctr_with_cube() {
check_cmd "mv models/data ./cube/"
check_cmd "mv models/ut_data ./"
cp ../../../build-server-$TYPE/output/bin/cube* ./cube/
mkdir -p $PYTHONROOT/lib/python2.7/site-packages/paddle_serving_server/serving-cpu-avx-openblas-0.1.3/
yes | cp ../../../build-server-$TYPE/output/demo/serving/bin/serving $PYTHONROOT/lib/python2.7/site-packages/paddle_serving_server/serving-cpu-avx-openblas-0.1.3/
sh cube_prepare.sh &
check_cmd "mkdir work_dir1 && cp cube/conf/cube.conf ./work_dir1/"
python test_server.py ctr_serving_model_kv &
sleep 5
check_cmd "python test_client.py ctr_client_conf/serving_client_conf.prototxt ./ut_data >score"
......@@ -257,10 +254,7 @@ function python_run_criteo_ctr_with_cube() {
check_cmd "mv models/data ./cube/"
check_cmd "mv models/ut_data ./"
cp ../../../build-server-$TYPE/output/bin/cube* ./cube/
mkdir -p $PYTHONROOT/lib/python2.7/site-packages/paddle_serving_server_gpu/serving-gpu-0.1.3/
yes | cp ../../../build-server-$TYPE/output/demo/serving/bin/serving $PYTHONROOT/lib/python2.7/site-packages/paddle_serving_server_gpu/serving-gpu-0.1.3/
sh cube_prepare.sh &
check_cmd "mkdir work_dir1 && cp cube/conf/cube.conf ./work_dir1/"
python test_server_gpu.py ctr_serving_model_kv &
sleep 5
# for warm up
......