未验证 提交 0849586d 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #882 from wangjiawei04/trt

Fix Local Predictor bug
...@@ -8,8 +8,8 @@ sh get_data.sh ...@@ -8,8 +8,8 @@ sh get_data.sh
## 启动服务 ## 启动服务
``` ```
python -m paddle_serving_server_gpu.serve --model imdb_cnn_model --port 9292 &> cnn.log & python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 &> cnn.log &
python -m paddle_serving_server_gpu.serve --model imdb_bow_model --port 9393 &> bow.log & python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 &> bow.log &
python test_pipeline_server.py &>pipeline.log & python test_pipeline_server.py &>pipeline.log &
``` ```
...@@ -17,8 +17,3 @@ python test_pipeline_server.py &>pipeline.log & ...@@ -17,8 +17,3 @@ python test_pipeline_server.py &>pipeline.log &
``` ```
python test_pipeline_client.py python test_pipeline_client.py
``` ```
## HTTP 测试
```
curl -X POST -k http://localhost:9999/prediction -d '{"key": ["words"], "value": ["i am very sad | 0"]}'
```
...@@ -41,7 +41,9 @@ class ImdbRequestOp(RequestOp): ...@@ -41,7 +41,9 @@ class ImdbRequestOp(RequestOp):
continue continue
words = request.value[idx] words = request.value[idx]
word_ids, _ = self.imdb_dataset.get_words_and_label(words) word_ids, _ = self.imdb_dataset.get_words_and_label(words)
dictdata[key] = np.array(word_ids) word_len = len(word_ids)
dictdata[key] = np.array(word_ids).reshape(word_len, 1)
dictdata["{}.lod".format(key)] = [0, word_len]
return dictdata return dictdata
...@@ -77,16 +79,18 @@ bow_op = Op(name="bow", ...@@ -77,16 +79,18 @@ bow_op = Op(name="bow",
server_endpoints=["127.0.0.1:9393"], server_endpoints=["127.0.0.1:9393"],
fetch_list=["prediction"], fetch_list=["prediction"],
client_config="imdb_bow_client_conf/serving_client_conf.prototxt", client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
client_type='brpc',
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=1, retry=1,
batch_size=3, batch_size=1,
auto_batching_timeout=1000) auto_batching_timeout=None)
cnn_op = Op(name="cnn", cnn_op = Op(name="cnn",
input_ops=[read_op], input_ops=[read_op],
server_endpoints=["127.0.0.1:9292"], server_endpoints=["127.0.0.1:9292"],
fetch_list=["prediction"], fetch_list=["prediction"],
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt", client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
client_type='brpc',
concurrency=1, concurrency=1,
timeout=-1, timeout=-1,
retry=1, retry=1,
......
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_server_gpu.pipeline import Op, RequestOp, ResponseOp from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server_gpu.pipeline import PipelineServer from paddle_serving_server.pipeline import PipelineServer
from paddle_serving_server_gpu.pipeline.proto import pipeline_service_pb2 from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server_gpu.pipeline.channel import ChannelDataEcode from paddle_serving_server.pipeline.channel import ChannelDataEcode
from paddle_serving_server_gpu.pipeline import LocalRpcServiceHandler from paddle_serving_server.pipeline import LocalServiceHandler
import numpy as np import numpy as np
import cv2 import cv2
import time import time
...@@ -56,9 +56,11 @@ class DetOp(Op): ...@@ -56,9 +56,11 @@ class DetOp(Op):
data = np.fromstring(data, np.uint8) data = np.fromstring(data, np.uint8)
# Note: class variables(self.var) can only be used in process op mode # Note: class variables(self.var) can only be used in process op mode
self.im = cv2.imdecode(data, cv2.IMREAD_COLOR) self.im = cv2.imdecode(data, cv2.IMREAD_COLOR)
print(self.im)
self.ori_h, self.ori_w, _ = self.im.shape self.ori_h, self.ori_w, _ = self.im.shape
det_img = self.det_preprocess(self.im) det_img = self.det_preprocess(self.im)
_, self.new_h, self.new_w = det_img.shape _, self.new_h, self.new_w = det_img.shape
print("image", det_img)
return {"image": det_img} return {"image": det_img}
def postprocess(self, input_dicts, fetch_dict): def postprocess(self, input_dicts, fetch_dict):
...@@ -111,11 +113,11 @@ read_op = RequestOp() ...@@ -111,11 +113,11 @@ read_op = RequestOp()
det_op = DetOp( det_op = DetOp(
name="det", name="det",
input_ops=[read_op], input_ops=[read_op],
local_rpc_service_handler=LocalRpcServiceHandler( client_type="local_predictor",
local_service_handler=LocalServiceHandler(
model_config="ocr_det_model", model_config="ocr_det_model",
workdir="det_workdir", # defalut: "workdir" workdir="det_workdir", # defalut: "workdir"
thread_num=2, # defalut: 2 thread_num=2, # defalut: 2
devices="0", # gpu0. defalut: "" (cpu)
mem_optim=True, # defalut: True mem_optim=True, # defalut: True
ir_optim=False, # defalut: False ir_optim=False, # defalut: False
available_port_generator=None), # defalut: None available_port_generator=None), # defalut: None
...@@ -123,8 +125,8 @@ det_op = DetOp( ...@@ -123,8 +125,8 @@ det_op = DetOp(
rec_op = RecOp( rec_op = RecOp(
name="rec", name="rec",
input_ops=[det_op], input_ops=[det_op],
local_rpc_service_handler=LocalRpcServiceHandler( client_type="local_predictor",
model_config="ocr_rec_model"), local_service_handler=LocalServiceHandler(model_config="ocr_rec_model"),
concurrency=1) concurrency=1)
response_op = ResponseOp(input_ops=[rec_op]) response_op = ResponseOp(input_ops=[rec_op])
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle_serving_server_gpu.pipeline import PipelineClient from paddle_serving_server.pipeline import PipelineClient
import numpy as np import numpy as np
import requests import requests
import json import json
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle_serving_server_gpu.pipeline import PipelineClient from paddle_serving_server.pipeline import PipelineClient
import numpy as np import numpy as np
import requests import requests
import json import json
...@@ -33,6 +33,6 @@ for img_file in os.listdir(test_img_dir): ...@@ -33,6 +33,6 @@ for img_file in os.listdir(test_img_dir):
image_data = file.read() image_data = file.read()
image = cv2_to_base64(image_data) image = cv2_to_base64(image_data)
for i in range(4): for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"]) ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret) print(ret)
...@@ -7,3 +7,4 @@ op: ...@@ -7,3 +7,4 @@ op:
local_service_conf: local_service_conf:
model_config: uci_housing_model model_config: uci_housing_model
devices: "" # "0,1" devices: "" # "0,1"
client_type: brpc
...@@ -22,6 +22,7 @@ except ImportError: ...@@ -22,6 +22,7 @@ except ImportError:
from paddle_serving_server import OpMaker, OpSeqMaker, Server from paddle_serving_server import OpMaker, OpSeqMaker, Server
PACKAGE_VERSION = "CPU" PACKAGE_VERSION = "CPU"
from . import util from . import util
from paddle_serving_app.local_predict import LocalPredictor
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_workdir_name_gen = util.NameGenerator("workdir_") _workdir_name_gen = util.NameGenerator("workdir_")
...@@ -30,6 +31,7 @@ _workdir_name_gen = util.NameGenerator("workdir_") ...@@ -30,6 +31,7 @@ _workdir_name_gen = util.NameGenerator("workdir_")
class LocalServiceHandler(object): class LocalServiceHandler(object):
def __init__(self, def __init__(self,
model_config, model_config,
client_type='local_predictor',
workdir="", workdir="",
thread_num=2, thread_num=2,
devices="", devices="",
...@@ -58,12 +60,13 @@ class LocalServiceHandler(object): ...@@ -58,12 +60,13 @@ class LocalServiceHandler(object):
self._port_list.append(available_port_generator.next()) self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in gpu device: {}. Port({})" _LOGGER.info("Model({}) will be launch in gpu device: {}. Port({})"
.format(model_config, devices, self._port_list)) .format(model_config, devices, self._port_list))
self.client_type = client_type
self._workdir = workdir self._workdir = workdir
self._devices = devices self._devices = devices
self._thread_num = thread_num self._thread_num = thread_num
self._mem_optim = mem_optim self._mem_optim = mem_optim
self._ir_optim = ir_optim self._ir_optim = ir_optim
self.local_predictor_client = None
self._rpc_service_list = [] self._rpc_service_list = []
self._server_pros = [] self._server_pros = []
self._fetch_vars = None self._fetch_vars = None
...@@ -74,6 +77,13 @@ class LocalServiceHandler(object): ...@@ -74,6 +77,13 @@ class LocalServiceHandler(object):
def get_port_list(self): def get_port_list(self):
return self._port_list return self._port_list
def get_client(self): # for local_predictor_only
if self.local_predictor_client is None:
self.local_predictor_client = LocalPredictor()
self.local_predictor_client.load_model_config(
"{}".format(self._model_config), gpu=False, profile=False)
return self.local_predictor_client
def get_client_config(self): def get_client_config(self):
return os.path.join(self._model_config, "serving_server_conf.prototxt") return os.path.join(self._model_config, "serving_server_conf.prototxt")
......
...@@ -51,6 +51,7 @@ class Op(object): ...@@ -51,6 +51,7 @@ class Op(object):
server_endpoints=None, server_endpoints=None,
fetch_list=None, fetch_list=None,
client_config=None, client_config=None,
client_type=None,
concurrency=None, concurrency=None,
timeout=None, timeout=None,
retry=None, retry=None,
...@@ -68,6 +69,7 @@ class Op(object): ...@@ -68,6 +69,7 @@ class Op(object):
self._server_endpoints = server_endpoints self._server_endpoints = server_endpoints
self._fetch_names = fetch_list self._fetch_names = fetch_list
self._client_config = client_config self._client_config = client_config
self.client_type = client_type
self._timeout = timeout self._timeout = timeout
self._retry = max(1, retry) self._retry = max(1, retry)
self._batch_size = batch_size self._batch_size = batch_size
...@@ -138,6 +140,7 @@ class Op(object): ...@@ -138,6 +140,7 @@ class Op(object):
if self.client_type == "brpc" or self.client_type == "grpc": if self.client_type == "brpc" or self.client_type == "grpc":
service_handler = local_service_handler.LocalServiceHandler( service_handler = local_service_handler.LocalServiceHandler(
model_config=model_config, model_config=model_config,
client_type=self.client_type,
workdir=local_service_conf["workdir"], workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"], thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"], devices=local_service_conf["devices"],
...@@ -155,12 +158,13 @@ class Op(object): ...@@ -155,12 +158,13 @@ class Op(object):
self._fetch_names = service_handler.get_fetch_list( self._fetch_names = service_handler.get_fetch_list(
) )
elif self.client_type == "local_predictor": elif self.client_type == "local_predictor":
service_handler = local_service_handler.LocalPredictorServiceHandler( service_handler = local_service_handler.LocalServiceHandler(
model_config=model_config, model_config=model_config,
client_type=self.client_type,
workdir=local_service_conf["workdir"], workdir=local_service_conf["workdir"],
thread_num=local_service_conf["thread_num"], thread_num=local_service_conf["thread_num"],
devices=local_service_conf["devices"]) devices=local_service_conf["devices"])
service_handler.prepare_server() # get fetch_list #service_handler.prepare_server() # get fetch_list
self.local_predictor = service_handler.get_client() self.local_predictor = service_handler.get_client()
if self._client_config is None: if self._client_config is None:
self._client_config = service_handler.get_client_config( self._client_config = service_handler.get_client_config(
...@@ -210,6 +214,9 @@ class Op(object): ...@@ -210,6 +214,9 @@ class Op(object):
" service: local_service_handler is None.")) " service: local_service_handler is None."))
return return
port = self._local_service_handler.get_port_list() port = self._local_service_handler.get_port_list()
#if self._local_service_handler.client_type == "local_predictor":
# _LOGGER.info("Op({}) use local predictor.")
# return
self._local_service_handler.start_server() self._local_service_handler.start_server()
_LOGGER.info("Op({}) use local rpc service at port: {}" _LOGGER.info("Op({}) use local rpc service at port: {}"
.format(self.name, port)) .format(self.name, port))
...@@ -248,6 +255,9 @@ class Op(object): ...@@ -248,6 +255,9 @@ class Op(object):
else: else:
raise ValueError("Failed to init client: unknow client " raise ValueError("Failed to init client: unknow client "
"type {}".format(self.client_type)) "type {}".format(self.client_type))
if self._fetch_names is None:
self._fetch_names = client.fetch_names_
_LOGGER.info("Op({}) has no fetch name set. So fetch all vars")
if self.client_type != "local_predictor": if self.client_type != "local_predictor":
client.connect(server_endpoints) client.connect(server_endpoints)
return client return client
...@@ -310,7 +320,7 @@ class Op(object): ...@@ -310,7 +320,7 @@ class Op(object):
(_, input_dict), = input_dicts.items() (_, input_dict), = input_dicts.items()
return input_dict return input_dict
def process(self, feed_batch, fetch_names, typical_logid): def process(self, feed_batch, typical_logid):
err, err_info = ChannelData.check_batch_npdata(feed_batch) err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0: if err != 0:
_LOGGER.critical( _LOGGER.critical(
...@@ -320,13 +330,13 @@ class Op(object): ...@@ -320,13 +330,13 @@ class Op(object):
if self.client_type == "local_predictor": if self.client_type == "local_predictor":
call_result = self.client.predict( call_result = self.client.predict(
feed=feed_batch[0], feed=feed_batch[0],
fetch=fetch_names, fetch=self._fetch_names,
batch=True, batch=True,
log_id=typical_logid) log_id=typical_logid)
else: else:
call_result = self.client.predict( call_result = self.client.predict(
feed=feed_batch, feed=feed_batch,
fetch=fetch_names, fetch=self._fetch_names,
batch=True, batch=True,
log_id=typical_logid) log_id=typical_logid)
if isinstance(self.client, MultiLangClient): if isinstance(self.client, MultiLangClient):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册