提交 f797d2b1 编写于 作者: B barriery 提交者: GitHub

Merge branch 'develop' into optimize-quantization-tool

......@@ -53,7 +53,7 @@ You may need to use a domestic mirror source (in China, you can use the Tsinghua
If you need install modules compiled with develop branch, please download packages from [latest packages list](./doc/LATEST_PACKAGES.md) and install with `pip install` command.
Client package support Centos 7 and Ubuntu 18, or you can use HTTP service without install client.
Packages of Paddle Serving support Centos 6/7 and Ubuntu 16/18, or you can use HTTP service without install client.
<h2 align="center"> Pre-built services with Paddle Serving</h2>
......
......@@ -55,7 +55,7 @@ pip install paddle-serving-server-gpu # GPU
如果需要使用develop分支编译的安装包,请从[最新安装包列表](./doc/LATEST_PACKAGES.md)中获取下载地址进行下载,使用`pip install`命令进行安装。
客户端安装包支持Centos 7和Ubuntu 18,或者您可以使用HTTP服务,这种情况下不需要安装客户端。
Paddle Serving安装包支持Centos 6/7和Ubuntu 16/18,或者您可以使用HTTP服务,这种情况下不需要安装客户端。
<h2 align="center"> Paddle Serving预装的服务 </h2>
......
......@@ -86,6 +86,64 @@ function(protobuf_generate_python SRCS)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
function(grpc_protobuf_generate_python SRCS)
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
if(NOT ARGN)
message(SEND_ERROR "Error: GRPC_PROTOBUF_GENERATE_PYTHON() called without any proto files")
return()
endif()
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
# Create an include path for each file specified
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
else()
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
endif()
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
endif()
if(DEFINED Protobuf_IMPORT_DIRS)
foreach(DIR ${Protobuf_IMPORT_DIRS})
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
endif()
set(${SRCS})
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif()
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2_grpc.py")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2_grpc.py"
COMMAND ${PYTHON_EXECUTABLE} -m grpc_tools.protoc --python_out ${CMAKE_CURRENT_BINARY_DIR} --grpc_python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
DEPENDS ${ABS_FIL}
COMMENT "Running Python grpc protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
# Print and set the protobuf library information,
# finish this cmake process and exit from this file.
macro(PROMPT_PROTOBUF_LIB)
......
......@@ -704,6 +704,15 @@ function(py_proto_compile TARGET_NAME)
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction()
function(py_grpc_proto_compile TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(py_grpc_proto_compile "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(py_srcs)
grpc_protobuf_generate_python(py_srcs ${py_grpc_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction()
function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options "")
......
......@@ -35,6 +35,10 @@ py_proto_compile(general_model_config_py_proto SRCS proto/general_model_config.p
add_custom_target(general_model_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(general_model_config_py_proto general_model_config_py_proto_init)
py_grpc_proto_compile(multi_lang_general_model_service_py_proto SRCS proto/multi_lang_general_model_service.proto)
add_custom_target(multi_lang_general_model_service_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(multi_lang_general_model_service_py_proto multi_lang_general_model_service_py_proto_init)
if (CLIENT)
py_proto_compile(sdk_configure_py_proto SRCS proto/sdk_configure.proto)
add_custom_target(sdk_configure_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......@@ -51,6 +55,11 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_client/proto
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_client/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
if (APP)
......@@ -77,6 +86,11 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server/proto
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_server/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_custom_command(TARGET server_config_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory
......@@ -95,5 +109,11 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
COMMENT "Copy generated general_model_config proto file into directory
paddle_serving_server_gpu/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET multi_lang_general_model_service_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_server_gpu/proto
COMMENT "Copy generated multi_lang_general_model_service proto file into directory paddle_serving_server_gpu/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
message Tensor {
optional bytes data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
repeated int32 lod = 7; // only for fetch tensor currently
};
message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message Request {
repeated FeedInst insts = 1;
repeated string feed_var_names = 2;
repeated string fetch_var_names = 3;
required bool is_python = 4 [ default = false ];
};
message Response {
repeated ModelOutput outputs = 1;
optional string tag = 2;
};
message ModelOutput {
repeated FetchInst insts = 1;
optional string engine_name = 2;
}
service MultiLangGeneralModelService {
rpc inference(Request) returns (Response) {}
};
......@@ -19,13 +19,11 @@ from __future__ import unicode_literals, absolute_import
import os
import sys
import time
import json
import requests
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from batching import pad_batch_data
import tokenization
import requests
import json
from paddle_serving_client.utils import benchmark_args, show_latency
from paddle_serving_app.reader import ChineseBertReader
args = benchmark_args()
......@@ -36,42 +34,105 @@ def single_func(idx, resource):
dataset = []
for line in fin:
dataset.append(line.strip())
profile_flags = False
latency_flags = False
if os.getenv("FLAGS_profile_client"):
profile_flags = True
if os.getenv("FLAGS_serving_latency"):
latency_flags = True
latency_list = []
if args.request == "rpc":
reader = ChineseBertReader(vocab_file="vocab.txt", max_seq_len=20)
reader = ChineseBertReader({"max_seq_len": 128})
fetch = ["pooled_output"]
client = Client()
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
if args.batch_size == 1:
feed_dict = reader.process(dataset[i])
result = client.predict(feed=feed_dict, fetch=fetch)
for i in range(turns):
if args.batch_size >= 1:
l_start = time.time()
feed_batch = []
b_start = time.time()
for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[bi]))
b_end = time.time()
if profile_flags:
sys.stderr.write(
"PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}\n".format(
os.getpid(),
int(round(b_start * 1000000)),
int(round(b_end * 1000000))))
result = client.predict(feed=feed_batch, fetch=fetch)
l_end = time.time()
if latency_flags:
latency_list.append(l_end * 1000 - l_start * 1000)
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
reader = ChineseBertReader({"max_seq_len": 128})
fetch = ["pooled_output"]
server = "http://" + resource["endpoint"][idx % len(resource[
"endpoint"])] + "/bert/prediction"
start = time.time()
header = {"Content-Type": "application/json"}
for i in range(1000):
dict_data = {"words": dataset[i], "fetch": ["pooled_output"]}
r = requests.post(
'http://{}/bert/prediction'.format(resource["endpoint"][
idx % len(resource["endpoint"])]),
data=json.dumps(dict_data),
headers=header)
for i in range(turns):
if args.batch_size >= 1:
l_start = time.time()
feed_batch = []
b_start = time.time()
for bi in range(args.batch_size):
feed_batch.append({"words": dataset[bi]})
req = json.dumps({"feed": feed_batch, "fetch": fetch})
b_end = time.time()
if profile_flags:
sys.stderr.write(
"PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}\n".format(
os.getpid(),
int(round(b_start * 1000000)),
int(round(b_end * 1000000))))
result = requests.post(
server,
data=req,
headers={"Content-Type": "application/json"})
l_end = time.time()
if latency_flags:
latency_list.append(l_end * 1000 - l_start * 1000)
else:
print("unsupport batch size {}".format(args.batch_size))
else:
raise ValueError("not implemented {} request".format(args.request))
end = time.time()
return [[end - start]]
if latency_flags:
return [[end - start], latency_list]
else:
return [[end - start]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9292"]
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
turns = 10
start = time.time()
result = multi_thread_runner.run(
single_func, args.thread, {"endpoint": endpoint_list,
"turns": turns})
end = time.time()
total_cost = end - start
avg_cost = 0
for i in range(args.thread):
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
print("total cost :{} s".format(total_cost))
print("each thread cost :{} s. ".format(avg_cost))
print("qps :{} samples/s".format(args.batch_size * args.thread * turns /
total_cost))
if os.getenv("FLAGS_serving_latency"):
show_latency(result[1])
rm profile_log
for thread_num in 1 2 4 8 16
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_profile_server=1
export FLAGS_profile_client=1
export FLAGS_serving_latency=1
python3 -m paddle_serving_server_gpu.serve --model $1 --port 9292 --thread 4 --gpu_ids 0,1,2,3 --mem_optim False --ir_optim True 2> elog > stdlog &
sleep 5
#warm up
python3 benchmark.py --thread 8 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
for thread_num in 4 8 16
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
for batch_size in 1 4 16 64 256
do
python3 benchmark.py --thread $thread_num --batch_size $batch_size --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "model name :" $1
echo "thread num :" $thread_num
echo "batch size :" $batch_size
echo "=================Done===================="
echo "model name :$1" >> profile_log_$1
echo "batch size :$batch_size" >> profile_log_$1
python3 ../util/show_profile.py profile $thread_num >> profile_log_$1
tail -n 8 profile >> profile_log_$1
echo "" >> profile_log_$1
done
done
ps -ef|grep 'serving'|grep -v grep|cut -c 9-15 | xargs kill -9
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from __future__ import unicode_literals, absolute_import
import os
import sys
import time
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
from batching import pad_batch_data
import tokenization
import requests
import json
from bert_reader import BertReader
args = benchmark_args()
def single_func(idx, resource):
fin = open("data-c.txt")
dataset = []
for line in fin:
dataset.append(line.strip())
profile_flags = False
if os.environ["FLAGS_profile_client"]:
profile_flags = True
if args.request == "rpc":
reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
fetch = ["pooled_output"]
client = Client()
client.load_client_config(args.model)
client.connect([resource["endpoint"][idx % len(resource["endpoint"])]])
start = time.time()
for i in range(1000):
if args.batch_size >= 1:
feed_batch = []
b_start = time.time()
for bi in range(args.batch_size):
feed_batch.append(reader.process(dataset[bi]))
b_end = time.time()
if profile_flags:
print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
os.getpid(),
int(round(b_start * 1000000)),
int(round(b_end * 1000000))))
result = client.predict(feed=feed_batch, fetch=fetch)
else:
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
raise ("no batch predict for http")
end = time.time()
return [[end - start]]
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9292"]
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
avg_cost = 0
for i in range(args.thread):
avg_cost += result[0][i]
avg_cost = avg_cost / args.thread
print("average total cost {} s.".format(avg_cost))
rm profile_log
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle_serving_server_gpu.serve --model bert_seq20_model/ --port 9295 --thread 4 --gpu_ids 0,1,2,3 2> elog > stdlog &
sleep 5
for thread_num in 1 2 4 8 16
do
for batch_size in 1 2 4 8 16 32 64 128 256 512
do
$PYTHONROOT/bin/python benchmark_batch.py --thread $thread_num --batch_size $batch_size --model serving_client_conf/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "thread num: ", $thread_num
echo "batch size: ", $batch_size
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
done
done
......@@ -14,15 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import paddlehub as hub
import ujson
import random
import time
from paddlehub.common.logger import logger
import socket
from paddle_serving_client import Client
from paddle_serving_client.utils import benchmark_args
from paddle_serving_app.reader import ChineseBertReader
......
......@@ -17,6 +17,6 @@
mkdir -p cube_model
mkdir -p cube/data
./seq_generator ctr_serving_model/SparseFeatFactors ./cube_model/feature
./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=./cube/data -shard_num=1 -only_build=false
./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=${PWD}/cube/data -shard_num=1 -only_build=false
mv ./cube/data/0_0/test_dict_part0/* ./cube/data/
cd cube && ./cube
......@@ -17,6 +17,6 @@
mkdir -p cube_model
mkdir -p cube/data
./seq_generator ctr_serving_model/SparseFeatFactors ./cube_model/feature 8
./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=./cube/data -shard_num=1 -only_build=false
./cube/cube-builder -dict_name=test_dict -job_mode=base -last_version=0 -cur_version=0 -depend_version=0 -input_path=./cube_model -output_path=${PWD}/cube/data -shard_num=1 -only_build=false
mv ./cube/data/0_0/test_dict_part0/* ./cube/data/
cd cube && ./cube
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_client import MultiLangClient
import sys
client = MultiLangClient()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9393"])
import paddle
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=1)
for data in test_reader():
future = client.predict(feed={"x": data[0][0]}, fetch=["price"], asyn=True)
fetch_map = future.result()
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import os
import sys
from paddle_serving_server import OpMaker
from paddle_serving_server import OpSeqMaker
from paddle_serving_server import MultiLangServer
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(response_op)
server = MultiLangServer()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
server.run_server()
......@@ -73,7 +73,7 @@ def single_func(idx, resource):
print("unsupport batch size {}".format(args.batch_size))
elif args.request == "http":
py_version = 2
py_version = sys.version_info[0]
server = "http://" + resource["endpoint"][idx % len(resource[
"endpoint"])] + "/image/prediction"
start = time.time()
......@@ -93,7 +93,7 @@ def single_func(idx, resource):
if __name__ == '__main__':
multi_thread_runner = MultiThreadRunner()
endpoint_list = ["127.0.0.1:9696"]
endpoint_list = ["127.0.0.1:9393"]
#endpoint_list = endpoint_list + endpoint_list + endpoint_list
result = multi_thread_runner.run(single_func, args.thread,
{"endpoint": endpoint_list})
......
rm profile_log
for thread_num in 1 2 4 8
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_profile_server=1
export FLAGS_profile_client=1
python -m paddle_serving_server_gpu.serve --model $1 --port 9292 --thread 4 --gpu_ids 0,1,2,3 2> elog > stdlog &
sleep 5
#warm up
$PYTHONROOT/bin/python benchmark.py --thread 8 --batch_size 1 --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
for thread_num in 4 8 16
do
for batch_size in 1 2 4 8 16 32 64 128
for batch_size in 1 4 16 64 256
do
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model ResNet50_vd_client_config/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "========================================"
echo "batch size : $batch_size" >> profile_log
$PYTHONROOT/bin/python benchmark.py --thread $thread_num --batch_size $batch_size --model $2/serving_client_conf.prototxt --request rpc > profile 2>&1
echo "model name :" $1
echo "thread num :" $thread_num
echo "batch size :" $batch_size
echo "=================Done===================="
echo "model name :$1" >> profile_log
echo "batch size :$batch_size" >> profile_log
$PYTHONROOT/bin/python ../util/show_profile.py profile $thread_num >> profile_log
tail -n 1 profile >> profile_log
tail -n 8 profile >> profile_log
done
done
ps -ef|grep 'serving'|grep -v grep|cut -c 9-15 | xargs kill -9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, ResizeByFactor
from paddle_serving_app.reader import Div, Normalize, Transpose
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
client = Client()
client.load_client_config("ocr_det_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9494"])
read_image_file = File2Image()
preprocess = Sequential([
ResizeByFactor(32, 960), Div(255),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
(2, 0, 1))
])
post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
})
filter_func = FilterBoxes(10, 10)
img = read_image_file(name)
ori_h, ori_w, _ = img.shape
img = preprocess(img)
new_h, new_w, _ = img.shape
ratio_list = [float(new_h) / ori_h, float(new_w) / ori_w]
outputs = client.predict(feed={"image": img}, fetch=["concat_1.tmp_0"])
dt_boxes_list = post_func(outputs["concat_1.tmp_0"], [ratio_list])
dt_boxes = filter_func(dt_boxes_list[0], [ori_h, ori_w])
......@@ -31,7 +31,7 @@ with open(profile_file) as f:
if line[0] == "PROFILE":
prase(line[2])
print("thread num {}".format(thread_num))
print("thread num :{}".format(thread_num))
for name in time_dict:
print("{} cost {} s in each thread ".format(name, time_dict[name] / (
print("{} cost :{} s in each thread ".format(name, time_dict[name] / (
1000000.0 * float(thread_num))))
......@@ -31,6 +31,7 @@ class ServingModels(object):
self.model_dict["ImageClassification"] = [
"resnet_v2_50_imagenet", "mobilenet_v2_imagenet"
]
self.model_dict["TextDetection"] = ["ocr_detection"]
self.model_dict["OCR"] = ["ocr_rec"]
image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/"
......@@ -40,6 +41,7 @@ class ServingModels(object):
senta_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/"
semantic_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticModel/"
wordseg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/LexicalAnalysis/"
ocr_det_url = "https://paddle-serving.bj.bcebos.com/ocr/"
self.url_dict = {}
......@@ -55,6 +57,7 @@ class ServingModels(object):
pack_url(self.model_dict, "ImageSegmentation", image_seg_url)
pack_url(self.model_dict, "ImageClassification", image_class_url)
pack_url(self.model_dict, "OCR", ocr_url)
pack_url(self.model_dict, "TextDetection", ocr_det_url)
def get_model_list(self):
return self.model_dict
......
......@@ -13,8 +13,9 @@
# limitations under the License.
from .chinese_bert_reader import ChineseBertReader
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize
from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB
from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor
from .image_reader import RCNNPostprocess, SegPostprocess, PadStride
from .image_reader import DBPostProcess, FilterBoxes
from .lac_reader import LACReader
from .senta_reader import SentaReader
from .imdb_reader import IMDBDataset
......
......@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import os
import numpy as np
......@@ -18,6 +21,8 @@ import base64
import sys
from . import functional as F
from PIL import Image, ImageDraw
from shapely.geometry import Polygon
import pyclipper
import json
_cv2_interpolation_to_str = {cv2.INTER_LINEAR: "cv2.INTER_LINEAR", None: "None"}
......@@ -43,6 +48,196 @@ def generate_colormap(num_classes):
return color_map
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self, params):
self.thresh = params['thresh']
self.box_thresh = params['box_thresh']
self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio']
self.min_size = 3
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
scores = np.zeros((num_contours, ), dtype=np.float32)
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes[index, :, :] = box.astype(np.int16)
scores[index] = score
return boxes, scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, pred, ratio_list):
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(
pred[batch_index], segmentation[batch_index], width, height)
boxes = []
for k in range(len(tmp_boxes)):
if tmp_scores[k] > self.box_thresh:
boxes.append(tmp_boxes[k])
if len(boxes) > 0:
boxes = np.array(boxes)
ratio_h, ratio_w = ratio_list[batch_index]
boxes[:, :, 0] = boxes[:, :, 0] / ratio_w
boxes[:, :, 1] = boxes[:, :, 1] / ratio_h
boxes_batch.append(boxes)
return boxes_batch
def __repr__(self):
return self.__class__.__name__ + \
" thresh: {1}, box_thresh: {2}, max_candidates: {3}, unclip_ratio: {4}, min_size: {5}".format(
self.thresh, self.box_thresh, self.max_candidates, self.unclip_ratio, self.min_size)
class FilterBoxes(object):
def __init__(self, width, height):
self.filter_width = width
self.filter_height = height
def order_points_clockwise(self, pts):
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
# sort the points based on their x-coordinates
"""
xSorted = pts[np.argsort(pts[:, 0]), :]
# grab the left-most and right-most points from the sorted
# x-roodinate points
leftMost = xSorted[:2, :]
rightMost = xSorted[2:, :]
# now, sort the left-most coordinates according to their
# y-coordinates so we can grab the top-left and bottom-left
# points, respectively
leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
(tl, bl) = leftMost
rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
(tr, br) = rightMost
rect = np.array([tl, tr, br, bl], dtype="float32")
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(4):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def __call__(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= self.filter_width or \
rect_height <= self.filter_height:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def __repr__(self):
return self.__class__.__name__ + " filter_width: {1}, filter_height: {2}".format(
self.filter_width, self.filter_height)
class SegPostprocess(object):
def __init__(self, class_num):
self.class_num = class_num
......@@ -473,6 +668,57 @@ class Resize(object):
_cv2_interpolation_to_str[self.interpolation])
class ResizeByFactor(object):
"""Resize the input numpy array Image to a size multiple of factor which is usually required by a network
Args:
factor (int): Resize factor. make width and height multiple factor of the value of factor. Default is 32
max_side_len (int): max size of width and height. if width or height is larger than max_side_len, just resize the width or the height. Default is 2400
"""
def __init__(self, factor=32, max_side_len=2400):
self.factor = factor
self.max_side_len = max_side_len
def __call__(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if max(resize_h, resize_w) > self.max_side_len:
if resize_h > resize_w:
ratio = float(self.max_side_len) / resize_h
else:
ratio = float(self.max_side_len) / resize_w
else:
ratio = 1.
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
if resize_h % self.factor == 0:
resize_h = resize_h
elif resize_h // self.factor <= 1:
resize_h = self.factor
else:
resize_h = (resize_h // 32 - 1) * 32
if resize_w % self.factor == 0:
resize_w = resize_w
elif resize_w // self.factor <= 1:
resize_w = self.factor
else:
resize_w = (resize_w // self.factor - 1) * self.factor
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
im = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(resize_w, resize_h)
sys.exit(0)
return im
def __repr__(self):
return self.__class__.__name__ + '(factor={0}, max_side_len={1})'.format(
self.factor, self.max_side_len)
class PadStride(object):
def __init__(self, stride):
self.coarsest_stride = stride
......
......@@ -21,7 +21,12 @@ import google.protobuf.text_format
import numpy as np
import time
import sys
from .serving_client import PredictorRes
import grpc
from .proto import multi_lang_general_model_service_pb2
sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
int_type = 0
float_type = 1
......@@ -125,6 +130,8 @@ class Client(object):
self.all_numpy_input = True
self.has_numpy_input = False
self.rpc_timeout_ms = 20000
from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes
def load_client_config(self, path):
from .serving_client import PredictorClient
......@@ -304,7 +311,7 @@ class Client(object):
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
result_batch_handle = PredictorRes()
result_batch_handle = self.predictorres_constructor()
if self.all_numpy_input:
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
......@@ -372,3 +379,172 @@ class Client(object):
def release(self):
self.client_handle_.destroy_predictor()
self.client_handle_ = None
class MultiLangClient(object):
def __init__(self):
self.channel_ = None
def load_client_config(self, path):
if not isinstance(path, str):
raise Exception("GClient only supports multi-model temporarily")
self._parse_model_config(path)
def connect(self, endpoint):
self.channel_ = grpc.insecure_channel(endpoint[0]) #TODO
self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
self.channel_)
def _flatten_list(self, nested_list):
for item in nested_list:
if isinstance(item, (list, tuple)):
for sub_item in self._flatten_list(item):
yield sub_item
else:
yield item
def _parse_model_config(self, model_config_path):
model_conf = m_config.GeneralModelConfig()
f = open(model_config_path, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _pack_feed_data(self, feed, fetch, is_python):
req = multi_lang_general_model_service_pb2.Request()
req.fetch_var_names.extend(fetch)
req.feed_var_names.extend(feed.keys())
req.is_python = is_python
feed_batch = None
if isinstance(feed, dict):
feed_batch = [feed]
elif isinstance(feed, list):
feed_batch = feed
else:
raise Exception("{} not support".format(type(feed)))
init_feed_names = False
for feed_data in feed_batch:
inst = multi_lang_general_model_service_pb2.FeedInst()
for name in req.feed_var_names:
tensor = multi_lang_general_model_service_pb2.Tensor()
var = feed_data[name]
v_type = self.feed_types_[name]
if is_python:
data = None
if isinstance(var, list):
if v_type == 0: # int64
data = np.array(var, dtype="int64")
elif v_type == 1: # float32
data = np.array(var, dtype="float32")
else:
raise Exception("error type.")
else:
data = var
if var.dtype == "float64":
data = data.astype("float32")
tensor.data = data.tobytes()
else:
if v_type == 0: # int64
if isinstance(var, np.ndarray):
tensor.int64_data.extend(var.reshape(-1).tolist())
else:
tensor.int64_data.extend(self._flatten_list(var))
elif v_type == 1: # float32
if isinstance(var, np.ndarray):
tensor.float_data.extend(var.reshape(-1).tolist())
else:
tensor.float_data.extend(self._flatten_list(var))
else:
raise Exception("error type.")
if isinstance(var, np.ndarray):
tensor.shape.extend(list(var.shape))
else:
tensor.shape.extend(self.feed_shapes_[name])
inst.tensor_array.append(tensor)
req.insts.append(inst)
return req
def _unpack_resp(self, resp, fetch, is_python, need_variant_tag):
result_map = {}
inst = resp.outputs[0].insts[0]
tag = resp.tag
for i, name in enumerate(fetch):
var = inst.tensor_array[i]
v_type = self.fetch_types_[name]
if is_python:
if v_type == 0: # int64
result_map[name] = np.frombuffer(var.data, dtype="int64")
elif v_type == 1: # float32
result_map[name] = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else:
if v_type == 0: # int64
result_map[name] = np.array(
list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
result_map[name] = np.array(
list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
result_map[name].shape = list(var.shape)
if name in self.lod_tensor_set_:
result_map["{}.lod".format(name)] = np.array(list(var.lod))
return result_map if not need_variant_tag else [result_map, tag]
def _done_callback_func(self, fetch, is_python, need_variant_tag):
def unpack_resp(resp):
return self._unpack_resp(resp, fetch, is_python, need_variant_tag)
return unpack_resp
def predict(self,
feed,
fetch,
need_variant_tag=False,
asyn=False,
is_python=True):
req = self._pack_feed_data(feed, fetch, is_python=is_python)
if not asyn:
resp = self.stub_.inference(req)
return self._unpack_resp(
resp,
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag)
else:
call_future = self.stub_.inference.future(req)
return MultiLangPredictFuture(
call_future,
self._done_callback_func(
fetch,
is_python=is_python,
need_variant_tag=need_variant_tag))
class MultiLangPredictFuture(object):
def __init__(self, call_future, callback_func):
self.call_future_ = call_future
self.callback_func_ = callback_func
def result(self):
resp = self.call_future_.result()
return self.callback_func_(resp)
......@@ -17,6 +17,7 @@ import sys
import subprocess
import argparse
from multiprocessing import Pool
import numpy as np
def benchmark_args():
......@@ -35,6 +36,17 @@ def benchmark_args():
return parser.parse_args()
def show_latency(latency_list):
latency_array = np.array(latency_list)
info = "latency:\n"
info += "mean :{} ms\n".format(np.mean(latency_array))
info += "median :{} ms\n".format(np.median(latency_array))
info += "80 percent :{} ms\n".format(np.percentile(latency_array, 80))
info += "90 percent :{} ms\n".format(np.percentile(latency_array, 90))
info += "99 percent :{} ms\n".format(np.percentile(latency_array, 99))
sys.stderr.write(info)
class MultiThreadRunner(object):
def __init__(self):
pass
......
......@@ -25,6 +25,16 @@ from contextlib import closing
import collections
import fcntl
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
import sys
sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures
class OpMaker(object):
def __init__(self):
......@@ -428,3 +438,158 @@ class Server(object):
print("Going to Run Command")
print(command)
os.system(command)
class MultiLangServerService(
multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
def __init__(self, model_config_path, endpoints):
from paddle_serving_client import Client
self._parse_model_config(model_config_path)
self.bclient_ = Client()
self.bclient_.load_client_config(
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path):
model_conf = m_config.GeneralModelConfig()
f = open("{}/serving_server_conf.prototxt".format(model_config_path),
'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _flatten_list(self, nested_list):
for item in nested_list:
if isinstance(item, (list, tuple)):
for sub_item in self._flatten_list(item):
yield sub_item
else:
yield item
def _unpack_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
data = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else:
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if is_python:
tensor.data = result[name].tobytes()
else:
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist())
inst.tensor_array.append(tensor)
model_output.insts.append(inst)
resp.outputs.append(model_output)
resp.tag = tag
return resp
def inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_request(request)
data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, is_python, tag)
class MultiLangServer(object):
def __init__(self, worker_num=2):
self.bserver_ = Server()
self.worker_num_ = worker_num
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
def load_model_config(self, model_config_path):
if not isinstance(model_config_path, str):
raise Exception(
"MultiLangServer only supports multi-model temporarily")
self.bserver_.load_model_config(model_config_path)
self.model_config_path_ = model_config_path
def prepare_server(self, workdir=None, port=9292, device="cpu"):
default_port = 12000
self.port_list_ = []
for i in range(1000):
if default_port + i != port and self._port_is_available(default_port
+ i):
self.port_list_.append(default_port + i)
break
self.bserver_.prepare_server(
workdir=workdir, port=self.port_list_[0], device=device)
self.gport_ = port
def _launch_brpc_service(self, bserver):
bserver.run_server()
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
return result != 0
def run_server(self):
p_bserver = Process(
target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_))
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerService(self.model_config_path_,
["0.0.0.0:{}".format(self.port_list_[0])]),
server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
p_bserver.join()
server.wait_for_termination()
......@@ -27,6 +27,16 @@ import argparse
import collections
import fcntl
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
import sys
sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures
def serve_args():
parser = argparse.ArgumentParser("serve")
......@@ -469,3 +479,158 @@ class Server(object):
print(command)
os.system(command)
class MultiLangServerService(
multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
def __init__(self, model_config_path, endpoints):
from paddle_serving_client import Client
self._parse_model_config(model_config_path)
self.bclient_ = Client()
self.bclient_.load_client_config(
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path):
model_conf = m_config.GeneralModelConfig()
f = open("{}/serving_server_conf.prototxt".format(model_config_path),
'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _flatten_list(self, nested_list):
for item in nested_list:
if isinstance(item, (list, tuple)):
for sub_item in self._flatten_list(item):
yield sub_item
else:
yield item
def _unpack_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
data = np.frombuffer(var.data, dtype="float32")
else:
raise Exception("error type.")
else:
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if is_python:
tensor.data = result[name].tobytes()
else:
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist())
inst.tensor_array.append(tensor)
model_output.insts.append(inst)
resp.outputs.append(model_output)
resp.tag = tag
return resp
def inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_request(request)
data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, is_python, tag)
class MultiLangServer(object):
def __init__(self, worker_num=2):
self.bserver_ = Server()
self.worker_num_ = worker_num
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
def load_model_config(self, model_config_path):
if not isinstance(model_config_path, str):
raise Exception(
"MultiLangServer only supports multi-model temporarily")
self.bserver_.load_model_config(model_config_path)
self.model_config_path_ = model_config_path
def prepare_server(self, workdir=None, port=9292, device="cpu"):
default_port = 12000
self.port_list_ = []
for i in range(1000):
if default_port + i != port and self._port_is_available(default_port
+ i):
self.port_list_.append(default_port + i)
break
self.bserver_.prepare_server(
workdir=workdir, port=self.port_list_[0], device=device)
self.gport_ = port
def _launch_brpc_service(self, bserver):
bserver.run_server()
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
return result != 0
def run_server(self):
p_bserver = Process(
target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_))
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerService(self.model_config_path_,
["0.0.0.0:{}".format(self.port_list_[0])]),
server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
p_bserver.join()
server.wait_for_termination()
numpy>=1.12, <=1.16.4 ; python_version<"3.5"
grpcio-tools>=1.28.1
grpcio>=1.28.1
......@@ -42,7 +42,8 @@ if '${PACK}' == 'ON':
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'sentencepiece', 'opencv-python', 'pillow'
'six >= 1.10.0', 'sentencepiece', 'opencv-python', 'pillow',
'shapely', 'pyclipper'
]
packages=['paddle_serving_app',
......
......@@ -58,7 +58,8 @@ if '${PACK}' == 'ON':
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0', 'numpy >= 1.12'
'six >= 1.10.0', 'protobuf >= 3.1.0', 'numpy >= 1.12', 'grpcio >= 1.28.1',
'grpcio-tools >= 1.28.1'
]
if not find_package("paddlepaddle") and not find_package("paddlepaddle-gpu"):
......
......@@ -37,7 +37,7 @@ def python_version():
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0',
'six >= 1.10.0', 'protobuf >= 3.1.0', 'grpcio >= 1.28.1', 'grpcio-tools >= 1.28.1',
'paddle_serving_client', 'flask >= 1.1.1', 'paddle_serving_app'
]
......
......@@ -37,7 +37,7 @@ def python_version():
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0',
'six >= 1.10.0', 'protobuf >= 3.1.0', 'grpcio >= 1.28.1', 'grpcio-tools >= 1.28.1',
'paddle_serving_client', 'flask >= 1.1.1', 'paddle_serving_app'
]
......
......@@ -9,4 +9,6 @@ RUN yum -y install wget && \
yum -y install python3 python3-devel && \
yum clean all && \
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
python get-pip.py && rm get-pip.py
python get-pip.py && rm get-pip.py && \
localedef -c -i en_US -f UTF-8 en_US.UTF-8 && \
echo "export LANG=en_US.utf8" >> /root/.bashrc
......@@ -44,4 +44,6 @@ RUN yum -y install wget && \
cd .. && rm -rf Python-3.6.8* && \
pip3 install google protobuf setuptools wheel flask numpy==1.16.4 && \
yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \
yum clean all
yum clean all && \
localedef -c -i en_US -f UTF-8 en_US.UTF-8 && \
echo "export LANG=en_US.utf8" >> /root/.bashrc
......@@ -44,4 +44,5 @@ RUN yum -y install wget && \
cd .. && rm -rf Python-3.6.8* && \
pip3 install google protobuf setuptools wheel flask numpy==1.16.4 && \
yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \
yum clean all
yum clean all && \
echo "export LANG=en_US.utf8" >> /root/.bashrc
......@@ -21,4 +21,6 @@ RUN yum -y install wget >/dev/null \
&& yum install -y python3 python3-devel \
&& pip3 install google protobuf setuptools wheel flask \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all
&& yum clean all \
&& localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc
......@@ -15,6 +15,7 @@ RUN yum -y install wget && \
echo 'export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH' >> /root/.bashrc && \
ln -s /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7 /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so && \
echo 'export LD_LIBRARY_PATH=/usr/local/cuda-9.0/targets/x86_64-linux/lib:$LD_LIBRARY_PATH' >> /root/.bashrc && \
echo "export LANG=en_US.utf8" >> /root/.bashrc && \
mkdir -p /usr/local/cuda/extras
COPY --from=builder /usr/local/cuda/extras/CUPTI /usr/local/cuda/extras/CUPTI
......@@ -22,4 +22,5 @@ RUN yum -y install wget >/dev/null \
&& yum install -y python3 python3-devel \
&& pip3 install google protobuf setuptools wheel flask \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all
&& yum clean all \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册