Unverified commit d5f1da66, authored by TeslaZhao, committed by GitHub

Merge branch 'develop' into grpc-fix

@@ -10,7 +10,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
### Start the service
```
tar xf faster_rcnn_hrnetv2p_w18_1x.tar
-python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
+python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT; if you want faster inference, please use `--use_trt`.
......
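As a usage note, the updated launch command with the `--use_trt` flag mentioned above would look like the following (a sketch; port 9494 and GPU 0 are simply the demo defaults):

```
python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0 --use_trt
```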
@@ -11,7 +11,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
### Start the service
```
tar xf faster_rcnn_hrnetv2p_w18_1x.tar
-python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
+python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT; enable the `--use_trt` option for faster inference.
......
@@ -10,7 +10,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
### Start the service
```
tar xf fcos_dcn_r50_fpn_1x_coco.tar
-python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
+python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT; if you want faster inference, please use `--use_trt`.
@@ -18,4 +18,3 @@ This model support TensorRT, if you want a faster inference, please use `--use_t
```
python test_client.py 000000570688.jpg
```
@@ -11,7 +11,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
### Start the service
```
tar xf fcos_dcn_r50_fpn_1x_coco.tar
-python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
+python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT; enable the `--use_trt` option for faster inference.
@@ -20,4 +20,3 @@ python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --g
```
python test_client.py 000000570688.jpg
```
@@ -10,7 +10,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
### Start the service
```
tar xf ssd_vgg16_300_240e_voc.tar
-python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
+python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT; if you want faster inference, please use `--use_trt`.
@@ -18,4 +18,3 @@ This model support TensorRT, if you want a faster inference, please use `--use_t
```
python test_client.py 000000570688.jpg
```
@@ -11,7 +11,7 @@ wget --no-check-certificate https://paddle-serving.bj.bcebos.com/pddet_demo/2.0/
### Start the service
```
tar xf ssd_vgg16_300_240e_voc.tar
-python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --gpu_ids 0
+python -m paddle_serving_server.serve --model serving_server --port 9494 --gpu_ids 0
```
This model supports TensorRT; enable the `--use_trt` option for faster inference.
@@ -20,4 +20,3 @@ python -m paddle_serving_server_gpu.serve --model serving_server --port 9494 --g
```
python test_client.py 000000570688.jpg
```
@@ -107,7 +107,7 @@ ocr_service.prepare_server(workdir="workdir", port=9292)
ocr_service.init_det_debugger(det_model_config="ocr_det_model")
if sys.argv[1] == 'gpu':
    ocr_service.set_gpus("2")
-    ocr_service.run_debugger_service(gpu = True)
+    ocr_service.run_debugger_service(gpu=True)
elif sys.argv[1] == 'cpu':
    ocr_service.run_debugger_service()
ocr_service.run_web_service()
@@ -71,7 +71,8 @@ ocr_service.load_model_config("ocr_rec_model")
if sys.argv[1] == 'gpu':
    ocr_service.set_gpus("0")
    ocr_service.init_rec()
-    ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
+    ocr_service.prepare_server(
+        workdir="workdir", port=9292, device="gpu", gpuid=0)
elif sys.argv[1] == 'cpu':
    ocr_service.init_rec()
    ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import yaml
import requests
import time
import json
-try:
-    from paddle_serving_server_gpu.pipeline import PipelineClient
-except ImportError:
-    from paddle_serving_server.pipeline import PipelineClient
+from paddle_serving_server.pipeline import PipelineClient
import numpy as np
client = PipelineClient()
client.connect(['127.0.0.1:9998'])
batch_size = 101
with open("data-c.txt", 'r') as fin:
    lines = fin.readlines()
    start_idx = 0
    while start_idx < len(lines):
        end_idx = min(len(lines), start_idx + batch_size)
        feed = {}
        for i in range(start_idx, end_idx):
            feed[str(i - start_idx)] = lines[i]
        ret = client.predict(feed_dict=feed, fetch=["res"])
        print(ret)
        start_idx += batch_size
@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-try:
-    from paddle_serving_server_gpu.web_service import WebService, Op
-except ImportError:
-    from paddle_serving_server.web_service import WebService, Op
+from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
@@ -37,7 +34,8 @@ class BertOp(Op):
        for i in range(batch_size):
            feed_dict = self.reader.process(input_dict[str(i)].encode("utf-8"))
            for key in feed_dict.keys():
-                feed_dict[key] = np.array(feed_dict[key]).reshape((1, len(feed_dict[key]), 1))
+                feed_dict[key] = np.array(feed_dict[key]).reshape(
+                    (1, len(feed_dict[key]), 1))
            feed_res.append(feed_dict)
        feed_dict = {}
        for key in feed_res[0].keys():
@@ -57,5 +55,5 @@ class BertService(WebService):
bert_service = BertService(name="bert")
-bert_service.prepare_pipeline_config("config2.yml")
+bert_service.prepare_pipeline_config("config.yml")
bert_service.run_service()
@@ -13,10 +13,7 @@
# limitations under the License.
import sys
from paddle_serving_app.reader import Sequential, URL2Image, Resize, CenterCrop, RGB2BGR, Transpose, Div, Normalize, Base64ToImage
-try:
-    from paddle_serving_server.web_service import WebService, Op
-except ImportError:
-    from paddle_serving_server.web_service import WebService, Op
+from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import base64, cv2
......
@@ -12,17 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
+import numpy as np
+from paddle_serving_app.reader.imdb_reader import IMDBDataset
+import logging
+from paddle_serving_server.web_service import WebService
from paddle_serving_server.pipeline import Op, RequestOp, ResponseOp
from paddle_serving_server.pipeline import PipelineServer
from paddle_serving_server.pipeline.proto import pipeline_service_pb2
from paddle_serving_server.pipeline.channel import ChannelDataErrcode
-import numpy as np
-from paddle_serving_app.reader.imdb_reader import IMDBDataset
-import logging
-try:
-    from paddle_serving_server.web_service import WebService
-except ImportError:
-    from paddle_serving_server.web_service import WebService
_LOGGER = logging.getLogger()
user_handler = logging.StreamHandler()
......
@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-try:
-    from paddle_serving_server_gpu.web_service import WebService, Op
-except ImportError:
-    from paddle_serving_server.web_service import WebService, Op
+from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import cv2
@@ -48,7 +45,7 @@ class DetOp(Op):
        imgs = []
        for key in input_dict.keys():
            data = base64.b64decode(input_dict[key].encode('utf8'))
-            data = np.fromstring(data, np.uint8)
+            data = np.frombuffer(data, np.uint8)
            self.im = cv2.imdecode(data, cv2.IMREAD_COLOR)
            self.ori_h, self.ori_w, _ = self.im.shape
            det_img = self.det_preprocess(self.im)
@@ -57,7 +54,7 @@ class DetOp(Op):
        return {"image": np.concatenate(imgs, axis=0)}, False, None, ""

    def postprocess(self, input_dicts, fetch_dict, log_id):
        # print(fetch_dict)
        det_out = fetch_dict["concat_1.tmp_0"]
        ratio_list = [
            float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
@@ -114,5 +111,5 @@ class OcrService(WebService):
uci_service = OcrService(name="ocr")
-uci_service.prepare_pipeline_config("config2.yml")
+uci_service.prepare_pipeline_config("config.yml")
uci_service.run_service()
@@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-try:
-    from paddle_serving_server.web_service import WebService, Op
-except ImportError:
-    from paddle_serving_server.web_service import WebService, Op
+from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
@@ -34,8 +32,11 @@ class UciOp(Op):
        x_value = input_dict["x"].split(self.batch_separator)
        x_lst = []
        for x_val in x_value:
-            x_lst.append(np.array([float(x.strip()) for x in x_val.split(self.separator)]).reshape(1, 13))
-        input_dict["x"] = np.concatenate(x_lst, axis=0)
+            x_lst.append(
+                np.array([
+                    float(x.strip()) for x in x_val.split(self.separator)
+                ]).reshape(1, 13))
+        input_dict["x"] = np.concatenate(x_lst, axis=0)
        proc_dict = {}
        return input_dict, False, None, ""
@@ -53,5 +54,5 @@ class UciService(WebService):
uci_service = UciService(name="uci")
-uci_service.prepare_pipeline_config("config2.yml")
+uci_service.prepare_pipeline_config("config.yml")
uci_service.run_service()
@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-try:
-    from paddle_serving_server.web_service import WebService, Op
-except ImportError:
-    from paddle_serving_server.web_service import WebService, Op
+from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
from numpy import array
......
@@ -13,13 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from paddle_serving_server.web_service import WebService
-from paddle_serving_client import Client
-from paddle_serving_app.reader import LACReader, SentaReader
import os
import sys
import numpy as np
-#senta_web_service.py
from paddle_serving_server.web_service import WebService
from paddle_serving_client import Client
from paddle_serving_app.reader import LACReader, SentaReader
......
@@ -31,6 +31,7 @@ class UciService(WebService):
uci_service = UciService(name="uci")
uci_service.load_model_config("uci_housing_model")
-uci_service.prepare_server(workdir="workdir", port=9393, use_lite=True, use_xpu=True, ir_optim=True)
+uci_service.prepare_server(
+    workdir="workdir", port=9393, use_lite=True, use_xpu=True, ir_optim=True)
uci_service.run_rpc_service()
uci_service.run_web_service()
@@ -19,16 +19,12 @@ import os
import google.protobuf.text_format
import numpy as np
import argparse
-import paddle.fluid as fluid
-import paddle.inference as inference
from .proto import general_model_config_pb2 as m_config
-from paddle.fluid.core import PaddleTensor
-from paddle.fluid.core import AnalysisConfig
-from paddle.fluid.core import create_paddle_predictor
+import paddle.inference as paddle_infer
import logging
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger("fluid")
+logger = logging.getLogger("LocalPredictor")
logger.setLevel(logging.INFO)
@@ -62,7 +58,7 @@ class LocalPredictor(object):
                          use_xpu=False,
                          use_feed_fetch_ops=False):
        """
-        Load model config and set the engine config for the paddle predictor
+        Load model configs and create the paddle predictor by Paddle Inference API.
        Args:
            model_path: model config path.
@@ -83,14 +79,18 @@ class LocalPredictor(object):
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
        if os.path.exists(os.path.join(model_path, "__params__")):
-            config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__"))
+            config = paddle_infer.Config(
+                os.path.join(model_path, "__model__"),
+                os.path.join(model_path, "__params__"))
        else:
-            config = AnalysisConfig(model_path)
+            config = paddle_infer.Config(model_path)
-        logger.info("load_model_config params: model_path:{}, use_gpu:{},\
+        logger.info(
+            "LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\
            gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
            use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format(
                model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
                ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
@@ -129,7 +129,7 @@ class LocalPredictor(object):
        if use_lite:
            config.enable_lite_engine(
-                precision_mode=inference.PrecisionType.Float32,
+                precision_mode=paddle_infer.PrecisionType.Float32,
                zero_copy=True,
                passes_filter=[],
                ops_filter=[])
@@ -138,11 +138,11 @@ class LocalPredictor(object):
            # 2MB l3 cache
            config.enable_xpu(8 * 1024 * 1024)
-        self.predictor = create_paddle_predictor(config)
+        self.predictor = paddle_infer.create_predictor(config)
    def predict(self, feed=None, fetch=None, batch=False, log_id=0):
        """
-        Predict locally
+        Run model inference by Paddle Inference API.
        Args:
            feed: feed var
@@ -155,14 +155,16 @@ class LocalPredictor(object):
            fetch_map: dict
        """
        if feed is None or fetch is None:
-            raise ValueError("You should specify feed and fetch for prediction")
+            raise ValueError("You should specify feed and fetch for prediction.\
+                log_id:{}".format(log_id))
        fetch_list = []
        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
-            raise ValueError("Fetch only accepts string and list of string")
+            raise ValueError("Fetch only accepts string and list of string.\
+                log_id:{}".format(log_id))
        feed_batch = []
        if isinstance(feed, dict):
@@ -170,27 +172,21 @@ class LocalPredictor(object):
        elif isinstance(feed, list):
            feed_batch = feed
        else:
-            raise ValueError("Feed only accepts dict and list of dict")
+            raise ValueError("Feed only accepts dict and list of dict.\
+                log_id:{}".format(log_id))
-        int_slot_batch = []
-        float_slot_batch = []
-        int_feed_names = []
-        float_feed_names = []
-        int_shape = []
-        float_shape = []
-        fetch_names = []
-        counter = 0
-        batch_size = len(feed_batch)
        fetch_names = []
+        # Filter invalid fetch names
        for key in fetch_list:
            if key in self.fetch_names_:
                fetch_names.append(key)
        if len(fetch_names) == 0:
            raise ValueError(
-                "Fetch names should not be empty or out of saved fetch list.")
-            return {}
+                "Fetch names should not be empty or out of saved fetch list.\
+                log_id:{}".format(log_id))
+        # Assemble the input data of paddle predictor
        input_names = self.predictor.get_input_names()
        for name in input_names:
            if isinstance(feed[name], list):
@@ -204,27 +200,31 @@ class LocalPredictor(object):
                feed[name] = feed[name].astype("int32")
            else:
                raise ValueError("local predictor receives wrong data type")
-            input_tensor = self.predictor.get_input_tensor(name)
+            input_tensor_handle = self.predictor.get_input_handle(name)
            if "{}.lod".format(name) in feed:
-                input_tensor.set_lod([feed["{}.lod".format(name)]])
+                input_tensor_handle.set_lod([feed["{}.lod".format(name)]])
            if batch == False:
-                input_tensor.copy_from_cpu(feed[name][np.newaxis, :])
+                input_tensor_handle.copy_from_cpu(feed[name][np.newaxis, :])
            else:
-                input_tensor.copy_from_cpu(feed[name])
+                input_tensor_handle.copy_from_cpu(feed[name])
-        output_tensors = []
+        output_tensor_handles = []
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
-            output_tensor = self.predictor.get_output_tensor(output_name)
-            output_tensors.append(output_tensor)
+            output_tensor_handle = self.predictor.get_output_handle(output_name)
+            output_tensor_handles.append(output_tensor_handle)
+        # Run inference
+        self.predictor.run()
+        # Assemble output data of predict results
        outputs = []
-        self.predictor.zero_copy_run()
-        for output_tensor in output_tensors:
-            output = output_tensor.copy_to_cpu()
+        for output_tensor_handle in output_tensor_handles:
+            output = output_tensor_handle.copy_to_cpu()
            outputs.append(output)
        fetch_map = {}
        for i, name in enumerate(fetch):
            fetch_map[name] = outputs[i]
-            if len(output_tensors[i].lod()) > 0:
-                fetch_map[name + ".lod"] = np.array(output_tensors[i].lod()[
-                    0]).astype('int32')
+            if len(output_tensor_handles[i].lod()) > 0:
+                fetch_map[name + ".lod"] = np.array(output_tensor_handles[i]
+                                                    .lod()[0]).astype('int32')
        return fetch_map
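For context on the hunks above: they move LocalPredictor from the legacy `AnalysisConfig`/`zero_copy_run` interface to the Paddle Inference 2.x handle-based API. A minimal end-to-end sketch of that new flow follows, assuming a model exported as `__model__`/`__params__` files; the paths, input name lookup, and 1x13 float input are placeholders for illustration, not part of this PR:

```
import numpy as np
import paddle.inference as paddle_infer

# Build the config from the exported program and parameter files.
config = paddle_infer.Config("serving_server/__model__",
                             "serving_server/__params__")
predictor = paddle_infer.create_predictor(config)

# Feed: copy a dummy numpy array into the first input handle.
input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)
input_handle.copy_from_cpu(np.random.rand(1, 13).astype("float32"))

# Run inference, then copy every output back to host memory.
predictor.run()
fetch_map = {
    name: predictor.get_output_handle(name).copy_to_cpu()
    for name in predictor.get_output_names()
}
print(fetch_map)
```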