Unverified commit 8ccaaf40, authored by Jiawei Wang, committed by GitHub

Merge branch 'develop' into develop

@@ -188,7 +188,7 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
 | `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
 | `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
 | `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
-| `use_calib` | bool | False | Only for deployment with TensorRT |
+| `use_calib` | bool | False | Use TRT int8 calibration |
 | `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
 #### Description of asynchronous model
......
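The updated `use_calib` row ties the flag to TRT int8 calibration. A minimal launch sketch in Python, assuming the `uci_housing_model` directory, thread count, and port from the quickstart command in the hunk header, and assuming lowercase precision values on the CLI:

```python
# Launch sketch for paddle_serving_server.serve with the flags documented above.
# Assumptions: model dir / thread count from the quickstart, port 9292,
# TensorRT enabled via --use_trt so that int8 calibration applies.
import subprocess

subprocess.run([
    "python3", "-m", "paddle_serving_server.serve",
    "--model", "uci_housing_model",
    "--thread", "10",
    "--port", "9292",
    "--use_trt",
    "--precision", "int8",    # FP32 / FP16 / INT8 per the table
    "--use_calib", "True",    # TRT int8 calibration per the updated row
], check=True)
```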
@@ -187,7 +187,7 @@ python3 -m paddle_serving_server.serve --model uci_housing_model --thread 10 --p
 | `use_lite` (Only for Intel x86 CPU or ARM CPU) | - | - | Run PaddleLite inference |
 | `use_xpu` | - | - | Run PaddleLite inference with Baidu Kunlun XPU |
 | `precision` | str | FP32 | Precision Mode, support FP32, FP16, INT8 |
-| `use_calib` | bool | False | Only for deployment with TensorRT |
+| `use_calib` | bool | False | Use TRT int8 calibration |
 | `gpu_multi_stream` | bool | False | EnableGpuMultiStream to get larger QPS |
 #### Description of asynchronous model
......
@@ -266,10 +266,14 @@ class PredictorClient {
       const std::vector<std::string>& float_feed_name,
       const std::vector<std::vector<int>>& float_shape,
       const std::vector<std::vector<int>>& float_lod_slot_batch,
-      const std::vector<py::array_t<int64_t>>& int_feed,
-      const std::vector<std::string>& int_feed_name,
-      const std::vector<std::vector<int>>& int_shape,
-      const std::vector<std::vector<int>>& int_lod_slot_batch,
+      const std::vector<py::array_t<int32_t>> &int32_feed,
+      const std::vector<std::string> &int32_feed_name,
+      const std::vector<std::vector<int>> &int32_shape,
+      const std::vector<std::vector<int>> &int32_lod_slot_batch,
+      const std::vector<py::array_t<int64_t>> &int64_feed,
+      const std::vector<std::string> &int64_feed_name,
+      const std::vector<std::vector<int>> &int64_shape,
+      const std::vector<std::vector<int>> &int64_lod_slot_batch,
       const std::vector<std::string>& string_feed,
       const std::vector<std::string>& string_feed_name,
       const std::vector<std::vector<int>>& string_shape,
......
@@ -168,10 +168,14 @@ int PredictorClient::numpy_predict(
     const std::vector<std::string> &float_feed_name,
     const std::vector<std::vector<int>> &float_shape,
     const std::vector<std::vector<int>> &float_lod_slot_batch,
-    const std::vector<py::array_t<int64_t>> &int_feed,
-    const std::vector<std::string> &int_feed_name,
-    const std::vector<std::vector<int>> &int_shape,
-    const std::vector<std::vector<int>> &int_lod_slot_batch,
+    const std::vector<py::array_t<int32_t>> &int32_feed,
+    const std::vector<std::string> &int32_feed_name,
+    const std::vector<std::vector<int>> &int32_shape,
+    const std::vector<std::vector<int>> &int32_lod_slot_batch,
+    const std::vector<py::array_t<int64_t>> &int64_feed,
+    const std::vector<std::string> &int64_feed_name,
+    const std::vector<std::vector<int>> &int64_shape,
+    const std::vector<std::vector<int>> &int64_lod_slot_batch,
    const std::vector<std::string> &string_feed,
    const std::vector<std::string> &string_feed_name,
    const std::vector<std::vector<int>> &string_shape,
@@ -190,7 +194,8 @@ int PredictorClient::numpy_predict(
   predict_res_batch.set_variant_tag(variant_tag);
   VLOG(2) << "fetch general model predictor done.";
   VLOG(2) << "float feed name size: " << float_feed_name.size();
-  VLOG(2) << "int feed name size: " << int_feed_name.size();
+  VLOG(2) << "int feed name size: " << int32_feed_name.size();
+  VLOG(2) << "int feed name size: " << int64_feed_name.size();
   VLOG(2) << "string feed name size: " << string_feed_name.size();
   VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
   Request req;
@@ -207,7 +212,11 @@ int PredictorClient::numpy_predict(
     tensor_vec.push_back(req.add_tensor());
   }
-  for (auto &name : int_feed_name) {
+  for (auto &name : int32_feed_name) {
     tensor_vec.push_back(req.add_tensor());
   }
+  for (auto &name : int64_feed_name) {
+    tensor_vec.push_back(req.add_tensor());
+  }
@@ -247,34 +256,58 @@ int PredictorClient::numpy_predict(
   }
   vec_idx = 0;
-  for (auto &name : int_feed_name) {
+  for (auto &name : int32_feed_name) {
     int idx = _feed_name_to_idx[name];
     if (idx >= tensor_vec.size()) {
       LOG(ERROR) << "idx > tensor_vec.size()";
       return -1;
     }
     Tensor *tensor = tensor_vec[idx];
-    int nbytes = int_feed[vec_idx].nbytes();
-    void *rawdata_ptr = (void *)(int_feed[vec_idx].data(0));
-    int total_number = int_feed[vec_idx].size();
+    int nbytes = int32_feed[vec_idx].nbytes();
+    void *rawdata_ptr = (void *)(int32_feed[vec_idx].data(0));
+    int total_number = int32_feed[vec_idx].size();
-    for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
-      tensor->add_shape(int_shape[vec_idx][j]);
+    for (uint32_t j = 0; j < int32_shape[vec_idx].size(); ++j) {
+      tensor->add_shape(int32_shape[vec_idx][j]);
     }
-    for (uint32_t j = 0; j < int_lod_slot_batch[vec_idx].size(); ++j) {
-      tensor->add_lod(int_lod_slot_batch[vec_idx][j]);
+    for (uint32_t j = 0; j < int32_lod_slot_batch[vec_idx].size(); ++j) {
+      tensor->add_lod(int32_lod_slot_batch[vec_idx][j]);
     }
     tensor->set_elem_type(_type[idx]);
     tensor->set_name(_feed_name[idx]);
     tensor->set_alias_name(name);
-    if (_type[idx] == P_INT64) {
-      tensor->mutable_int64_data()->Resize(total_number, 0);
-      memcpy(tensor->mutable_int64_data()->mutable_data(), rawdata_ptr, nbytes);
-    } else {
-      tensor->mutable_int_data()->Resize(total_number, 0);
-      memcpy(tensor->mutable_int_data()->mutable_data(), rawdata_ptr, nbytes);
-    }
+    tensor->mutable_int_data()->Resize(total_number, 0);
+    memcpy(tensor->mutable_int_data()->mutable_data(), rawdata_ptr, nbytes);
     vec_idx++;
   }
+  // Individual INT_64 feed data of int_input to tensor_content
+  vec_idx = 0;
+  for (auto &name : int64_feed_name) {
+    int idx = _feed_name_to_idx[name];
+    if (idx >= tensor_vec.size()) {
+      LOG(ERROR) << "idx > tensor_vec.size()";
+      return -1;
+    }
+    Tensor *tensor = tensor_vec[idx];
+    int nbytes = int64_feed[vec_idx].nbytes();
+    void *rawdata_ptr = (void *)(int64_feed[vec_idx].data(0));
+    int total_number = int64_feed[vec_idx].size();
+    for (uint32_t j = 0; j < int64_shape[vec_idx].size(); ++j) {
+      tensor->add_shape(int64_shape[vec_idx][j]);
+    }
+    for (uint32_t j = 0; j < int64_lod_slot_batch[vec_idx].size(); ++j) {
+      tensor->add_lod(int64_lod_slot_batch[vec_idx][j]);
+    }
+    tensor->set_elem_type(_type[idx]);
+    tensor->set_name(_feed_name[idx]);
+    tensor->set_alias_name(name);
+    tensor->mutable_int64_data()->Resize(total_number, 0);
+    memcpy(tensor->mutable_int64_data()->mutable_data(), rawdata_ptr, nbytes);
+    vec_idx++;
+  }
......
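The rewritten loops above copy the raw numpy buffer with `memcpy(..., rawdata_ptr, nbytes)`, so splitting the single `int_feed` path into `int32_feed` and `int64_feed` matters at the byte level: arrays with the same `size()` differ in `nbytes`. A small illustration in plain numpy (no Serving dependency):

```python
import numpy as np

ids32 = np.array([1, 2, 3], dtype=np.int32)  # would travel as int32_feed -> int_data
ids64 = np.array([1, 2, 3], dtype=np.int64)  # would travel as int64_feed -> int64_data

# Same element count, double the bytes: the old single path had to branch on
# _type[idx] at copy time; the new code picks the buffer by feed list instead.
print(ids32.size, ids32.nbytes)  # 3 12
print(ids64.size, ids64.nbytes)  # 3 24
```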
@@ -117,10 +117,14 @@ PYBIND11_MODULE(serving_client, m) {
              const std::vector<std::string> &float_feed_name,
              const std::vector<std::vector<int>> &float_shape,
              const std::vector<std::vector<int>> &float_lod_slot_batch,
-             const std::vector<py::array_t<int64_t>> &int_feed,
-             const std::vector<std::string> &int_feed_name,
-             const std::vector<std::vector<int>> &int_shape,
-             const std::vector<std::vector<int>> &int_lod_slot_batch,
+             const std::vector<py::array_t<int32_t>> &int32_feed,
+             const std::vector<std::string> &int32_feed_name,
+             const std::vector<std::vector<int>> &int32_shape,
+             const std::vector<std::vector<int>> &int32_lod_slot_batch,
+             const std::vector<py::array_t<int64_t>> &int64_feed,
+             const std::vector<std::string> &int64_feed_name,
+             const std::vector<std::vector<int>> &int64_shape,
+             const std::vector<std::vector<int>> &int64_lod_slot_batch,
             const std::vector<std::string> &string_feed,
             const std::vector<std::string> &string_feed_name,
             const std::vector<std::vector<int>> &string_shape,
@@ -133,10 +137,14 @@ PYBIND11_MODULE(serving_client, m) {
                 float_feed_name,
                 float_shape,
                 float_lod_slot_batch,
-                int_feed,
-                int_feed_name,
-                int_shape,
-                int_lod_slot_batch,
+                int32_feed,
+                int32_feed_name,
+                int32_shape,
+                int32_lod_slot_batch,
+                int64_feed,
+                int64_feed_name,
+                int64_shape,
+                int64_lod_slot_batch,
                 string_feed,
                 string_feed_name,
                 string_shape,
......
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/cascade_mask_rcnn_r50_vd_fpn_ssld_2x_coco_serving.tar.gz
tar xf cascade_mask_rcnn_r50_vd_fpn_ssld_2x_coco_serving.tar.gz
@@ -12,29 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
+import numpy as np
 from paddle_serving_client import Client
 from paddle_serving_app.reader import *
-import numpy as np
 import cv2
 preprocess = Sequential([
-    File2Image(), BGR2RGB(), Div(255.0),
-    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
-    Resize(800, 1333), Transpose((2, 0, 1)), PadStride(32)
+    File2Image(), BGR2RGB(), Resize(
+        (608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose(
+            (2, 0, 1))
 ])
-postprocess = RCNNPostprocess("label_list.txt", "output")
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
 client = Client()
 client.load_client_config("serving_client/serving_client_conf.prototxt")
 client.connect(['127.0.0.1:9292'])
 im = preprocess('000000570688.jpg')
 fetch_map = client.predict(
     feed={
         "image": im,
-        "im_info": np.array(list(im.shape[1:]) + [1.0]),
-        "im_shape": np.array(list(im.shape[1:]) + [1.0])
+        "im_shape": np.array(list(im.shape[1:])).reshape(-1),
+        "scale_factor": np.array([1.0, 1.0]).reshape(-1),
     },
-    fetch=["multiclass_nms_0.tmp_0"],
+    fetch=["save_infer_model/scale_0.tmp_1"],
     batch=False)
-fetch_map["image"] = '000000570688.jpg'
-print(fetch_map)
+fetch_map["image"] = '000000570688.jpg'
+postprocess(fetch_map)
+print(fetch_map)
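For reference, a worked sketch of what the new feed keys evaluate to under the fixed 608x608 pipeline above (dummy CHW array; the key names come from the diff, the rest is illustrative):

```python
import numpy as np

im = np.zeros((3, 608, 608), dtype=np.float32)       # CHW, as left by Transpose((2, 0, 1))
im_shape = np.array(list(im.shape[1:])).reshape(-1)  # height/width after the fixed resize
scale_factor = np.array([1.0, 1.0]).reshape(-1)      # resize already baked into im

print(im_shape)      # [608 608]
print(scale_factor)  # [1. 1.]
```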
@@ -4,7 +4,7 @@
 ### Get The Faster RCNN HRNet Model
 ```
-wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_hrnetv2p_w18_1x.tar
+wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/faster_rcnn_hrnetv2p_w18_1x.tar.gz
 ```
 ### Start the service
......
@@ -4,7 +4,7 @@
 ## Get the Faster RCNN HRNet Model
 ```
-wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/faster_rcnn_hrnetv2p_w18_1x.tar
+wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/faster_rcnn_hrnetv2p_w18_1x.tar.gz
 ```
......
-from paddle_serving_client import Client
-from paddle_serving_app.reader import *
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import sys
 import numpy as np
+from paddle_serving_client import Client
+from paddle_serving_app.reader import *
 import cv2
 preprocess = Sequential([
-    File2Image(), BGR2RGB(), Div(255.0),
-    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
-    Resize(640, 640), Transpose((2, 0, 1))
+    File2Image(), BGR2RGB(), Resize(
+        (608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose(
+            (2, 0, 1))
 ])
-postprocess = RCNNPostprocess("label_list.txt", "output")
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
 client = Client()
 client.load_client_config("serving_client/serving_client_conf.prototxt")
@@ -19,9 +33,11 @@ im = preprocess(sys.argv[1])
 fetch_map = client.predict(
     feed={
         "image": im,
-        "im_info": np.array(list(im.shape[1:]) + [1.0]),
-        "im_shape": np.array(list(im.shape[1:]) + [1.0])
+        "im_shape": np.array(list(im.shape[1:])).reshape(-1),
+        "scale_factor": np.array([1.0, 1.0]).reshape(-1),
     },
-    fetch=["multiclass_nms_0.tmp_0"],
+    fetch=["save_infer_model/scale_0.tmp_1"],
     batch=False)
 print(fetch_map)
+fetch_map["image"] = sys.argv[1]
+postprocess(fetch_map)
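This client reads the image path from `sys.argv[1]`, so a hedged invocation sketch (the script and image file names are assumptions, and the HRNet server is assumed to be up on 127.0.0.1:9292):

```python
# Smoke-test invocation of the demo client above.
import subprocess

subprocess.run(["python3", "test_client.py", "000000570688.jpg"], check=True)
```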
@@ -12,18 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddle_serving_client import Client
-from paddle_serving_app.reader import *
 import sys
 import numpy as np
+from paddle_serving_client import Client
+from paddle_serving_app.reader import *
 import cv2
 preprocess = Sequential([
-    File2Image(), BGR2RGB(), Div(255.0),
-    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
-    Resize((608, 608)), Transpose((2, 0, 1))
+    File2Image(), BGR2RGB(), Resize(
+        (608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose(
+            (2, 0, 1))
 ])
-postprocess = RCNNPostprocess("label_list.txt", "output")
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
 client = Client()
 client.load_client_config("serving_client/serving_client_conf.prototxt")
......
@@ -12,18 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddle_serving_client import Client
-from paddle_serving_app.reader import *
 import sys
 import numpy as np
+from paddle_serving_client import Client
+from paddle_serving_app.reader import *
 import cv2
 preprocess = Sequential([
-    File2Image(), BGR2RGB(), Div(255.0),
-    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], False),
-    Resize((608, 608)), Transpose((2, 0, 1))
+    File2Image(), BGR2RGB(), Resize(
+        (608, 608), interpolation=cv2.INTER_LINEAR), Div(255.0), Transpose(
+            (2, 0, 1))
 ])
-postprocess = RCNNPostprocess("label_list.txt", "output")
+postprocess = RCNNPostprocess("label_list.txt", "output", [608, 608])
 client = Client()
 client.load_client_config("serving_client/serving_client_conf.prototxt")
......
@@ -335,10 +335,15 @@ class Client(object):
         if len(feed_batch) != 1:
             raise ValueError("len of feed_batch can only be 1.")
-        int_slot = []
-        int_feed_names = []
-        int_shape = []
-        int_lod_slot_batch = []
+        int32_slot = []
+        int32_feed_names = []
+        int32_shape = []
+        int32_lod_slot_batch = []
+        int64_slot = []
+        int64_feed_names = []
+        int64_shape = []
+        int64_lod_slot_batch = []
         float_slot = []
         float_feed_names = []
@@ -364,27 +369,39 @@
             self.shape_check(feed_dict, key)
             if self.feed_types_[key] in int_type:
-                int_feed_names.append(key)
                 shape_lst = []
                 if batch == False:
                     feed_dict[key] = np.expand_dims(feed_dict[key], 0).repeat(
                         1, axis=0)
-                if isinstance(feed_dict[key], np.ndarray):
-                    shape_lst.extend(list(feed_dict[key].shape))
-                    int_shape.append(shape_lst)
-                else:
-                    int_shape.append(self.feed_shapes_[key])
-                if "{}.lod".format(key) in feed_dict:
-                    int_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
+                # verify different input int_type
+                if(self.feed_types_[key] == int64_type):
+                    int64_feed_names.append(key)
+                    if isinstance(feed_dict[key], np.ndarray):
+                        shape_lst.extend(list(feed_dict[key].shape))
+                        int64_shape.append(shape_lst)
+                        self.has_numpy_input = True
+                    else:
+                        int64_shape.append(self.feed_shapes_[key])
+                        self.all_numpy_input = False
+                    if "{}.lod".format(key) in feed_dict:
+                        int64_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
+                    else:
+                        int64_lod_slot_batch.append([])
+                    int64_slot.append(np.ascontiguousarray(feed_dict[key]))
                 else:
-                    int_lod_slot_batch.append([])
-                if isinstance(feed_dict[key], np.ndarray):
-                    int_slot.append(np.ascontiguousarray(feed_dict[key]))
-                    self.has_numpy_input = True
-                else:
-                    int_slot.append(np.ascontiguousarray(feed_dict[key]))
-                    self.all_numpy_input = False
+                    int32_feed_names.append(key)
+                    if isinstance(feed_dict[key], np.ndarray):
+                        shape_lst.extend(list(feed_dict[key].shape))
+                        int32_shape.append(shape_lst)
+                        self.has_numpy_input = True
+                    else:
+                        int32_shape.append(self.feed_shapes_[key])
+                        self.all_numpy_input = False
+                    if "{}.lod".format(key) in feed_dict:
+                        int32_lod_slot_batch.append(feed_dict["{}.lod".format(key)])
+                    else:
+                        int32_lod_slot_batch.append([])
+                    int32_slot.append(np.ascontiguousarray(feed_dict[key]))
             elif self.feed_types_[key] in float_type:
                 float_feed_names.append(key)
@@ -430,7 +447,8 @@ class Client(object):
         if self.all_numpy_input:
             res = self.client_handle_.numpy_predict(
                 float_slot, float_feed_names, float_shape, float_lod_slot_batch,
-                int_slot, int_feed_names, int_shape, int_lod_slot_batch,
+                int32_slot, int32_feed_names, int32_shape, int32_lod_slot_batch,
+                int64_slot, int64_feed_names, int64_shape, int64_lod_slot_batch,
                 string_slot, string_feed_names, string_shape,
                 string_lod_slot_batch, fetch_names, result_batch_handle,
                 self.pid, log_id)
......
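From the caller's side, the branch above routes each integer feed by the type recorded in the client config (`feed_types_`), not by the numpy dtype, so the array dtype should agree with the configured type now that raw buffers are memcpy'd on the C++ side. A hedged sketch with hypothetical feed and fetch names:

```python
import numpy as np
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9292'])

# "token_ids" / "seq_mask" / "prob" are hypothetical names; feed_types_ from
# the prototxt decides whether a key lands in the int64 or int32 slot lists.
fetch_map = client.predict(
    feed={
        "token_ids": np.array([[1, 2, 3]], dtype=np.int64),  # int64 path
        "seq_mask": np.array([[1, 1, 1]], dtype=np.int32),   # int32 path
    },
    fetch=["prob"],
    batch=True)
```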
@@ -564,7 +564,7 @@ class Server(object):
                       "-num_threads {} " \
                       "-port {} " \
                       "-precision {} " \
-                      "-use_calib {} " \
+                      "-use_calib={} " \
                       "-reload_interval_s {} " \
                       "-resource_path {} " \
                       "-resource_file {} " \
......
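The switch from `-use_calib {}` to `-use_calib={}` is more than cosmetic: gflags treats boolean flags specially, and only the `-flag=value` form reliably carries the value, whereas `-use_calib false` is read as `-use_calib` (true) followed by a stray token. A quick look at the two formatted strings:

```python
use_calib = False
print("-use_calib {}".format(use_calib))   # "-use_calib False"  (space form, misparsed)
print("-use_calib={}".format(use_calib))   # "-use_calib=False"  (equals form, correct)
```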