Commit 991f7efb authored by MRXLT

fix conflict

......@@ -84,6 +84,7 @@ python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --po
| `model` | str | `""` | Path of paddle model directory to be served |
| `mem_optim` | bool | `False` | Enable memory / graphic memory optimization |
| `ir_optim` | bool | `False` | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | bool | `False` | Run inference with MKL |
Here, we use `curl` to send an HTTP POST request to the service we just started. Users can use any Python library to send HTTP POST requests as well, e.g., [requests](https://requests.readthedocs.io/en/master/).
</center>
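The same request can also be issued from Python. A minimal sketch with [requests](https://requests.readthedocs.io/en/master/), assuming the uci_housing service above listens on port 9292 and exposes a `/uci/prediction` endpoint (the `x` feature values are illustrative):
```
import requests

# hypothetical quick-start request: endpoint and sample features are assumptions
data = {
    "feed": [{"x": [0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727,
                    -0.1583, -0.0584, 0.6283, 0.4919, 0.1856, 0.0795,
                    -0.0332]}],
    "fetch": ["price"],
}
r = requests.post("http://127.0.0.1:9292/uci/prediction", json=data)
print(r.json())
```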
......
......@@ -88,6 +88,7 @@ python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --po
| `model` | str | `""` | Path of paddle model directory to be served |
| `mem_optim` | bool | `False` | Enable memory optimization |
| `ir_optim` | bool | `False` | Enable analysis and optimization of calculation graph |
| `use_mkl` (Only for cpu version) | bool | `False` | Run inference with MKL |
We use the `curl` command to send an HTTP POST request to the service we just started. Users can also use any Python library to send HTTP POST requests; see [requests](https://requests.readthedocs.io/en/master/).
</center>
......
......@@ -83,9 +83,6 @@ func JsonReq(method, requrl string, timeout int, kv *map[string]string,
}
func GetHdfsMeta(src string) (master, ugi, path string, err error) {
//src = "hdfs://root:rootpasst@st1-inf-platform0.st01.baidu.com:54310/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar"
//src = "hdfs://st1-inf-platform0.st01.baidu.com:54310/user/mis_user/news_dnn_ctr_cube_1/1501836820/news_dnn_ctr_cube_1_part54.tar"
ugiBegin := strings.Index(src, "//")
ugiPos := strings.LastIndex(src, "@")
if ugiPos != -1 && ugiBegin != -1 {
......
......@@ -69,9 +69,15 @@ class ModelRes {
const std::vector<int64_t>& get_int64_by_name(const std::string& name) {
return _int64_value_map[name];
}
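// The _with_rv variants below move the stored vector out of the map instead of copying it; the corresponding map entry is left empty afterwards.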
std::vector<int64_t>&& get_int64_by_name_with_rv(const std::string& name) {
return std::move(_int64_value_map[name]);
}
const std::vector<float>& get_float_by_name(const std::string& name) {
return _float_value_map[name];
}
std::vector<float>&& get_float_by_name_with_rv(const std::string& name) {
return std::move(_float_value_map[name]);
}
const std::vector<int>& get_shape(const std::string& name) {
return _shape_map[name];
}
......@@ -121,10 +127,18 @@ class PredictorRes {
const std::string& name) {
return _models[model_idx].get_int64_by_name(name);
}
std::vector<int64_t>&& get_int64_by_name_with_rv(const int model_idx,
const std::string& name) {
return std::move(_models[model_idx].get_int64_by_name_with_rv(name));
}
const std::vector<float>& get_float_by_name(const int model_idx,
const std::string& name) {
return _models[model_idx].get_float_by_name(name);
}
std::vector<float>&& get_float_by_name_with_rv(const int model_idx,
const std::string& name) {
return std::move(_models[model_idx].get_float_by_name_with_rv(name));
}
const std::vector<int>& get_shape(const int model_idx,
const std::string& name) {
return _models[model_idx].get_shape(name);
......
......@@ -258,9 +258,10 @@ int PredictorClient::batch_predict(
ModelRes model;
model.set_engine_name(output.engine_name());
int idx = 0;
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
int idx = 0;
int shape_size = output.insts(0).tensor_array(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size;
......@@ -279,9 +280,9 @@ int PredictorClient::batch_predict(
idx += 1;
}
idx = 0;
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
int idx = 0;
if (_fetch_name_to_type[name] == 0) {
VLOG(2) << "ferch var " << name << "type int";
model._int64_value_map[name].resize(
......@@ -345,7 +346,7 @@ int PredictorClient::numpy_predict(
PredictorRes &predict_res_batch,
const int &pid) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
VLOG(2) << "batch size: " << batch_size;
predict_res_batch.clear();
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
......@@ -462,7 +463,7 @@ int PredictorClient::numpy_predict(
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
for (ssize_t l = 0; l < int_array.shape(3); l++) {
tensor->add_float_data(int_array(i, j, k, l));
tensor->add_int64_data(int_array(i, j, k, l));
}
}
}
......@@ -474,7 +475,7 @@ int PredictorClient::numpy_predict(
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
tensor->add_float_data(int_array(i, j, k));
tensor->add_int64_data(int_array(i, j, k));
}
}
}
......@@ -484,7 +485,7 @@ int PredictorClient::numpy_predict(
auto int_array = int_feed[vec_idx].unchecked<2>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
tensor->add_float_data(int_array(i, j));
tensor->add_int64_data(int_array(i, j));
}
}
break;
......@@ -492,7 +493,7 @@ int PredictorClient::numpy_predict(
case 1: {
auto int_array = int_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
tensor->add_float_data(int_array(i));
tensor->add_int64_data(int_array(i));
}
break;
}
......@@ -536,9 +537,9 @@ int PredictorClient::numpy_predict(
ModelRes model;
model.set_engine_name(output.engine_name());
int idx = 0;
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
int idx = 0;
int shape_size = output.insts(0).tensor_array(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size;
......@@ -557,9 +558,10 @@ int PredictorClient::numpy_predict(
idx += 1;
}
idx = 0;
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
int idx = 0;
if (_fetch_name_to_type[name] == 0) {
VLOG(2) << "ferch var " << name << "type int";
model._int64_value_map[name].resize(
......
......@@ -32,14 +32,23 @@ PYBIND11_MODULE(serving_client, m) {
.def(py::init())
.def("get_int64_by_name",
[](PredictorRes &self, int model_idx, std::string &name) {
return self.get_int64_by_name(model_idx, name);
},
py::return_value_policy::reference)
// see more: https://github.com/pybind/pybind11/issues/1042
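// Move the result into a heap-allocated vector and let the capsule free it once the numpy array is garbage-collected, so no extra copy is made on the way into Python.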
std::vector<int64_t> *ptr = new std::vector<int64_t>(
std::move(self.get_int64_by_name_with_rv(model_idx, name)));
auto capsule = py::capsule(ptr, [](void *p) {
delete reinterpret_cast<std::vector<int64_t> *>(p);
});
return py::array(ptr->size(), ptr->data(), capsule);
})
.def("get_float_by_name",
[](PredictorRes &self, int model_idx, std::string &name) {
return self.get_float_by_name(model_idx, name);
},
py::return_value_policy::reference)
std::vector<float> *ptr = new std::vector<float>(
std::move(self.get_float_by_name_with_rv(model_idx, name)));
auto capsule = py::capsule(ptr, [](void *p) {
delete reinterpret_cast<std::vector<float> *>(p);
});
return py::array(ptr->size(), ptr->data(), capsule);
})
.def("get_shape",
[](PredictorRes &self, int model_idx, std::string &name) {
return self.get_shape(model_idx, name);
......
......@@ -16,7 +16,11 @@ It is recommended to use Docker for compilation. We have prepared the Paddle Ser
- CPU: `hub.baidubce.com/paddlepaddle/serving:0.2.0-devel`,dockerfile: [Dockerfile.devel](../tools/Dockerfile.devel)
- GPU: `hub.baidubce.com/paddlepaddle/serving:0.2.0-gpu-devel`,dockerfile: [Dockerfile.gpu.devel](../tools/Dockerfile.gpu.devel)
This document will take Python2 as an example to show how to compile Paddle Serving. If you want to compile with Python 3, just adjust the Python options of cmake.
This document will take Python2 as an example to show how to compile Paddle Serving. If you want to compile with Python3, just adjust the Python options of cmake:
- Set `DPYTHON_INCLUDE_DIR` to `$PYTHONROOT/include/python3.6m/`
- Set `DPYTHON_LIBRARIES` to `$PYTHONROOT/lib64/libpython3.6.so`
- Set `DPYTHON_EXECUTABLE` to `$PYTHONROOT/bin/python3`
## Get Code
......@@ -54,7 +58,7 @@ make -j10
Execute `make install` to put the build targets under the `./output` directory.
**Attention:** After the compilation is successful, the serving binary file will be generated in the ./core/general-server directory. Before starting the server, `export SERVING_BIN=${path/to/serving/bin}` is required so that the server uses the compiled serving binary file.
**Attention:** After the compilation is successful, you need to set the path of `SERVING_BIN`. See [Note](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE.md#Note) for details.
## Compile Client
......
......@@ -16,7 +16,11 @@
- CPU: `hub.baidubce.com/paddlepaddle/serving:0.2.0-devel`,dockerfile: [Dockerfile.devel](../tools/Dockerfile.devel)
- GPU: `hub.baidubce.com/paddlepaddle/serving:0.2.0-gpu-devel`,dockerfile: [Dockerfile.gpu.devel](../tools/Dockerfile.gpu.devel)
This document takes Python2 as an example to show how to compile Paddle Serving. If you want to compile with Python3, just adjust the Python options of cmake.
This document takes Python2 as an example to show how to compile Paddle Serving. If you want to compile with Python3, just adjust the Python options of cmake:
- Set `DPYTHON_INCLUDE_DIR` to `$PYTHONROOT/include/python3.6m/`
- Set `DPYTHON_LIBRARIES` to `$PYTHONROOT/lib64/libpython3.6.so`
- Set `DPYTHON_EXECUTABLE` to `$PYTHONROOT/bin/python3`
## Get Code
......@@ -54,7 +58,7 @@ make -j10
Execute `make install` to put the build targets under the `./output` directory.
**Attention:** After the compilation is successful, the serving binary file will be generated in the ./core/general-server directory. Before starting the server, `export SERVING_BIN=${path/to/serving/bin}` is required so that the server uses the compiled serving binary file.
**Attention:** After the compilation is successful, you need to set the `SERVING_BIN` path. See the [Note](https://github.com/PaddlePaddle/Serving/blob/develop/doc/COMPILE_CN.md#注意事项) section for details.
## Compile Client
......
# Cascade RCNN model on Paddle Serving
([简体中文](./README_CN.md)|English)
### Get the Cascade RCNN Model
```
sh get_data.sh
```
If you want more detection models, please refer to the [Paddle Detection Model Zoo](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.2/docs/MODEL_ZOO_cn.md).
### Start the service
```
python -m paddle_serving_server_gpu.serve --model serving_server --port 9292 --gpu_id 0
```
### Perform prediction
```
python test_client.py
```
Images with bounding boxes and the JSON results will be saved in the `output` folder.
# Deploying the Cascade RCNN Model with Paddle Serving
(Simplified Chinese|[English](./README.md))
### Get the Cascade RCNN Model
```
sh get_data.sh
```
If you want more detection models, please refer to the [Paddle Detection Model Zoo](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.2/docs/MODEL_ZOO_cn.md).
### Start the service
```
python -m paddle_serving_server_gpu.serve --model serving_server --port 9292 --gpu_id 0
```
### Perform prediction
```
python test_client.py
```
The client already post-processes the images: the `output` folder stores the JSON information of each bounding box together with the post-processed result images.
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/cascade_rcnn_r50_fpx_1x_serving.tar.gz
tar xf cascade_rcnn_r50_fpx_1x_serving.tar.gz
......@@ -2,16 +2,6 @@
([简体中文](./README_CN.md)|English)
### Compile Source Code
In the root directory of this git project, run:
```
mkdir build_server
cd build_server
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython2.7.so -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python -DSERVER=ON ..
make -j10
make install -j10
```
### Get Sample Dataset
go to directory `python/examples/criteo_ctr_with_cube`
......
## CTR Prediction Service with Sparse Parameter Indexing
(Simplified Chinese|[English](./README.md))
### Compile Source Code
In the root directory of this project, run:
```
mkdir build_server
cd build_server
cmake -DPYTHON_INCLUDE_DIR=$PYTHONROOT/include/python2.7/ -DPYTHON_LIBRARIES=$PYTHONROOT/lib64/libpython2.7.so -DPYTHON_EXECUTABLE=$PYTHONROOT/bin/python -DSERVER=ON ..
make -j10
make install -j10
```
### Get Sample Dataset
Go to directory `python/examples/criteo_ctr_with_cube`
```
......
......@@ -15,34 +15,35 @@ sh get_model.sh
pip install paddle_serving_app
```
### HTTP Infer
### HTTP Service
Launch the server side
```
python image_classification_service.py ResNet50_vd_model workdir 9393 #cpu inference service
python resnet50_web_service.py ResNet50_vd_model cpu 9696 #cpu inference service
```
```
python image_classification_service_gpu.py ResNet50_vd_model workdir 9393 #gpu inference service
python resnet50_web_service.py ResNet50_vd_model gpu 9696 #gpu inference service
```
The client sends an inference request
```
python image_http_client.py
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://127.0.0.1:9696/image/prediction
```
### RPC Infer
### RPC Service
Launch the server side
```
python -m paddle_serving_server.serve --model ResNet50_vd_model --port 9393 #cpu inference service
python -m paddle_serving_server.serve --model ResNet50_vd_model --port 9696 #cpu inference service
```
```
python -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 9393 --gpu_ids 0 #gpu inference service
python -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 9696 --gpu_ids 0 #gpu inference service
```
The client sends an inference request
```
python image_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt
```
*the port of the server side in this example is 9393; the sample data used by the client side is in the folder ./data. These parameters can be modified in practice*
*the port of server side in this example is 9696
......@@ -15,34 +15,35 @@ sh get_model.sh
pip install paddle_serving_app
```
### Run the HTTP Inference Service
### HTTP Service
Launch the server side
```
python image_classification_service.py ResNet50_vd_model workdir 9393 #cpu inference service
python resnet50_web_service.py ResNet50_vd_model cpu 9696 #cpu inference service
```
```
python image_classification_service_gpu.py ResNet50_vd_model workdir 9393 #gpu inference service
python resnet50_web_service.py ResNet50_vd_model gpu 9696 #gpu inference service
```
The client performs prediction
Send an HTTP POST request
```
python image_http_client.py
curl -H "Content-Type:application/json" -X POST -d '{"feed":[{"image": "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"}], "fetch": ["score"]}' http://127.0.0.1:9696/image/prediction
```
### Run the RPC Inference Service
### RPC Service
Launch the server side
```
python -m paddle_serving_server.serve --model ResNet50_vd_model --port 9393 #cpu inference service
python -m paddle_serving_server.serve --model ResNet50_vd_model --port 9696 #cpu inference service
```
```
python -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 9393 --gpu_ids 0 #gpu inference service
python -m paddle_serving_server_gpu.serve --model ResNet50_vd_model --port 9696 --gpu_ids 0 #gpu inference service
```
The client performs prediction
```
python image_rpc_client.py ResNet50_vd_client_config/serving_client_conf.prototxt
```
*In the server-side example the service port is 9393; the client-side data comes from the ./data folder and the server address is the local port 9393. Adjust the scripts to your actual setup.*
*In the server-side example the service port is 9696
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import cv2
import base64
import numpy as np
from paddle_serving_app import ImageReader
from paddle_serving_server_gpu.web_service import WebService
class ImageService(WebService):
def preprocess(self, feed={}, fetch=[]):
reader = ImageReader()
feed_batch = []
for ins in feed:
if "image" not in ins:
raise ("feed data error!")
sample = base64.b64decode(ins["image"])
img = reader.process_image(sample)
feed_batch.append({"image": img})
return feed_batch, fetch
image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1])
image_service.set_gpus("0,1")
image_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="gpu")
image_service.run_server()
image_service.run_flask()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import base64
import json
import time
import os
import sys
py_version = sys.version_info[0]
def predict(image_path, server):
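# b64encode returns str on Python 2 but bytes on Python 3, hence the decode below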
if py_version == 2:
image = base64.b64encode(open(image_path).read())
else:
image = base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
req = json.dumps({"feed": [{"image": image}], "fetch": ["score"]})
r = requests.post(
server, data=req, headers={"Content-Type": "application/json"})
try:
print(r.json()["result"]["score"])
except ValueError:
print(r.text)
return r
if __name__ == "__main__":
server = "http://127.0.0.1:9393/image/prediction"
image_list = os.listdir("./image_data/n01440764/")
start = time.time()
for img in image_list:
image_file = "./image_data/n01440764/" + img
res = predict(image_file, server)
end = time.time()
print(end - start)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
class ImageReader():
def __init__(self):
self.image_mean = [0.485, 0.456, 0.406]
self.image_std = [0.229, 0.224, 0.225]
self.image_shape = [3, 224, 224]
self.resize_short_size = 256
self.interpolation = None
def resize_short(self, img, target_size, interpolation=None):
"""resize image
Args:
img: image data
target_size: resize short target size
interpolation: interpolation mode
Returns:
resized image data
"""
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
if interpolation:
resized = cv2.resize(
img, (resized_width, resized_height),
interpolation=interpolation)
else:
resized = cv2.resize(img, (resized_width, resized_height))
return resized
def crop_image(self, img, target_size, center):
"""crop image
Args:
img: images data
target_size: crop target size
center: crop mode
Returns:
img: cropped image data
"""
height, width = img.shape[:2]
size = target_size
if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def process_image(self, sample):
""" process_image """
mean = self.image_mean
std = self.image_std
crop_size = self.image_shape[1]
data = np.frombuffer(sample, np.uint8)
img = cv2.imdecode(data, cv2.IMREAD_COLOR)
if img is None:
print("img is None, pass it.")
return None
if crop_size > 0:
target_size = self.resize_short_size
img = self.resize_short(
img, target_size, interpolation=self.interpolation)
img = self.crop_image(img, target_size=crop_size, center=True)
img = img[:, :, ::-1]
img = img.astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
return img
......@@ -14,23 +14,35 @@
import sys
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, File2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize
from paddle_serving_app.reader import Sequential, URL2Image, Resize
from paddle_serving_app.reader import CenterCrop, RGB2BGR, Transpose, Div, Normalize
import time
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9393"])
client.connect(["127.0.0.1:9696"])
label_dict = {}
label_idx = 0
with open("imagenet.label") as fin:
for line in fin:
label_dict[label_idx] = line.strip()
label_idx += 1
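# preprocessing pipeline: fetch the image by URL, resize the short side to 256, center-crop to 224x224, swap RGB/BGR, HWC->CHW, scale to [0, 1], then normalize per channel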
seq = Sequential([
File2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose((2, 0, 1)),
Div(255), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], True)
])
print(seq)
start = time.time()
image_file = "daisy.jpg"
for i in range(1000):
image_file = "https://paddle-serving.bj.bcebos.com/imagenet-example/daisy.jpg"
for i in range(10):
img = seq(image_file)
fetch_map = client.predict(feed={"image": img}, fetch=["score"])
prob = max(fetch_map["score"][0])
label = label_dict[fetch_map["score"][0].tolist().index(prob)].strip(
).replace(",", "")
print("prediction: {}, probability: {}".format(label, prob))
end = time.time()
print(end - start)
......@@ -11,29 +11,62 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_server.web_service import WebService
from paddle_serving_app import ImageReader
import sys
import base64
from paddle_serving_client import Client
from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize
if len(sys.argv) != 4:
print("python resnet50_web_service.py model device port")
sys.exit(-1)
device = sys.argv[2]
if device == "cpu":
from paddle_serving_server.web_service import WebService
else:
from paddle_serving_server_gpu.web_service import WebService
class ImageService(WebService):
def preprocess(self, feed={}, fetch=[]):
reader = ImageReader()
def init_imagenet_setting(self):
self.seq = Sequential([
URL2Image(), Resize(256), CenterCrop(224), RGB2BGR(), Transpose(
(2, 0, 1)), Div(255), Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225], True)
])
self.label_dict = {}
label_idx = 0
with open("imagenet.label") as fin:
for line in fin:
self.label_dict[label_idx] = line.strip()
label_idx += 1
def preprocess(self, feed=[], fetch=[]):
feed_batch = []
for ins in feed:
if "image" not in ins:
raise ("feed data error!")
sample = base64.b64decode(ins["image"])
img = reader.process_image(sample)
img = self.seq(ins["image"])
feed_batch.append({"image": img})
return feed_batch, fetch
def postprocess(self, feed=[], fetch=[], fetch_map={}):
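# map each sample's top-1 score to its ImageNet label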
score_list = fetch_map["score"]
result = {"label": [], "prob": []}
for score in score_list:
score = score.tolist()
max_score = max(score)
result["label"].append(self.label_dict[score.index(max_score)]
.strip().replace(",", ""))
result["prob"].append(max_score)
return result
image_service = ImageService(name="image")
image_service.load_model_config(sys.argv[1])
image_service.init_imagenet_setting()
if device == "gpu":
image_service.set_gpus("0,1")
image_service.prepare_server(
workdir=sys.argv[2], port=int(sys.argv[3]), device="cpu")
workdir="workdir", port=int(sys.argv[3]), device=device)
image_service.run_server()
image_service.run_flask()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .reader.chinese_bert_reader import ChineseBertReader
from .reader.image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize
from .reader.image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, PadStride
from .reader.lac_reader import LACReader
from .reader.senta_reader import SentaReader
from .reader.imdb_reader import IMDBDataset
......
......@@ -71,6 +71,7 @@ class Debugger(object):
if profile:
config.enable_profile()
config.set_cpu_math_library_num_threads(cpu_num)
config.switch_ir_optim(False)
self.predictor = create_paddle_predictor(config)
......
......@@ -465,6 +465,24 @@ class Resize(object):
_cv2_interpolation_to_str[self.interpolation])
class PadStride(object):
def __init__(self, stride):
self.coarsest_stride = stride
def __call__(self, img):
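# pad H and W up to the nearest multiple of the stride so downstream detection backbones receive aligned inputs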
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return img
im_c, im_h, im_w = img.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
return padding_im
class Transpose(object):
def __init__(self, transpose_target):
self.transpose_target = transpose_target
......
......@@ -329,9 +329,9 @@ class Client(object):
# result map needs to be a numpy array
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
# result_map[name] will be py::array(numpy array)
result_map[name] = result_batch.get_int64_by_name(mi, name)
shape = result_batch.get_shape(mi, name)
result_map[name] = np.array(result_map[name], dtype='int64')
result_map[name].shape = shape
if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array(
......@@ -339,8 +339,6 @@ class Client(object):
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(mi, name)
shape = result_batch.get_shape(mi, name)
result_map[name] = np.array(
result_map[name], dtype='float32')
result_map[name].shape = shape
if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array(
......
......@@ -290,8 +290,8 @@ class Server(object):
# check config here
# print config here
def use_mkl(self):
self.mkl_flag = True
def use_mkl(self, flag):
self.mkl_flag = flag
def get_device_version(self):
avx_flag = False
......@@ -306,6 +306,10 @@ class Server(object):
else:
device_version = "serving-cpu-avx-openblas-"
else:
if mkl_flag:
print(
"Your CPU does not support AVX, server will running with noavx-openblas mode."
)
device_version = "serving-cpu-noavx-openblas-"
return device_version
......
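With this change, `use_mkl` takes an explicit flag instead of unconditionally switching MKL on. A minimal sketch of the updated call, assuming a locally configured `Server` instance:
```
from paddle_serving_server import Server

server = Server()
server.use_mkl(True)  # previously server.use_mkl() with no argument
```
On the command line, the same switch is exposed through the new `--use_mkl` argument added to `serve.py` below.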
......@@ -43,6 +43,7 @@ def parse_args(): # pylint: disable=doc-string-missing
"--mem_optim", type=bool, default=False, help="Memory optimize")
parser.add_argument(
"--ir_optim", type=bool, default=False, help="Graph optimize")
parser.add_argument("--use_mkl", type=bool, default=False, help="Use MKL")
parser.add_argument(
"--max_body_size",
type=int,
......@@ -61,6 +62,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
mem_optim = args.mem_optim
ir_optim = args.ir_optim
max_body_size = args.max_body_size
use_mkl = args.use_mkl
if model == "":
print("You must specify your serving model")
......@@ -82,6 +84,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
server.set_port(port)
......