提交 6de663ee 编写于 作者: M MRXLT

Merge remote-tracking branch 'upstream/develop' into 0.2.2-doc-fix

......@@ -85,6 +85,17 @@ include(generic)
include(flags)
endif()
if (APP)
include(external/zlib)
include(external/boost)
include(external/protobuf)
include(external/gflags)
include(external/glog)
include(external/pybind11)
include(external/python)
include(generic)
endif()
if (SERVER)
include(external/cudnn)
include(paddlepaddle)
......
......@@ -31,7 +31,7 @@ message( "WITH_GPU = ${WITH_GPU}")
# Paddle Version should be one of:
# latest: latest develop build
# version number like 1.5.2
SET(PADDLE_VERSION "1.7.1")
SET(PADDLE_VERSION "1.7.2")
if (WITH_GPU)
SET(PADDLE_LIB_VERSION "${PADDLE_VERSION}-gpu-cuda${CUDA_VERSION_MAJOR}-cudnn7-avx-mkl")
......
......@@ -23,6 +23,11 @@ add_subdirectory(pdcodegen)
add_subdirectory(sdk-cpp)
endif()
if (APP)
add_subdirectory(configure)
endif()
if(CLIENT)
add_subdirectory(general-client)
endif()
......
if (SERVER OR CLIENT)
LIST(APPEND protofiles
${CMAKE_CURRENT_LIST_DIR}/proto/server_configure.proto
${CMAKE_CURRENT_LIST_DIR}/proto/sdk_configure.proto
......@@ -28,6 +29,7 @@ FILE(GLOB inc ${CMAKE_CURRENT_BINARY_DIR}/*.pb.h)
install(FILES ${inc}
DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/include/configure)
endif()
py_proto_compile(general_model_config_py_proto SRCS proto/general_model_config.proto)
add_custom_target(general_model_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......@@ -51,6 +53,14 @@ add_custom_command(TARGET general_model_config_py_proto POST_BUILD
endif()
if (APP)
add_custom_command(TARGET general_model_config_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/proto
COMMAND cp *.py ${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/proto
COMMENT "Copy generated general_model_config proto file into directory paddle_serving_app/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
if (SERVER)
py_proto_compile(server_config_py_proto SRCS proto/server_configure.proto)
add_custom_target(server_config_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......
if(CLIENT)
add_subdirectory(pybind11)
pybind11_add_module(serving_client src/general_model.cpp src/pybind_general_model.cpp)
target_link_libraries(serving_client PRIVATE -Wl,--whole-archive utils sdk-cpp pybind python -Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz)
target_link_libraries(serving_client PRIVATE -Wl,--whole-archive utils sdk-cpp pybind python -Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz -Wl,-rpath,'$ORIGIN'/lib)
endif()
......@@ -13,10 +13,7 @@
# limitations under the License.
from paddle_serving_client import Client
import sys
from paddle_serving_app.reader.pddet import Detection
from paddle_serving_app.reader import File2Image, Sequential, Normalize, Resize, Transpose, Div, BGR2RGB, RCNNPostprocess
import numpy as np
from paddle_serving_app.reader import *
preprocess = Sequential([
File2Image(), BGR2RGB(), Div(255.0),
......@@ -25,19 +22,19 @@ preprocess = Sequential([
])
postprocess = RCNNPostprocess("label_list.txt", "output")
client = Client()
client.load_client_config(sys.argv[1])
client.load_client_config(
"faster_rcnn_client_conf/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9393'])
for i in range(100):
im = preprocess(sys.argv[2])
fetch_map = client.predict(
feed={
"image": im,
"im_info": np.array(list(im.shape[1:]) + [1.0]),
"im_shape": np.array(list(im.shape[1:]) + [1.0])
},
fetch=["multiclass_nms"])
fetch_map["image"] = sys.argv[2]
postprocess(fetch_map)
im = preprocess(sys.argv[2])
fetch_map = client.predict(
feed={
"image": im,
"im_info": np.array(list(im.shape[1:]) + [1.0]),
"im_shape": np.array(list(im.shape[1:]) + [1.0])
},
fetch=["multiclass_nms"])
fetch_map["image"] = sys.argv[1]
postprocess(fetch_map)
......@@ -8,6 +8,13 @@ The example uses the ResNet50_vd model to perform the imagenet 1000 classificati
```
sh get_model.sh
```
### Install preprocess module
```
pip install paddle_serving_app
```
### HTTP Infer
launch server side
......
......@@ -8,6 +8,13 @@
```
sh get_model.sh
```
### 安装数据预处理模块
```
pip install paddle_serving_app
```
### 执行HTTP预测服务
启动server端
......
......@@ -14,8 +14,10 @@
from paddle_serving_client import Client
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
py_version = sys.version_info[0]
if py_version == 2:
reload(sys)
sys.setdefaultencoding('utf-8')
import os
import io
......
......@@ -12,22 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, Resize, CenterCrop
from paddle_serving_app.reader import RGB2BGR, Transpose, Div, Normalize
from paddle_serving_app import Debugger
import sys
import os
import time
from paddle_serving_app.reader.pddet import Detection
import numpy as np
py_version = sys.version_info[0]
debugger = Debugger()
debugger.load_model_config(sys.argv[1], gpu=True)
feed_var_names = ['image', 'im_shape', 'im_info']
fetch_var_names = ['multiclass_nms']
pddet = Detection(config_path=sys.argv[2], output_dir="./output")
feed_dict = pddet.preprocess(feed_var_names, sys.argv[3])
client = Client()
client.load_client_config(sys.argv[1])
client.connect(['127.0.0.1:9494'])
fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
outs = fetch_map.values()
pddet.postprocess(fetch_map, fetch_var_names)
seq = Sequential([
File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
])
image_file = "daisy.jpg"
img = seq(image_file)
fetch_map = debugger.predict(feed={"image": img}, fetch=["feature_map"])
print(fetch_map["feature_map"].reshape(-1))
......@@ -4,6 +4,12 @@
```
sh get_data.sh
```
## Install preprocess module
```
pip install paddle_serving_app
```
## Start http service
```
python senta_web_service.py senta_bilstm_model/ workdir 9292
......
......@@ -4,6 +4,11 @@
```
sh get_data.sh
```
## 安装数据预处理模块
```
pip install paddle_serving_app
```
## 启动HTTP服务
```
python senta_web_service.py senta_bilstm_model/ workdir 9292
......
......@@ -16,3 +16,4 @@ from .reader.image_reader import ImageReader, File2Image, URL2Image, Sequential,
from .reader.lac_reader import LACReader
from .reader.senta_reader import SentaReader
from .models import ServingModels
from .local_predict import Debugger
# -*- coding: utf-8 -*-
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import google.protobuf.text_format
import numpy as np
import argparse
import paddle.fluid as fluid
from .proto import general_model_config_pb2 as m_config
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
import logging
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
class Debugger(object):
def __init__(self):
self.feed_names_ = []
self.fetch_names_ = []
self.feed_types_ = {}
self.fetch_types_ = {}
self.feed_shapes_ = {}
self.feed_names_to_idx_ = {}
self.fetch_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
def load_model_config(self, model_path, gpu=False, profile=True, cpu_num=1):
client_config = "{}/serving_server_conf.prototxt".format(model_path)
model_conf = m_config.GeneralModelConfig()
f = open(client_config, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
config = AnalysisConfig(model_path)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_names_to_idx_ = {}
self.fetch_names_to_idx_ = {}
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
if not gpu:
config.disable_gpu()
else:
config.enable_use_gpu(100, 0)
if profile:
config.enable_profile()
config.set_cpu_math_library_num_threads(cpu_num)
self.predictor = create_paddle_predictor(config)
def predict(self, feed=None, fetch=None):
if feed is None or fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
fetch_list = []
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
else:
raise ValueError("Fetch only accepts string and list of string")
feed_batch = []
if isinstance(feed, dict):
feed_batch.append(feed)
elif isinstance(feed, list):
feed_batch = feed
else:
raise ValueError("Feed only accepts dict and list of dict")
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
float_feed_names = []
int_shape = []
float_shape = []
fetch_names = []
counter = 0
batch_size = len(feed_batch)
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"Fetch names should not be empty or out of saved fetch list.")
return {}
inputs = []
for name in self.feed_names_:
inputs.append(PaddleTensor(feed[name][np.newaxis, :]))
outputs = self.predictor.run(inputs)
fetch_map = {}
for name in fetch:
fetch_map[name] = outputs[self.fetch_names_to_idx_[
name]].as_ndarray()
return fetch_map
......@@ -16,7 +16,7 @@ import os
import urllib
import numpy as np
import base64
import functional as F
from . import functional as F
from PIL import Image, ImageDraw
import json
......
......@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
py_version = sys.version_info[0]
if py_version == 2:
reload(sys)
sys.setdefaultencoding('utf-8')
import os
import io
......
......@@ -112,7 +112,6 @@ class Client(object):
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.rpath()
self.pid = os.getpid()
self.predictor_sdk_ = None
self.producers = []
......@@ -121,16 +120,6 @@ class Client(object):
self.all_numpy_input = True
self.has_numpy_input = False
def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__)
client_path = os.path.join(lib_path, 'serving_client.so')
lib_path = os.path.join(lib_path, 'lib')
ld_path = os.getenv('LD_LIBRARY_PATH')
if ld_path == None:
os.environ['LD_LIBRARY_PATH'] = lib_path
elif ld_path not in lib_path:
os.environ['LD_LIBRARY_PATH'] = ld_path + ':' + lib_path
def load_client_config(self, path):
from .serving_client import PredictorClient
from .serving_client import PredictorRes
......
......@@ -42,10 +42,11 @@ if '${PACK}' == 'ON':
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'sentencepiece'
'six >= 1.10.0', 'sentencepiece', 'opencv-python', 'pillow'
]
packages=['paddle_serving_app',
'paddle_serving_app.proto',
'paddle_serving_app.reader',
'paddle_serving_app.utils',
'paddle_serving_app.models',
......@@ -54,6 +55,8 @@ packages=['paddle_serving_app',
package_data={}
package_dir={'paddle_serving_app':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app',
'paddle_serving_app.proto':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/proto',
'paddle_serving_app.reader':
'${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader',
'paddle_serving_app.utils':
......
......@@ -26,7 +26,7 @@ from setuptools import setup
from paddle_serving_client.version import serving_client_version
from pkg_resources import DistributionNotFound, get_distribution
py_version = sys.version_info[0]
py_version = sys.version_info
def python_version():
return [int(v) for v in platform.python_version().split(".")]
......@@ -39,7 +39,12 @@ def find_package(pkgname):
return False
def copy_lib():
lib_list = ['libpython2.7.so.1.0', 'libssl.so.10', 'libcrypto.so.10'] if py_version == 2 else ['libpython3.6m.so.1.0', 'libssl.so.10', 'libcrypto.so.10']
if py_version[0] == 2:
lib_list = ['libpython2.7.so.1.0', 'libssl.so.10', 'libcrypto.so.10']
elif py_version[1] == 6:
lib_list = ['libpython3.6m.so.1.0', 'libssl.so.10', 'libcrypto.so.10']
elif py_version[1] == 7:
lib_list = ['libpython3.7m.so.1.0', 'libssl.so.10', 'libcrypto.so.10']
os.popen('mkdir -p paddle_serving_client/lib')
for lib in lib_list:
r = os.popen('whereis {}'.format(lib))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册