Unverified commit d2d95908 authored by: J Jiawei Wang, committed by: GitHub

Merge pull request #1358 from TeslaZhao/develop

Modify model conversion interfaces and model load methods
......@@ -22,6 +22,7 @@ import argparse
from .proto import general_model_config_pb2 as m_config
import paddle.inference as paddle_infer
import logging
import glob
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LocalPredictor")
......@@ -51,6 +52,23 @@ class LocalPredictor(object):
self.fetch_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
def search_suffix_files(self, model_path, target_suffix):
"""
Find all files with the target suffix in the specified directory.
Args:
model_path: model directory, not None.
target_suffix: target filename suffix, not None. e.g. *.pdmodel
Returns:
file_list: None, [] or [path, ...]
"""
if model_path is None or target_suffix is None:
return None
file_list = glob.glob(os.path.join(model_path, target_suffix))
return file_list
def load_model_config(self,
model_path,
use_gpu=False,
......@@ -97,11 +115,30 @@ class LocalPredictor(object):
f = open(client_config, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
# Init paddle_infer config
# Paddle's model files and parameter files have multiple naming rules:
# 1) __model__, __params__
# 2) *.pdmodel, *.pdiparams
# 3) __model__, conv2d_1.w_0, conv2d_2.w_0, fc_1.w_0, conv2d_1.b_0, ...
pdmodel_file_list = self.search_suffix_files(model_path, "*.pdmodel")
pdiparams_file_list = self.search_suffix_files(model_path,
"*.pdiparams")
if os.path.exists(os.path.join(model_path, "__params__")):
# case 1) initializing
config = paddle_infer.Config(
os.path.join(model_path, "__model__"),
os.path.join(model_path, "__params__"))
elif pdmodel_file_list and len(
pdmodel_file_list) > 0 and pdiparams_file_list and len(
pdiparams_file_list) > 0:
# case 2) initializing
logger.info("pdmodel_file_list:{}, pdiparams_file_list:{}".format(
pdmodel_file_list, pdiparams_file_list))
config = paddle_infer.Config(pdmodel_file_list[0],
pdiparams_file_list[0])
else:
# case 3) initializing.
config = paddle_infer.Config(model_path)
logger.info(
......@@ -201,8 +238,9 @@ class LocalPredictor(object):
Run model inference by Paddle Inference API.
Args:
feed: feed var
fetch: fetch var
feed: feed var list, None is not allowed.
fetch: fetch var list, None is allowed. When fetch is None, all fetch
vars are returned; otherwise, only the specified fetch results are returned.
batch: whether the feed data is batched, False by default. If batch is False,
a new dimension [np.newaxis] is added at the head of the shape.
log_id: for logging
......@@ -210,16 +248,8 @@ class LocalPredictor(object):
Returns:
fetch_map: dict
"""
if feed is None or fetch is None:
raise ValueError("You should specify feed and fetch for prediction.\
log_id:{}".format(log_id))
fetch_list = []
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
else:
raise ValueError("Fetch only accepts string and list of string.\
if feed is None:
raise ValueError("You should specify feed vars for prediction.\
log_id:{}".format(log_id))
feed_batch = []
......@@ -231,18 +261,20 @@ class LocalPredictor(object):
raise ValueError("Feed only accepts dict and list of dict.\
log_id:{}".format(log_id))
fetch_names = []
fetch_list = []
if fetch is not None:
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
# Filter invalid fetch names
fetch_names = []
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"Fetch names should not be empty or out of saved fetch list.\
log_id:{}".format(log_id))
# Assemble the input data of paddle predictor
# Assemble the input data of paddle predictor, and filter invalid inputs.
input_names = self.predictor.get_input_names()
for name in input_names:
if isinstance(feed[name], list):
......@@ -282,11 +314,15 @@ class LocalPredictor(object):
input_tensor_handle.copy_from_cpu(feed[name][np.newaxis, :])
else:
input_tensor_handle.copy_from_cpu(feed[name])
# Set output tensor handles
output_tensor_handles = []
output_name_to_index_dict = {}
output_names = self.predictor.get_output_names()
for output_name in output_names:
for i, output_name in enumerate(output_names):
output_tensor_handle = self.predictor.get_output_handle(output_name)
output_tensor_handles.append(output_tensor_handle)
output_name_to_index_dict[output_name] = i
# Run inference
self.predictor.run()
......@@ -296,10 +332,43 @@ class LocalPredictor(object):
for output_tensor_handle in output_tensor_handles:
output = output_tensor_handle.copy_to_cpu()
outputs.append(output)
outputs_len = len(outputs)
# Copy fetch vars. If fetch is None, all results are copied from output_tensor_handles.
# Otherwise, only the specified fields are copied from output_tensor_handles.
fetch_map = {}
for i, name in enumerate(fetch):
fetch_map[name] = outputs[i]
if len(output_tensor_handles[i].lod()) > 0:
fetch_map[name + ".lod"] = np.array(output_tensor_handles[i]
.lod()[0]).astype('int32')
if fetch is None:
for i, name in enumerate(output_names):
fetch_map[name] = outputs[i]
if len(output_tensor_handles[i].lod()) > 0:
fetch_map[name + ".lod"] = np.array(output_tensor_handles[
i].lod()[0]).astype('int32')
else:
# Because the save_inference_model interface may insert an extra scale op
# into the network, the fetch_var names can differ from those in the prototxt.
# To stay compatible with v0.6.x and earlier model saving formats, fetch
# names that do not match the outputs are also handled here.
fetch_match_num = 0
for i, name in enumerate(fetch):
output_index = output_name_to_index_dict.get(name)
if output_index is None:
continue
fetch_map[name] = outputs[output_index]
fetch_match_num += 1
if len(output_tensor_handles[output_index].lod()) > 0:
fetch_map[name + ".lod"] = np.array(output_tensor_handles[
output_index].lod()[0]).astype('int32')
# Compatible with v0.6.x and lower versions model saving formats.
if fetch_match_num == 0:
logger.debug("fetch match num is 0. Retrain the model please!")
for i, name in enumerate(fetch):
if i >= outputs_len:
break
fetch_map[name] = outputs[i]
if len(output_tensor_handles[i].lod()) > 0:
fetch_map[name + ".lod"] = np.array(
output_tensor_handles[i].lod()[0]).astype('int32')
return fetch_map
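For context, a minimal usage sketch of the updated load/predict interface; the model directory, the feed name "x", and its shape below are hypothetical and not part of this commit:

    # Minimal sketch: load a model saved in any of the three naming formats
    # listed above and run prediction, letting fetch default to all outputs.
    from paddle_serving_app.local_predict import LocalPredictor
    import numpy as np

    predictor = LocalPredictor()
    # "uci_housing_model" is a hypothetical directory path.
    predictor.load_model_config("uci_housing_model", use_gpu=False)
    fetch_map = predictor.predict(
        feed={"x": np.random.rand(1, 13).astype("float32")},
        fetch=None,  # None now returns all fetch vars
        batch=True,
        log_id=0)
    print(fetch_map.keys())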
......@@ -67,7 +67,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
}
config = model_conf.GeneralModelConfig()
#int64 = 0; float32 = 1; int32 = 2;
for key in feed_var_dict:
feed_var = model_conf.FeedVar()
feed_var.alias_name = key
......@@ -127,7 +126,6 @@ def save_dygraph_model(serving_model_folder, client_config_folder, model):
def var_type_conversion(dtype):
"""
Variable type conversion
Args:
dtype: type of core.VarDesc.VarType.xxxxx
(https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/framework/dtype.py)
......@@ -184,7 +182,9 @@ def save_model(server_model_folder,
main_program=None,
encryption=False,
key_len=128,
encrypt_conf=None):
encrypt_conf=None,
model_filename=None,
params_filename=None):
executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
......@@ -194,15 +194,27 @@ def save_model(server_model_folder,
target_vars.append(fetch_var_dict[key])
target_var_names.append(key)
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
if not encryption:
save_inference_model(
server_model_folder,
feed_var_names,
target_vars,
executor,
model_filename="__model__",
params_filename="__params__",
main_program=main_program)
if not model_filename:
model_filename = "model.pdmodel"
if not params_filename:
params_filename = "params.pdiparams"
new_model_path = os.path.join(server_model_folder, model_filename)
new_params_path = os.path.join(server_model_folder, params_filename)
with open(new_model_path, "wb") as new_model_file:
new_model_file.write(main_program.desc.serialize_to_string())
paddle.static.save_vars(
executor=executor,
dirname=server_model_folder,
main_program=main_program,
vars=None,
predicate=paddle.static.io.is_persistable,
filename=params_filename)
else:
if encrypt_conf == None:
aes_cipher = CipherFactory.create_cipher()
......@@ -296,7 +308,8 @@ def inference_model_to_serving(dirname,
}
fetch_dict = {x.name: x for x in fetch_targets}
save_model(serving_server, serving_client, feed_dict, fetch_dict,
inference_program, encryption, key_len, encrypt_conf)
inference_program, encryption, key_len, encrypt_conf,
model_filename, params_filename)
feed_names = feed_dict.keys()
fetch_names = fetch_dict.keys()
return feed_names, fetch_names
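As a conversion sketch using the extended interface (the directory paths and filenames below are hypothetical; keyword names follow the io interface shown above):

    import paddle_serving_client.io as serving_io

    # Hypothetical paths; the inference model directory is assumed to contain
    # a model file and a params file matching the names passed below.
    feed_names, fetch_names = serving_io.inference_model_to_serving(
        dirname="./inference_model",
        serving_server="./serving_server",
        serving_client="./serving_client",
        model_filename="model.pdmodel",
        params_filename="params.pdiparams")
    print(feed_names, fetch_names)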
......@@ -40,6 +40,7 @@ from .channel import (ThreadChannel, ProcessChannel, ChannelDataErrcode,
from .util import NameGenerator
from .profiler import UnsafeTimeProfiler as TimeProfiler
from . import local_service_handler
from .pipeline_client import PipelineClient as PPClient
_LOGGER = logging.getLogger(__name__)
_op_name_gen = NameGenerator("Op")
......@@ -330,9 +331,8 @@ class Op(object):
if self.client_type == 'brpc':
client = Client()
client.load_client_config(client_config)
# After testing is complete, replace this with brpc-http.
# elif self.client_type == 'grpc':
# client = MultiLangClient()
elif self.client_type == 'pipeline_grpc':
client = PPClient()
elif self.client_type == 'local_predictor':
if self.local_predictor is None:
raise ValueError("local predictor not yet created")
......@@ -531,32 +531,72 @@ class Op(object):
Returns:
call_result: predict result
"""
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
_LOGGER.critical(
self._log("Failed to run process: {}. Please override "
"preprocess func.".format(err_info)))
os._exit(-1)
call_result = None
err_code = ChannelDataErrcode.OK.value
err_info = ""
if self.client_type == "local_predictor":
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
_LOGGER.error(
self._log("Failed to run process: {}. feed_batch must be \
npdata in process for local_predictor mode."
.format(err_info)))
return call_result, ChannelDataErrcode.TYPE_ERROR.value, "feed_batch must be npdata"
call_result = self.client.predict(
feed=feed_batch[0],
fetch=self._fetch_names,
batch=True,
log_id=typical_logid)
else:
elif self.client_type == "brpc":
err, err_info = ChannelData.check_batch_npdata(feed_batch)
if err != 0:
_LOGGER.error(
self._log("Failed to run process: {}. feed_batch must be \
npdata in process for brpc mode.".format(err_info)))
return call_result, ChannelDataErrcode.TYPE_ERROR.value, "feed_batch must be npdata"
call_result = self.client.predict(
feed=feed_batch,
feed=feed_batch[0],
fetch=self._fetch_names,
batch=True,
log_id=typical_logid)
# To be replaced by HttpClient later
'''
if isinstance(self.client, MultiLangClient):
if call_result is None or call_result["serving_status_code"] != 0:
return None
call_result.pop("serving_status_code")
'''
return call_result
elif self.client_type == "pipeline_grpc":
err, err_info = ChannelData.check_dictdata(feed_batch)
if err != 0:
_LOGGER.error(
self._log("Failed to run process: {}. feed_batch must be \
dict data in process for pipeline_grpc mode."
.format(err_info)))
return call_result, ChannelDataErrcode.TYPE_ERROR.value, "feed_batch must be dict"
call_result = self.client.predict(
feed_dict=feed_batch[0],
fetch=self._fetch_names,
asyn=False,
profile=False)
if call_result is None:
_LOGGER.error(
self._log("Failed in pipeline_grpc. call_result is None."))
return call_result, ChannelDataErrcode.UNKNOW.value, "pipeline_grpc error"
if call_result.err_no != 0:
_LOGGER.error(
self._log("Failed in pipeline_grpc. err_no:{}, err_info:{}".
format(call_result.err_no, call_result.err_msg)))
return call_result, ChannelDataErrcode(
call_result.err_no).value, call_result.err_msg
new_dict = {}
err_code = ChannelDataErrcode(call_result.err_no).value
err_info = call_result.err_msg
for idx, key in enumerate(call_result.key):
new_dict[key] = [call_result.value[idx]]
call_result = new_dict
return call_result, err_code, err_info
def postprocess(self, input_data, fetch_data, data_id=0, log_id=0):
"""
......@@ -891,16 +931,20 @@ class Op(object):
midped_batch = None
error_code = ChannelDataErrcode.OK.value
error_info = ""
if self._timeout <= 0:
# No retry
try:
if batch_input is False:
midped_batch = self.process(feed_batch, typical_logid)
midped_batch, error_code, error_info = self.process(
feed_batch, typical_logid)
else:
midped_batch = []
for idx in range(len(feed_batch)):
predict_res = self.process([feed_batch[idx]],
typical_logid)
predict_res, error_code, error_info = self.process(
[feed_batch[idx]], typical_logid)
if error_code != ChannelDataErrcode.OK.value:
break
midped_batch.append(predict_res)
except Exception as e:
error_code = ChannelDataErrcode.UNKNOW.value
......@@ -913,14 +957,14 @@ class Op(object):
try:
# time out for each process
if batch_input is False:
midped_batch = func_timeout.func_timeout(
midped_batch, error_code, error_info = func_timeout.func_timeout(
self._timeout,
self.process,
args=(feed_batch, typical_logid))
else:
midped_batch = []
for idx in range(len(feed_batch)):
predict_res = func_timeout.func_timeout(
predict_res, error_code, error_info = func_timeout.func_timeout(
self._timeout,
self.process,
args=([feed_batch[idx]], typical_logid))
......
......@@ -93,13 +93,19 @@ class PipelineClient(object):
def _unpack_response_package(self, resp, fetch):
return resp
def predict(self, feed_dict, fetch=None, asyn=False, profile=False):
def predict(self,
feed_dict,
fetch=None,
asyn=False,
profile=False,
log_id=0):
if not isinstance(feed_dict, dict):
raise TypeError(
"feed must be dict type with format: {name: value}.")
if fetch is not None and not isinstance(fetch, list):
raise TypeError("fetch must be list type with format: [name].")
req = self._pack_request_package(feed_dict, profile)
req.logid = log_id
if not asyn:
resp = self._stub.inference(req)
return self._unpack_response_package(resp, fetch)
......
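A client-side sketch of the extended predict interface with the new log_id field; the endpoint, feed names, and import path are illustrative and may differ between CPU and GPU packages:

    from paddle_serving_server.pipeline import PipelineClient

    client = PipelineClient()
    # Hypothetical endpoint of a running pipeline service.
    client.connect(["127.0.0.1:9998"])
    # "key" is a placeholder feed name; real names depend on the service.
    ret = client.predict(
        feed_dict={"key": "value"},
        fetch=None,
        asyn=False,
        profile=False,
        log_id=12345)
    print(ret)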
......@@ -39,7 +39,7 @@ class AvailablePortGenerator(object):
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('127.0.0.1', port))
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
......
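For reference, a standalone sketch of the connect_ex-based check after the address change; the port number is an arbitrary example:

    import socket
    from contextlib import closing

    def port_is_available(port):
        # connect_ex returns 0 when something already listens on the port;
        # a non-zero result is treated as "available".
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            return sock.connect_ex(('0.0.0.0', port)) != 0

    print(port_is_available(12000))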